In [14]:
import numpy as np
from scipy import special

In [16]:
class Transition:
    """Row-wise Dirichlet posterior over a ``dim x dim`` transition matrix.

    Row ``i`` of ``alpha`` holds the Dirichlet concentration parameters of the
    outgoing-transition probabilities of state ``i``: ``mean``/``loggeomean``
    normalize per row and ``KLqprior`` sums a per-row Dirichlet KL.

    Parameters
    ----------
    dim : int
        Number of states.
    alpha_0 : array_like, optional
        Prior concentrations, shape ``(dim, dim)``. Defaults to all ones.
    A : array_like, optional
        Initial posterior concentrations. When omitted, ``alpha_0`` is scaled
        element-wise by random factors in ``[1, 2)`` drawn from the global
        ``np.random`` state.
    """

    def __init__(self, dim, alpha_0=None, A=None):
        self.dim = dim
        if alpha_0 is None:
            self.alpha_0 = np.ones((dim, dim))
        else:
            self.alpha_0 = np.asarray(alpha_0, dtype=float)

        if A is None:
            # Random jitter breaks the symmetry between states at initialization.
            jitter = 1 + np.random.rand(*self.alpha_0.shape)
            self.alpha = np.multiply(self.alpha_0, jitter)
        else:
            self.alpha = np.asarray(A, dtype=float)

    def updateraw(self, data=None, p=None):
        """Set the posterior from a stack of (weighted) transition-count slices.

        Parameters
        ----------
        data : array_like, optional
            Count matrices of shape ``(dim, dim, n)``; slice ``data[:, :, k]``
            is observation ``k``. Slices containing any NaN are skipped. When
            omitted, the posterior is reset to (a copy of) the prior.
        p : array_like, optional
            Per-slice weights of length ``n`` (any shape that ravels to ``n``,
            e.g. ``(n, 1)``). Defaults to all ones.
        """
        if data is None:
            self.alpha = np.array(self.alpha_0, dtype=float)
            return

        data = np.asarray(data, dtype=float)
        weights = np.ones(data.shape[2]) if p is None else np.ravel(p)
        # Keep only slices without any missing (NaN) entry.
        valid = ~np.isnan(data.sum(axis=(0, 1)))
        # Weighted sum over the retained slices -> (dim, dim) count matrix.
        counts = np.tensordot(data[:, :, valid], weights[valid], axes=([2], [0]))
        self.alpha = self.alpha_0 + counts

    def update(self, data=None, beta=None):
        """Set ``alpha = beta * alpha_0 + data.sum(axis=2)``.

        Parameters
        ----------
        data : array_like, optional
            Count matrices of shape ``(dim, dim, n)``, summed over the last
            axis. When omitted, no data contributes.
        beta : scalar or array_like, optional
            Scaling applied to the prior; defaults to 1.
        """
        if beta is None:
            beta = 1
        counts = 0 if data is None else np.sum(data, axis=2)
        self.alpha = self.alpha_0 * beta + counts

    def mean(self):
        """Posterior mean transition matrix; each row sums to 1."""
        # keepdims makes the row sums a column so each row is divided by its own sum.
        return self.alpha / np.sum(self.alpha, axis=1, keepdims=True)

    def geomean(self):
        """Element-wise ``exp(E[log A])`` under the posterior."""
        return np.exp(self.loggeomean())

    def loggeomean(self):
        """Element-wise ``E[log A_ij] = psi(alpha_ij) - psi(sum_j alpha_ij)``."""
        return special.psi(self.alpha) - special.psi(np.sum(self.alpha, axis=1, keepdims=True))

    def KLqprior(self):
        """KL(q || prior), summed over the independent Dirichlet rows."""
        alpha_sum = np.sum(self.alpha, axis=1)
        per_row = (
            special.gammaln(alpha_sum)
            - np.sum(special.gammaln(self.alpha), axis=1)
            - special.gammaln(np.sum(self.alpha_0, axis=1))
            + np.sum(special.gammaln(self.alpha_0), axis=1)
            + np.sum((self.alpha - self.alpha_0) * self.loggeomean(), axis=1)
        )
        return np.sum(per_row)
    
    

In [37]:
class Dirichlet:
    """Dirichlet posterior ``q(pi) = Dir(pi | alpha)`` with prior ``Dir(pi | alpha_0)``.

    Parameters
    ----------
    dim : int
        Number of categories.
    alpha_0 : array_like, optional
        Prior concentrations; defaults to ``np.ones((dim, 1))``.
    alpha : array_like, optional
        Initial posterior concentrations. When omitted, ``alpha_0`` is scaled
        element-wise by random factors in ``[1, 2)`` drawn from the global
        ``np.random`` state.
    """

    def __init__(self, dim, alpha_0=None, alpha=None):
        self.dim = dim
        if alpha_0 is None:
            self.alpha_0 = np.ones((dim, 1))
        else:
            self.alpha_0 = np.asarray(alpha_0, dtype=float)

        if alpha is None:
            # Random jitter breaks symmetry between categories at initialization.
            jitter = 1 + np.random.rand(*self.alpha_0.shape)
            self.alpha = np.multiply(self.alpha_0, jitter)
        else:
            self.alpha = np.asarray(alpha, dtype=float)

    def _reset_if_nan(self):
        """Fall back to the prior (with a notice) if an update produced NaNs."""
        if np.isnan(self.alpha).any():
            print('NaNs detected')
            self.alpha = self.alpha_0

    def update(self, data=None, beta=None):
        """Set ``alpha = beta * alpha_0 + data`` (``data`` defaults to 0, ``beta`` to 1)."""
        if data is None:
            data = 0
        if beta is None:
            beta = 1
        self.alpha = np.multiply(beta, self.alpha_0) + data
        self._reset_if_nan()

    def updateSS(self, NA=None):
        """Set ``alpha = alpha_0 + NA`` from sufficient statistics (``NA`` defaults to 0)."""
        if NA is None:
            NA = 0
        self.alpha = self.alpha_0 + NA
        self._reset_if_nan()

    def rawupdate(self, data, p=None):
        """Update from raw count vectors stored column-wise.

        Parameters
        ----------
        data : array_like, shape (dim, n)
            Column ``k`` holds the counts of observation ``k``. Columns that
            contain any NaN are skipped.
        p : array_like, optional
            Per-column weights of length ``n`` (any shape that ravels to
            ``n``). Defaults to all ones.
        """
        data = np.asarray(data, dtype=float)
        weights = np.ones(data.shape[1]) if p is None else np.ravel(p)
        valid = ~np.isnan(data.sum(axis=0))
        counts = data[:, valid] @ weights[valid]
        # Match the prior's shape so update() adds element-wise instead of broadcasting.
        self.update(np.reshape(counts, np.shape(self.alpha_0)))

    def mean(self):
        """Posterior mean ``alpha / sum(alpha)``."""
        return self.alpha / np.sum(self.alpha)

    def geomean(self):
        """Element-wise ``exp(E[log pi])`` under the posterior."""
        return np.exp(self.loggeomean())

    def loggeomean(self):
        """Element-wise ``E[log pi_k] = psi(alpha_k) - psi(sum(alpha))``."""
        return special.psi(self.alpha) - special.psi(np.sum(self.alpha))

    def variance(self):
        """Marginal variances ``alpha_k (S - alpha_k) / (S^2 (S + 1))`` with ``S = sum(alpha)``."""
        alpha_sum = np.sum(self.alpha)
        return self.alpha * (alpha_sum - self.alpha) / (alpha_sum ** 2 * (alpha_sum + 1))

    def covariance(self):
        """Full ``K x K`` covariance matrix of the posterior (``K = alpha.size``)."""
        a = np.ravel(self.alpha)
        alpha_sum = a.sum()
        cov = -np.outer(a, a) / (alpha_sum ** 2 * (alpha_sum + 1))
        np.fill_diagonal(cov, np.ravel(self.variance()))
        return cov

    def entropy(self):
        """Differential entropy of ``Dir(alpha)``."""
        a = np.ravel(self.alpha)
        alpha_sum = a.sum()
        log_beta = np.sum(special.gammaln(a)) - special.gammaln(alpha_sum)
        return (log_beta
                + (alpha_sum - a.size) * special.psi(alpha_sum)
                - np.sum((a - 1) * special.psi(a)))

    def KLqprior(self):
        """KL divergence ``KL(Dir(alpha) || Dir(alpha_0))``."""
        alpha_sum = np.sum(self.alpha)
        return (special.gammaln(alpha_sum) - np.sum(special.gammaln(self.alpha))
                - special.gammaln(np.sum(self.alpha_0)) + np.sum(special.gammaln(self.alpha_0))
                + np.sum((self.alpha - self.alpha_0) * self.loggeomean()))

    def Eloglikelihood(self, data):
        """Expected multinomial log-likelihood of each row of ``data`` under q.

        Parameters
        ----------
        data : array_like, shape (n, K)
            One count vector per row (a single 1-D vector is treated as one row).

        Returns
        -------
        ndarray, shape (n,)
            Rows whose result is NaN (e.g. missing counts) contribute 0.
        """
        data = np.atleast_2d(np.asarray(data, dtype=float))
        res = (data @ np.ravel(self.loggeomean())
               + special.gammaln(1 + np.sum(data, axis=1))
               - np.sum(special.gammaln(1 + data), axis=1))
        res[np.isnan(res)] = 0
        return res

    def expectlogjoint(self, beta=None):
        """Expected log prior density ``E_q[log Dir(pi | beta * alpha_0)]`` as a scalar."""
        alpha_prior = self.alpha_0 if beta is None else np.multiply(self.alpha_0, beta)
        return (-np.sum(special.gammaln(alpha_prior))
                + special.gammaln(np.sum(alpha_prior))
                + np.sum((alpha_prior - 1) * self.loggeomean()))

    def lowerboundcontrib(self, beta=None):
        """Contribution ``H[q] + E_q[log prior]`` to the variational lower bound."""
        if beta is None:
            beta = 1
        return self.entropy() + self.expectlogjoint(beta)

In [28]:
class HHMM:
    """Variational posteriors of a hierarchical HMM.

    The model has ``Qdim`` super-states, each owning a contiguous block of
    ``dim`` sub-states (``Qdim * dim`` states in the flattened state space),
    and ``D``-dimensional observations.

    Parameters
    ----------
    Qdim : int
        Number of super-states.
    dim : int
        Number of sub-states within each super-state.
    D : int
        Observation dimensionality.
    obsTypes : list of dict, optional
        Observation-model specs with ``'dist'`` and ``'idx'`` keys. When None
        or empty, a single ``'mvn'`` spec over all ``D`` dimensions is used.
    Qalpha_0, Qpi0alpha_0 : array_like, optional
        Priors of the super-state transition matrix and initial distribution.
    Aalpha_0, pi0alpha_0 : array_like, optional
        Priors shared by every super-state's sub-state transition matrix and
        initial distribution.
    """

    def __init__(self, Qdim, dim, D, obsTypes=None, Qalpha_0=None, Qpi0alpha_0=None,
                 Aalpha_0=None, pi0alpha_0=None):
        self.Qdim = Qdim
        self.dim = dim
        self.D = D
        # Top level: transitions and initial distribution over the Qdim super-states.
        self.Q = Transition(Qdim, Qalpha_0)
        self.Qpi0 = Dirichlet(Qdim, Qpi0alpha_0)

        # Per super-state i: 0-based indices of its sub-states in the flattened
        # state space, plus its own sub-state transition matrix and initial
        # distribution. A plain loop keeps the random-initialization order
        # A[0], pi0[0], A[1], pi0[1], ...
        self.Aidx = {}
        self.A = {}
        self.pi0 = {}
        for i in range(Qdim):
            self.Aidx[i] = np.arange(i * dim, (i + 1) * dim)
            self.A[i] = Transition(dim, Aalpha_0)
            self.pi0[i] = Dirichlet(dim, pi0alpha_0)

        if obsTypes is None or len(obsTypes) == 0:
            print('Defaulting to normally distributed observations.')
            # 0-based indices of the observation dimensions this model covers.
            self.obsTypes = [{'dist': 'mvn', 'idx': np.arange(D)}]
        else:
            self.obsTypes = obsTypes

In [None]:
# Transition/Dirichlet draw their initial posteriors from np.random; seed for reproducibility.
np.random.seed(0)
# Small demo: 2 super-states x 3 sub-states, 2-D observations; None -> default priors.
hhmm = HHMM(Qdim=2, dim=3, D=2, obsTypes=None,
            Qalpha_0=None, Qpi0alpha_0=None, Aalpha_0=None, pi0alpha_0=None)
hhmm.Aidx