In [56]:
import torchvision
import torch
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm
from scipy.special import logsumexp
from multiprocessing import Pool

MAX_ITER = 10
EPSILON = 1e-7
TOLERANCE = 1e-5

In [57]:
def initializeModel(K, d):
    np.random.seed(0)
    pi = np.random.rand(K)
    pi = pi / np.sum(pi)

    np.random.seed(0)
    mu = np.random.normal(0, 3, size=(K, d))

    np.random.seed(0)
    S = np.random.rand(K, d) + 0.5 + EPSILON

    return pi, mu, S



In [58]:
print(initializeModel(4, 3))
a, b, c = initializeModel(4,3)

(array([0.2275677 , 0.29655611, 0.24993822, 0.22593797]), array([[ 5.29215704,  1.20047163,  2.93621395],
       [ 6.7226796 ,  5.60267397, -2.93183364],
       [ 2.85026525, -0.45407162, -0.30965656],
       [ 1.23179551,  0.43213071,  4.36282052]]), array([[1.0488136 , 1.21518947, 1.10276348],
       [1.04488328, 0.9236549 , 1.14589421],
       [0.93758731, 1.3917731 , 1.46366286],
       [0.88344162, 1.29172514, 1.02889502]]))


In [69]:
def GMM(X, K_RANGE):
    N, d = X.shape
    pi, mu, S = initializeModel(K_RANGE, d)
    log_r = np.zeros((N, K_RANGE))
    loss = [0.0] * MAX_ITER

    # for iter in tqdm(range(MAX_ITER), total=MAX_ITER):
    for iter in range(MAX_ITER):
        print(f"{iter}th pi: {pi}")
        print(f"{iter}th S: {S}")
        print(f"{iter}th {np.log(pi[3])}")
        for k in range(K_RANGE):
            # exp_power = np.dot((X-mu[k]) ** 2, 1/S[k]) * (-1/2)
            # if iter==2:
                # print(pi[k] * np.power(np.prod(S[k]), -1/2) * np.exp(exp_power))
            # r[:,k] = pi[k] * np.power(np.prod(S[k]), -1/2) * np.exp(exp_power)
            log_r[:,k] = np.log(pi[k]) - 0.5 * np.sum(np.log(S[k])) - 0.5 * np.dot((X-mu[k]) ** 2, 1/S[k])
        r_total = logsumexp(log_r, axis=1)
        log_r = log_r - r_total[:,None]
        loss[iter] = -np.sum(r_total)

        if iter > 1 and abs(loss[iter] - loss[iter-1]) <= TOLERANCE * abs(loss[iter]):
            break
        
        r = np.exp(log_r)
        r_dot_k = np.exp(logsumexp(log_r, axis=0))
        print(r_dot_k)
        pi = r_dot_k / N
        mu = np.dot(r.T, X) / r_dot_k[:,None]
        S = np.dot(r.T, X ** 2) / r_dot_k[:,None] - mu ** 2 + EPSILON

    return pi, mu, S, loss

In [70]:
data_train = datasets.MNIST(root = "./data/",
                        transform=transform,
                        train = True,
                        download = True)
idx = data_train.targets == 0
np_X = data_train.data[idx].numpy()
N, d1, d2 = np_X.shape
X = np_X.reshape(N, d1*d2)
pi, mu, S, loss = GMM(X, 5)
print(pi)
print(mu)
print(S)
print(loss)

0th pi: [0.19356424 0.25224431 0.21259213 0.19217803 0.14942128]
0th S: [[1.0488136  1.21518947 1.10276348 ... 0.92690446 1.34285499 1.31803341]
 [0.60241386 0.65638345 0.80419879 ... 0.84010085 0.56859693 0.7289077 ]
 [0.85798404 0.93514209 1.09092683 ... 0.50090339 0.5698144  0.72649138]
 [0.98110206 0.75152284 1.37668199 ... 1.33130376 0.66284915 1.25190614]
 [1.17071006 1.39068749 1.46879135 ... 0.62987038 1.17056602 0.96383263]]
0th -1.6493330935541435
[ 126.00000069  812.         3168.99999931  790.         1026.        ]
1th pi: [0.021273   0.13709269 0.53503292 0.13337836 0.17322303]
1th S: [[1.e-07 1.e-07 1.e-07 ... 1.e-07 1.e-07 1.e-07]
 [1.e-07 1.e-07 1.e-07 ... 1.e-07 1.e-07 1.e-07]
 [1.e-07 1.e-07 1.e-07 ... 1.e-07 1.e-07 1.e-07]
 [1.e-07 1.e-07 1.e-07 ... 1.e-07 1.e-07 1.e-07]
 [1.e-07 1.e-07 1.e-07 ... 1.e-07 1.e-07 1.e-07]]
1th -2.014565410815898


  log_r[:,k] = np.log(pi[k]) - 0.5 * np.sum(np.log(S[k])) - 0.5 * np.dot((X-mu[k]) ** 2, 1/S[k])


[nan nan nan nan nan]
2th pi: [nan nan nan nan nan]
2th S: [[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
2th nan
[nan nan nan nan nan]
3th pi: [nan nan nan nan nan]
3th S: [[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
3th nan
[nan nan nan nan nan]
4th pi: [nan nan nan nan nan]
4th S: [[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
4th nan
[nan nan nan nan nan]
5th pi: [nan nan nan nan nan]
5th S: [[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
5th nan
[nan nan nan nan nan]
6th pi: [nan nan nan nan nan]
6th S: [[nan nan nan ... nan nan nan]
 [nan nan nan ... 

In [43]:
X = np.loadtxt(open("gmm_dataset.csv", "rb"), delimiter=",")

K = 10
pis, mus, Ss, losses = [[None for _ in range(K)] for _ in range(4)]
for k in range(10):
    pi, mu, S, loss = GMM(X, k+1)
    pis[k] = pi
    mus[k] = mu
    Ss[k] = S
    losses[k] = [ val for val in loss if val > 0.0]
    print(pi)


[1.]
[0.7006095 0.2993905]
[0.39550385 0.09995025 0.5045459 ]
[0.39981342 0.09993414 0.20025588 0.29999657]
[0.09965521 0.09999844 0.19999783 0.3000009  0.30034762]
[9.96553183e-02 9.99984421e-02 1.99997834e-01 3.00000900e-01
 3.00147506e-01 2.00000000e-04]
[9.96554813e-02 9.99984424e-02 1.49279474e-01 2.99999062e-01
 3.00147576e-01 2.00000000e-04 5.07199639e-02]
[9.88494429e-02 9.99984500e-02 1.48502976e-01 2.99999017e-01
 3.00153354e-01 2.00000000e-04 5.14968557e-02 7.99904608e-04]
[9.88497978e-02 9.99984611e-02 1.48125087e-01 2.99798994e-01
 3.00152972e-01 2.00000000e-04 5.18746937e-02 5.99995193e-04
 3.99999808e-04]
[9.88502760e-02 9.99984621e-02 1.48125473e-01 2.99798994e-01
 2.99752492e-01 2.00000000e-04 5.18743079e-02 5.99995193e-04
 3.99999808e-04 3.99999569e-04]


In [44]:
for loss in losses:
    print(loss[-1])

69939.42104205661
58105.141162558255
51670.14515830016
41554.08327275577
34512.132311872665
34354.02638227194
34313.23906612599
34280.77315290888
34236.87015535322
34210.192847969694


In [45]:
transform = transforms.Compose([transforms.ToTensor()])


In [39]:
pis, mus, Ss, losses = [[None for _ in range(10)] for _ in range(4)]

for identify_cls in range(10):
    print("train for class {}".format(identify_cls))
    data_train = datasets.MNIST(root = "./data/",
                            transform=transform,
                            train = True,
                            download = True)
    idx = data_train.targets == identify_cls
    np_X = data_train.data[idx].numpy()
    N, d1, d2 = np_X.shape
    X = np_X.reshape(N, d1*d2)
    pi, mu, S, loss = GMM(X, 5)
    pis[k] = pi
    mus[k] = mu
    Ss[k] = S
    losses[k] = [ val for val in loss if val > 0.0]

train for class 0


  r = r / r_total[:,None]
 30%|███       | 15/50 [00:01<00:04,  8.17it/s]


KeyboardInterrupt: 

In [48]:

data_train = datasets.MNIST(root = "./data/",
                        transform=transform,
                        train = True,
                        download = True)
idx = data_train.targets == 0
np_X = data_train.data[idx].numpy()
N, d1, d2 = np_X.shape
X = np_X.reshape(N, d1*d2)
pi, mu, S, loss = GMM(X, 5)
print(pi)
print(mu)
print(S)
print(loss)

  log_r[:,k] = np.log(pi[k]) - 0.5 * np.sum(np.log(S[k])) - 0.5 * np.dot((X-mu[k]) ** 2, 1/S[k])


[nan nan nan nan nan]
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
[24574538708.09118, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]


In [37]:
X = np.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])
mu = np.mean(X, axis=0)
print(mu)
print(X-mu)
np.sum((X - mu) ** 2, axis=0)

[5.5 6.5 7.5]
[[-4.5 -4.5 -4.5]
 [-1.5 -1.5 -1.5]
 [ 1.5  1.5  1.5]
 [ 4.5  4.5  4.5]]


array([45., 45., 45.])

In [38]:
X + mu

array([[ 6.5,  8.5, 10.5],
       [ 9.5, 11.5, 13.5],
       [12.5, 14.5, 16.5],
       [15.5, 17.5, 19.5]])