### Problem 26
Baum-Welch Learning Problem

Given: A sequence of emitted symbols $x = x_1 \ldots x_n$ in an alphabet $\mathcal{A}$, generated by a $k$-state HMM with unknown transition and emission probabilities, initial Transition and Emission matrices, and a number of iterations $I$.

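Return: The matrix of transition probabilities Transition and the matrix of emission probabilities Emission obtained after $I$ iterations of Baum-Welch learning. No iteration decreases the likelihood $\Pr(x) = \sum_{\pi} \Pr(x, \pi)$ (summed over all hidden paths $\pi$), so the result approximates a local maximum of $\Pr(x)$ over all possible transition and emission matrices.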

In [1]:
import numpy as np
import itertools
class BaumWelch():
    """
    Baum-Welch (soft expectation-maximization) learning of HMM transition and
    emission matrices.

    Every forward/backward pass uses a uniform initial state distribution
    (probability 1/k for each of the k states).
    """
    def __init__(self, infile):
        '''
        Constructor: remember the input file name.

        Parameters
        ----------
            infile : str
                path of the input file read by readHMM
        '''
        self.file=infile

    def readHMM(self):
        """
        Read the Baum-Welch input file.

        Layout (0-based line numbers, k = number of states); lines 1, 3, 5, 7
        and 9+k are separator lines and are ignored:
            0        number of iterations I
            2        the emitted string x
            4        tab-separated alphabet
            6        tab-separated state names
            8        transition header, followed by k rows on lines 9 .. 8+k
            10+k     emission header, followed by k rows on lines 11+k .. 10+2k
        Each matrix row starts with its state name, followed by the
        tab-separated probabilities.

        Return
        ----------
        iteration:int
            the number of Baum-Welch iterations
        String:list
            the emitted string, each symbol replaced by its index in the alphabet
        state:list
            all the state names
        emission:list
            the alphabet of emitted symbols
        transitionProb:ndarray
            (k, k) float transition matrix
        emissionProb:ndarray
            (k, len(emission)) float emission matrix

        Raises
        ----------
        ValueError
            if a symbol of x is missing from the alphabet, or a matrix has
            fewer than k rows
        """
        with open(self.file) as rawData:
            data=[line.rstrip() for line in rawData]
        iteration=int(data[0])
        string=data[2]
        emission=data[4].split('\t')
        state=data[6].split('\t')
        k=len(state)
        transitionRows=data[9:9+k]
        emissionRows=data[11+k:11+2*k]
        if len(transitionRows)!=k or len(emissionRows)!=k:
            raise ValueError('input file is truncated: expected %d rows per matrix' % k)
        # drop the leading state name of every row; numpy converts the strings to float
        transitionProb=np.array([row.split('\t')[1:] for row in transitionRows],dtype=float)
        emissionProb=np.array([row.split('\t')[1:] for row in emissionRows],dtype=float)
        # list.index raises ValueError for a symbol that is not in the alphabet
        String=[emission.index(symbol) for symbol in string]
        return iteration,String,state,emission,transitionProb,emissionProb

    def forward(self,string,state,emission,transitionProb,emissionProb):
        """
        Forward algorithm with a uniform initial distribution.

        forwardMatrix[k, i] is the probability of emitting x_1..x_i and being
        in state k at position i.

        Parameters
        ----------
        string:list
            the index of observable string the path emits
        state:list
            all the states
        emission:
            all the emission (unused, kept for a uniform signature)
        transitionProb:ndarray
            transition probability
        emissionProb:ndarray
            emission matrix

        Returns
        ----------
        forwardMatrix:ndarray
            (len(state), len(string)) forward probabilities
        stringProb:float
            the probability Pr(x) that the HMM emits x

        Raises
        ----------
        ValueError
            if string is empty
        """
        if len(string)==0:
            raise ValueError('the emitted string must not be empty')
        transitionProb=np.asarray(transitionProb,dtype=float)
        emissionProb=np.asarray(emissionProb,dtype=float)
        nStates=len(state)
        forwardMatrix=np.empty([nStates,len(string)])
        forwardMatrix[:,0]=emissionProb[:,string[0]]/nStates #uniform start
        for col in range(1,len(string)):
            # forward[l, i] = sum_k forward[k, i-1] * Transition[k, l], times Emission[l, x_i]
            forwardMatrix[:,col]=(forwardMatrix[:,col-1]@transitionProb)*emissionProb[:,string[col]]
        return forwardMatrix,forwardMatrix[:,-1].sum()

    def backward(self,string,state,emission,transitionProb,emissionProb):
        """
        Backward algorithm.

        backwardMatrix[k, i] is the probability of emitting x_{i+1}..x_n given
        state k at position i; the last column is 1.

        Parameters
        ----------
        string:list
            the index of observable string the path emits
        state:list
            all the states
        emission:
            all the emission (unused, kept for a uniform signature)
        transitionProb:ndarray
            transition probability
        emissionProb:ndarray
            emission matrix

        Returns
        ----------
        backwardMatrix:ndarray
            (len(state), len(string)) backward probabilities

        Raises
        ----------
        ValueError
            if string is empty
        """
        if len(string)==0:
            raise ValueError('the emitted string must not be empty')
        transitionProb=np.asarray(transitionProb,dtype=float)
        emissionProb=np.asarray(emissionProb,dtype=float)
        backwardMatrix=np.empty([len(state),len(string)])
        backwardMatrix[:,-1]=1 #the last column is the base case
        for col in range(len(string)-2,-1,-1):
            # backward[k, i] = sum_l Transition[k, l] * Emission[l, x_{i+1}] * backward[l, i+1]
            backwardMatrix[:,col]=transitionProb@(emissionProb[:,string[col+1]]*backwardMatrix[:,col+1])
        return backwardMatrix

    def responsibilityMatrix(self,string,state,emission,transitionProb,emissionProb):
        """
        E-step: posterior node and edge probabilities.

        Parameters
        ----------
        string:list
            the index of observable string the path emits
        state:list
            all the states (names may have any length)
        emission:
            all the emission
        transitionProb:ndarray
            transition probability
        emissionProb:ndarray
            emission matrix

        Returns
        ----------
        diState:list
            all ordered state pairs (from, to); pair number r is row r of edgeMatrix
        nodeMatrix:ndarray
            (k, n) matrix: forward * backward / Pr(x) for every node
        edgeMatrix:ndarray
            (k*k, n-1) matrix: row from*k+to, column i is the posterior
            probability of the edge from -> to between positions i and i+1

        Raises
        ----------
        ValueError
            if string is empty or has probability 0 under the given HMM
        """
        transitionProb=np.asarray(transitionProb,dtype=float)
        emissionProb=np.asarray(emissionProb,dtype=float)
        nStates=len(state)
        forwardMatrix,Px=self.forward(string,state,emission,transitionProb,emissionProb)
        if Px==0:
            raise ValueError('the emitted string has probability 0 under the current HMM')
        backwardMatrix=self.backward(string,state,emission,transitionProb,emissionProb)
        nodeMatrix=forwardMatrix*backwardMatrix/Px
        # pair whole state names, not their characters, so multi-character names work
        diState=list(itertools.product(state,repeat=2))
        # weight of arriving in state l at position i+1: Emission[l, x_{i+1}] * backward[l, i+1]
        arrival=emissionProb[:,string[1:]]*backwardMatrix[:,1:]
        # edges[k, l, i] = forward[k, i] * Transition[k, l] * arrival[l, i] / Pr(x)
        edges=forwardMatrix[:,None,:-1]*transitionProb[:,:,None]*arrival[None,:,:]/Px
        # C-order reshape puts pair (k, l) on row k*nStates + l, matching diState
        edgeMatrix=edges.reshape(nStates*nStates,len(string)-1)
        return diState,nodeMatrix,edgeMatrix

    @staticmethod
    def _normalizeRows(counts):
        """Divide every row of counts by its sum; an all-zero row becomes uniform instead of nan."""
        totals=counts.sum(axis=1,keepdims=True)
        normalized=np.full(counts.shape,1.0/counts.shape[1])
        np.divide(counts,totals,out=normalized,where=totals>0)
        return normalized

    def BaumWelch(self,iteration,string,state,emission,transitionProb,emissionProb):
        """
        Run `iteration` rounds of Baum-Welch learning.

        Parameters
        ----------
        iteration:int
            the number of iteration
        string:list
            the index of observable string the path emits
        state:list
            all the states
        emission:
            all the emission
        transitionProb:ndarray
            initial transition matrix (not modified)
        emissionProb:ndarray
            initial emission matrix (not modified)

        Returns
        ----------
        np.around(emissionProb,decimals=3):ndarray
            estimated emission matrix
        np.around(transitionProb,decimals=3):ndarray
            estimated transition matrix
        """
        # float copies: the caller's arrays stay untouched and integer inputs are not truncated
        transitionProb=np.array(transitionProb,dtype=float)
        emissionProb=np.array(emissionProb,dtype=float)
        nStates=len(state)
        # one-hot (n, len(emission)) indicator of the symbol emitted at each position
        symbolIndicator=np.eye(len(emission))[np.asarray(string,dtype=int)]
        for _ in range(iteration):
            # E-step with the current parameters
            _,nodeMatrix,edgeMatrix=self.responsibilityMatrix(string,state,emission,transitionProb,emissionProb)
            # M-step: expected counts, normalized per state
            emissionCounts=nodeMatrix@symbolIndicator
            transitionCounts=edgeMatrix.sum(axis=1).reshape(nStates,nStates)
            transitionProb=self._normalizeRows(transitionCounts)
            emissionProb=self._normalizeRows(emissionCounts)
        return np.around(emissionProb,decimals=3),np.around(transitionProb,decimals=3)

### Main

In [2]:
def main(infile):
    '''
    Run Baum-Welch learning on an input file and print the learned matrices.

    Parameters
    ----------
    infile : str
        path of the input file holding the iteration count, the emitted
        string, the alphabet, the states and the initial matrices

    Returns
    -------
    None
        Prints to STDOUT the transition matrix, a '--------' separator line and
        the emission matrix as tab-separated tables with row/column names;
        every probability is printed with three decimals.
    '''
    hmm=BaumWelch(infile) #instantiation
    iteration,string,state,emission,transitionProb,emissionProb=hmm.readHMM()
    estimateEmit,estimateTrans=hmm.BaumWelch(iteration,string,state,emission,transitionProb,emissionProb)

    def printTable(columnNames,matrix):
        # header row, then one row per state: name followed by its probabilities
        print('\t'+'\t'.join(columnNames))
        for rowName,row in zip(state,matrix):
            print(rowName+'\t'+'\t'.join('%.3f' % value for value in row))

    printTable(state,estimateTrans)
    print('--------') #separate the two matrices
    printTable(emission,estimateEmit)

### Run the program here

In [6]:
if __name__ == "__main__":
    # presumably the downloaded Rosalind BA10K dataset, read from the working directory
    inputFile = 'rosalind_ba10k.txt'
    main(inputFile)

	A	B	C
A	0.223	0.777	0.0
B	0.0	0.449	0.551
C	0.557	0.0	0.443
--------
	x	y	z
A	0.346	0.0	0.654
B	0.0	0.757	0.243
C	0.738	0.13	0.132
