# Burrows-Wheeler Transform and Alignment (BWA)

The Burrows-Wheeler Transform (BWT) was developed in 1994 by Michael Burrows and David Wheeler. In simple terms, BWT is a reversible string transformation that acts as a preprocessing step for lossless compression. It can be implemented in linear O(n) time and space. Originally designed to prepare data for compression (bzip2 is a well-known compressor built on it), BWT has found prominence in bioinformatics, where it allows fast mapping of short reads and has paved the way for high-throughput genetic sequencing.
> https://towardsdatascience.com/burrows-wheeler-in-python-c07cbf71b3f0

## Burrows Wheeler Transform (BWT)

The Burrows–Wheeler transform (BWT, also called block-sorting compression) rearranges a character string into runs of similar characters. This is useful for compression, since it tends to be easy to compress a string that has runs of repeated characters by techniques such as move-to-front transform and run-length encoding.

When a character string is transformed by the BWT, the transformation permutes the order of the characters. If the original string had several substrings that occurred often, then the transformed string will have several places where a single character is repeated multiple times in a row.

> https://en.wikipedia.org/wiki/Burrows%E2%80%93Wheeler_transform
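To make the compression claim concrete, here is a minimal sketch of run-length encoding and the move-to-front transform applied to a BWT output. The helper names `run_length_encode` and `move_to_front` are illustrative and not taken from any library, and the input string is the BWT of "abracadabra$" quoted in the FM-index section.

```python
from itertools import groupby


def run_length_encode(text):
    """Collapse each run of identical characters into a (char, length) pair."""
    return [(char, len(list(group))) for char, group in groupby(text)]


def move_to_front(text, alphabet):
    """Replace each character by its position in a list that moves
    the most recently seen character to the front."""
    symbols = list(alphabet)
    codes = []
    for char in text:
        position = symbols.index(char)
        codes.append(position)
        symbols.insert(0, symbols.pop(position))
    return codes


transformed = "ard$rcaaaabb"  # BWT of "abracadabra$"
print(run_length_encode(transformed))
# [('a', 1), ('r', 1), ('d', 1), ('$', 1), ('r', 1), ('c', 1), ('a', 4), ('b', 2)]
print(move_to_front(transformed, sorted(set(transformed))))
# [1, 5, 5, 3, 2, 5, 4, 0, 0, 0, 5, 0]
```

The runs "aaaa" and "bb" collapse into single pairs, and repeated characters become zeros after move-to-front, which is what makes the transformed string easier to compress.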

1. List all rotations of the input sequence. \
(Here '$' marks the end of the sequence.)

In [2]:
from pprint import pprint

example = "OMICSSBS$"
rotations = []
for i in range(len(example)):
    rotations.append(example[i:] + example[:i])
pprint(rotations)

['OMICSSBS$',
 'MICSSBS$O',
 'ICSSBS$OM',
 'CSSBS$OMI',
 'SSBS$OMIC',
 'SBS$OMICS',
 'BS$OMICSS',
 'S$OMICSSB',
 '$OMICSSBS']


2. Then we sort the rotations lexicographically ('$' sorts before every letter).

In [3]:
sorted_rotations = sorted(rotations)
pprint(sorted_rotations)

['$OMICSSBS',
 'BS$OMICSS',
 'CSSBS$OMI',
 'ICSSBS$OM',
 'MICSSBS$O',
 'OMICSSBS$',
 'S$OMICSSB',
 'SBS$OMICS',
 'SSBS$OMIC']


3. We take the last column of the sorted rotations. Read from top to bottom, it is the BWT of the input.

In [4]:
for seq in sorted_rotations:
    pprint(seq[-1])

'S'
'S'
'I'
'M'
'O'
'$'
'B'
'S'
'C'
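The three steps can also be folded into one small function. The sketch below is an assumption-laden shortcut rather than the notebook's own method: it sorts suffix start positions instead of full rotations, which gives the same order as long as the text ends with a sentinel that occurs exactly once. The name `bwt_from_suffixes` is made up for illustration, and this is still not a linear-time construction.

```python
def bwt_from_suffixes(text):
    """Return the BWT of `text`, which must end with a unique '$' sentinel."""
    # Sorting suffix start positions gives the same row order as sorting
    # rotations, because the unique sentinel decides every comparison
    # before either suffix runs out.
    suffix_starts = sorted(range(len(text)), key=lambda i: text[i:])
    # The last character of the rotation starting at i is the character
    # just before i (text[-1] wraps around to the sentinel when i == 0).
    return "".join(text[i - 1] for i in suffix_starts)


print(bwt_from_suffixes("OMICSSBS$"))  # SSIMO$BSC
```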


## FM-index

In computer science, an FM-index is a compressed full-text substring index based on the Burrows–Wheeler transform, with some similarities to the suffix array. 

An FM-index is created by first taking the Burrows–Wheeler transform (BWT) of the input text. For example, the BWT of the string T = "abracadabra$" is "ard$rcaaaabb". It can be read off a matrix M in which each row is a rotation of the text and the rows are sorted lexicographically: the transform is the last column, labeled L, and the first column is labeled F.

> https://en.wikipedia.org/wiki/FM-index
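As a quick check of the example above, the following sketch builds the sorted rotation matrix for "abracadabra$" and reads off its F and L columns. The variable names (`abra_text`, `abra_rows`, `f_col`, `l_col`) are chosen here only to avoid clashing with the notebook's own variables.

```python
abra_text = "abracadabra$"
# Matrix M: every rotation of the text, sorted lexicographically
abra_rows = sorted(abra_text[i:] + abra_text[:i] for i in range(len(abra_text)))
f_col = "".join(row[0] for row in abra_rows)    # first column F
l_col = "".join(row[-1] for row in abra_rows)   # last column L, i.e. the BWT
print("F:", f_col)  # F: $aaaaabbcdrr
print("L:", l_col)  # L: ard$rcaaaabb
```

F is simply the characters of the text in sorted order, so it never needs to be stored explicitly: it can be rebuilt from character counts.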

In [5]:
import numpy as np

initial_indices = np.argsort(rotations)
for ind, seq in zip(initial_indices, sorted_rotations):
    print(f"Initial index: {ind} of the sequence: {seq}")

Initial index: 8 of the sequence: $OMICSSBS
Initial index: 6 of the sequence: BS$OMICSS
Initial index: 3 of the sequence: CSSBS$OMI
Initial index: 2 of the sequence: ICSSBS$OM
Initial index: 1 of the sequence: MICSSBS$O
Initial index: 0 of the sequence: OMICSSBS$
Initial index: 7 of the sequence: S$OMICSSB
Initial index: 5 of the sequence: SBS$OMICS
Initial index: 4 of the sequence: SSBS$OMIC


We can collect the last column into a single string.

In [6]:
last_column = ''
for seq in sorted_rotations:
    last_column += seq[-1]
print(last_column) 

SSIMO$BSC


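Then, for each character, we count its occurrences in every prefix of the last column: `tallymatrix[c][i]` is the number of times `c` appears in `last_column[0..i]`, and `totals[c]` is its overall count.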

In [7]:
totals = {k: 0 for k in "".join(set(last_column))}
tallymatrix = {k: [] for k in "".join(set(last_column))}

for i in last_column:
    totals[i] += 1
    for j in tallymatrix.keys():
        if i != j and tallymatrix[j]:
            tallymatrix[j].append(tallymatrix[j][-1])
        elif i == j:
            tallymatrix[j].append(totals[i])
        else:
            tallymatrix[j].append(0)

pprint(totals)
pprint(tallymatrix)

{'$': 1, 'B': 1, 'C': 1, 'I': 1, 'M': 1, 'O': 1, 'S': 3}
{'$': [0, 0, 0, 0, 0, 1, 1, 1, 1],
 'B': [0, 0, 0, 0, 0, 0, 1, 1, 1],
 'C': [0, 0, 0, 0, 0, 0, 0, 0, 1],
 'I': [0, 0, 1, 1, 1, 1, 1, 1, 1],
 'M': [0, 0, 0, 1, 1, 1, 1, 1, 1],
 'O': [0, 0, 0, 0, 1, 1, 1, 1, 1],
 'S': [1, 2, 2, 2, 2, 2, 2, 3, 3]}


Using the totals, we can rebuild an index of where each character's block starts and stops in the first column (the stop index is exclusive).

In [8]:
first = {}
totc = 0
for i, count in sorted(totals.items()):
    first[i] = (totc, totc+count)
    totc += count
pprint(first)

{'$': (0, 1),
 'B': (1, 2),
 'C': (2, 3),
 'I': (3, 4),
 'M': (4, 5),
 'O': (5, 6),
 'S': (6, 9)}


### Last to First Mapping

Given the character at row i of the last column, we can map it back to the row of the first column that holds the same occurrence of that character.

To get a better grasp of what's happening, let's track the index jumps too.

In [30]:
i = 0
t = "$"
j = 0
print(f"loop\tloopindex\tindex\tinitial_ind\tfirst_column\tlast_column\tc\tt\n")
while last_column[i] != "$":
    c = last_column[i]
    t = c + t
    i = first[c][0] + tallymatrix[c][i] - 1
    j += 1
    print(f"END{j:10} {i:13} {initial_indices[j]:10}{' '*15}{[seq[0] for seq in sorted_rotations][j]:15} {[seq[-1] for seq in sorted_rotations][j]:10} {c:8} {t}")

print("\n"+'first:')
pprint(first)
print("\n"+'last column:')
pprint(last_column)
print("\n"+'tallymatrix:')
pprint(tallymatrix)
print("\n"+'initial indices:')
pprint(list(initial_indices))
print("\n"+'sorted rotations:')
pprint(sorted_rotations)

loop	loopindex	index	initial_ind	first_column	last_column	c	t

END         1             6          6               B               S          S        S$
END         2             1          3               C               I          B        BS$
END         3             7          2               I               M          S        SBS$
END         4             8          1               M               O          S        SSBS$
END         5             2          0               O               $          C        CSSBS$
END         6             3          7               S               B          I        ICSSBS$
END         7             4          5               S               S          M        MICSSBS$
END         8             5          4               S               C          O        OMICSSBS$

first:
{'$': (0, 1),
 'B': (1, 2),
 'C': (2, 3),
 'I': (3, 4),
 'M': (4, 5),
 'O': (5, 6),
 'S': (6, 9)}

last column:
'SSIMO$BSC'

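tallymatrix:
{'$': [0, 0, 0, 0, 0, 1, 1, 1, 1],
 'B': [0, 0, 0, 0, 0, 0, 1, 1, 1],
 'C': [0, 0, 0, 0, 0, 0, 0, 0, 1],
 'I': [0, 0, 1, 1, 1, 1, 1, 1, 1],
 'M': [0, 0, 0, 1, 1, 1, 1, 1, 1],
 'O': [0, 0, 0, 0, 1, 1, 1, 1, 1],
 'S': [1, 2, 2, 2, 2, 2, 2, 3, 3]}

initial indices:
[8, 6, 3, 2, 1, 0, 7, 5, 4]

sorted rotations:
['$OMICSSBS',
 'BS$OMICSS',
 'CSSBS$OMI',
 'ICSSBS$OM',
 'MICSSBS$O',
 'OMICSSBS$',
 'S$OMICSSB',
 'SBS$OMICS',
 'SSBS$OMIC']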

The trace above shows each jump. At row `i` we read `c = last_column[i]`, the character that precedes the current row's first character in the original text, and prepend it to `t`. We then jump to the row of `sorted_rotations` that starts with that same occurrence of `c`, using `first[c][0] + tallymatrix[c][i] - 1`. Here `first[c][0]` is the first row that starts with `c`, and `tallymatrix[c][i]` is the rank of this `c` among the occurrences of `c` in `last_column[0..i]` (counting from 1, hence the `- 1`). This works because occurrences of a character appear in the same relative order in the first and last columns: a `c` that occurs earlier in `last_column` also occurs earlier among the rows that start with `c`. Repeating the jump until we reach `$` rebuilds the original sequence from back to front.
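The last-to-first property can be checked directly. The sketch below is self-contained and uses illustrative names (`lf_rows`, `block_start`, `lf_row`) that are not part of the notebook. It verifies that moving the last character of row `i` to the front yields exactly the row that the mapping points to.

```python
lf_text = "OMICSSBS$"
lf_rows = sorted(lf_text[i:] + lf_text[:i] for i in range(len(lf_text)))
lf_last = "".join(row[-1] for row in lf_rows)

# First row of each character's block in the first column
block_start = {}
for row_index, row in enumerate(lf_rows):
    block_start.setdefault(row[0], row_index)


def lf_row(i):
    """Row that starts with the same occurrence of the character lf_last[i]."""
    c = lf_last[i]
    rank = lf_last[:i + 1].count(c)  # 1-based rank of this c within lf_last[0..i]
    return block_start[c] + rank - 1


for i, row in enumerate(lf_rows):
    # Moving the last character to the front gives the rotation that
    # starts one position earlier in the text.
    assert lf_rows[lf_row(i)] == row[-1] + row[:-1]

print([lf_row(i) for i in range(len(lf_rows))])  # [6, 7, 3, 4, 5, 0, 1, 8, 2]
```

Following the mapping from row 0 gives 6, 1, 7, 8, 2, 3, 4, 5, the same index jumps as in the trace.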

In [31]:
i = 0
t = "$"
while last_column[i] != "$":
    c = last_column[i]
    t = c + t
    i = first[c][0] + tallymatrix[c][i] - 1
print(t)

OMICSSBS$


## Building a class

So now we can build a class `BWA` that contains the 4 core data structures used in the algorithm:

1. Suffix Array
2. BWT
3. C: C[c] is a table that, for each character c in the alphabet, contains the number of occurrences of lexically smaller characters in the text.
4. Occ: The function Occ(c, k) is the number of occurrences of character c in the prefix L[1..k].
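For a small worked instance of the C and Occ definitions, the sketch below builds both tables for the BWT of "abracadabra$". It is a standalone illustration: `build_tables`, `c_table` and `occ_table` are hypothetical names, and `Occ[c][k]` is taken here to count `c` in the 0-based prefix `bwt[0..k]`, matching the indexing used in the code cells rather than the 1-based notation above.

```python
from collections import Counter


def build_tables(bwt):
    """Return (C, Occ) for a BWT string, with Occ[c][k] counting c in bwt[0..k]."""
    counts = Counter(bwt)
    c_table, running = {}, 0
    for char in sorted(counts):
        c_table[char] = running      # number of characters strictly smaller than char
        running += counts[char]
    occ_table = {char: [] for char in counts}
    seen = Counter()
    for char in bwt:
        seen[char] += 1
        for symbol in occ_table:
            occ_table[symbol].append(seen[symbol])
    return c_table, occ_table


c_table, occ_table = build_tables("ard$rcaaaabb")
print(c_table)         # {'$': 0, 'a': 1, 'b': 6, 'c': 8, 'd': 9, 'r': 10}
print(occ_table["a"])  # [1, 1, 1, 1, 1, 1, 2, 3, 4, 5, 5, 5]
```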

In [32]:
class BWA:
    """ A Burrows-Wheeler Alignment class. """

    def __init__(self, reference: str):
        """ Initiation """
        self.ref = reference + "$"
        self.alphabet = sorted(['a', 'g', 't', 'c'])

    def suffix_array(self):
        """ Get the suffix array of the reference. """
        # List all rotations of the reference.
        rotations = []
        for i in range(len(self.ref)):
            rotations.append(self.ref[i:] + self.ref[:i])
        # Sort the rotations lexicographically
        sorted_rotations = sorted(rotations)
        # Get the original start index of each sorted rotation
        initial_indices = np.argsort(rotations)
        return list(initial_indices), sorted_rotations

    def bwt(self):
        """ Get the Burrows–Wheeler transform array. """
        # Keep the last column in a sequence
        last = ''
        _, sorted_rotations = self.suffix_array()
        for seq in sorted_rotations:
            last += seq[-1]
        return last

    def Occ(self):
        """ Get the Occ matrix.
            Occ(c, k) is the number of occurrences of character c in the prefix L[1..k].
        """
        # Gain the bwt last column
        last_column = self.bwt()
        # Initiation
        totals = {k: 0 for k in "".join(set(last_column))}
        tally_matrix = {k: [] for k in "".join(set(last_column))}
        # Get the Occ matrix
        for i in last_column:
            totals[i] += 1
            for j in tally_matrix.keys():
                if i != j and tally_matrix[j]:
                    tally_matrix[j].append(tally_matrix[j][-1])
                elif i == j:
                    tally_matrix[j].append(totals[i])
                else:
                    tally_matrix[j].append(0)
        return tally_matrix, totals

    def C(self):
        """ Get the C table.
            C[c] is a table that, for each character c in the alphabet,
            contains the number of occurrences of lexically smaller characters in the text.
        """
        first = {}
        totc = 0
        for i, count in sorted(self.Occ()[1].items()):
            first[i] = totc
            totc += count
        return first

We can try out our class `BWA` on a simple example:

In [33]:
ref = 'atgcgtaatgccgtcgatcg'
bwa = BWA(ref)
print('Suffix Array:')
pprint(bwa.suffix_array()[0])
print("\n"+'BWT:')
pprint(bwa.bwt())
print("\n"+'Occ:')
pprint(bwa.Occ())
print("\n"+'C:')
pprint(bwa.C())

Suffix Array:
[20, 6, 16, 7, 0, 10, 18, 14, 3, 11, 19, 15, 9, 2, 4, 12, 5, 17, 13, 8, 1]

BWT:
'gtga$gttgcccttccgagaa'

Occ:
({'$': [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  'a': [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4],
  'c': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 5, 5, 5, 5, 5],
  'g': [1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 6, 6, 6],
  't': [0, 1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 4, 5, 5, 5, 5, 5, 5, 5, 5]},
 {'$': 1, 'a': 4, 'c': 5, 'g': 6, 't': 5})

C:
{'$': 0, 'a': 1, 'c': 5, 'g': 10, 't': 16}


## Mapping

### Exact Matching: backward search

In [34]:
class BWA:
    """ A Burrows-Wheeler Alignment class. """

    def __init__(self, reference: str):
        """ Initiation """
        self.ref = reference + "$"
        self.alphabet = sorted(['a', 'g', 't', 'c'])

    def suffix_array(self):
        """ Get the suffix array of the reference. """
        # List all rotations of the reference.
        rotations = []
        for i in range(len(self.ref)):
            rotations.append(self.ref[i:] + self.ref[:i])
        # Sort the rotations lexicographically
        sorted_rotations = sorted(rotations)
        # Get the original start index of each sorted rotation
        initial_indices = np.argsort(rotations)
        return list(initial_indices), sorted_rotations

    def bwt(self):
        """ Get the Burrows–Wheeler transform array. """
        # Keep the last column in a sequence
        last = ''
        _, sorted_rotations = self.suffix_array()
        for seq in sorted_rotations:
            last += seq[-1]
        return last

    def Occ(self):
        """ Get the Occ matrix.
            Occ(c, k) is the number of occurrences of character c in the prefix L[1..k].
        """
        # Gain the bwt last column
        last_column = self.bwt()
        # Initiation
        totals = {k: 0 for k in "".join(set(last_column))}
        tally_matrix = {k: [] for k in "".join(set(last_column))}
        # Get the Occ matrix
        for i in last_column:
            totals[i] += 1
            for j in tally_matrix.keys():
                if i != j and tally_matrix[j]:
                    tally_matrix[j].append(tally_matrix[j][-1])
                elif i == j:
                    tally_matrix[j].append(totals[i])
                else:
                    tally_matrix[j].append(0)
        return tally_matrix, totals

    def C(self):
        """ Get the C table.
            C[c] is a table that, for each character c in the alphabet,
            contains the number of occurrences of lexically smaller characters in the text.
        """
        first = {}
        totc = 0
        for i, count in sorted(self.Occ()[1].items()):
            first[i] = totc
            totc += count
        return first

    def lf(self, c, i):
        """ Last-to-first mapping: the k-th occurrence of character c in last is
            the same text character as the k-th occurrence of c in the first.
            Returns the row of first that holds the last c seen in last[0..i].
        """
        first = self.C()
        # c never occurs in the reference: any range built from it is empty
        if c not in first:
            return -1
        # Nothing has been counted before row 0, so Occ is 0 there
        if i < 0:
            return first[c] - 1
        Occ = self.Occ()[0]
        return first[c] + Occ[c][i] - 1

    def exact_match(self, read):
        """ exact match - no indels or mismatches allowed. """
        # Start from the range that covers every row of the sorted rotations
        low, high = 0, len(self.ref) - 1
        # Extend the match backwards, one read character at a time
        i = len(read) - 1
        while low <= high and i >= 0:
            low = self.lf(read[i], low-1) + 1
            high = self.lf(read[i], high)
            i -= 1
        return self.suffix_array()[0][low: high+1]

In [35]:
ref = 'atgcgtaatgccgtcgatcg'
read = 'gta'
bwa = BWA(ref)
for ind in bwa.exact_match(read):
    print(ref[ind:ind+len(read)])

gta
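To see what backward search does at each step, the sketch below prints the row range after every extension of the match. `backward_search_trace` is a hypothetical standalone helper, not part of the `BWA` class; it recomputes C and Occ by direct counting for clarity rather than speed, and assumes 0-based rows with an inclusive `high`.

```python
def backward_search_trace(reference, read):
    """Yield (character, low, high) after each backward extension of the match."""
    text = reference + "$"
    order = sorted(range(len(text)), key=lambda i: text[i:])
    last = "".join(text[i - 1] for i in order)   # BWT of the reference
    low, high = 0, len(text) - 1                 # start with every row
    for char in reversed(read):
        smaller = sum(1 for x in text if x < char)          # C[char]
        low = smaller + last[:low].count(char)               # C + Occ(char, low-1)
        high = smaller + last[:high + 1].count(char) - 1     # C + Occ(char, high) - 1
        yield char, low, high
        if low > high:   # empty range: the read does not occur
            return


for step in backward_search_trace("atgcgtaatgccgtcgatcg", "gta"):
    print(step)
# ('a', 1, 4)
# ('t', 16, 16)
# ('g', 14, 14)
```

The range shrinks from the four rows that start with "a", to the single row that starts with "ta", to the single row that starts with "gta". Row 14 of the suffix array printed earlier is position 4, where "gta" occurs in the reference.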


### Inexact Matching: bounded traversal/backtracking

In [36]:
import numpy as np


class BWA:
    """ A Burrows-Wheeler Alignment class. """

    def __init__(self, reference: str):
        """ Initiation """
        self.ref = reference + "$"
        self.alphabet = sorted(['a', 'g', 't', 'c'])

    def suffix_array(self):
        """ Get the suffix array of the reference. """
        # List all rotations of the reference.
        rotations = []
        for i in range(len(self.ref)):
            rotations.append(self.ref[i:] + self.ref[:i])
        # Sort the rotations lexicographically
        sorted_rotations = sorted(rotations)
        # Get the original start index of each sorted rotation
        initial_indices = np.argsort(rotations)
        return list(initial_indices), sorted_rotations

    def bwt(self):
        """ Get the Burrows–Wheeler transform array. """
        # Keep the last column in a sequence
        last = ''
        _, sorted_rotations = self.suffix_array()
        for seq in sorted_rotations:
            last += seq[-1]
        return last

    def Occ(self):
        """ Get the Occ matrix.
            Occ(c, k) is the number of occurrences of character c in the prefix L[1..k].
        """
        # Gain the bwt last column
        last_column = self.bwt()
        # Initiation
        totals = {k: 0 for k in "".join(set(last_column))}
        tally_matrix = {k: [] for k in "".join(set(last_column))}
        # Get the Occ matrix
        for i in last_column:
            totals[i] += 1
            for j in tally_matrix.keys():
                if i != j and tally_matrix[j]:
                    tally_matrix[j].append(tally_matrix[j][-1])
                elif i == j:
                    tally_matrix[j].append(totals[i])
                else:
                    tally_matrix[j].append(0)
        return tally_matrix, totals

    def C(self):
        """ Get the C table.
            C[c] is a table that, for each character c in the alphabet,
            contains the number of occurrences of lexically smaller characters in the text.
        """
        first = {}
        totc = 0
        for i, count in sorted(self.Occ()[1].items()):
            first[i] = totc
            totc += count
        return first

    def lf(self, c, i):
        """ Last-to-first mapping: the k-th occurrence of character c in last is
            the same text character as the k-th occurrence of c in the first.
            Returns the row of first that holds the last c seen in last[0..i].
        """
        first = self.C()
        # c never occurs in the reference: any range built from it is empty
        if c not in first:
            return -1
        # Nothing has been counted before row 0, so Occ is 0 there
        if i < 0:
            return first[c] - 1
        Occ = self.Occ()[0]
        return first[c] + Occ[c][i] - 1

    def inexact_recursion(self, low, high, mismatch_left, read, index):
        """ Recursion function for inexact match. """
        # stop condition 1: the entire read has been matched
        if index <= 0:
            return [(low, high)]
        # stop condition 2: the reference does not contain the substring
        if low > high:
            return []
        matches = []
        next_character = read[index-1]
        for c in self.alphabet:
            low_ = self.lf(c, low-1) + 1
            high_ = self.lf(c, high)
            # if the substring was found
            if low_ <= high_:
                # exact match
                if c == next_character:
                    matches.extend(self.inexact_recursion(low_, high_, mismatch_left, read, index-1))
                # mismatch
                elif mismatch_left > 0:
                    matches.extend(self.inexact_recursion(low_, high_, mismatch_left-1, read, index-1))
        return matches

    def inexact_match(self, read, mismatch=1):
        """ inexact match - only mismatches allowed. """
        sa = self.suffix_array()[0]
        # Start from the range that covers every row of the sorted rotations
        ranges = self.inexact_recursion(0, len(self.ref)-1, mismatch, read, len(read))
        # Report every row of each matching range, not only the first one
        return [sa[row] for low, high in ranges for row in range(low, high+1)]

In [37]:
ref = 'atgcgtaatgccgtcgatcg'
read = 'gta'
bwa = BWA(ref)
matches = bwa.inexact_match(read, mismatch=1)
for match in matches:
    print(ref[match:match+len(read)])

gta
gtc
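A brute-force scan is a handy cross-check for the result above. The sketch below is not part of the BWA approach: `hamming_matches` is a hypothetical helper that simply slides the read along the reference and counts substitutions, so it is only practical for short sequences.

```python
def hamming_matches(reference, read, max_mismatch=1):
    """Return start positions where read aligns with at most max_mismatch substitutions."""
    hits = []
    for start in range(len(reference) - len(read) + 1):
        window = reference[start:start + len(read)]
        mismatches = sum(a != b for a, b in zip(window, read))
        if mismatches <= max_mismatch:
            hits.append(start)
    return hits


print(hamming_matches("atgcgtaatgccgtcgatcg", "gta"))  # [4, 12]
```

Positions 4 and 12 hold "gta" and "gtc", which agrees with the two matches printed above.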


## Putting it all together

In [38]:
"""
Burrows-Wheeler Alignment
---------------------------------------------
This is a very simple implementation of a Burrows-Wheeler Aligner for indexing and sequence alignment.

Differences between this code and the real BWA algorithm:
1. It does not use array D to estimate the lower bound of the number of differences.
2. It does not take insertions and deletions into consideration, only mismatches.
3. It does not use a difference score to ignore worse results.
4. It does not reduce the required operating memory by storing small fractions of Occ and SA.

Author: Zell Wu
Date: 4/2/2023
"""

import numpy as np


class BWA:
    """ A Burrows-Wheeler Alignment class. """

    def __init__(self, reference: str):
        """ Initiation """
        self.ref = reference + "$"
        self.alphabet = sorted(['a', 'g', 't', 'c'])

    def suffix_array(self):
        """ Get the suffix array of the reference. """
        # List all rotations of the reference.
        rotations = []
        for i in range(len(self.ref)):
            rotations.append(self.ref[i:] + self.ref[:i])
        # Sort the rotations lexicographically
        sorted_rotations = sorted(rotations)
        # Get the original start index of each sorted rotation
        initial_indices = np.argsort(rotations)
        return list(initial_indices), sorted_rotations

    def bwt(self):
        """ Get the Burrows–Wheeler transform array. """
        # Keep the last column in a sequence
        last = ''
        _, sorted_rotations = self.suffix_array()
        for seq in sorted_rotations:
            last += seq[-1]
        return last

    def Occ(self):
        """ Get the Occ matrix.
            Occ(c, k) is the number of occurrences of character c in the prefix L[1..k].
        """
        # Gain the bwt last column
        last_column = self.bwt()
        # Initiation
        totals = {k: 0 for k in "".join(set(last_column))}
        tally_matrix = {k: [] for k in "".join(set(last_column))}
        # Get the Occ matrix
        for i in last_column:
            totals[i] += 1
            for j in tally_matrix.keys():
                if i != j and tally_matrix[j]:
                    tally_matrix[j].append(tally_matrix[j][-1])
                elif i == j:
                    tally_matrix[j].append(totals[i])
                else:
                    tally_matrix[j].append(0)
        return tally_matrix, totals

    def C(self):
        """ Get the C table.
            C[c] is a table that, for each character c in the alphabet,
            contains the number of occurrences of lexically smaller characters in the text.
        """
        first = {}
        totc = 0
        for i, count in sorted(self.Occ()[1].items()):
            first[i] = totc
            totc += count
        return first

    def lf(self, c, i):
        """ Last-to-first mapping: the k-th occurrence of character c in last is
            the same text character as the k-th occurrence of c in the first.
            Returns the row of first that holds the last c seen in last[0..i].
        """
        first = self.C()
        # c never occurs in the reference: any range built from it is empty
        if c not in first:
            return -1
        # Nothing has been counted before row 0, so Occ is 0 there
        if i < 0:
            return first[c] - 1
        Occ = self.Occ()[0]
        return first[c] + Occ[c][i] - 1

    def exact_match(self, read):
        """ exact match - no indels or mismatches allowed. """
        # Start from the range that covers every row of the sorted rotations
        low, high = 0, len(self.ref) - 1
        # Extend the match backwards, one read character at a time
        i = len(read) - 1
        while low <= high and i >= 0:
            low = self.lf(read[i], low-1) + 1
            high = self.lf(read[i], high)
            i -= 1
        return self.suffix_array()[0][low: high+1]

    def inexact_recursion(self, low, high, mismatch_left, read, index):
        """ Recursion function for inexact match. """
        # stop condition 1: the entire read has been matched
        if index <= 0:
            return [(low, high)]
        # stop condition 2: the reference does not contain the substring
        if low > high:
            return []
        matches = []
        next_character = read[index-1]
        for c in self.alphabet:
            low_ = self.lf(c, low-1) + 1
            high_ = self.lf(c, high)
            # if the substring was found
            if low_ <= high_:
                # exact match
                if c == next_character:
                    matches.extend(self.inexact_recursion(low_, high_, mismatch_left, read, index-1))
                # mismatch
                elif mismatch_left > 0:
                    matches.extend(self.inexact_recursion(low_, high_, mismatch_left-1, read, index-1))
        return matches

    def inexact_match(self, read, mismatch=1):
        """ inexact match - only mismatches allowed. """
        sa = self.suffix_array()[0]
        # Start from the range that covers every row of the sorted rotations
        ranges = self.inexact_recursion(0, len(self.ref)-1, mismatch, read, len(read))
        # Report every row of each matching range, not only the first one
        return [sa[row] for low, high in ranges for row in range(low, high+1)]

## Reference

> https://omics.sbs/blog/bwa/bwa.html \
> https://web.stanford.edu/class/cs262/archives/notes/lecture4.pdf \
> https://mr-easy.github.io/2019-12-19-burrows-wheeler-alignment-part-1 \
> https://mr-easy.github.io/2019-12-21-burrows-wheeler-alignment-part-2 \
> https://github.com/Jwomers/burrows_wheeler_alignment/blob/master/BWA.py

Thanks to the authors of the blogs and code above!