### BWT Homework

Today we will construct BWT from scratch. First we will construct a suffix array for a given string using pysuffixarray.


In [1]:
!pip install pysuffixarray

from pysuffixarray.core import SuffixArray
sa = SuffixArray('ACAACG')
print(sa.suffix_array())

Collecting pysuffixarray
  Downloading pysuffixarray-0.0.1.tar.gz (2.8 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: pysuffixarray
  Building wheel for pysuffixarray (setup.py) ... [?25ldone
[?25h  Created wheel for pysuffixarray: filename=pysuffixarray-0.0.1-py3-none-any.whl size=3482 sha256=03ca6792f4da6e3800203f11c84c70aacba5c5a17fc2abbd2aa3eb7e45e8077d
  Stored in directory: /Users/polina/Library/Caches/pip/wheels/3e/50/c2/b3e3f16ef336e594fcf1c604b84b4c88a72f79b7f9cb18e3b3
Successfully built pysuffixarray
Installing collected packages: pysuffixarray
Successfully installed pysuffixarray-0.0.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[6, 2, 0, 3, 1, 4, 5]


In [15]:
from Bio import SeqIO
import numpy as np
from pysuffixarray.core import SuffixArray
from collections import defaultdict, Counter

## Task 1: Create BWT using suffix array:

- Using BioPython upload SARS-CoV-2 reference genome from fasta file (genome.fa in BWT_folder)
- Construct suffix array
- Construct BWT from suffix array
- Don't forget to add special symbol (but after SA construction)

![correct](BWT_folder/BWT1.png)

In [11]:
# Load the reference genome and append $

ref_path = "BWT_folder/genome.fa"  
with open(ref_path, "r") as handle:
    record = next(SeqIO.parse(handle, "fasta"))
    ref_seq = str(record.seq)  
print(f"Reference length = {len(ref_seq)}")


# Build the SA using pysuffixarray

sa_obj = SuffixArray(ref_seq)
SA = sa_obj.suffix_array() 
print(f"First 10 entries of SA: {SA[:10]}")

Reference length = 29903
First 10 entries of SA: [29903, 29902, 29901, 29900, 29899, 29898, 29897, 29896, 29895, 29894]


In [16]:
# Construct the BWT from SA

n = len(ref_seq)
bwt_list = []
for idx in SA:
    if idx == 0:
        bwt_list.append("$")          
    else:
        bwt_list.append(ref_seq[idx - 1])

BWT = "".join(bwt_list)
print("First 10 chars of BWT:", BWT[:10])

First 10 chars of BWT: AAAAAAAAAA


## Task 2: Create FM index
- Construct Occurence array
- Construct Count dictionary
- Make a class BWTSearcher

![correct](BWT_folder/BWT3.png)

In [33]:
# Your impletementation of BWT

class BWTSearcher:
    def __init__(self, reference):
        if isinstance(reference, str) and (reference.endswith(".fa") or reference.endswith(".fasta")):
            rec = next(SeqIO.parse(reference, "fasta"))
            raw_seq = str(rec.seq)
        else:
            raw_seq = str(reference)
            
        self.ref = raw_seq
        
        sa_obj = SuffixArray(self.ref)
        self.SA = list(sa_obj.suffix_array())
        self.n   = len(self.SA)  

        bwt_chars = []
        
        for pos in self.SA:
            if pos == 0:
                bwt_chars.append("$")
            else:
                bwt_chars.append(self.ref[pos - 1])
        
        self.BWT = "".join(bwt_chars)

        total_counts = Counter(self.BWT)
        alphabet = sorted(total_counts.keys())

        self.Count = {}
        cum = 0

        for ch in alphabet:
            self.Count[ch] = cum
            cum += total_counts[ch]
        
        self.Occ = {ch: [0] * self.n for ch in alphabet}
        running_counts = {ch: 0 for ch in alphabet}

        for i, bch in enumerate(self.BWT):
            running_counts[bch] += 1
            for ch in alphabet:
                self.Occ[ch][i] = running_counts[ch]
    
    
    def _occ(self, ch, idx):
        if idx < 0:
            return 0
        return self.Occ[ch][idx]

    def bwt_pattern_search(self, pattern):
        m = len(pattern)
        # Initialize the full interval [sp, ep]
        sp = 0
        ep = self.n - 1

        for j in range(m - 1, -1, -1):
            c = pattern[j]
            if c not in self.Count:
                # Character not even in our alphabet - no match possible
                return []
            
            sp = self.Count[c] + self._occ(c, sp - 1)
            ep = self.Count[c] + self._occ(c, ep) - 1
            
            # If the interval becomes invalid, there is no match
            if sp > ep:
                return []
    
        matched_positions = self.SA[sp : ep + 1]
        return sorted(matched_positions)


### Task 4:
- There are 100 reads that were randomly sampled from genome.fa
- Some of them are error free, some contain one mutation, and some contain 5 mutations
- Could you use your BWTSearcher class to classify them? Think about the solution and implement it. You can add any functions of class members
- How many reads of each class did you find?

In [34]:
def hamming_distance(s1, s2):
    """
    Compute the Hamming distance between two strings s1 and s2 of equal length.
    """
    return sum(ch1 != ch2 for ch1, ch2 in zip(s1, s2))


class BWTSearcher(BWTSearcher):
    def get_substring(self, pos, length):
        """
        Return the substring of the reference of length 'length', starting at index pos.
        """
        if pos + length > len(self.ref) - 1:  
            # we avoid crossing the '$' character
            return None
        return self.ref[pos : pos + length]
    
    def find_candidates_prefix(self, prefix):
        """
        Return all starting positions in the reference where 'prefix' occurs exactly.
        """
        return self.bwt_pattern_search(prefix)

In [39]:
from Bio import SeqIO

# Counters for each class
count_error_free = 0
count_one_mut = 0
count_five_mut = 0
count_unk_mut = 0

reference_fasta = "BWT_folder/genome.fa"
searcher = BWTSearcher(reference_fasta)

# Process each read through the BWTSearcher
with open("BWT_folder/sample_reads.fasta", "r") as fh:
    for record in SeqIO.parse(fh, "fasta"):
        read_id   = record.id
        read_seq  = str(record.seq)
        m         = len(read_seq)
        mid       = m // 2
        left_half  = read_seq[:mid]
        right_half = read_seq[mid:]

        # 0 mismatches
        exact_hits = searcher.bwt_pattern_search(read_seq)
        if exact_hits:
            count_error_free += 1
            continue

        # 1 mismatch
        found_one_mismatch = False

        cand_left = searcher.bwt_pattern_search(left_half)
        for pos in cand_left:
            ref_sub = searcher.get_substring(pos, m)
            if ref_sub is None:
                continue
            if hamming_distance(read_seq, ref_sub) == 1:
                found_one_mismatch = True
                break

        if not found_one_mismatch:
            cand_right = searcher.bwt_pattern_search(right_half)
            for pos_r in cand_right:
                full_start = pos_r - mid
                if full_start < 0:
                    continue
                ref_sub = searcher.get_substring(full_start, m)
                if ref_sub is None:
                    continue
                if hamming_distance(read_seq, ref_sub) == 1:
                    found_one_mismatch = True
                    break

        if found_one_mismatch:
            count_one_mut += 1
            continue
        
        # 5 mismatches
        found_five_mismatch = False

        cand_left = searcher.bwt_pattern_search(left_half)

        for pos in cand_left:
            ref_sub = searcher.get_substring(pos, m)
            if ref_sub is None:
                continue
            if hamming_distance(read_seq, ref_sub) == 5:
                found_five_mismatch = True
                break

        if not found_five_mismatch:
            cand_right = searcher.bwt_pattern_search(right_half)
            for pos_r in cand_right:
                full_start = pos_r - mid
                if full_start < 0:
                    continue
                ref_sub = searcher.get_substring(full_start, m)
                if ref_sub is None:
                    continue
                if hamming_distance(read_seq, ref_sub) == 5:
                    found_five_mismatch = True
                    break

        if found_five_mismatch:
            count_five_mut += 1

        # neither 0, 1, nor 5 mismatches
        else:
            count_unk_mut += 1


print("===== Classification Summary =====")
print(f"Reads with 0 mismatches: {count_error_free}")
print(f"Reads with 1 mismatches: {count_one_mut}")
print(f"Reads with 5 mismatches: {count_five_mut}")
print(f"Reads with neither 0, 1, nor 5 mismatches: {count_unk_mut}")

print("----------------------------------")
print(f"Total classified (should be 100): {count_error_free + count_one_mut + count_five_mut + count_unk_mut}")

===== Classification Summary =====
Reads with 0 mismatches: 45
Reads with 1 mismatches: 31
Reads with 5 mismatches: 2
Reads with neither 0, 1, nor 5 mismatches: 22
----------------------------------
Total classified (should be 100): 100
