# ДЗ № 14, Волжина Лена

Реализуйте алгоритм Forward-Backward для выравнивания двух последовательностей. [Задание](https://compscicenter.ru/learning/assignments/27582/)

In [1]:
from collections import defaultdict
from itertools import product


![FB_alignment](hw14_algo.png)

In [31]:
d = d1 = d2 = 1 / 3  # can start from any point
r1 = r2 = 0  # no different gaps in row
t = 0.2  # has it some meaning?
e = e1 = e2 = 1 / 4

In [32]:
transitions = {
    'M': {'X': d1, 'Y': d2, 'M': 1 - d1 - d2 - t, 'E': t},
    'X': {'X': e1, 'Y': r2, 'M': 1 - e1 - r2 - t, 'E': t},
    'Y': {'X': r1, 'Y': e2, 'M': 1 - e2 - r1 - t, 'E': t},
    'E': {'X': 0,  'Y': 0,  'M': 0, 'E': 0},
}
pi = {'X': d1, 'Y': d2, 'M': 1 - d1 - d2}

chars = 'ATGC'
alpha = 1/8      # probability of some mismatch
pam = {(c1, c2): 1 - alpha * (len(chars) - 1) if c1 == c2 else alpha   # match or mismatch
       for c1 in chars for c2 in chars}
gaps = {c: 1 / len(chars) for c in chars}

s1, s2 = 'AGA', 'AGAGA'

In [46]:
class FBAligner(object):
    def __init__(self, transition, match, gap):
        self.transition = transition
        self.match = match
        self.gap = gap
        self.states = ['M', 'X', 'Y']

    def calculate_alphas(self, s1, s2):
        alphas = defaultdict(dict)
        n, m = len(s1), len(s2)
        
        # init alphas
        alphas[(0, 0)] = {'M': 1, 'X': 0, 'Y': 0}
        for i in range(n): 
            alphas[(i, -1)] = {state: 0 for state in self.states}
        for j in range(m): 
            alphas[(-1, j)] = {state: 0 for state in self.states}
        
        def process(i, j):
            #print(i, j)
            c1, c2 = s1[i], s2[j]
            pij = self.match[(c1, c2)]
            qi, qj = self.gap[c1], self.gap[c2]
            if i == 0 and j == 2:
                print(c1, c2, pij, qi, qj)
            
            alphas[(i, j)] = {
                'M': pij * ((1 - 2 * d - t) * alphas[(i - 1, j - 1)]['M'] + 
                            (1 - e - t) * (alphas[(i - 1, j - 1)]['X'] + 
                                           alphas[(i - 1, j - 1)]['Y'])),
                'X': qi * (d * alphas[(i - 1, j)]['M'] + e * alphas[(i - 1, j)]['X']),
                'Y': qj * (d * alphas[(i, j - 1)]['M'] + e * alphas[(i, j - 1)]['Y']) #???
            }
        
        # calculate alphas
        for k in range(min(n, m)):
            #print('alphas, k=', k)
            if k != 0:
                process(k, k)
            for n_row in range(k + 1, n):
                process(n_row, k)
            for n_col in range(k + 1, m):
                process(k, n_col)            
        
        return alphas
       
    def calculate_betas(self, s1, s2):
        betas = defaultdict(dict)
        n, m = len(s1), len(s2)
        
        # init betas
        betas[(n - 1, m - 1)] = {state: t for state in self.states}
        for i in range(n): 
            betas[(i, m)] = {state: 0 for state in self.states}
        for j in range(m): 
            betas[(n, j)] = {state: 0 for state in self.states}
            
        def process(i, j):
            #print(i, j)
            get_or_None = lambda xs, idx: xs[idx] if len(xs) > idx else None
            c1, c2 = get_or_None(s1, i + 1), get_or_None(s2, j + 1)
            pij = self.match.get((c1, c2), 100500)
            qi, qj = self.gap.get(c1, 3030), self.gap.get(c2, 4040)
            
            betas[(i, j)] = {
                'M': ((1 - 2 * d - t) * pij * betas[(i + 1, j + 1)]['M'] +
                      d * (qi * betas[(i + 1, j)]['X'] + qj * betas[(i, j + 1)]['Y'])),
                'X': ((1 - e - t) * pij * betas[(i + 1, j + 1)]['M'] +
                      e * qi * betas[(i + 1, j)]['X']),
                'Y': ((1 - e - t) * pij * betas[(i + 1, j + 1)]['M'] +
                      e * qj * betas[(i, j + 1)]['Y']),
            }
        
        # calculate betas
        for k in range(min(n, m)):
            #print('betas, k=', k)
            if k != 0:
                process(n - k - 1, m - k - 1)
            for n_row in reversed(range(n - k - 1)):
                process(n_row, m - k - 1)
            for n_col in reversed(range(m - k - 1)):
                process(n - k - 1, n_col)
        return betas
        
    def process(self, s1, s2):
        n, m = len(s1), len(s2)
        alphas = self.calculate_alphas(s1, s2)
        betas = self.calculate_betas(s1, s2)
        
        # calculate alignment probabilities
        p_sum_fwd = sum(alphas[(n - 1, m - 1)].values()) * t
        p_sum_bwd = 0
        c1, c2 = s1[0], s2[0]
        pij = self.match[(c1, c2)]
        qi, qj = self.gap[c1], self.gap[c2]
        for state in self.states:
            p_open = pij if state == 'M' else (qi if state == 'X' else qj)
            p_sum_bwd += betas[(0, 0)][state] * (1 / 3) * p_open
        
        print(p_sum_fwd, p_sum_bwd)
        
        # calculate result
        result = []
        for i in range(n):
            row = [alphas[(i, j)]['M'] * betas[(i, j)]['M'] / p_sum_fwd
                   for j in range(m)]
            result.append(row)
            
        return result

In [47]:
aligner = FBAligner(transitions, pam, gaps)
alphas, betas = aligner.calculate_alphas(s1, s2), aligner.calculate_betas(s1, s2)
res = aligner.process(s1, s2)

print(' ' * 8 + (' ' * 9).join(s2))
for i in range(len(s1)):
    print(s1[i], end='    ')
    print('    '.join('{:.4f}'.format(v) for v in res[i]))

A A 0.625 0.25 0.25
A A 0.625 0.25 0.25
0.00011003056278935187 5.7943932804060574e-05
        A         G         A         G         A
A    1.0000    0.0000    0.0000    0.0000    0.0000
G    0.0000    0.4093    0.3128    0.2712    0.0034
A    0.0000    0.0034    0.0657    0.0868    0.8407


In [43]:
alphas
for state in ['X', 'M', 'Y']:
    print('state', state)
    print(' ' * 8 + (' ' * 9).join(s2))
    for i in range(len(s1)):
        print(s1[i], end='    ')
        print('    '.join('{:.4f}'.format(alphas[(i, j)][state]) for j in range(len(s2))))
              

state X
        A         G         A         G         A
A    0.0000    0.0000    0.0000    0.0000    0.0000
G    0.0833    0.0000    0.0000    0.0000    0.0000
A    0.0052    0.0069    0.0005    0.0001    0.0000
state M
        A         G         A         G         A
A    1.0000    0.0000    0.0000    0.0000    0.0000
G    0.0000    0.0833    0.0057    0.0018    0.0000
A    0.0000    0.0057    0.0069    0.0006    0.0005
state Y
        A         G         A         G         A
A    0.0000    0.0833    0.0052    0.0003    0.0000
G    0.0000    0.0000    0.0069    0.0009    0.0002
A    0.0000    0.0000    0.0005    0.0006    0.0001


In [22]:
betas

defaultdict(dict,
            {(0, 0): {'M': 0.00010602710865162035,
              'X': 0.0001399019029405382,
              'Y': 0.00018147023518880206},
             (0, 1): {'M': 0.0006206958912037037,
              'X': 0.0003080376519097222,
              'Y': 0.0006677381727430555},
             (0, 2): {'M': 0.0022395833333333334,
              'X': 0.005966796875,
              'Y': 0.005966796875},
             (0, 3): {'M': 0.003628472222222222,
              'X': 0.0032161458333333334,
              'Y': 0.0006770833333333334},
             (0, 4): {'M': 0.0005208333333333333, 'X': 0.000390625, 'Y': 0.0},
             (0, 5): {'M': 0, 'X': 0, 'Y': 0},
             (1, 0): {'M': 3.5332573784722216e-05,
              'X': 2.6448567708333334e-06,
              'Y': 2.8432210286458333e-05},
             (1, 1): {'M': 0.0003439670138888889,
              'X': 0.00021158854166666667,
              'Y': 0.00041259765625},
             (1, 2): {'M': 0.003628472222222222,
           