In [302]:
import torchvision
import torch
from torchvision import datasets, transforms
import numpy as np

MAX_ITER = 500
EPSILON = 1e-10
TOLERANCE = 1e-5

In [260]:
def initializeModel(K, d):
    np.random.seed(0)
    pi = np.random.rand(K)
    pi = pi / np.sum(pi)

    np.random.seed(0)
    mu = np.random.normal(0, 3, size=(K, d))

    np.random.seed(0)
    S = np.random.rand(K, d) + 0.5

    return pi, mu, S



In [261]:
print(initializeModel(4, 3))
a, b, c = initializeModel(4,3)

(array([0.2275677 , 0.29655611, 0.24993822, 0.22593797]), array([[ 5.29215704,  1.20047163,  2.93621395],
       [ 6.7226796 ,  5.60267397, -2.93183364],
       [ 2.85026525, -0.45407162, -0.30965656],
       [ 1.23179551,  0.43213071,  4.36282052]]), array([[1.0488135 , 1.21518937, 1.10276338],
       [1.04488318, 0.9236548 , 1.14589411],
       [0.93758721, 1.391773  , 1.46366276],
       [0.88344152, 1.29172504, 1.02889492]]))


In [298]:
def GMM(X, K_RANGE):
    N, d = X.shape
    pi, mu, S = initializeModel(K_RANGE, d)
    r = np.zeros((N, K_RANGE))
    loss = [0.0] * MAX_ITER

    for iter in range(MAX_ITER):
        for k in range(K_RANGE):
            exp_power = np.dot((X-mu[k]) ** 2, 1/S[k]) * (-1/2)
            # if iter==2:
                # print(pi[k] * np.power(np.prod(S[k]), -1/2) * np.exp(exp_power))
            r[:,k] = pi[k] * np.power(np.prod(S[k]), -1/2) * np.exp(exp_power)
        r_total = np.sum(r, axis=1)
        r = r / r_total[:,None]
        loss[iter] = -np.sum(np.log(r_total + EPSILON))

        if iter > 1 and abs(loss[iter] - loss[iter-1]) <= TOLERANCE * abs(loss[iter]):
            break
        
        r_total_i_wise = np.sum(r, axis=0)
        pi = r_total_i_wise / N
        mu = np.dot(r.T, X) / r_total_i_wise[:,None]
        S = np.dot(r.T, X ** 2) / r_total_i_wise[:,None] - mu ** 2 + EPSILON

    return pi, mu, S, loss

In [311]:
X = np.loadtxt(open("gmm_dataset.csv", "rb"), delimiter=",")

K = 10
pis, mus, Ss, losses = [[None for _ in range(K)] for _ in range(4)]
for k in range(10):
    pi, mu, S, loss = GMM(X, k)
    pis[k] = pi
    mus[k] = mu
    Ss[k] = S
    losses[k] = [ val for val in loss if val > 0.0]


In [314]:
for loss in losses:
    print(loss[-1])

115129.25464970224
69690.09745993733
57911.14954814103
51217.83275597244
41200.11275456188
34501.16250346653
34267.03191816248
34226.913473276094
34194.12075254171
34157.41194223416


In [316]:
transform = transforms.Compose([transforms.ToTensor()])
data_train = datasets.MNIST(root = "./data/",
                            transform=transform,
                            train = True,
                            download = True)
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size = 64,
                                                shuffle = True)
dataset = data_loader_train.dataset

In [317]:
for identify_cls in range(10):
    



tensor([5, 0, 4,  ..., 5, 6, 8])