In [38]:
import numpy as np

In [39]:
def forward(model, seq):
    """Run the forward pass of a discrete HMM.

    Parameters
    ----------
    model : tuple
        ``(pi, A, B)`` where ``pi[i]`` is the initial probability of state i,
        ``A[i, j]`` the transition probability i -> j and ``B[o, i]`` the
        probability of emitting symbol ``o`` in state i.
    seq : np.ndarray
        1-D array of observed symbol indices.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(seq), num_states)``; row ``t`` holds the joint
        probability of the first ``t + 1`` observations and being in each
        state at time ``t``.
    """
    pi, A, B = model
    n_steps = seq.shape[0]
    alpha = np.empty((n_steps, A.shape[0]))  # indexed as [time, state]
    alpha[0] = pi * B[seq[0]]
    for t, obs in enumerate(seq[1:], start=1):
        # Propagate one step through A, then weight by the emission of obs.
        predicted = alpha[t - 1] @ A
        alpha[t] = predicted * B[obs]
    return alpha

def backward(model, seq):
    """Run the backward pass of a discrete HMM.

    Parameters
    ----------
    model : tuple
        ``(pi, A, B)`` with the same layout as in ``forward``; ``pi`` is
        unpacked but not used here.
    seq : np.ndarray
        1-D array of observed symbol indices.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(seq), num_states)``; row ``t`` holds, for each
        state, the probability of the observations after time ``t`` given
        that state at time ``t``. The last row is all ones.
    """
    pi, A, B = model
    n_steps = seq.shape[0]
    beta = np.empty((n_steps, A.shape[0]))  # indexed as [time, state]
    beta[n_steps - 1] = 1
    for t in reversed(range(n_steps - 1)):
        # Scale column j of A by the emission probability of the next symbol
        # in state j, then sum over j against the already-known beta[t + 1].
        weighted_transitions = A * B[seq[t + 1]]
        beta[t] = weighted_transitions @ beta[t + 1]
    return beta

def alpha_prob(alpha):
    """Return the sequence likelihood as the sum of the last row of ``alpha``."""
    total = 0
    for state_prob in alpha[-1]:
        total += state_prob
    return total

def beta_prob(pi, B, beta, seq):
    """Return the sequence likelihood computed from the backward table.

    Combines the initial distribution ``pi``, the emission probabilities of
    the first observed symbol and the first row of ``beta``, then sums over
    states.
    """
    first_symbol = seq[0]
    per_state = pi * B[first_symbol] * beta[0]
    return sum(per_state)

In [40]:
# Toy HMM with three hidden states and two observable symbols.
pi = np.array([0.4, 0.3, 0.3])  # initial state distribution
# A[i, j]: probability of moving from state i+1 to state j+1 (each row sums to 1)
A = np.array([[0.8, 0.2, 0.0],
              [0.3, 0.4, 0.3],
              [0.0, 0.3, 0.7]])
# B[o, i]: probability of emitting symbol o in state i+1 (each column sums to 1)
B = np.array([[0.9, 0.5, 0.2],
              [0.1, 0.5, 0.8]])

seq = np.array([0, 0, 1, 0])
model = (pi, A, B)
alpha = forward(model, seq)
beta = backward(model, seq)
# The forward and backward tables should give the same sequence likelihood.
alpha_prob(alpha), beta_prob(pi, B, beta, seq)

(0.054768480000000015, 0.054768480000000015)

In [132]:
def compute_xi():
    """Compute the pairwise transition posteriors with explicit loops.

    Reads the notebook-level ``alpha``, ``beta``, ``A``, ``B`` and ``seq``.
    The table is sized from ``seq`` and ``A`` rather than hard-coded, so it
    works for any number of states and any sequence length.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(seq) - 1, num_states, num_states)`` where
        ``xi[t, i, j]`` is the posterior probability of being in state i at
        time t and state j at time t + 1.

    Raises
    ------
    AssertionError
        If any time slice of the result does not sum to 1.
    """
    num_steps = seq.shape[0] - 1  # one slice per transition
    num_states = A.shape[0]
    xi = np.zeros((num_steps, num_states, num_states))
    for t in range(num_steps):
        for i in range(num_states):
            for j in range(num_states):
                xi[t, i, j] = alpha[t, i] * A[i, j] * B[seq[t + 1], j] * beta[t + 1, j]
    # Dividing by the sequence likelihood turns joint into posterior probabilities.
    xi /= alpha_prob(alpha)
    for t in range(num_steps):
        assert np.isclose(xi[t].sum(), 1)
    return xi

In [138]:
def compute_xi_fast():
    """Compute the pairwise transition posteriors without Python loops.

    Reads the notebook-level ``alpha``, ``beta``, ``A``, ``B`` and ``seq``.
    The result is sized from the inputs instead of being fixed at 3 x 3 x 3.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(seq) - 1, num_states, num_states)`` where
        ``xi[t, i, j]`` is the posterior probability of being in state i at
        time t and state j at time t + 1.
    """
    # Per-step factor that depends only on the destination state j:
    # emission probability of the next symbol times beta at the next step.
    next_step = B[seq[1:]] * beta[1:]
    # Broadcast to [t, i, j]: alpha over i, A over (i, j), next_step over j.
    xi = alpha[:-1, :, None] * A * next_step[:, None, :]
    return xi / alpha_prob(alpha)

In [140]:
xi1, xi2 = compute_xi(), compute_xi_fast()
# The loop-based and the vectorised implementation should agree.
np.allclose(xi1, xi2)

True

In [142]:
# Time the loop-based and the vectorised xi computation, 50 loops per run.
%timeit -n 50 compute_xi()
%timeit -n 50 compute_xi_fast()

333 µs ± 51.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
61.2 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [None]:
# Element-wise product of the forward and backward tables, written as
# explicit loops. The result is not yet divided by the sequence likelihood.
# Sizes come from alpha so the cell is not tied to a 4-step, 3-state example.
gamma = np.empty(alpha.shape)
for t in range(alpha.shape[0]):
    for i in range(alpha.shape[1]):
        gamma[t, i] = alpha[t, i] * beta[t, i]
gamma

In [None]:
# State posteriors: forward times backward, divided by the sequence likelihood.
gamma = (alpha * beta) / alpha_prob(alpha)

In [None]:
# update pi: the new pi is gamma[0]; summing xi[0] over the destination
# state gives the same vector.
# Earlier cells only bind xi1 and xi2, so bind xi here before using it.
xi = compute_xi_fast()
xi[0].sum(axis=1), gamma[0]

In [None]:
# update A: expected i -> j transition counts divided by the expected number
# of times state i is occupied before the last step. Rebinds the notebook-level A.
numerator = np.sum(xi, axis=0)
denom = np.sum(gamma[:-1], axis=0)
A = numerator / denom[:, np.newaxis]
print(A.sum(axis=1))  # every row should sum to 1
A

In [None]:
# update B: for each symbol k, the expected time spent in each state while
# emitting k, divided by the expected total time spent in that state.
# B is overwritten in place; the symbol count comes from B itself rather
# than being fixed at 2.
denom = gamma.sum(axis=0)  # soft counts of state i over entire seq
for k in range(B.shape[0]):
    B[k] = gamma[seq == k].sum(axis=0)
B /= denom
print(B.sum(axis=0))  # every column should sum to 1
B

In [6]:
class HMM:
    """Discrete hidden Markov model with Baum-Welch re-estimation.

    Parameters
    ----------
    pi : array-like, optional
        Initial state distribution, shape ``(num_states,)``.
    A : array-like, optional
        Transition matrix; ``A[i, j]`` is the probability of moving from
        state i to state j. Shape ``(num_states, num_states)``.
    B : array-like, optional
        Emission matrix; ``B[o, i]`` is the probability of emitting symbol
        ``o`` in state i. Shape ``(num_emissions, num_states)``.

    Any argument left as ``None`` falls back to the three-state, two-symbol
    example model. Supplied arrays are copied, so training never modifies the
    caller's data.

    Attributes
    ----------
    alpha, beta : np.ndarray or None
        Forward / backward tables of shape ``(T, num_states)``, set by
        ``forward`` / ``backward``.
    xi : np.ndarray or None
        Transition posteriors of shape ``(T - 1, num_states, num_states)``,
        set by ``expectation``.
    gamma : np.ndarray or None
        State posteriors of shape ``(T, num_states)``, set by ``expectation``.
    """

    def __init__(self, pi=None, A=None, B=None):
        if pi is None:
            pi = [0.4, 0.3, 0.3]
        if A is None:
            # A[i, j] from state i+1 to j+1
            A = [[0.8, 0.2, 0.0],
                 [0.3, 0.4, 0.3],
                 [0.0, 0.3, 0.7]]
        if B is None:
            # B[o, i] emitting o at state i+1
            B = [[0.9, 0.5, 0.2],
                 [0.1, 0.5, 0.8]]
        # np.array copies, so later updates cannot touch the caller's arrays.
        self.pi = np.array(pi, dtype=float)
        self.A = np.array(A, dtype=float)
        self.B = np.array(B, dtype=float)
        # Sizes follow from the parameters instead of being hard-coded.
        self.num_states = self.A.shape[0]
        self.num_emissions = self.B.shape[0]
        self.alpha = None
        self.beta = None
        self.xi = None
        self.gamma = None

    def forward(self, seq):
        """Fill ``self.alpha`` for the observation sequence ``seq``."""
        seq = np.asarray(seq)
        T = seq.shape[0]
        self.alpha = np.empty((T, self.num_states))  # [time, state]
        self.alpha[0] = self.pi * self.B[seq[0]]
        for t in range(1, T):
            # One step through A, then weight by the emission of seq[t].
            self.alpha[t] = (self.alpha[t - 1] @ self.A) * self.B[seq[t]]

    def backward(self, seq):
        """Fill ``self.beta`` for the observation sequence ``seq``."""
        seq = np.asarray(seq)
        T = seq.shape[0]
        self.beta = np.empty((T, self.num_states))  # [time, state]
        self.beta[T - 1] = 1
        for t in range(T - 2, -1, -1):
            # Scale column j of A by the emission of the next symbol in
            # state j, then sum over j against beta[t + 1].
            self.beta[t] = (self.A * self.B[seq[t + 1]]) @ self.beta[t + 1]

    def alpha_prob(self):
        """Return the sequence likelihood from the last row of ``self.alpha``."""
        return self.alpha[-1].sum()

    def beta_prob(self, seq):
        """Return the sequence likelihood from the first row of ``self.beta``."""
        return (self.pi * self.B[seq[0]] * self.beta[0]).sum()

    def expectation(self, seq):
        """E-step: recompute ``alpha``, ``beta``, ``xi`` and ``gamma`` for ``seq``."""
        seq = np.asarray(seq)
        self.forward(seq)
        self.backward(seq)

        likelihood = self.alpha_prob()  # computed once, shared by xi and gamma
        # Factor depending only on the destination state j at each step.
        next_step = self.B[seq[1:]] * self.beta[1:]
        # Broadcast to [t, i, j]: alpha over i, A over (i, j), next_step over j.
        self.xi = (
            self.alpha[:-1, :, None] * self.A * next_step[:, None, :]
        ) / likelihood
        self.gamma = self.alpha * self.beta / likelihood

    def maximization(self, seq):
        """M-step: re-estimate ``pi``, ``A`` and ``B`` from ``xi`` and ``gamma``.

        ``expectation`` must have been called with the same ``seq`` first.
        """
        seq = np.asarray(seq)
        # update pi: soft counts of each state at time 1; copied so pi does
        # not share memory with the gamma table.
        self.pi = self.gamma[0].copy()

        # update A
        transitions = self.xi.sum(axis=0)  # soft counts of transitions from i to j
        departures = self.gamma[:-1].sum(axis=0)  # soft counts of transitions out of i
        self.A = transitions / departures[:, None]

        # update B
        occupancy = self.gamma.sum(axis=0)  # soft counts of state i over entire seq
        emitted = np.stack([
            self.gamma[seq == k].sum(axis=0) for k in range(self.num_emissions)
        ])
        self.B = emitted / occupancy

    def check_probabilities(self):
        """Assert that ``pi``, the rows of ``A`` and the columns of ``B`` sum to 1."""
        # Checks this instance, not a notebook-level variable named `hmm`.
        assert np.isclose(self.pi.sum(), 1)
        assert np.allclose(self.A.sum(axis=1), 1)
        assert np.allclose(self.B.sum(axis=0), 1)

In [15]:
seq = np.array([0, 0, 1, 0])
hmm = HMM()
# Run both sweeps; they should report the same sequence likelihood.
for sweep in (hmm.forward, hmm.backward):
    sweep(seq)
(hmm.alpha_prob(), hmm.beta_prob(seq))

(0.054768480000000015, 0.054768480000000015)

In [16]:
np.set_printoptions(precision=5, suppress=True)
separator = '\n\n**\n\n'
# Ten rounds of Baum-Welch on the single sequence, printing A, B and pi
# after every round and checking they are still valid distributions.
for i in range(10):
    hmm.expectation(seq)
    hmm.maximization(seq)
    print(separator.join(str(param) for param in (hmm.A, hmm.B, hmm.pi)))
    hmm.check_probabilities()

[[0.72932 0.27068 0.     ]
 [0.36632 0.44105 0.19264]
 [0.      0.45664 0.54336]]

**

[[0.8303  0.63726 0.64754]
 [0.1697  0.36274 0.35246]]

**

[0.69289 0.23711 0.07   ]
[[0.69354 0.30646 0.     ]
 [0.35507 0.44931 0.19563]
 [0.      0.45448 0.54552]]

**

[[0.81489 0.63966 0.72834]
 [0.18511 0.36034 0.27166]]

**

[0.74027 0.19774 0.06199]
[[0.66598 0.33402 0.     ]
 [0.34553 0.45811 0.19636]
 [0.      0.4542  0.5458 ]]

**

[[0.82017 0.62388 0.75873]
 [0.17983 0.37612 0.24127]]

**

[0.77552 0.16377 0.0607 ]
[[0.6368  0.3632  0.     ]
 [0.33703 0.46772 0.19525]
 [0.      0.45552 0.54448]]

**

[[0.83386 0.59847 0.78089]
 [0.16614 0.40153 0.21911]]

**

[0.80896 0.13019 0.06085]
[[0.60291 0.39709 0.     ]
 [0.33004 0.4767  0.19326]
 [0.      0.45655 0.54345]]

**

[[0.85314 0.56719 0.80466]
 [0.14686 0.43281 0.19534]]

**

[0.84166 0.09699 0.06135]
[[0.56374 0.43626 0.     ]
 [0.32562 0.48298 0.1914 ]
 [0.      0.45667 0.54333]]

**

[[0.8771  0.53285 0.83139]
 [0.1229  0.46715 0.1