In [1]:
import pandas as pd
import numpy as np

import matplotlib
from matplotlib import pyplot as plt


#states = PATH1, PATH2, PATH3 = 0, 1, 2

#V = np.array([0, 1, 2, 2, 1, 0, 0, 1, 2, 0])
V = np.array([0, 0, 1, 1, 2, 2])

# Transition Probabilities
#A = np.array(((0.571, 0.179, 0.250), (0.401, 0.528, 0.071), (0.561, 0.160, 0.279)))
A = np.array(((0.682, 0.068, 0.250), (0.059, 0.759, 0.182), (0.214, 0.155, 0.631)))

# Emission Probabilities
#B = np.array(((0.625, 0.125, 0.25), (0.028, 0.805, 0.167), (1/8, 1/16, 13/16)))
#B = np.array(((5./7., 1./7., 1./7.), (1./25., 24./25., 0.), (1./25., 0., 24./25.)))
B = np.array(((5./8., 1./8., 1./4.), (1./61., 157./192., 1./6.), (1./16., 1./36., 131./144.)))

# Equal Probabilities for the initial distribution
#pi = np.array((0.303, 0.341, 0.356))
pi = np.array((0.280, 0.314, 0.406))

def forward(V, A, B, pi):
    global alpha
    alpha = np.zeros((V.shape[0], A.shape[0]))
    alpha[0, :] = pi * B[:, V[0]]

    for t in range(1, V.shape[0]):
        for j in range(A.shape[0]):

            alpha[t, j] = alpha[t - 1].dot(A[:, j]) * B[j, V[t]]

    return alpha


def backward(V, A, B):
    global beta
    beta = np.zeros((V.shape[0], A.shape[0]))

    # setting beta(T) = 1
    beta[V.shape[0] - 1] = np.ones((A.shape[0]))

    # Loop in backward way from T-1 to
    # Due to python indexing the actual loop will be T-2 to 0
    for t in range(V.shape[0] - 2, -1, -1):
        for j in range(A.shape[0]):
            beta[t, j] = (beta[t + 1] * B[:, V[t + 1]]).dot(A[j, :])

    return beta


def baum_welch(V, A, B, pi, n_iter=100):
    global gamma, numerator, denominator, xi
    M = A.shape[0]
    T = len(V)

    for n in range(n_iter):
        alpha = forward(V, A, B, pi)
        beta = backward(V, A, B)

        xi = np.zeros((M, M, T - 1))
        for t in range(T - 1):
            denominator = np.dot(np.dot(alpha[t, :].T, A) * B[:, V[t + 1]].T, beta[t + 1, :])
            for i in range(M):
                numerator = alpha[t, i] * A[i, :] * B[:, V[t + 1]].T * beta[t + 1, :].T
                xi[i, :, t] = numerator / denominator

        gamma = np.sum(xi, axis=1)
        A = np.sum(xi, 2) / np.sum(gamma, axis=1).reshape((-1, 1))

        # Add additional T'th element in gamma
        gamma = np.hstack((gamma, np.sum(xi[:, :, T - 2], axis=0).reshape((-1, 1))))

        K = B.shape[1]
        denominator = np.sum(gamma, axis=1)
        for l in range(K):
            B[:, l] = np.sum(gamma[:, V == l], axis=1)

        B = np.divide(B, denominator.reshape((-1, 1)))
    return {"A_hat":A, "B_hat":B, "Pi_hat": gamma[:,0]}

In [2]:
baum_welch(V, A, B, pi, n_iter=1)

{'A_hat': array([[0.59014994, 0.29080349, 0.11904657],
        [0.03159304, 0.57677031, 0.39163665],
        [0.11832686, 0.12678427, 0.75488888]]),
 'B_hat': array([[0.72925127, 0.19087333, 0.07987539],
        [0.02034277, 0.80770864, 0.1719486 ],
        [0.08040263, 0.04357981, 0.87601756]]),
 'Pi_hat': array([0.9273805 , 0.00878841, 0.06383108])}

In [49]:
baum_welch(V, A, B, pi, n_iter=2)

{'A_hat': array([[0.87748492, 0.01295527, 0.1095598 ],
        [0.91165126, 0.05360651, 0.03474223],
        [0.88492554, 0.01770503, 0.09736943]]),
 'B_hat': array([[0.41014493, 0.34136941, 0.24848565],
        [0.78979552, 0.21020448, 0.        ],
        [0.10654666, 0.        , 0.89345334]]),
 'Pi_hat': array([0.53475317, 0.39416723, 0.0710796 ])}

In [3]:
baum_welch(V, A, B, pi, n_iter=5)

{'A_hat': array([[8.21783182e-01, 1.45954055e-02, 1.63621412e-01],
        [9.98560495e-01, 1.43171283e-03, 7.79208236e-06],
        [9.15074687e-01, 1.47527742e-02, 7.01725392e-02]]),
 'B_hat': array([[0.35848839, 0.3826689 , 0.2588427 ],
        [0.99881539, 0.00118461, 0.        ],
        [0.29252822, 0.        , 0.70747178]]),
 'Pi_hat': array([0.18747245, 0.66804596, 0.14448159])}

In [29]:
baum_welch(V, A, B, pi, n_iter=10)

{'A_hat': array([[9.99999987e-001, 6.71847722e-014, 1.31617802e-008],
        [1.00000000e+000, 2.19211172e-115, 7.69527435e-118],
        [1.00000000e+000, 9.01932542e-074, 7.47681502e-070]]),
 'B_hat': array([[3.42803960e-001, 3.28598020e-001, 3.28598020e-001],
        [1.00000000e+000, 9.72262592e-116, 0.00000000e+000],
        [1.00000000e+000, 0.00000000e+000, 9.81666974e-064]]),
 'Pi_hat': array([0.12969602, 0.42578717, 0.44451681])}

In [31]:
baum_welch(V, A, B, pi, n_iter=50)


{'A_hat': array([[9.99999906e-01, 1.90801523e-13, 9.40088605e-08],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.34280391, 0.32859804, 0.32859804],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.129696  , 0.42578718, 0.44451682])}

In [27]:
baum_welch(V, A, B, pi, n_iter=100)

{'A_hat': array([[9.99973281e-01, 3.41971526e-10, 2.67188056e-05],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.34278777, 0.32860611, 0.32860611],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.12968827, 0.42579096, 0.44452077])}

In [26]:
baum_welch(V, A, B, pi, n_iter=500)

{'A_hat': array([[9.13783677e-01, 1.74376327e-06, 8.62145789e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [25]:
baum_welch(V, A, B, pi, n_iter=1000)

{'A_hat': array([[9.13783677e-01, 2.74981838e-06, 8.62135728e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [24]:
baum_welch(V, A, B, pi, n_iter=1500)

{'A_hat': array([[9.13783677e-01, 4.32349122e-06, 8.62119992e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [23]:
baum_welch(V, A, B, pi, n_iter=2000)

{'A_hat': array([[9.13783677e-01, 6.76965149e-06, 8.62095530e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [22]:
baum_welch(V, A, B, pi, n_iter=2500)

{'A_hat': array([[9.13783677e-01, 1.05390393e-05, 8.62057836e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [21]:
baum_welch(V, A, B, pi, n_iter=3000)

{'A_hat': array([[9.13783677e-01, 1.62782094e-05, 8.62000444e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [20]:
baum_welch(V, A, B, pi, n_iter=4000)

{'A_hat': array([[9.13783677e-01, 2.48755414e-05, 8.61914471e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [19]:
baum_welch(V, A, B, pi, n_iter=5000)

{'A_hat': array([[9.13783677e-01, 3.74788927e-05, 8.61788438e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [33]:
baum_welch(V, A, B, pi, n_iter=10000)

{'A_hat': array([[9.13783677e-01, 6.94707484e-08, 8.62162532e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [34]:
baum_welch(V, A, B, pi, n_iter=15000)

{'A_hat': array([[9.13783677e-01, 4.37589998e-08, 8.62162789e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [39]:
baum_welch(V, A, B, pi, n_iter=20000)

{'A_hat': array([[9.13783677e-01, 4.33340847e-09, 8.62163183e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [36]:
baum_welch(V, A, B, pi, n_iter=25000)

{'A_hat': array([[9.13783677e-01, 1.73559696e-08, 8.62163053e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [37]:
baum_welch(V, A, B, pi, n_iter=50000)

{'A_hat': array([[9.13783677e-01, 1.09293165e-08, 8.62163117e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}

In [38]:
baum_welch(V, A, B, pi, n_iter=100000)

{'A_hat': array([[9.13783677e-01, 6.88205310e-09, 8.62163158e-02],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 'B_hat': array([[0.28912734, 0.35543633, 0.35543633],
        [1.        , 0.        , 0.        ],
        [1.        , 0.        , 0.        ]]),
 'Pi_hat': array([0.10302077, 0.43883776, 0.45814147])}