## Task BA10K: Implement Baum-Welch Learning

In [1]:
import numpy as np
import copy

In [21]:
def _read_prob_matrix(f, n_rows):
    """Skip one header line, then read ``n_rows`` rows of ``<label> p1 p2 ...``.

    The leading label of each row is discarded; the remaining fields are
    parsed as floats and returned, in file order, as an ``np.matrix``.
    """
    f.readline()  # column-header line (state names or symbols)
    rows = []
    for _ in range(n_rows):
        fields = f.readline().split()
        rows.append([float(p) for p in fields[1:]])
    return np.matrix(rows)


def read_data(fname='rosalind_ba10k.txt'):
    """Parse a Rosalind BA10K (Baum-Welch learning) input file.

    Layout read, line by line (each "separator" line is skipped, not checked)::

        <number of iterations>
        separator
        <observed string>
        separator
        <alphabet symbols, whitespace separated>
        separator
        <state names, whitespace separated>
        separator
        <header line>, then one "<state> <probs...>" row per state  (transition)
        separator
        <header line>, then one "<state> <probs...>" row per state  (emission)

    Returns
    -------
    tuple
        ``(it, x, alphabet, states, transition, emission)``: ``it`` is an int,
        ``x`` a str, ``alphabet`` and ``states`` lists of str, and
        ``transition`` (n_states x n_states) and ``emission``
        (n_states x n_symbols) are ``np.matrix`` objects.
    """
    with open(fname, 'r') as f:
        it = int(f.readline().strip())
        f.readline()  # separator
        x = f.readline().strip()
        f.readline()  # separator
        alphabet = f.readline().split()
        f.readline()  # separator
        states = f.readline().split()
        f.readline()  # separator
        transition = _read_prob_matrix(f, len(states))
        f.readline()  # separator
        emission = _read_prob_matrix(f, len(states))
    return it, x, alphabet, states, transition, emission

In [90]:
class hmm:
    """Discrete hidden Markov model fitted with scaled Baum-Welch (EM).

    Parameters
    ----------
    states : list of str
        State names. Their order fixes the row order of ``transition`` and
        ``emission`` and the column order of ``transition``.
    alphabet : list of str
        Observable symbols. Their order fixes the column order of
        ``emission``.
    transition : 2-D array-like, shape (n_states, n_states)
        ``transition[i, j]`` is the probability of moving from state ``i``
        to state ``j``.
    emission : 2-D array-like, shape (n_states, n_symbols)
        ``emission[i, k]`` is the probability that state ``i`` emits
        ``alphabet[k]``.

    The initial state distribution is uniform. Public methods return
    ``np.matrix`` objects; the arithmetic is done on plain ndarrays.
    """

    def __init__(self, states, alphabet, transition, emission):
        self.states = states
        self.alphabet = alphabet
        n_states = len(states)
        # Uniform start distribution as a (1, n_states) row matrix.
        self.start = np.matrix([1 / n_states for _ in range(n_states)])
        self.transition = transition
        self.emission = emission
        # symbol -> column index of ``emission``
        self.alpha_map = {sym: k for k, sym in enumerate(alphabet)}
        # row index -> state name
        self.states_map = dict(enumerate(states))

    def _symbol_indices(self, x):
        """Map the observed string *x* to an int array of alphabet indices.

        Raises KeyError for a symbol that is not in the alphabet.
        """
        return np.array([self.alpha_map[sym] for sym in x], dtype=int)

    def get_forward_prob(self, x):
        """Scaled forward pass over the observed string *x*.

        Returns
        -------
        fw_prob : np.matrix, shape (n_states, len(x))
            Forward probabilities; each column is rescaled to sum to 1.
        scale_mtx : np.matrix, shape (len(x), 1)
            ``scale_mtx[t]`` is the factor (1 / unscaled column sum) that was
            applied to column ``t``.
        """
        trans = np.asarray(self.transition)
        emis = np.asarray(self.emission)
        obs = self._symbol_indices(x)
        n_states, x_len = emis.shape[0], len(obs)
        fw = np.zeros((n_states, x_len))
        scale = np.zeros(x_len)

        # alpha_0(i) = start(i) * e_i(x_0)
        fw[:, 0] = np.asarray(self.start).ravel() * emis[:, obs[0]]
        scale[0] = 1.0 / fw[:, 0].sum()
        fw[:, 0] *= scale[0]
        for t in range(1, x_len):
            # alpha_t(j) = e_j(x_t) * sum_i alpha_{t-1}(i) * a_ij
            fw[:, t] = (fw[:, t - 1] @ trans) * emis[:, obs[t]]
            scale[t] = 1.0 / fw[:, t].sum()
            fw[:, t] *= scale[t]
        return np.asmatrix(fw), np.asmatrix(scale).T

    def get_backward_prob(self, x, scale_mtx):
        """Scaled backward pass that pairs with :meth:`get_forward_prob`.

        *scale_mtx* must be the scale factors returned by
        ``get_forward_prob(x)``. Column ``t`` of the result is multiplied by
        ``scale_mtx[t]``, so ``fw * bw / scale`` yields state posteriors.

        Returns an ``np.matrix`` of shape (n_states, len(x)).
        """
        trans = np.asarray(self.transition)
        emis = np.asarray(self.emission)
        obs = self._symbol_indices(x)
        scale = np.asarray(scale_mtx, dtype=float).ravel()
        n_states, x_len = emis.shape[0], len(obs)
        bw = np.zeros((n_states, x_len))

        bw[:, x_len - 1] = scale[x_len - 1]
        for t in range(x_len - 1, 0, -1):
            # beta_{t-1}(i) = c_{t-1} * sum_j a_ij * e_j(x_t) * beta_t(j)
            bw[:, t - 1] = (trans @ (emis[:, obs[t]] * bw[:, t])) * scale[t - 1]
        return np.asmatrix(bw)

    def _expected_counts(self, x):
        """E-step under the current parameters (one forward-backward pass).

        Returns ``(obs, gamma, xi)`` as ndarrays:
        ``obs`` are the symbol indices of *x*; ``gamma[i, t]`` is the
        posterior probability of state ``i`` at position ``t``; ``xi[i, j]``
        is the expected number of ``i -> j`` transitions summed over ``t``.
        """
        fw_prob, scale_mtx = self.get_forward_prob(x)
        bw_prob = self.get_backward_prob(x, scale_mtx)
        fw = np.asarray(fw_prob)
        bw = np.asarray(bw_prob)
        scale = np.asarray(scale_mtx).ravel()
        obs = self._symbol_indices(x)
        trans = np.asarray(self.transition)
        emis = np.asarray(self.emission)

        # The scale factors cancel, so alpha_hat * beta_hat / c_t is the state
        # posterior directly; no division by P(x) is needed.
        gamma = fw * bw / scale
        # xi[i, j] = a_ij * sum_t alpha_hat_t(i) * e_j(x_{t+1}) * beta_hat_{t+1}(j)
        xi = trans * (fw[:, :-1] @ (emis[:, obs[1:]] * bw[:, 1:]).T)
        return obs, gamma, xi

    def _emission_counts(self, obs, gamma):
        """Sum posterior mass per (state, symbol): the emission E-step counts."""
        n_states, n_symbols = np.asarray(self.emission).shape
        counts = np.zeros((n_states, n_symbols))
        for k in range(n_symbols):
            counts[:, k] = gamma[:, obs == k].sum(axis=1)
        return counts

    @staticmethod
    def _normalize_rows(counts, fallback):
        """Row-normalize *counts*; rows with zero total keep *fallback*'s row.

        A state that is never occupied (or never left) has an all-zero count
        row. Normalizing it would give 0/0 = NaN and poison every later
        iteration, so the previous parameters are kept for that row instead.
        """
        totals = counts.sum(axis=1)
        result = np.array(fallback, dtype=float)
        seen = totals > 0
        result[seen] = counts[seen] / totals[seen][:, None]
        return np.asmatrix(result)

    def update_emission(self, x):
        """Return the re-estimated emission matrix for *x* (does not mutate)."""
        obs, gamma, _ = self._expected_counts(x)
        return self._normalize_rows(self._emission_counts(obs, gamma), self.emission)

    def update_transition(self, x):
        """Return the re-estimated transition matrix for *x* (does not mutate)."""
        _, _, xi = self._expected_counts(x)
        return self._normalize_rows(xi, self.transition)

    def get_baum_welch_params(self, x, iterations):
        """Run *iterations* Baum-Welch rounds on *x*, updating the model in place.

        Each round does a single forward-backward pass and re-estimates both
        matrices from the same (pre-update) parameters.

        Returns
        -------
        tuple
            ``(transition, emission)`` after the last round.
        """
        for _ in range(iterations):
            obs, gamma, xi = self._expected_counts(x)
            new_emission = self._normalize_rows(
                self._emission_counts(obs, gamma), self.emission)
            new_transition = self._normalize_rows(xi, self.transition)
            self.emission, self.transition = new_emission, new_transition
        return self.transition, self.emission

In [99]:
def print_res(res, states, alphabet):
    """Print learned HMM parameters as two tab-separated tables.

    Parameters
    ----------
    res : tuple
        ``(transition, emission)``, e.g. the value returned by
        ``hmm.get_baum_welch_params``. Each entry is a 2-D array-like whose
        rows follow the order of *states*.
    states : list of str
        Row labels of both tables and column labels of the transition table.
    alphabet : list of str
        Column labels of the emission table.

    Both header lines start with a tab so their labels sit above the value
    columns. The two tables are separated by a ``--------`` line, and values
    are printed with three decimals.
    """
    def print_table(col_labels, matrix):
        # One labelled, tab-separated table; rows are indexed by position in *states*.
        table = np.asarray(matrix)
        print('\t' + '\t'.join(col_labels))
        for i, state in enumerate(states):
            # Format once: round() on numpy floats followed by '.3f' can
            # double-round values that sit near a tie.
            print(state + '\t' + '\t'.join('{0:.3f}'.format(v) for v in table[i]))

    print_table(states, res[0])
    print('--------')
    print_table(alphabet, res[1])

In [100]:
def main(fname='rosalind_ba10k.txt'):
    """Train an HMM with Baum-Welch on the dataset in *fname* and print it.

    The iteration count, observed string, alphabet, states and starting
    matrices come from ``read_data``. The learned tables are printed with
    ``print_res``. Nothing is returned.
    """
    n_iter, observed, symbols, state_names, trans0, emit0 = read_data(fname)
    learner = hmm(state_names, symbols, trans0, emit0)
    learned = learner.get_baum_welch_params(observed, n_iter)
    print_res(learned, state_names, symbols)


In [101]:
# main() prints the learned tables and returns None, so ``res`` is None here.
res = main('sample_data/ba10k/sample.txt')

A   B
A	0.000	1.000
B	0.786	0.214
--------
	x	y	z
A	0.242	0.000	0.758
B	0.172	0.828	0.000


In [102]:
# Train on the larger input file and print the learned matrices.
main(fname='sample_data/ba10k/input.txt')

A   B   C   D
A	0.000	0.384	0.186	0.430
B	0.989	0.000	0.000	0.011
C	0.000	0.000	0.596	0.404
D	0.403	0.596	0.000	0.001
--------
	x	y	z
A	0.447	0.000	0.553
B	0.032	0.648	0.321
C	0.000	1.000	0.000
D	0.714	0.000	0.286


In [105]:
# Driver code: uncomment the call below to run on the default input file.
# main()