In [525]:
import scipy.io as sio
import scipy.stats as stats
import numpy as np
from numpy.linalg import *
import pandas as pd

In [526]:
def Single_Gaussian(x, mu, var):
    x_mu = x-mu
    b = x_mu**2 / var
    N = np.exp(-b/2)/ np.sqrt(var*2*np.pi)
    return N

In [527]:
def E_step(data , m , k , d , model):
    a = np.zeros((m,k,d))
    b = np.zeros((m,k,d))
    for j in range(k):
        a[:,j,:] = model['rho'] * Single_Gaussian(data,model['theta_mu'][j,:],model['theta_var'][j,:])
        b[:,j,:] = (1-model['rho']) * Single_Gaussian(data , model['lam_mu'] , model['lam_var'])
    c = a+b
    
    yu = c.prod(2)*model['alpha']
    w = yu / yu.sum(1).reshape(m,1)
    c[c==0] = 1
    w[np.isnan(w)] = 1

    u = np.zeros((m,k,d))
    v = np.zeros((m,k,d))
    for j in range(k):
        u[:,j,:] = (a[:,j,:]/c[:,j,:]) * w[:,j].reshape(m,1)
        v[:,j,:] = -u[:,j,:] + w[:,j].reshape(m,1)
    return w,u,v

In [528]:
def M_step(data, w, u, v, m, k, d ):
    model = {}
    #alpha = w.sum(0)/m
    yu = np.maximum(w.sum(0)-d,0)
    alpha = yu / yu.sum()

    #rho = u.sum(1).sum(0)/m
    rho = (np.maximum(u.sum(1).sum(0)-k,0)) / (np.maximum(u.sum(1).sum(0)-k,0) + np.maximum(v.sum(1).sum(0)-1,0))

    theta_mu = np.zeros((k,d))
    theta_var = np.zeros((k,d))
    lam_mu = np.zeros(d)
    lam_var = np.zeros(d)
    
    for l in range(d):
        for j in range(k):
            if (u[:,j,l].sum()==0):
                alpha[j]=0
            theta_mu[j,l] = np.dot(u[:,j,l].T , data[:,l])/u[:,j,l].sum()
            theta_var[j,l] = np.dot(u[:,j,l].T,((data[:,l]-theta_mu[j,l])**2))/u[:,j,l].sum()
    theta_var[theta_var==0] = 0.0001
    
    
    for l in range(d):
        e = 0
        for i in range(m):
            e = e + v[i,:,l].sum() * data[i,l]
        f = v[:,:,l]
        lam_mu[l] = e/ f.sum()
        if(f.sum()==0):
            lam_mu[l] = (e+0.0001)/(f.sum()+0.0001)
            rho[l] = rho[l]*0.9
    
    for l in range(d):
        e = 0
        for i in range(m):
            e = e + v[i,:,l].sum() * ((data[i,l]-lam_mu[l])**2)
        f = v[:,:,l]
        lam_var[l] = e/ f.sum()
        if(f.sum()==0 or e==0):
            lam_var[l] = (e+0.0001)/(f.sum()+0.0001)
            rho[l] = rho[l]*0.9
    
    #theta_mu = (u*data.reshape(m,1,d)).sum(0)/u.sum(0)
    #theta_var = np.zeros((k,d))
    #for j in range(k):
    #    theta_var[j,:] = (u[:,j,:] * ((data-theta_mu[j:j+1,:])**2)).sum(0)/u[:,j,:].sum(0)
    #    
    #theta_var[theta_var==0] = 0.0001
    #yu = v.sum(1).sum(0)
    #rho[yu==0] = rho[yu==0]*0.9
    #yu[yu==0] = 0.001
    #lam_mu = (v.sum(1)*data).sum(0)/yu
    #lam_var = (v.sum(1) * ((data-lam_mu)**2)).sum(0)/yu
    
    model['rho'] = rho
    model['alpha'] = alpha
    model['theta_mu'] = theta_mu
    model['theta_var'] = theta_var
    model['lam_mu'] = lam_mu
    model['lam_var'] = lam_var
    
    return model

In [529]:
def move_superfluous(model, k, d , data):
    #move the superfluous feature and component
    rho = model['rho']
    alpha = model['alpha']
    theta_mu = model['theta_mu']
    theta_var = model['theta_var']
    lam_mu = model['lam_mu']
    lam_var = model['lam_var']
    
    alpha0 = alpha<0.001
    if alpha0.sum()!=0:
        theta_mu = theta_mu[~alpha0,:]
        theta_var = theta_var[~alpha0,:]
        alpha = alpha[~alpha0]
        k = k-alpha0.sum()

    rho1 = rho==1
    if rho1.sum()!=0:
        lam_mu[rho1] = lam_mu[rho1]*0.9
        lam_var[rho1] = lam_var[rho1]*0.9
        rho[rho1] = 0.9

    rho0 = rho<0.001
    if rho0.sum()!=0:
        lam_mu = lam_mu[~rho0]
        lam_var = lam_var[~rho0]
        theta_mu = theta_mu[:,~rho0]
        theta_var = theta_var[:,~rho0]
        data = data[:,~rho0]
        rho = rho[~rho0]
        d = d-(rho0).sum()
    
    model['rho'] = rho
    model['alpha'] = alpha
    model['theta_mu'] = theta_mu
    model['theta_var'] = theta_var
    model['lam_mu'] = lam_mu
    model['lam_var'] = lam_var
    return model,k,d,data

In [530]:
def MML(data, m, d, k, model):
    # compute the cost
    a = np.zeros((m,k,d))
    b = np.zeros((m,k,d))
    for j in range(k):
        a[:,j,:] = model['rho'] * Single_Gaussian(data[j,:],model['theta_mu'][j,:],model['theta_var'][j,:])
        b[:,j,:] = (1-model['rho']) * Single_Gaussian(data[j,:] , model['lam_mu'] , model['lam_var'])
    c = a+b
    c[c>1] = 1
    yu = c.prod(2)*model['alpha']
    yu = yu.sum()
    yu = np.log(yu)
    #cost = -yu/m
    cost = -yu + d*np.log(alpha).sum() + np.log(1-rho).sum() + k*np.log(rho).sum()
    return cost

In [531]:
np.random.seed(15)
key = np.array([[0.0,1.0,6.0,7.0],[3.0,9.0,4.0,10.0]])
label = 4
key = np.row_stack( (key,np.random.normal(size=(8,4))) )
print(key)
std = np.array([1,1,1,1])
data =  key[:,0:1] + np.random.normal(size=(10,200))*std[0]
for i in range(1,label):
    data = np.column_stack( (data , key[:,i:i+1] + np.random.normal(size=(10,200))*std[i] ) )
data = data.T

[[  0.           1.           6.           7.        ]
 [  3.           9.           4.          10.        ]
 [ -0.31232848   0.33928471  -0.15590853  -0.50178967]
 [  0.23556889  -1.76360526  -1.09586204  -1.08776574]
 [ -0.30517005  -0.47374837  -0.20059454   0.35519677]
 [  0.68951772   0.41058968  -0.56497844   0.59939069]
 [ -0.16293631   1.6002145    0.6816272    0.0148801 ]
 [ -0.08777963  -0.98211784   0.12169048  -1.13743729]
 [  0.34900258  -1.85851316  -1.16718189   1.42489683]
 [  1.49656536   1.28993206  -1.81174527  -1.49830721]]


In [532]:
[m,d] = data.shape
k = 100
kmin = 4
cost =  0
oldcost = 10000

model = {}
#theta_mu = np.random.normal(size=(k,d))
kk = np.arange(m)
np.random.shuffle(kk)
model['theta_mu'] = data[kk[:k],:]
model['theta_var'] = np.random.uniform(size=(k,d))
model['lam_mu'] = data.mean(0)
model['lam_var'] = np.random.uniform(size=(d,1)).reshape(d)
model['rho'] = np.ones((1,d))*0.5
model['alpha'] = np.random.uniform(size=(1,k))
model['alpha'] = model['alpha']/model['alpha'].sum()

In [533]:
model['alpha']

array([[  1.32293382e-02,   1.57302193e-02,   1.29829359e-02,
          7.46900940e-03,   1.21090666e-02,   1.18229176e-02,
          3.60281188e-03,   1.41005905e-02,   1.50132651e-02,
          4.29517078e-03,   1.93570632e-02,   9.27868559e-03,
          1.28042258e-02,   1.36269496e-02,   1.71201415e-02,
          1.93832436e-02,   1.14521692e-02,   1.45683003e-02,
          1.07093719e-03,   7.37117536e-04,   1.94716846e-02,
          1.59483116e-03,   1.59149143e-02,   1.13155499e-02,
          1.89235611e-02,   1.65456874e-02,   1.15897516e-02,
          1.73443063e-03,   1.41604559e-02,   6.47767022e-03,
          4.51776637e-05,   3.12432214e-04,   7.03616052e-03,
          1.57938522e-02,   1.86220235e-02,   4.17944627e-03,
          4.81234341e-03,   6.17463216e-03,   8.67225359e-04,
          3.16596023e-03,   7.84788288e-04,   4.36818794e-03,
          1.14161082e-02,   1.51134014e-02,   1.73596436e-02,
          1.71443554e-02,   8.76785431e-03,   6.50454506e-03,
        

In [534]:
savemodel = []
while(k>kmin):
    print(k)
    step = 0
    oldcost = 10000
    while( (abs(cost-oldcost)>0.001) and (step<100) ):
        step = step+1
        [ w, u, v ] = E_step(data , m , k , d , model)
        model = M_step(data, w, u, v, m, k, d)

        [model,k,d,data] = move_superfluous(model, k, d,data)
        savemodel.append(model)
        print(k)
        if(k<kmin):
            model = savemodel[-2]
            k = model['alpha'].shape[0]
            break
        oldcost = cost
        cost = MML(data, m, d, k, model)
    
    if( (step==100) and (abs(cost-oldcost)>0.001) and (k>kmin)):
        alpha0 = model['alpha'].argmin()
        model['theta_mu'] = np.delete(model['theta_mu'], alpha0 , axis=0)
        model['theta_var'] = np.delete(model['theta_var'], alpha0 , axis=0)
        model['alpha'] = np.delete(model['alpha'], alpha0 , axis=0)
        savemodel.append(model)
        k = k-1
        
        

100
[[[  7.95470514e-02   9.29684046e-10   1.03589959e-02 ...,   4.33017046e-01
     2.62476076e-07   3.12752427e-01]
  [  3.93193634e-10   9.29684046e-10   6.72072640e-03 ...,   3.94224958e-01
     2.62476133e-07   2.19281327e-01]
  [  3.91069867e-10   9.29684046e-10   1.55792769e-01 ...,   6.02999820e-01
     2.39223197e-06   2.80997417e-01]
  ..., 
  [  2.62587117e-04   9.29684046e-10   6.96841173e-03 ...,   4.23690588e-01
     2.62606821e-07   2.19207763e-01]
  [  3.97219603e-10   5.31064077e-05   7.49103595e-03 ...,   2.91428798e-01
     2.62475902e-07   2.19206536e-01]
  [  3.91069867e-10   9.30137578e-10   6.71607956e-03 ...,   4.59118805e-01
     2.62475902e-07   2.19205219e-01]]

 [[  4.90154156e-02   2.87752885e-04   2.73096413e-01 ...,   1.92617323e-01
     2.68797485e-01   2.58841192e-01]
  [  4.97602291e-11   2.87752885e-04   2.56053625e-01 ...,   2.27816617e-01
     2.68589882e-01   1.98591717e-01]
  [  4.97254213e-11   2.87752887e-04   2.77745704e-01 ...,   5.85051052e-0

18
[[[  1.15589909e-04   2.85436256e-04   1.20830687e-01 ...,   4.09012400e-01
     2.80642819e-05   1.14125747e-01]
  [  3.84958603e-02   2.85436256e-04   5.18257925e-02 ...,   2.13507465e-01
     2.80128734e-05   1.52150404e-01]
  [  2.24247632e-02   2.85436256e-04   5.18258695e-02 ...,   2.32915985e-01
     2.80162546e-05   1.15514287e-01]
  ..., 
  [  1.15763203e-04   3.73363337e-01   1.58867144e-01 ...,   2.78268823e-01
     2.80128740e-05   1.28389003e-01]
  [  1.15589909e-04   3.32203162e-01   6.97936576e-02 ...,   3.71062978e-01
     2.80129111e-05   1.41312680e-01]
  [  2.30059989e-02   2.85436256e-04   5.18313230e-02 ...,   2.94887249e-01
     2.80128734e-05   1.05827655e-01]]

 [[  6.68145509e-05   3.77550765e-03   2.45095291e-01 ...,   1.82099959e-01
     1.33947299e-01   1.12729536e-01]
  [  2.18684956e-02   3.77553046e-03   2.30436447e-01 ...,   5.05505828e-01
     1.35650765e-01   1.20930577e-01]
  [  1.14980447e-02   3.77550765e-03   2.76973771e-01 ...,   1.14634552e-01

14
[[[  4.92295450e-02   1.03861863e-04   6.56436368e-02 ...,   2.16357140e-01
     9.54376862e-06   1.44752046e-01]
  [  3.54273478e-02   1.03861863e-04   6.56436552e-02 ...,   2.53138290e-01
     9.54699710e-06   1.06248549e-01]
  [  8.44337312e-02   2.63239592e-01   1.75526175e-01 ...,   3.47183457e-01
     2.35403137e-05   1.13640084e-01]
  ..., 
  [  1.20320960e-01   2.05026859e-01   1.54341092e-01 ...,   3.87066137e-01
     3.54756317e-03   9.97840266e-02]
  [  2.43490505e-05   2.80999192e-01   7.30058689e-02 ...,   3.84529526e-01
     9.54502593e-06   1.43125941e-01]
  [  4.19684358e-02   1.03861863e-04   6.56436514e-02 ...,   2.56986990e-01
     9.54377370e-06   1.00072776e-01]]

 [[  3.11196815e-02   1.76050825e-03   2.31527630e-01 ...,   4.47850331e-01
     1.18780327e-01   1.11086102e-01]
  [  1.97620683e-02   1.76049622e-03   2.64803608e-01 ...,   1.16872054e-01
     1.53901926e-01   8.99109099e-02]
  [  4.56333530e-02   2.29340639e-01   2.98012010e-01 ...,   1.19920875e-01

12
[[[  5.62287803e-02   1.21961070e-05   7.04112483e-02 ...,   2.21300336e-01
     6.43395784e-06   1.39498887e-01]
  [  8.65551993e-02   2.62752519e-01   1.84247779e-01 ...,   3.69204684e-01
     1.79048415e-05   1.13026761e-01]
  [  2.06843103e-01   5.98176925e-01   7.28557656e-02 ...,   4.48772726e-01
     1.93752095e-02   9.39453365e-02]
  ..., 
  [  1.23347583e-01   2.13285119e-01   1.57260649e-01 ...,   3.77397630e-01
     4.40855272e-03   9.43530455e-02]
  [  9.61669103e-07   2.64218304e-01   7.28141105e-02 ...,   4.00876368e-01
     6.44898379e-06   1.45204771e-01]
  [  4.31405925e-02   1.21961070e-05   7.04112483e-02 ...,   2.33259551e-01
     6.43637670e-06   9.80652516e-02]]

 [[  3.67996740e-02   5.63732810e-04   2.30648655e-01 ...,   3.90471839e-01
     1.08001856e-01   1.04950088e-01]
  [  4.72280094e-02   2.34148084e-01   2.87997950e-01 ...,   1.17359495e-01
     4.78228230e-01   9.04188738e-02]
  [  1.76884671e-01   2.76241370e-02   4.09450524e-01 ...,   1.18525626e-01

10
[[[  5.34255019e-02   2.11118464e-08   6.97884378e-02 ...,   2.31782744e-01
     8.31119018e-06   1.39763675e-01]
  [  9.03374353e-02   2.65784777e-01   1.87988279e-01 ...,   3.92937883e-01
     2.78922670e-05   1.06324158e-01]
  [  2.13392534e-01   6.81315304e-01   8.11255519e-02 ...,   4.20160042e-01
     2.44219922e-02   8.38529088e-02]
  ..., 
  [  1.38459400e-02   2.11118554e-08   7.02208378e-02 ...,   3.19502531e-01
     8.31208368e-06   1.02188477e-01]
  [  1.28760203e-01   2.25986978e-01   1.61869989e-01 ...,   3.79968647e-01
     5.76640311e-03   8.88192636e-02]
  [  3.93044917e-12   2.77189433e-01   7.02045082e-02 ...,   4.33678204e-01
     8.52376038e-06   1.42437174e-01]]

 [[  3.41410625e-02   3.26965688e-05   2.31529078e-01 ...,   3.44993862e-01
     1.03431720e-01   1.06026442e-01]
  [  4.99174543e-02   2.35646961e-01   2.82349579e-01 ...,   1.14357218e-01
     4.72161060e-01   8.35520987e-02]
  [  1.86476660e-01   1.10453696e-02   4.16933161e-01 ...,   1.01571989e-01

[[[  3.86246545e-02   4.19970768e-08   6.75159436e-02 ...,   2.38947751e-01
     1.06276842e-05   1.34602591e-01]
  [  9.09561421e-02   2.43174311e-01   1.87766952e-01 ...,   4.00290210e-01
     4.10058974e-05   9.90585215e-02]
  [  2.32025697e-01   6.54325725e-01   9.18603583e-02 ...,   3.57823818e-01
     3.15838694e-02   7.73553778e-02]
  ..., 
  [  1.70008508e-02   4.19970970e-08   6.77760258e-02 ...,   3.13997053e-01
     1.06290867e-05   9.79189575e-02]
  [  1.30583895e-01   2.11338230e-01   1.64913210e-01 ...,   3.82935070e-01
     6.59171433e-03   8.72489107e-02]
  [  4.69060449e-10   2.49577526e-01   6.77435047e-02 ...,   4.60944366e-01
     1.23776810e-05   1.19505639e-01]]

 [[  2.23542988e-02   1.41497307e-03   2.28633361e-01 ...,   3.39399257e-01
     9.81745973e-02   1.02660051e-01]
  [  5.00889020e-02   2.12551154e-01   2.81881864e-01 ...,   1.20031085e-01
     4.68404529e-01   7.77361724e-02]
  [  2.07018974e-01   6.88700221e-03   3.89279272e-01 ...,   9.07034476e-02
  

9
[[[  3.32609648e-02   5.16795559e-07   6.55435026e-02 ...,   2.47668839e-01
     1.10834997e-05   1.27380181e-01]
  [  9.97172967e-02   2.88877811e-01   1.84481648e-01 ...,   4.03943805e-01
     5.29593621e-05   9.11689544e-02]
  [  5.95720566e-11   5.16795569e-07   1.73258755e-01 ...,   2.36634828e-01
     7.31210073e-02   2.43854505e-01]
  ..., 
  [  2.29361808e-02   5.16795596e-07   6.57341519e-02 ...,   3.09834206e-01
     1.10853125e-05   9.53938375e-02]
  [  1.31057922e-01   2.52341206e-01   1.66402179e-01 ...,   3.84565725e-01
     7.72373826e-03   8.66421037e-02]
  [  5.95720530e-11   2.87066549e-01   6.56904258e-02 ...,   4.87324786e-01
     1.55035367e-05   9.60059837e-02]]

 [[  1.94044812e-02   7.48995441e-04   2.26637416e-01 ...,   3.34545828e-01
     9.54228927e-02   9.76378522e-02]
  [  5.63272687e-02   2.17976911e-01   2.84969362e-01 ...,   1.31885582e-01
     4.65614502e-01   7.18076340e-02]
  [  9.85796888e-12   7.48998351e-04   2.97662050e-01 ...,   2.93313749e-01


8
[[[  2.81102869e-02   2.21869797e-07   6.39541982e-02 ...,   2.57555214e-01
     1.06490589e-05   1.18961695e-01]
  [  1.10328014e-01   2.85579125e-01   1.80092526e-01 ...,   4.07122421e-01
     5.77497447e-05   8.39059401e-02]
  [  2.71165378e-15   2.21869810e-07   1.74097310e-01 ...,   2.40625081e-01
     7.10925698e-02   2.44725068e-01]
  ..., 
  [  8.57789042e-15   2.21869797e-07   2.05178279e-01 ...,   3.86764287e-01
     4.30039303e-04   1.28491741e-01]
  [  2.58015135e-02   2.21869863e-07   6.41163566e-02 ...,   3.06432310e-01
     1.06513156e-05   9.34469153e-02]
  [  1.51324326e-01   2.55683941e-01   1.70684022e-01 ...,   3.91222993e-01
     9.82127283e-03   8.66469175e-02]]

 [[  1.63078522e-02   1.60186143e-04   2.25226733e-01 ...,   3.29948585e-01
     9.32320954e-02   9.15229400e-02]
  [  6.40824489e-02   2.24667686e-01   2.89413760e-01 ...,   1.41212042e-01
     4.66833218e-01   6.65230570e-02]
  [  2.08139610e-16   1.60189403e-04   2.95548808e-01 ...,   2.93591691e-01


8
[[[  2.28568991e-02   2.33973875e-05   6.30230626e-02 ...,   2.65425678e-01
     9.75325847e-06   1.11959464e-01]
  [  9.45201797e-02   2.57892257e-01   1.76332014e-01 ...,   4.11236353e-01
     4.67197010e-05   7.83741121e-02]
  [  1.74542790e-16   2.33973875e-05   1.74893052e-01 ...,   2.42997979e-01
     6.92943670e-02   2.45411530e-01]
  ..., 
  [  1.03376377e-14   2.33973875e-05   2.07871994e-01 ...,   3.84631452e-01
     3.59345981e-04   1.19750141e-01]
  [  2.42405909e-02   2.33973876e-05   6.31776274e-02 ...,   3.02188758e-01
     9.75609513e-06   9.17113487e-02]
  [  1.36735563e-01   2.27956533e-01   1.75757772e-01 ...,   3.94518057e-01
     9.69373642e-03   8.76103895e-02]]

 [[  1.32495105e-02   8.88217275e-03   2.23638745e-01 ...,   3.28117274e-01
     8.94420066e-02   8.63424908e-02]
  [  5.33506547e-02   2.11957685e-01   2.93671566e-01 ...,   1.45685623e-01
     4.74011687e-01   6.24246041e-02]
  [  2.24580385e-17   8.88217585e-03   2.93646523e-01 ...,   2.93993801e-01


7
[[[  2.09959213e-02   3.56428562e-05   6.29544904e-02 ...,   2.71501445e-01
     8.68146665e-06   1.07253515e-01]
  [  9.73465988e-02   2.98353897e-01   1.73413034e-01 ...,   4.18033251e-01
     3.89808092e-05   7.41743799e-02]
  [  1.42747726e-17   3.56428562e-05   1.76600290e-01 ...,   2.45186191e-01
     6.72628923e-02   2.45556227e-01]
  ..., 
  [  4.51634570e-15   3.56428562e-05   2.08671109e-01 ...,   3.85420264e-01
     3.08097253e-04   1.09837749e-01]
  [  2.69359218e-02   3.56428563e-05   6.31262350e-02 ...,   2.98004519e-01
     8.68529501e-06   9.00915342e-02]
  [  1.52011499e-01   2.63898451e-01   1.79681048e-01 ...,   3.96260078e-01
     9.32819266e-03   8.86942100e-02]]

 [[  1.14388768e-02   3.02661970e-03   2.22856893e-01 ...,   3.28340437e-01
     8.62833351e-02   8.29477708e-02]
  [  5.32543153e-02   2.11115754e-01   2.97365831e-01 ...,   1.49124801e-01
     4.80128901e-01   5.93483567e-02]
  [  1.38769551e-18   3.02662093e-03   2.92905670e-01 ...,   2.94057705e-01


7
[[[  2.16347305e-02   5.15805193e-06   6.29306447e-02 ...,   2.75995842e-01
     8.37839701e-06   1.04690682e-01]
  [  8.91765324e-02   2.92875680e-01   1.68493983e-01 ...,   4.24939023e-01
     3.25675227e-05   7.07419881e-02]
  [  1.11716059e-16   5.15805193e-06   1.75751545e-01 ...,   2.48360883e-01
     6.50460349e-02   2.45531600e-01]
  ..., 
  [  5.41082913e-14   5.15805193e-06   2.12611976e-01 ...,   3.77415638e-01
     2.80991791e-04   9.93824940e-02]
  [  2.79404590e-02   5.15805205e-06   6.31348109e-02 ...,   2.92410219e-01
     8.38398198e-06   8.92561671e-02]
  [  1.53776855e-01   2.54771464e-01   1.84765552e-01 ...,   3.99968632e-01
     8.54166032e-03   8.99437539e-02]]

 [[  1.22908678e-02   2.23536619e-04   2.21931582e-01 ...,   3.30413941e-01
     7.89153986e-02   8.14888733e-02]
  [  4.67347511e-02   2.19302915e-01   3.03246436e-01 ...,   1.57032113e-01
     4.88553416e-01   5.71946755e-02]
  [  1.38334031e-17   2.23538102e-04   2.91714907e-01 ...,   2.93036950e-01


     1.99880418e-01   7.57689140e-04]]]
7
[[[  1.87112190e-02   4.76344553e-04   6.26745010e-02 ...,   2.79035631e-01
     9.11185182e-06   1.03075338e-01]
  [  7.99323165e-02   3.00835167e-01   1.64310811e-01 ...,   4.24233103e-01
     2.66144214e-05   6.74653879e-02]
  [  3.67928214e-17   4.76344553e-04   1.76495589e-01 ...,   2.53046305e-01
     6.22053857e-02   2.42885624e-01]
  ..., 
  [  3.03632811e-13   4.76344553e-04   2.14151570e-01 ...,   3.69081650e-01
     2.76146677e-04   9.01559767e-02]
  [  2.75334596e-02   4.76344553e-04   6.29297342e-02 ...,   2.88957245e-01
     9.12247828e-06   8.82812568e-02]
  [  1.52680994e-01   2.68022565e-01   1.89151997e-01 ...,   4.05470850e-01
     7.86182449e-03   9.05631812e-02]]

 [[  1.00837611e-02   6.36645501e-03   2.22077283e-01 ...,   3.29274268e-01
     7.33460022e-02   8.06173893e-02]
  [  3.98058947e-02   2.08183115e-01   3.08429973e-01 ...,   1.65594754e-01
     4.99495335e-01   5.50974397e-02]
  [  4.25868764e-18   6.36645570e-03

     2.00064716e-01   7.55337460e-04]]]
7
[[[  1.88560211e-02   5.05614899e-05   6.23656277e-02 ...,   2.80954671e-01
     1.26083157e-05   1.03085435e-01]
  [  6.44145050e-02   2.93636827e-01   1.60673127e-01 ...,   4.17013264e-01
     2.25377586e-05   6.45955084e-02]
  [  1.58223243e-16   5.05614899e-05   1.76552582e-01 ...,   2.57469254e-01
     6.01287149e-02   2.39204274e-01]
  ..., 
  [  5.23064524e-12   5.05614899e-05   2.15794609e-01 ...,   3.68477328e-01
     2.66503423e-04   8.26726850e-02]
  [  2.51210109e-02   5.05614900e-05   6.26591720e-02 ...,   2.86404220e-01
     1.26258161e-05   8.79901997e-02]
  [  1.37343963e-01   2.58698714e-01   1.92319772e-01 ...,   4.11172670e-01
     7.15237173e-03   8.99991075e-02]]

 [[  1.06966380e-02   5.43283968e-04   2.22232709e-01 ...,   3.27843956e-01
     6.73275791e-02   8.08651561e-02]
  [  3.01963150e-02   2.15931576e-01   3.12845713e-01 ...,   1.77139483e-01
     5.16082371e-01   5.32326968e-02]
  [  1.99404997e-17   5.43284932e-04

     2.00618592e-01   7.50984950e-04]]]
7
[[[  1.99405219e-02   2.84325912e-03   6.18882991e-02 ...,   2.83415654e-01
     1.73441893e-05   1.02780012e-01]
  [  6.17421873e-02   3.03371210e-01   1.58979582e-01 ...,   4.07660916e-01
     2.17500312e-05   6.18119494e-02]
  [  9.44532438e-17   2.84325912e-03   1.78734562e-01 ...,   2.63383737e-01
     5.70480218e-02   2.32743044e-01]
  ..., 
  [  3.27872774e-10   2.84325912e-03   2.13289733e-01 ...,   3.78094926e-01
     1.92856937e-04   8.24190216e-02]
  [  2.79520735e-02   2.84325912e-03   6.22389500e-02 ...,   2.85458443e-01
     1.73757112e-05   8.74894075e-02]
  [  1.51270462e-01   2.73025523e-01   1.94091534e-01 ...,   4.15879066e-01
     6.55327267e-03   8.96795482e-02]]

 [[  1.09810990e-02   1.27725886e-02   2.23183128e-01 ...,   3.25590006e-01
     6.24736317e-02   8.09133589e-02]
  [  2.68339444e-02   2.00685006e-01   3.15291968e-01 ...,   1.88849870e-01
     5.37844193e-01   5.14019330e-02]
  [  1.14426833e-17   1.27725890e-02

[[[  1.89017359e-02   8.15052812e-04   6.22548853e-02 ...,   2.86527964e-01
     2.33290664e-05   1.02489616e-01]
  [  4.82927407e-02   2.96103907e-01   1.60046370e-01 ...,   3.99796000e-01
     2.54584452e-05   5.95692755e-02]
  [  1.11316577e-15   8.15052812e-04   1.86741871e-01 ...,   2.72921840e-01
     5.22167618e-02   2.22615931e-01]
  [  1.43724234e-12   8.37717939e-02   1.12326099e-01 ...,   3.94465809e-01
     2.41644041e-05   1.41663595e-01]
  [  2.51516888e-02   8.15052812e-04   6.26393504e-02 ...,   2.87224936e-01
     2.33709594e-05   8.72611977e-02]
  [  1.35630664e-01   2.65882100e-01   1.95181015e-01 ...,   4.19734133e-01
     6.11154143e-03   8.86789472e-02]]

 [[  1.06741562e-02   3.46235066e-03   2.22352416e-01 ...,   3.23238177e-01
     5.80121899e-02   8.08255516e-02]
  [  1.96210949e-02   2.08771557e-01   3.14373942e-01 ...,   1.97812306e-01
     5.56220596e-01   4.98571751e-02]
  [  1.55899156e-16   3.46235121e-03   2.85545876e-01 ...,   2.80461441e-01
     1.096

6
[[[  1.99410173e-02   2.56154104e-05   6.07234864e-02 ...,   2.84170558e-01
     3.38145845e-05   1.03433406e-01]
  [  3.95678750e-02   2.95541262e-01   1.61557607e-01 ...,   3.81238069e-01
     3.44345078e-05   5.71310466e-02]
  [  9.22827050e-16   2.56154104e-05   1.86971236e-01 ...,   2.77304722e-01
     5.25404643e-02   2.22564946e-01]
  [  8.52385276e-13   8.38350430e-02   1.11655850e-01 ...,   3.93968489e-01
     3.49391273e-05   1.42009452e-01]
  [  2.70583141e-02   2.56154104e-05   6.11510030e-02 ...,   2.85186126e-01
     3.38733936e-05   8.78852320e-02]
  [  1.48878306e-01   2.62319118e-01   1.92170140e-01 ...,   4.21409488e-01
     5.41338238e-03   8.76115079e-02]]

 [[  1.08318498e-02   9.66411390e-05   2.25192455e-01 ...,   3.26619122e-01
     5.41852755e-02   8.21245821e-02]
  [  1.39113253e-02   2.10384458e-01   3.13923851e-01 ...,   2.14694836e-01
     5.85770981e-01   4.83581097e-02]
  [  1.27145524e-16   9.66420129e-05   2.87243154e-01 ...,   2.76434384e-01
     1.0

[[[  1.89908072e-02   3.87272133e-03   6.01023218e-02 ...,   2.82631740e-01
     4.12492673e-05   1.03308572e-01]
  [  2.45118558e-02   2.98814556e-01   1.65306058e-01 ...,   3.65618202e-01
     4.15464482e-05   5.53402333e-02]
  [  1.34703003e-15   3.87272133e-03   1.86507046e-01 ...,   2.78113185e-01
     5.16259618e-02   2.22897192e-01]
  [  1.40858489e-12   8.20351933e-02   1.10213496e-01 ...,   3.94035313e-01
     4.21760509e-05   1.41460473e-01]
  [  2.51921932e-02   3.87272133e-03   6.05514689e-02 ...,   2.84763110e-01
     4.13254654e-05   8.80453122e-02]
  [  1.33512008e-01   2.73556388e-01   1.90312489e-01 ...,   4.22369835e-01
     4.93036639e-03   8.70112622e-02]]

 [[  1.04865711e-02   9.65372485e-03   2.27519896e-01 ...,   3.28649568e-01
     5.29931944e-02   8.24474194e-02]
  [  7.26585558e-03   1.98374904e-01   3.10857741e-01 ...,   2.26660804e-01
     6.02689958e-01   4.72213542e-02]
  [  1.90750761e-16   9.65372504e-03   2.88817068e-01 ...,   2.76126128e-01
     1.073

6
[[[  1.84248274e-02   4.68444730e-04   6.01328903e-02 ...,   2.81997482e-01
     5.92025920e-05   1.05018034e-01]
  [  7.69635486e-03   2.90049649e-01   1.73871612e-01 ...,   3.32310011e-01
     5.93382375e-05   5.24877928e-02]
  [  1.05978718e-15   4.68444730e-04   1.85534369e-01 ...,   2.78505253e-01
     5.24296582e-02   2.23200509e-01]
  [  8.60462875e-13   8.27602345e-02   1.10961459e-01 ...,   3.94936144e-01
     6.04973245e-05   1.41481131e-01]
  [  2.66199046e-02   4.68444730e-04   6.05508434e-02 ...,   2.85228944e-01
     5.92863915e-05   8.94836476e-02]
  [  1.46940063e-01   2.68263924e-01   1.86875607e-01 ...,   4.22956521e-01
     4.18450893e-03   8.49686938e-02]]

 [[  9.18333783e-03   1.18369584e-03   2.31570181e-01 ...,   3.29277178e-01
     5.89667512e-02   8.51798228e-02]
  [  1.38253844e-03   2.12820998e-01   3.02491765e-01 ...,   2.48188434e-01
     6.17763370e-01   4.54531808e-02]
  [  1.47224651e-16   1.18369628e-03   2.89756455e-01 ...,   2.75709764e-01
     1.0

     1.93178140e-01   7.09894335e-04]]]
5
[[[  3.82518320e-04   2.79628831e-01   1.83217062e-01 ...,   2.71417597e-01
     8.28198599e-05   4.93743947e-02]
  [  8.04812013e-16   4.18363322e-06   1.85673474e-01 ...,   2.78427845e-01
     5.30234982e-02   2.23146817e-01]
  [  2.07499527e-13   8.38201334e-02   1.11868832e-01 ...,   3.96202131e-01
     8.43561590e-05   1.41503682e-01]
  [  2.30945163e-02   4.18363322e-06   6.03932625e-02 ...,   2.86764591e-01
     8.27839006e-05   9.71502869e-02]
  [  1.43577543e-01   2.70370745e-01   1.83135330e-01 ...,   4.22569471e-01
     3.42524601e-03   8.18441702e-02]]

 [[  2.54524056e-05   2.47815992e-01   2.87654248e-01 ...,   2.73016124e-01
     6.05786978e-01   4.37691269e-02]
  [  1.10117632e-16   2.27322374e-05   2.89512754e-01 ...,   2.75761600e-01
     1.08526732e-01   2.72602066e-01]
  [  3.15764486e-14   3.72367876e-01   3.59875303e-01 ...,   5.61807466e-02
     2.09336847e-01   1.77136429e-01]
  [  1.23784219e-02   2.27392835e-05   3.661

4
[[[  1.45216515e-15   2.38438624e-04   1.84665267e-01 ...,   2.76261894e-01
     5.10483836e-02   2.23756321e-01]
  [  1.06399736e-12   8.22487477e-02   1.10984947e-01 ...,   3.96928377e-01
     1.17046490e-04   1.40347251e-01]
  [  2.64059601e-02   2.38438624e-04   6.03225420e-02 ...,   2.86866410e-01
     1.15622986e-04   9.90083709e-02]
  [  1.38916143e-01   2.71275213e-01   1.84390014e-01 ...,   4.21066517e-01
     2.73817407e-03   7.87167427e-02]]

 [[  2.05648622e-16   1.43447035e-03   2.91570940e-01 ...,   2.77893203e-01
     9.92179267e-02   2.73500876e-01]
  [  1.79621436e-13   3.75345532e-01   3.61619156e-01 ...,   5.55297564e-02
     2.09113318e-01   1.76410938e-01]
  [  1.50881564e-02   1.43447430e-03   3.60373713e-01 ...,   2.90589282e-01
     7.39262808e-02   7.22509100e-02]
  [  9.37471597e-02   2.45278850e-01   2.84885970e-01 ...,   1.05491759e-01
     3.64486622e-01   6.02916694e-02]]

 [[  2.21510814e-10   9.78342933e-05   3.94178335e-01 ...,   2.27361768e-01
     8

4
[[[  1.02461965e-15   1.96718367e-06   1.83707664e-01 ...,   2.76098689e-01
     5.19081903e-02   2.23743724e-01]
  [  4.68770903e-13   8.32036171e-02   1.11783276e-01 ...,   3.99066488e-01
     1.49031820e-04   1.40105018e-01]
  [  2.49565295e-02   1.96718368e-06   6.04365603e-02 ...,   2.88030196e-01
     1.47200350e-04   9.88413407e-02]
  [  1.38374650e-01   2.69834631e-01   1.82361530e-01 ...,   4.20970057e-01
     2.75521961e-03   7.78976146e-02]]

 [[  1.42253151e-16   6.71150566e-05   2.92550780e-01 ...,   2.77905550e-01
     1.02153409e-01   2.73638813e-01]
  [  7.51322893e-14   3.72232567e-01   3.60106572e-01 ...,   5.37284285e-02
     2.11358759e-01   1.76501143e-01]
  [  1.38895158e-02   6.71214261e-05   3.57895473e-01 ...,   2.89047423e-01
     7.46378066e-02   7.18759295e-02]
  [  9.34248057e-02   2.52673873e-01   2.88495568e-01 ...,   1.06573649e-01
     3.64846348e-01   5.94490307e-02]]

 [[  1.76491036e-10   3.65201005e-07   3.96495235e-01 ...,   2.27693282e-01
     8

4
[[[  1.54586571e-15   5.15351625e-12   1.83640909e-01 ...,   2.75796221e-01
     5.24782810e-02   2.23523818e-01]
  [  1.35051014e-12   8.35894481e-02   1.11659209e-01 ...,   4.00814730e-01
     1.25654805e-04   1.40137245e-01]
  [  2.38585264e-02   5.15652849e-12   6.04622595e-02 ...,   2.88271058e-01
     1.23363822e-04   9.84925181e-02]
  [  1.25293061e-01   2.70067996e-01   1.81102913e-01 ...,   4.20985949e-01
     2.74501710e-03   7.74929445e-02]]

 [[  2.20886794e-16   5.14529238e-08   2.92879910e-01 ...,   2.78267811e-01
     1.04768817e-01   2.73599621e-01]
  [  2.33034317e-13   3.71363207e-01   3.59820437e-01 ...,   5.26235747e-02
     2.13313572e-01   1.76846669e-01]
  [  1.36668279e-02   5.84636283e-08   3.56202314e-01 ...,   2.88630443e-01
     7.56694498e-02   7.14891815e-02]
  [  8.46637221e-02   2.54430641e-01   2.90522653e-01 ...,   1.06687199e-01
     3.64959050e-01   5.90135436e-02]]

 [[  2.23085663e-10   6.92501943e-14   3.97176381e-01 ...,   2.27865211e-01
     9

4


In [404]:
model['theta_mu']

array([[ 3.71964773,  4.73824463, -0.09818195, -0.85280108, -0.10825509,
         0.27027388,  0.57029293, -0.59609696],
       [ 3.71964773,  4.73824463, -0.09818195, -0.85280108, -0.10825509,
         0.27027388,  0.57029293, -0.59609696],
       [ 3.71964773,  4.73824463, -0.09818195, -0.85280108, -0.10825509,
         0.27027388,  0.57029293, -0.59609696]])

In [400]:
k

0.07573777963996467

In [53]:
a = np.zeros((m,k,d))
b = np.zeros((m,k,d))
for j in range(k):
    for l in range(d):
        a[:,j,l] = rho[l] * stats.norm.pdf(data[:,l],theta_mu[j,l],theta_var[j,l])
        b[:,j,l] = (1-rho[l]) * stats.norm.pdf(data[:,l] , lam_mu[l] , lam_var[l])
c = a+b
c[c>0] = 1
yu = c.prod(2)*alpha
yu = yu.sum(1).sum(0)
yu = np.log(yu)
#cost = -yu/m
cost = -yu + d*np.log(alpha).sum() + np.log(1-rho).sum() + k*np.log(rho).sum()

In [54]:
np.log(alpha).sum()

-102.26131709656298

In [48]:
np.maximum([1,9],[3,6])

array([3, 9])

In [360]:
a = np.array([[[1,2],[3,4],[4,8]],[[5,6],[7,8],[8,7]]])
a.sum()

63

In [135]:
theta_var = np.random.uniform(size=(k,d))
a = np.random.shuffle(np.arange(10))
print(a)

None


In [134]:
arr = np.arange(10)
np.random.shuffle(np.arange(10))
arr

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [137]:
theta_mu.shape

(100, 10)