In [214]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [215]:
#Reading in the data
data = np.load('MC_500_5_4.npz.npy')
trueCls = np.load('MC_500_5_4_reference_classes.npy')

In [216]:
data.shape

(500, 5, 4)

In [254]:
#Inferring the number of classes and data length
numData = data.shape[0]
numExperts = data.shape[1]
numClasses = data.shape[2]

In [255]:
#Initialization
alpha = [1, 1, 1, 1]
z = np.random.dirichlet(alpha, size=numData)
observed_classes = np.argmax(z, axis=1)
lam = 10
lambda_mat = lam*np.eye( numClasses) + np.ones((numClasses, numClasses))
ConfMatList = []
for exp_index in range(numExperts):
    confMat = []
    for cls_idx in range(numClasses):
        confMat.append(np.random.dirichlet(lambda_mat[cls_idx,:],1)[0])
    ConfMatList.append(np.array(confMat))

In [256]:
def calc_log_like(data, confMatList, z):
    log_like = 0
    numData = data.shape[0]
    numExperts = data.shape[1]
    numClasses = data.shape[2]
    for data_idx in range(numData):
        dat = data[data_idx]
        observed_cls = z[data_idx]
        for expert_idx in range(numExperts):
            dat_exp = dat[expert_idx]
            confMat = confMatList[expert_idx]
            for class_idx_row in range(numClasses):
                for class_idx_column in range(numClasses):
                    log_like += observed_cls[class_idx_row]\
                    *np.sum(np.multiply(dat_exp ,np.log(confMat[class_idx_row])))
    return log_like

In [257]:
calc_log_like(data, ConfMatList, z)

-24529.784860181484

In [258]:
def E_Step(data, ConfMatList, z, observed_classes):
    numData = data.shape[0]
    numExperts = data.shape[1]
    numClasses = data.shape[2]
    
    posterior = np.zeros(z.shape)
    unique_classes = np.unique(observed_classes)
    class_dist = {}
    for cls in unique_classes:
        class_dist[cls] = float(len(observed_classes == cls))/numData
            
    for d_idx in range(numData):
        dat = data[d_idx]
        prior_z = z[d_idx]
        obs_class = observed_classes[d_idx]
        for cls_idx in range(numClasses):
            for exp_idx in range(numExperts):
                posterior[d_idx, cls_idx] += dat[exp_idx, cls_idx] * \
                np.log(ConfMatList[exp_idx][obs_class,cls_idx])
            posterior[d_idx, cls_idx] += np.log(class_dist[cls_idx])
    row_sum = posterior.sum(axis=1)
    return posterior/row_sum[:, np.newaxis]

In [259]:
def M_Step(data, confMatList, z, observed_classes):
    numData = data.shape[0]
    numExperts = data.shape[1]
    numClasses = data.shape[2]
    
    newConfMatList = [0.01*np.ones((numClasses, numClasses)) for idx in range(numExperts)]
    for e_idx in range(numExperts):
        for row_cls_idx in range(numClasses):
            dat = data[observed_classes == row_cls_idx]
            z_dat = z[observed_classes == row_cls_idx]
            for cls_idx in range(numClasses):
                for dat_idx in range(len(dat)):
                    newConfMatList[e_idx][row_cls_idx, cls_idx] += \
                    dat[dat_idx][e_idx][cls_idx]*z_dat[dat_idx][cls_idx]
        row_sum = newConfMatList[e_idx].sum(axis = 0)
        newConfMatList[e_idx] = newConfMatList[e_idx]/row_sum[:, np.newaxis]
    return newConfMatList

In [260]:
new_z = E_Step(data, ConfMatList, z, observed_classes)
new_observed_classes = np.argmax(z, axis=1)

In [261]:
newConfMatList = M_Step(data, ConfMatList, new_z, new_observed_classes)

In [262]:
calc_log_like(data, newConfMatList, new_z)

-17044.463463475273

In [264]:
for i in range(10):
    z = E_Step(data, ConfMatList, z, observed_classes)
    observed_classes = np.argmax(z, axis=1)
    ConfMatList = M_Step(data, ConfMatList, z, observed_classes)
    print calc_log_like(data, ConfMatList, z)

-17633.6138018
-21433.5843091
-17652.5149441
-21445.8471023
-17656.7596778
-21448.75493
-17657.7321385
-21449.43183
-17657.9557317
-21449.5883809


In [242]:
np.sum(np.argmax(z, axis=1) == trueCls)

133

In [243]:
np.argmax(z, axis=1), trueCls

(array([1, 2, 0, 1, 2, 0, 2, 1, 0, 3, 2, 3, 1, 2, 2, 3, 0, 1, 0, 2, 1, 0, 3,
        2, 3, 0, 0, 2, 1, 1, 2, 0, 3, 0, 0, 1, 2, 3, 2, 3, 0, 3, 0, 1, 0, 2,
        1, 3, 0, 1, 0, 0, 0, 2, 3, 1, 1, 3, 3, 3, 2, 1, 0, 1, 0, 1, 2, 2, 0,
        2, 2, 3, 2, 2, 0, 0, 1, 3, 2, 0, 3, 3, 3, 1, 2, 3, 1, 1, 3, 3, 1, 0,
        3, 3, 2, 3, 0, 2, 2, 3, 2, 0, 0, 2, 0, 3, 1, 2, 3, 3, 3, 1, 0, 1, 0,
        3, 0, 0, 3, 3, 1, 1, 0, 3, 2, 3, 2, 0, 2, 1, 2, 1, 3, 0, 3, 1, 2, 1,
        1, 1, 0, 0, 1, 0, 2, 0, 3, 2, 2, 2, 1, 1, 1, 3, 2, 2, 1, 1, 1, 1, 1,
        2, 1, 2, 0, 0, 0, 3, 3, 2, 1, 0, 1, 1, 2, 0, 3, 0, 2, 1, 2, 0, 0, 1,
        3, 0, 3, 0, 0, 3, 1, 2, 3, 0, 0, 0, 3, 1, 2, 3, 2, 1, 3, 3, 3, 1, 1,
        3, 0, 0, 0, 1, 0, 1, 3, 1, 3, 0, 0, 3, 3, 2, 2, 1, 0, 3, 1, 1, 0, 2,
        0, 0, 2, 2, 0, 1, 3, 1, 3, 0, 2, 3, 3, 0, 2, 3, 2, 3, 3, 0, 0, 1, 3,
        2, 0, 0, 2, 2, 1, 3, 3, 3, 0, 3, 1, 3, 2, 3, 3, 1, 0, 1, 2, 3, 0, 1,
        1, 0, 3, 0, 3, 3, 2, 2, 3, 0, 0, 0, 3, 3, 0, 1, 2, 2, 2, 2, 3, 3, 3,

# 