In [1]:
from numpy.linalg import inv
import numpy as np
import matplotlib.pyplot as plt
import math
from math import e
from math import log
from numpy import linalg
import sys
import sympy
from sympy import *
from IPython.display import display, HTML
import random
from math import log
from math import exp

%matplotlib inline
%config IPCompleter.greedy=True
sympy.init_printing(use_unicode=False, wrap_line=True)
np.set_printoptions(suppress=True)

from utils import *

In [2]:
# Load mnist data
print('Load Data...', end=' ')
imgs_train, labels_train = load_mnist(train=True)
imgs_test, labels_test = load_mnist(train=False)

# Change Image to Feature Vector
X_train, Y_train = img_to_vector(imgs_train), labels_train
X_test, Y_test = img_to_vector(imgs_test), labels_test

print('Finish!')

Load Data... Finish!


In [3]:
print(X_train.shape)
print(max(X_train[0]))

(60000, 784)
255.0


In [4]:
# E-Step : Compute the probabilities of cluster assignment (r_ik)
# M-step : Update parameters mu, pi givn r 

# X = [[x_11...x_1d], ..., [x_n1...x_nd]] where x_ij = 0/1 (success/fail)
# mu : [mu_1...mu_k] where mu_i is the vector of prob of success for cluster k , mu_i's shape = (1, D)
# pi : [pi_1 ... pi_k] where pi_i is the prob to draw cluster k
# r_iK : the prob(expectation) that Xi belong to cluster k
# Zi : [z_i1, ..., z_ik] binary k=dim data(assign to cluster k)

![title](formula.PNG)

In [5]:
num_cluster = 10
K = num_cluster
X = np.copy(X_train)[:]
Y = np.copy(Y_train)[:]
N, D = X.shape
# init parameters mu & pi
mu = np.random.random((K, D))
pi = np.random.random((K, 1)) 
pi = pi / pi.sum()
r = np.zeros((N,K))
print(X.max())
print(X.min())

255.0
0.0


In [6]:
# 0~255 to 0 or 1
X[X<128.0] = 0
X[X>=128.0] = 1
print(X.max())
print(X.min())

1.0
0.0


In [9]:
from numpy import prod
from tqdm import tqdm_notebook as tqdm
import time
def L2distance(A,B):
    A = A.reshape(prod(A.shape))
    B = B.reshape(prod(B.shape))
    dis = math.sqrt(np.dot(A-B, A-B))
    return dis

def EM(X, mu, pi, r, max_iter=100):
    
    N, D = X.shape
    K, _ = mu.shape
    
    new_mu = np.copy(mu)
    
    
    for it in range(max_iter):
        # E-Step : Compute the probabilities of cluster assignment (r_ik)
        pbar = tqdm(total=N, unit='instance') # for progress bar
        pbar.set_description('E-Step(iteration %d)' % (it)) # for progress bar
        for i in range(N):
            for k in range(K):
                r[i][k] = log(pi[k]) # Log scale
                for d in range(D):
                    xid = X[i][d]
                    try:
                        r[i][k] += log((mu[k][d]**xid) * ((1-mu[k][d])**(1-xid))+1e-7)
                    except:
                        print('domain error')
                        print(mu[k][d], xid)
                        print((mu[k][d]**xid) * ((1-mu[k][d])**(1-xid)))
            r[i] -= r[i].max()
            r[i] = np.exp(r[i]) # Exp, back to origin scale
            r[i] = r[i] / r[i].sum() # normalize to 1 
            
            pbar.update() # for progress bar
        Nk = r.sum(axis=0)  # prob to draw k-th cluster
        pi = Nk/Nk.sum()
        pbar.close() # for progress bar
        time.sleep(1) # for progress bar
        # M-step : Update parameters mu, pi givn r 
        pbar = tqdm(total=K, unit='cluster') # for progress bar
        pbar.set_description('M-Step(iteration %d)' % (it)) # for progress bar
        for k in range(K):
            mu_k = 0
            for i in range(N):
                mu_k += r[i][k] * X[i]
            new_mu[k] = mu_k / Nk[k]
            pbar.update() # for progress bar
        diff = L2distance(new_mu, mu)
        print('L2 Distance Between old & new mu : ', diff ) 
        mu = np.copy(new_mu)
        if diff < 1e-5:
            print('converge after %d iteration'%(it))
            break
        pbar.close() # for progress bar
        time.sleep(1) # for progress bar
    return mu, pi, r

def EM_inference(X, mu, pi):
    N, D = X.shape
    K, _ = mu.shape
    y_pred = np.zeros((N,))
    pbar = tqdm(total=N, unit='instance') # for progress bar
    pbar.set_description('E-Step(iteration %d)' % (it)) # for progress bar
    for i in range(N):
        
        for k in range(K):
            r[i][k] = log(pi[k]) # Log scale
            for d in range(D):
                xid = X[i][d]
                try:
                    r[i][k] += log((mu[k][d]**xid) * ((1-mu[k][d])**(1-xid))+1e-7)
                except:
                    print('domain error')
                    print(mu[k][d], xid)
                    print((mu[k][d]**xid) * ((1-mu[k][d])**(1-xid)))
#         print(r[i])
        y_pred[i] = np.argmax(r[i]) 
        pbar.update() # for progress bar
    pbar.close() # for progress bar
    time.sleep(1) # for progress bar
    return y_pred

In [10]:
mu, pi, r = EM(X, mu, pi, r)

HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

44.08816238557957



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

7.6151550469241975



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

4.834874767562426



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

3.4503743493086545



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

2.06743373221047



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

1.3929488558055187



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

1.0004792232979052



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.7793532219280064



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.6760161484226596



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.6365821725327467



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.6214602614789104



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.6341445852757264



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.6545410978833524



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.6493978857886389



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.5858376339875957



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.5558621260803037



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.4984871766993327



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.45486057804739155



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.4096975429245441



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.3696663301180982



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.34151522043795035



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.3191563507757973



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.29614530910124487



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.2875507371270962



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.2735214331564815



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.2560462282936223



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.24891406313743483



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.2742699411035644



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.27986041462286637



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.28943261233363005



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.30591676138750545



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.31665749206774146



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.32301524418176264



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.3246256649277126



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.30210684616295586



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.2791610195163686



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.25135946542752213



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.21421733055091433



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.190678105621654



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.19377916075146312



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.19383779265947523



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.17520436945616638



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.14510081348925902



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.1171090324343892



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.09316916586955304



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.07803069436663135



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.06689343659549356



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.05526347857743265



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.045436077877356054



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.037362360737272865



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.030874364678400403



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.026005525224129047



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.022756159377046704



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.020275529578071553



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.018231369086506



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.016694374857260465



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.015513925567247927



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.014915793226836487



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.014302699565744807



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.013884447314246168



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.013131233720693643



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.012207014153497864



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.011323055364273273



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.010480732778109724



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.009720274532042807



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.00906513023989921



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.008521765543748746



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0080865718092452



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.007755582218898565



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.00751329632326332



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.007310683033069174



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.007141023334946896



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.007028091490021577



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.006981863708501603



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.006997353650713018



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.007049871977334805



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.007098070088044218



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.007123029311223163



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0071162916621864275



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0070633169518354326



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.006953423217483521



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.006783476996813323



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0065640457360261785



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.006473711988444342



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.006046719766957037



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.005634716291807656



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.005198782447873531



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0047480644082287275



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.004298134228997266



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0038671425187720423



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0034669605168538113



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.003102673130627031



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.002775271032163038



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0024835590196581026



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0022253086953990483



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.001997980164328339



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0017991756577878674



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0016269395171545854



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0014799741038355474



HBox(children=(IntProgress(value=0, max=60000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0.0013577385073864194



In [None]:
y_pred = EM_inference(X, mu, pi)

In [None]:
from sklearn.metrics import confusion_matrix

count_y = [np.count_nonzero(Y == i) for i in range(10)]
count_y_pred = [np.count_nonzero(y_pred == i) for i in range(num_cluster)]
print(count_y)
print(count_y_pred)

print(confusion_matrix(y_pred, Y))

In [None]:
print(r.shape)
print(mu.shape)

In [None]:
# print(Nk)
# print(Nk.sum())
# print(Nk/Nk.sum())

In [None]:
print(mu.max())
print(mu.min())

In [None]:
for i in range(num_cluster):
    plt.subplot(2,5,i+1)
    p = mu[i].reshape((28,28))
    plt.imshow(p, cmap='gray'), plt.xticks([]), plt.yticks([])
plt.show()