In [13]:
import math
import random
import numpy as np
import matplotlib.pyplot as plt
import struct
import gzip
%matplotlib inline

# EM algorithm

In [14]:
# data
def read_IDX(fname):
    with gzip.open(fname) as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        shape = tuple(struct.unpack('>I', f.read(4))[0] for d in range(dims))
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)


def pic_reshape(matrix):
    d = matrix.shape[0]
    dd = matrix.shape[1]
    ddd = matrix.shape[2]
    matrix = matrix.reshape((d, dd*ddd))
    return matrix

In [15]:
train_image = read_IDX('train-images-idx3-ubyte.gz')
train_label = read_IDX('train-labels-idx1-ubyte.gz')

In [16]:
print(train_image.shape)
print(train_label.shape)
train_image = pic_reshape(train_image)
train_label = train_label.reshape(len(train_label))

print(train_image.shape)
print(train_label.shape)

(60000, 28, 28)
(60000,)
(60000, 784)
(60000,)


In [17]:
# Let train_x become (60000,784) 0-1 matrix
def two_bins(x):
    '''
    x : (60000,784)
    '''
    # let x --> 0,1 matrix
    new = np.zeros((60000, 784))
    for pic in range(len(x)):
        for j in range(28*28):
            new[pic][j] = np.int(x[pic, j]//128)
    return new

In [18]:
def pixel_bernoulli(x, y):
    '''
    x : (60000,784)
    y : (60000,)
    return (10,784) probability
    '''
    ans = np.zeros(10)
    bernoulli = np.zeros((10, 784))
    for i in y:
        ans[i] += 1

    for i in range(len(x)):
        label = y[i]
        for j in range(28*28):
            if x[i, j] == 1:
                bernoulli[label, j] += 1

    '''
    bernoulli: (10,784)
    ans:(10,)
    ans.reshape(-1,1): (10,1)
    '''
    # normalization
    bernoulli = bernoulli/ans.reshape(-1, 1)

    return bernoulli


def plot_bernoulli(bernoulli):
    '''
    bernoulli: (10,784)
    '''
    for i in range(10):
        print('Class ', i, ':')
        for j in range(28):
            for d in range(28):
                if bernoulli[i, j*28+d] > 0.5:
                    print(1, end='')
                else:
                    print(0, end='')
            print()
        print()
        print()

In [19]:
# utility

def plot_pattern(pattern):
    '''
    :param pattern: (784)
    :return:
    '''
    for i in range(28):
        for j in range(28):
            print(pattern[i*28+j], end=' ')
        print()
    print()
    print()
    return


def plot(distribution, classes_order, threshold):
    '''
    plot each classes expected pattern
    distribution: (10,784)
    classes_order: (10)
    threshold: value between 0.0~1.0
    :return:
    '''
    Pattern = np.asarray(distribution > threshold, dtype='uint8')
    for i in range(10):
        print('class {}:'.format(i))
        plot_pattern(Pattern[classes_order[i]])
    return


def plot_confusion_matrix(c, TP, FN, FP, TN):
    print('------------------------------------------------------------')
    print()
    print('Confusion Matrix {}:'.format(c))
    print('\t\t\t  Predict number {} Predict not number {}'.format(c, c))
    print('Is number  \t{}\t\t{}\t\t\t\t{}'.format(c, TP, FN))
    print('Isn\'t number {}\t\t{}\t\t\t\t{}'.format(c, FP, TN))
    print()
    print('Sensitivity (Successfully predict number {}    ): {:.5f}'.format(c, TP/(TP+FN)))
    print('Specificity (Successfully predict not number {}): {:.5f}'.format(c, TN/(TN+FP)))
    print()


# confusion matrix
def confusion_matrix(real, predict, classes_order):
    '''
    :param real: (60000)
    :param predict: (60000)
    :param classes_order: (10)
    :return:
    '''
    for i in range(10):
        c = classes_order[i]
        TP, FN, FP, TN = 0, 0, 0, 0
        for i in range(60000):
            if real[i] != c and predict[i] != c:
                TN += 1
            elif real[i] == c and predict[i] == c:
                TP += 1
            elif real[i] != c and predict[i] == c:
                FP += 1
            else:
                FN += 1
        plot_confusion_matrix(c, TP, FN, FP, TN)


def print_error_rate(count, real, predict, classes_order):
    '''
    :param count: int
    :param real: (60000)
    :param predict: (60000)
    :param classes_order: (10)
    :return:
    '''
    print('Total iteration to converge: {}'.format(count))
    real_transform = np.zeros(60000)
    for i in range(60000):
        real_transform[i] = classes_order[real[i]]
    error = np.count_nonzero(real_transform-predict)
    print('Total error rate: {}'.format(error/60000))

In [20]:
# E-step

def update_posterior(x, lam, distribution):
    '''
    update posterior using log likelihood
    x: (60000,784) 0-1 uint8 matrix
    lam: (10,1)
    distribution: (10,784)
    return: (60000,10) matrix
    '''
    distribution_c = 1-distribution
    w = np.zeros((60000, 10))
    for i in range(60000):
        for j in range(10):
            w[i, j] = np.prod(x[i]*distribution[j]+(1-x[i])*distribution_c[j])
    # add prior
    w = w*lam.reshape(1, -1)

    add = np.sum(w, axis=1).reshape(-1, 1)
    add[add == 0] = 1
    w = w/add

    return w

In [21]:
# M-step

def lambda_update(w):
    '''
    W: (60000,10)
    return: (10,1)
    '''
    lam = np.sum(w, axis=0)
    lam = lam/60000
    return lam.T


def distribution_update(A, w):
    '''
    A.T@W -> normalized, transpose -> concate with 1-complement
    A: (60000,784)
    w: (60000,10)
    return: (10,784)
    '''
    # normalized w
    add = np.sum(w, axis=0)
    add[add == 0] = 1
    w_normalized = w/add
    prob = A.T@w_normalized

    return prob.T

In [22]:
# Here
from scipy.optimize import linear_sum_assignment


def distance(x, y):
    '''
    x: (784)
    y: (784)
    return: euclidean distance between x and y
    '''
    return np.linalg.norm(x-y)


def hungarian_algo(c):
    '''
    match GT to our estimate
    c: (10,10)
    return: (10) column index
    '''
    row_idx, col_idx = linear_sum_assignment(c)
    return col_idx


def perfect_matching(ground_truth, estimate):
    '''
    matching GT_distribution to estimate_distribution by minimizing the sum of the distance
    ground_truth: (10,784)
    estimate: (10,784)
    return: (10)
    '''
    cost = np.zeros((10, 10))
    for i in range(10):
        for j in range(10):
            cost[i, j] = distance(ground_truth[i], estimate[j])

    classes_order = hungarian_algo(cost)

    return classes_order

In [23]:
# initial

def init_lambda():
    '''
    lambda[k]= prior of class k
    sum(lambda)=1
    :return: (10,1) matrix
    '''
    re = np.random.rand(10)
    re = re/np.sum(re)
    return re

In [24]:
# main

eps = 1

A = two_bins(train_image)
b = train_label

# init, lambda represent by L
L = init_lambda()
P = np.random.rand(10, 784)  # Distribution(784-dim) of each class

last_diff = 1000
diff = 100
count = 0
while abs(last_diff-diff) > eps and diff > eps:
    # E-step (calculate posterior)
    w = update_posterior(A, L, P)

    # M-step (update L,P)
    L_new = lambda_update(w)
    P_new = distribution_update(A, w)
    # calculate diff
    last_diff = diff
    diff = np.sum(np.abs(L-L_new))+np.sum(np.abs(P-P_new))
    #print('Difference', diff)
    #print('Lambda:', L_new.reshape(1, -1)[0])
    L = L_new
    P = P_new
    count += 1
    print('No. of Iteration: ', count, ', Difference: ', diff)


# take a view of classes belonging (but not exactly class)
maxs = np.argmax(w, axis=1)
unique, counts = np.unique(maxs, return_counts=True)
print(dict(zip(unique, counts)))
#print('Lambda:', L.reshape(1, -1))

# plot classes predict & confusion matrix
GT_distribution = pixel_bernoulli(A, b)

class_order = perfect_matching(GT_distribution, P)

plot(P, class_order, threshold=0.35)
confusion_matrix(b, maxs, class_order)
print_error_rate(count, b, maxs, class_order)

No. of Iteration:  1 , Difference:  3186.2694577495945
No. of Iteration:  2 , Difference:  234.5329462137064
No. of Iteration:  3 , Difference:  133.50444231728687
No. of Iteration:  4 , Difference:  72.52342241917118
No. of Iteration:  5 , Difference:  44.06523194263516
No. of Iteration:  6 , Difference:  31.260411867554005
No. of Iteration:  7 , Difference:  25.96872715821486
No. of Iteration:  8 , Difference:  23.455145247718853
No. of Iteration:  9 , Difference:  20.585044173529237
No. of Iteration:  10 , Difference:  17.74424842215053
No. of Iteration:  11 , Difference:  15.695800396205172
No. of Iteration:  12 , Difference:  14.457125470706304
No. of Iteration:  13 , Difference:  13.187414217908763
No. of Iteration:  14 , Difference:  11.833350756630633
No. of Iteration:  15 , Difference:  10.939051881226604
{0: 5214, 1: 4610, 2: 6923, 3: 4305, 4: 8225, 5: 8372, 6: 4765, 7: 7432, 8: 2, 9: 10152}
class 0:
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0

0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 


class 6:
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 1 1 

------------------------------------------------------------

Confusion Matrix 7:
			  Predict number 7 Predict not number 7
Is number  	7		336				5929
Isn't number 7		7096				46639

Sensitivity (Successfully predict number 7    ): 0.05363
Specificity (Successfully predict not number 7): 0.86794

------------------------------------------------------------

Confusion Matrix 1:
			  Predict number 1 Predict not number 1
Is number  	1		3				6739
Isn't number 1		4607				48651

Sensitivity (Successfully predict number 1    ): 0.00044
Specificity (Successfully predict not number 1): 0.91350

------------------------------------------------------------

Confusion Matrix 5:
			  Predict number 5 Predict not number 5
Is number  	5		126				5295
Isn't number 5		8246				46333

Sensitivity (Successfully predict number 5    ): 0.02324
Specificity (Successfully predict not number 5): 0.84892

------------------------------------------------------------

Confusion Matrix 6:
			  Predict number 6 Pred