In [1]:
%matplotlib inline

import numpy as np
from pyxhmm import baum_welch

np.set_printoptions(formatter={'float': '{: 0.3f}'.format})

In [4]:
%%time

# setup toy problem
def gen_p_matrix(x0, x1):
    p = np.random.rand(x0, x1) + 1e-17
    return (p.T / p.sum(axis=1)).T

def gen_data(i_state, tau, epsilon, n_data):
    n_states, n_emissions = tau.shape[0], epsilon.shape[1]
    data, states = np.zeros(n_data, dtype=np.int64), np.zeros(n_data, dtype=np.int64)
    for i in range(n_data):
        # sample emission
        states[i] = i_state
        data[i] = np.random.choice(n_emissions, p=epsilon[i_state])
        # change state
        i_state = np.random.choice(n_states, p=tau[i_state])
    return data, states

# setup i_state, transition, and emission matrix
n_states, n_emissions, i_state = 3, 4, 0
tau = np.array([
    [0.8, 0.2, 0.0],
    [0.0, 0.8, 0.2],
    [0.2, 0.0, 0.8]], dtype = np.float64)
epsilon = np.array([
    [0.8, 0.1, 0.05, 0.05],
    [0.05, 0.8, 0.1, 0.05],
    [0.05, 0.05, 0.8, 0.1]], dtype = np.float64)
# tau, epsilon = gen_p_matrix(n_states, n_states), gen_p_matrix(n_states, n_emissions)

print(f'n_states: {n_states}, n_emissions: {n_emissions}, i_state:{i_state}')
print('\ntau:\n', tau, '\nepsilon:\n', epsilon)

# generate dataset
data, states = gen_data(i_state, tau, epsilon, 100000)
print('\ndata:\n', data, '\nstates:\n', states)

n_states: 3, n_emissions: 4, i_state:0

tau:
 [[ 0.800  0.200  0.000]
 [ 0.000  0.800  0.200]
 [ 0.200  0.000  0.800]] 
epsilon:
 [[ 0.800  0.100  0.050  0.050]
 [ 0.050  0.800  0.100  0.050]
 [ 0.050  0.050  0.800  0.100]]

data:
 [0 0 0 ... 2 2 2] 
states:
 [0 0 0 ... 2 2 2]
CPU times: user 3.65 s, sys: 3.32 ms, total: 3.65 s
Wall time: 3.65 s


In [5]:
%%time

# attempt to recover pi, tau, and epsilon from data (n_states, n_emissions are given)
_tau, _epsilon = gen_p_matrix(n_states, n_states), gen_p_matrix(n_states, n_emissions)
_pi = np.ones(n_states)
_pi /= _pi.sum()
alpha, beta, gamma, delta, ksi, zeta = baum_welch(_pi, _tau, _epsilon, data, 1000)

print('\n_pi:\n', _pi, '\n_tau:\n', _tau, '\n_epsilon:\n', _epsilon)


_pi:
 [ 0.000  1.000  0.000] 
_tau:
 [[ 0.795  0.000  0.205]
 [ 0.200  0.800  0.000]
 [ 0.000  0.199  0.801]] 
_epsilon:
 [[ 0.048  0.800  0.105  0.048]
 [ 0.797  0.101  0.051  0.051]
 [ 0.054  0.050  0.794  0.103]]
CPU times: user 10.7 s, sys: 6.67 ms, total: 10.7 s
Wall time: 10.7 s
