In [47]:
%%writefile intrinsic_main.c

#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <stdlib.h>
#include <sys/time.h>

// External declarations from intrinsic.c
extern int SneakySnake(int EditThreshold, char* ReadSeq, char* RefSeq, int ReadLength, int IterationNo);
extern uint64_t best_diagonal_score;

typedef struct {
    char* read_seq;
    char* ref_seq;
    int length;
} SequencePair;

int read_sequences_from_file(const char* filename, SequencePair** pairs, int* count) {
    FILE* fp = fopen(filename, "r");
    if (!fp) {
        fprintf(stderr, "Error", filename);
        return -1;
    }

    // Count lines first
    int lines = 0;
    char ch;
    while ((ch = fgetc(fp)) != EOF) {
        if (ch == '\n') lines++;
    }
    
    *count = lines;
    *pairs = (SequencePair*)malloc(sizeof(SequencePair) * (*count));
    
    rewind(fp);
    
    char line[512]; 
    int pair_idx = 0;
    
    while (fgets(line, sizeof(line), fp)) {
        line[strcspn(line, "\r\n")] = 0;
        
        char *separator = strchr(line, '\t');
        if (!separator) {
            separator = strchr(line, ' ');
        }
        
        if (!separator) {
            fprintf(stderr, "Warning: Line %d has no separator, skipping\n", pair_idx + 1);
            continue;
        }
        
        *separator = '\0';
        char *read = line;
        char *ref = separator + 1;
        
        int len = strlen(read);
        (*pairs)[pair_idx].length = len;
        (*pairs)[pair_idx].read_seq = (char*)malloc(len + 1);
        (*pairs)[pair_idx].ref_seq = (char*)malloc(len + 1);
        strcpy((*pairs)[pair_idx].read_seq, read);
        strcpy((*pairs)[pair_idx].ref_seq, ref);
        
        pair_idx++;
    }
    
    fclose(fp);
    *count = pair_idx;
    return 0;
}

int main(int argc, char* argv[]) {
    if (argc < 3) {
        fprintf(stderr, "Usage: %s <input_file> <edit_threshold> [KmerSize] [IterationNo]\n", argv[0]);
        fprintf(stderr, "  input_file: File with read/ref pairs (one pair per line, tab/space separated)\n");
        fprintf(stderr, "  edit_threshold: Maximum edit distance to check\n");
        fprintf(stderr, "  KmerSize: Optional kmer size (default: 100)\n");
        fprintf(stderr, "  IterationNo: Optional iterations (default: 100)\n");
        return 1;
    }

    const char* input_file = argv[1];
    int EditThreshold = atoi(argv[2]);
    int KmerSize = 100;
    int IterationNo = 100;
    
    if (argc >= 4) KmerSize = atoi(argv[3]);
    if (argc >= 5) IterationNo = atoi(argv[4]);

    SequencePair* pairs = NULL;
    int pair_count = 0;

    if (read_sequences_from_file(input_file, &pairs, &pair_count) != 0) {
        return 1;
    }

    printf("Loaded %d sequence pairs from %s\n", pair_count, input_file);
    printf("Edit Threshold: %d, KmerSize: %d, IterationNo: %d\n", EditThreshold, KmerSize, IterationNo);

    int total_accepted = 0;
    int total_rejected = 0;

    // Start timing
    struct timeval start_time, end_time;
    gettimeofday(&start_time, NULL);

    // Process each pair
    for (int i = 0; i < pair_count; i++) {
        char* ReadSeq = strdup(pairs[i].read_seq);
        char* RefSeq = strdup(pairs[i].ref_seq);
        int len = pairs[i].length;

        int result = SneakySnake(EditThreshold, ReadSeq, RefSeq, len, IterationNo);

        if (result) {
            total_accepted++;
        } else {
            total_rejected++;
        }

        free(ReadSeq);
        free(RefSeq);
    }

    // End timing
    gettimeofday(&end_time, NULL);
    long elapsed_ms = (end_time.tv_sec - start_time.tv_sec) * 1000 + 
                      (end_time.tv_usec - start_time.tv_usec) / 1000;

    printf("\nResults:\n");
    printf("  Total pairs: %d\n", pair_count);
    printf("  Accepted: %d\n", total_accepted);
    printf("  Rejected: %d\n", total_rejected);
    printf("  Time: %ld milliseconds\n", elapsed_ms);
    
    // Cleanup
    for (int i = 0; i < pair_count; i++) {
        free(pairs[i].read_seq);
        free(pairs[i].ref_seq);
    }
    free(pairs);

    return 0;
}

Overwriting intrinsic_main.c


In [48]:
%%writefile intrinsic.c

#include <immintrin.h>
#include <stdint.h>
#include <string.h>
#include <stdio.h>

uint64_t best_diagonal_score = 0;

static inline int count_consecutive_matches_avx512(char* read, char* ref, int start, int end) {
    int count = 0;
    int i = start;
    
    // Process 64 bytes at a time
    for (; i <= end - 64; i += 64) {
        __m512i read_vec = _mm512_loadu_si512((__m512i*)(read + i));
        __m512i ref_vec = _mm512_loadu_si512((__m512i*)(ref + i));
        __mmask64 cmp_mask = _mm512_cmpeq_epi8_mask(read_vec, ref_vec);
        
        if (cmp_mask == 0xFFFFFFFFFFFFFFFF) {
            count += 64;
        } else {
            count += __builtin_ctzll(~cmp_mask);
            return count;
        }
    }
    
    // Process remaining bytes
    for (; i < end; i++) {
        if (read[i] == ref[i]) {
            count++;
        } else {
            break;
        }
    }
    
    return count;
}

static inline int count_diagonal_matches_avx512(char* read, char* ref, int start, int end, int shift, int read_length, int is_right_diag) {
    int count = 0;
    int i = start;
    
    if (is_right_diag) {
        // Right diagonal (deletion)
        // read[i-shift] vs ref[i]
        for (; i <= end - 64; i += 64) {
            if (i - shift < 0) break;
            
            __m512i read_vec = _mm512_loadu_si512((__m512i*)(read + i - shift));
            __m512i ref_vec = _mm512_loadu_si512((__m512i*)(ref + i));
            __mmask64 cmp_mask = _mm512_cmpeq_epi8_mask(read_vec, ref_vec);
            
            if (cmp_mask == 0xFFFFFFFFFFFFFFFF) {
                count += 64;
            } else {
                count += __builtin_ctzll(~cmp_mask);
                return count;
            }
        }
        
        // Handle remaining bytes
        for (; i < end; i++) {
            int read_pos = i - shift;
            if (read_pos < 0) break;
            if (read[read_pos] == ref[i]) {
                count++;
            } else {
                break;
            }
        }
    } else {
        // Left diagonal (insertion)
        // read[i+shift] vs ref[i]  
        for (; i <= end - 64; i += 64) {
            if (i + shift + 64 > read_length) break;
            
            __m512i read_vec = _mm512_loadu_si512((__m512i*)(read + i + shift));
            __m512i ref_vec = _mm512_loadu_si512((__m512i*)(ref + i));
            __mmask64 cmp_mask = _mm512_cmpeq_epi8_mask(read_vec, ref_vec);
            
            if (cmp_mask == 0xFFFFFFFFFFFFFFFF) {
                count += 64;
            } else {
                count += __builtin_ctzll(~cmp_mask);
                return count;
            }
        }
        
        // Handle remaining bytes
        for (; i < end; i++) {
            int read_pos = i + shift;
            if (read_pos >= read_length) break;
            if (read[read_pos] == ref[i]) {
                count++;
            } else {
                break;
            }
        }
    }
    
    return count;
}

int SneakySnake(int EditThreshold, char* ReadSeq, char* RefSeq, int ReadLength, int IterationNo)
{
    int Edits = 0;
    
    int KmerSize = 100;
    int NumKmers = ReadLength / KmerSize;
    if (NumKmers == 0) {
        NumKmers = 1;
        KmerSize = ReadLength;
    }
    
    for (int K = 0; K < NumKmers; K++) {
        int KmerStart = K * KmerSize;
        int KmerEnd = (K < NumKmers - 1) ? (K + 1) * KmerSize : ReadLength;
        
        int index = KmerStart;
        int roundsNo = 1;
        
        while (index < KmerEnd) {
            int GlobalCount = 0;
            
            // Check main diagonal first
            GlobalCount = count_consecutive_matches_avx512(ReadSeq, RefSeq, index, KmerEnd);
            
            if (GlobalCount == (KmerEnd - index)) {
                break; // Perfect match in this segment
            }
            
            // Check diagonals within edit threshold
            for (int e = 1; e <= EditThreshold; e++) {
                int count = 0;
                
                // Right diagonal
                count = count_diagonal_matches_avx512(ReadSeq, RefSeq, index, KmerEnd, e, ReadLength, 1);
                if (count > GlobalCount) GlobalCount = count;
                if (GlobalCount == (KmerEnd - index)) break;
                
                // Left diagonal
                count = count_diagonal_matches_avx512(ReadSeq, RefSeq, index, KmerEnd, e, ReadLength, 0);
                if (count > GlobalCount) GlobalCount = count;
                if (GlobalCount == (KmerEnd - index)) break;
            }
            
            // Move index forward based on matches found
            index += GlobalCount;
            if (index < KmerEnd) {
                Edits++;
                index++;
                
                if (Edits > EditThreshold) {
                    best_diagonal_score = Edits;
                    return 0; // Rejected
                }
            }
            
            if (roundsNo++ > IterationNo) break;
        }
        
        if (Edits > EditThreshold) {
            best_diagonal_score = Edits;
            return 0; // Rejected
        }
    }
    
    best_diagonal_score = Edits;
    return 1; // Accepted
}

Overwriting intrinsic.c


In [54]:
!gcc -O3 -march=native -mavx512f -mavx512bw -w -o intrinsic_ver intrinsic_main.c intrinsic.c
# (caller) - (dataset file) - (edit distance) - (kmer size) - (iteration number)
!./intrinsic_ver SRR826471_1_E100_30million.txt 2 250 100

Loaded 30000000 sequence pairs from ERR240727_1_E2_30million.txt
Edit Threshold: 2, KmerSize: 250, IterationNo: 100

Results:
  Total pairs: 30000000
  Accepted: 3343620
  Rejected: 26656380
  Time: 3947 milliseconds
