In [25]:
import numpy as np

class Baum_welch():
    def __init__(self, pi, P, E, seq):
        self.pi = pi
        self.P = P
        self.E = E
        self.seq = seq
        self.update_dependent_variables()
    
    def update_dependent_variables(self):
        self.alphas = self.forward()
        self.betas = self.backward()
        self.normalizing = self.normalizing_factor()
        self.gammas = self.gamma()
        self.xi_all = self.xi()

    def forward(self):
        Alpha = np.zeros((self.P.shape[0], len(self.seq)))
        Alpha[:,0] = (np.diagflat(self.E[:, self.seq[0]]) @ self.pi).flatten()
        Pt = np.transpose(self.P)
        for index, observation in enumerate(self.seq[1:], 1):
            Alpha[:,index] = np.diagflat(self.E[:, observation]) @ Pt @ Alpha[:,index-1]
        return Alpha

    def backward(self):
        seq_size = len(self.seq)
        Beta = np.ones((self.P.shape[0], seq_size))
        rev = self.seq[::-1][:-1]
        for index, observation in enumerate(rev, 1):
            Beta[:,seq_size-index-1] = self.P @ np.diagflat(self.E[:, observation]) @ Beta[:,seq_size-index]
        return Beta

    def normalizing_factor(self):
        return self.alphas[:,-1].sum()

    def gamma(self):
        return (self.alphas * self.betas) / self.normalizing

    def xi(self):
        xi_all = []
        for index in range(len(self.seq[:-1])):
            xi_1 = np.zeros(self.P.shape)
            for i in range(self.P.shape[0]):
                for j in range(self.P.shape[0]):
                    c = (self.alphas[i,index] * self.P[i,j] * 
                         self.E[j,self.seq[index+1]] * self.betas[j,index+1]) / self.normalizing
                    xi_1[i][j] = c
            xi_all.append(xi_1)
        return xi_all

    def sum_seq_gammas(self, state):
        unique_observation = {}
        for num in self.seq:
            if num not in unique_observation:
                unique_observation[num] = 0
        for idx, num in enumerate(self.seq):
            unique_observation[num] += self.gammas[state][idx]
        return unique_observation

    def calculate_emission(self):
        size = self.E.shape
        new_E = np.zeros(size)
        for i in range(size[0]):
            numerator_part = self.sum_seq_gammas(i)
            denominator = self.gammas[i].sum()
            if denominator != 0:  # prevent division by zero
                for j in range(size[1]):
                    new_E[i,j] = numerator_part[j] / denominator
        return new_E

    def update_sum(self, state):
        sum_xi = 0
        sum_gamma = 0
        for i in range(len(self.seq)-1):
            sum_xi += self.xi_all[i][state[0]][state[1]]
            sum_gamma += self.gammas[state[0]][i]
        return sum_xi, sum_gamma

    def fit_once(self):
        p_size = self.P.shape
        new_P = np.zeros(p_size)
        new_pi = np.zeros(self.pi.shape)
        
        for i in range(new_pi.shape[0]):
            for j in range(new_pi.shape[1]):
                new_pi[i, j] = self.gammas[i, 0]
        
        for i in range(p_size[0]):
            for j in range(p_size[0]):
                sum_xi, sum_gamma = self.update_sum((i, j))
                new_P[i, j] = sum_xi / sum_gamma if sum_gamma != 0 else 0
        
        new_E = self.calculate_emission()
        
        return new_P, new_pi, new_E

    def fit_all(self, max_iter=1000, tolerance=1e-6):
        old_P = np.copy(self.P)
        
        for i in range(max_iter):
            print(f"\nIteration {i}:")

            
            new_P, new_pi, new_E = self.fit_once()
            
            self.P = new_P.copy()
            self.pi = new_pi.copy()
            self.E = new_E.copy()
            self.update_dependent_variables()
            print("\n New P matrix:")
            print(self.P)
            print("\n New E matrix:")
            print(self.P)
            print("\n New pi matrix:")
            print(self.P)
            print('-' * 80)
            
            diff = np.max(np.abs(self.P - old_P))
            print(f"Maximum parameter difference: {diff}")
            
            if diff < tolerance:
                print(f"\nConverged after {i+1} iterations!")
                break
        
        return self.P, self.pi, self.E

In [26]:
pi = np.matrix([[0.8],
                [0.2]])
P= np.matrix([[0.1,0.9],
             [0.5,0.5]])
E = np.matrix([[0.1,0.7,0.2],
              [0.6,0.2,0.2]])
seq = [0,2,1,1,1,2,2,2]
baum = Baum_welch(pi, P ,E ,seq)
baum.fit_all(max_iter=100)


Iteration 0:

 New P matrix:
[[0.14460942 0.85539058]
 [0.58437685 0.41562315]]

 New E matrix:
[[0.14460942 0.85539058]
 [0.58437685 0.41562315]]

 New pi matrix:
[[0.14460942 0.85539058]
 [0.58437685 0.41562315]]
--------------------------------------------------------------------------------
Maximum parameter difference: 0.08437684653819666

Iteration 1:

 New P matrix:
[[0.13970413 0.86029587]
 [0.5752851  0.4247149 ]]

 New E matrix:
[[0.13970413 0.86029587]
 [0.5752851  0.4247149 ]]

 New pi matrix:
[[0.13970413 0.86029587]
 [0.5752851  0.4247149 ]]
--------------------------------------------------------------------------------
Maximum parameter difference: 0.07528509901702801

Iteration 2:

 New P matrix:
[[0.12722712 0.87277288]
 [0.56470429 0.43529571]]

 New E matrix:
[[0.12722712 0.87277288]
 [0.56470429 0.43529571]]

 New pi matrix:
[[0.12722712 0.87277288]
 [0.56470429 0.43529571]]
--------------------------------------------------------------------------------
Maximum p

(array([[0., 1.],
        [0., 1.]]),
 array([[1.],
        [0.]]),
 array([[1.        , 0.        , 0.        ],
        [0.        , 0.42857143, 0.57142857]]))

In [16]:
new_pi = np.array([[0.43885602],
                   [0.56114398]])

# New P
new_P = np.array([[0.14460942, 0.85539058],
                  [0.58437685, 0.41562315]])

# New E
new_E = np.array([[0.13501309, 0.48074687, 0.38424003],
                  [0.11814727, 0.30262925, 0.57922348]])
baum2 = Baum_welch(new_pi, new_P ,new_E ,seq)
baum2.fit_all(max_iter=100)

[[0.14460942 0.85539058]
 [0.58437685 0.41562315]]

 Iteration:  0 
 ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
New pi: 
 [[0.53013787]
 [0.46986213]] 
 --------------------------------------------------------------------------------
New P: 
 [[0.13970414 0.86029586]
 [0.5752851  0.4247149 ]] 
 --------------------------------------------------------------------------------
New E: 
 [[0.16192498 0.44718552 0.3908895 ]
 [0.09942009 0.32499323 0.57558667]] 
 --------------------------------------------------------------------------------

 Difference =  0.009091748978305003
[[0.13970414 0.86029586]
 [0.5752851  0.4247149 ]]

 Iteration:  1 
 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

(array([[0.13970414, 0.86029586],
        [0.5752851 , 0.4247149 ]]),
 array([[0.53013787],
        [0.46986213]]),
 array([[0.16192498, 0.44718552, 0.3908895 ],
        [0.09942009, 0.32499323, 0.57558667]]))