/
interface.py
36 lines (31 loc) · 1.55 KB
/
interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import HMM.hmm_for_baxter_using_only_success_trials.util as util
from scipy.special import logsumexp
import ipdb
class HmmlearnModelIncrementalLoglikCalculator(object):
    """Incrementally evaluate the log-likelihood of a growing observation sequence.

    Runs the HMM forward recursion in log space one sample at a time, so the
    log-likelihood of the sequence seen so far is available after every new
    sample without re-processing the earlier ones.

    Attributes:
        model: the wrapped model; must expose ``n_components``, ``transmat_``,
            ``startprob_`` and ``_compute_log_likelihood(sample)``.
        n_components: number of hidden states, copied from ``model``.
        log_transmat: ``util.log_mask_zero(model.transmat_)``.
        log_startprob: ``util.log_mask_zero(model.startprob_)``.
        fwdlattice: ``None`` until the first sample is added, then an array of
            shape ``(samples_seen, n_components)`` holding the log forward
            variables, one row per sample.
        work_buffer: scratch array of length ``n_components``; kept only for
            backward compatibility, the vectorised recursion does not use it.
    """

    def __init__(self, model):
        """Cache the log-domain parameters of ``model`` and reset the lattice.

        Args:
            model: object exposing ``n_components``, ``transmat_``,
                ``startprob_`` and ``_compute_log_likelihood``.
        """
        self.model = model
        self.n_components = model.n_components
        # NOTE(review): presumably an element-wise log that maps zero
        # probabilities to -inf without warnings -- confirm against ``util``.
        self.log_transmat = util.log_mask_zero(model.transmat_)
        self.log_startprob = util.log_mask_zero(model.startprob_)
        self.fwdlattice = None
        self.work_buffer = np.zeros(self.n_components)

    def add_one_sample_and_get_loglik(self, sample):
        """Append one sample and return the log-likelihood of the sequence so far.

        Args:
            sample: passed unchanged to ``model._compute_log_likelihood``.
                Only row 0 of the returned array is used.

        Returns:
            ``logsumexp`` of the newest forward-lattice row, i.e. the log
            probability of every sample added so far.

        Raises:
            ValueError: if the model returns fewer than ``n_components``
                per-state log-likelihoods for the sample.
        """
        n_states = self.n_components
        framelogprob = np.asarray(self.model._compute_log_likelihood(sample))
        frame = framelogprob[0, :n_states]
        if frame.shape[0] != n_states:
            # Without this check a single-column result would silently
            # broadcast across all states in the vectorised sums below.
            raise ValueError(
                'expected %d per-state log-likelihoods, got %d'
                % (n_states, frame.shape[0])
            )

        if self.fwdlattice is None:
            # First sample: alpha_0(i) = log_startprob[i] + log b_i(sample).
            first_row = np.asarray(self.log_startprob, dtype=float) + frame
            self.fwdlattice = first_row.reshape(1, n_states)
        else:
            previous = self.fwdlattice[-1]
            # alpha_t(j) = logsumexp_i(alpha_{t-1}(i) + log_transmat[i, j])
            #              + log b_j(sample)
            # Broadcasting ``previous`` down the columns builds the full
            # (i, j) table at once; reducing over axis 0 sums over the
            # source state i for every destination state j.
            new_row = logsumexp(
                previous[:, np.newaxis] + self.log_transmat, axis=0
            ) + frame
            # The whole history is kept because ``fwdlattice`` is a public
            # attribute; note that stacking copies the lattice on every call.
            self.fwdlattice = np.vstack((self.fwdlattice, new_row))

        return logsumexp(self.fwdlattice[-1])
def get_calculator(model):
    """Return an incremental log-likelihood calculator for ``model``.

    Args:
        model: a fitted model; its type must derive from
            ``hmmlearn.hmm._BaseHMM``.

    Returns:
        A ``HmmlearnModelIncrementalLoglikCalculator`` wrapping ``model``.

    Raises:
        Exception: if the type of ``model`` is not a subclass of
            ``hmmlearn.hmm._BaseHMM``.
    """
    # Imported here, as in the original, so hmmlearn is only needed on call.
    import hmmlearn.hmm

    model_type = type(model)
    # Guard clause: reject unsupported model types up front.
    if not issubclass(model_type, hmmlearn.hmm._BaseHMM):
        raise Exception('model of type %s is not supported by fast_log_curve_calculation.'%(model_type,))
    return HmmlearnModelIncrementalLoglikCalculator(model)