In [7]:
import numpy as np
import torchvision.datasets as datasets
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
import torch.distributions as tdist
import scipy.stats as st
from scipy.special import logsumexp

In [8]:
a = torch.tensor([1.0,2.0,3,4,5,6,7,8,9])
b = torch.tensor([0.0,0,0,4,5,6,0,0,0])
x = torch.rand(784)
c = torch.rand(2, 784,784)

# t1 = st.multivariate_normal(a,8)
# t2 = tdist.MultivariateNormal(a, torch.diag_embed(b))
# print(t1.logpdf(x), t2.log_prob(x))


In [9]:
# Get train and test data
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

In [10]:
def pca(data, out_dim=10):
    data = data.view(-1,784)
    data_mean = data.mean(0)
    Q = (data-data_mean).T@(data-data_mean)
    U, S, V = torch.svd(Q)
    data = data @ U[torch.argsort(S,descending=True)[:out_dim]].T
    return data


In [11]:
class GMM():
    def __init__(self, latentspace, input_space=784):
        super().__init__()
        self.C = latentspace
        self.N = input_space
        self.pi = None
        self.mu = []
        self.cov = []

    def init(self):
        x = torch.rand(self.C)
        self.pi = x/x.sum()
        for i in range(self.C):
            self.mu.append(torch.rand(self.N))
            self.cov.append(torch.diag_embed(torch.rand(self.N)))

    def e_stemp(self, data):
        gamma = []
        for k in range(self.C):
            pdf = st.multivariate_normal(self.mu[k], self.cov[k])
            gamma.append(torch.log(self.pi[k]) + pdf.logpdf(data))
        gamma = torch.stack(gamma)
        gamma_z = torch.exp(gamma-torch.logsumexp(gamma, axis=0))
        torch.testing.assert_close(
            gamma_z.sum(), torch.tensor(data.shape[0], dtype=gamma_z.dtype))
        return gamma_z

    def m_step(self, data, gamma_z):
        N_k = gamma_z.sum(1)
        log_mu = (torch.log(gamma_z).unsqueeze(
            2) + torch.log(data.unsqueeze(0))).logsumexp(dim=1) - N_k.unsqueeze(1)
        
        self.mu = torch.exp(log_mu)
        self.cov = torch.exp(torch.logsumexp(torch.log(gamma_z).view(self.C, data.shape[0], 1, 1)+torch.log(torch.einsum(
            'ijk, ijg -> ijkg', data-self.mu.unsqueeze(1), data-self.mu.unsqueeze(1))), dim=1) - N_k.view(-1, 1, 1))+1e-4*torch.eye(self.N)
  
        # self.cov = (((data-self.mu.unsqueeze(1))**2 * gamma_z.unsqueeze(2)).sum(1))/N_k.unsqueeze(1)+1e-4


        self.pi = N_k/N_k.sum()

        # for k in range(self.C):
        #     # self.mu[k] = (gamma_z[k]*data.T).sum(1)/N_k[k]
        #     # self.cov[k] = torch.exp(torch.log(gamma_z[k].reshape(data.shape[0], 1, 1))+torch.log(torch.einsum(
        #     #     'ij, il -> ijl', (data - self.mu[k]), (data - self.mu[k])))-torch.log(N_k[k])).sum(0)
        #     # self.pi[k] = N_k[k]/N_k.sum()
        #     temp_mu = 0
        #     for n, d in enumerate(data):
        #         temp_mu += gamma_z[k][n]*d
        #         assert not torch.isnan(
        #             gamma_z[k][n]*d).any(), print(gamma_z[k][n]*d)
        #     self.mu[k] = temp_mu/N_k[k]
        #     assert not torch.isnan(self.mu[k]).any(), print(self.mu[k])
        #     temp_cov = 0
        #     for n, d in enumerate(data):
        #         temp_cov += gamma_z[k][n] * \
        #             torch.outer(d-self.mu[k], d-self.mu[k])
        #     self.cov[k] = temp_cov/N_k[k]
        #     self.pi[k] = N_k[k]/N_k.sum()

    def log_likelihood(self, data):
        gamma = []
        for k in range(self.C):
            pdf = st.multivariate_normal(self.mu[k], self.cov[k])
            gamma.append(torch.log(self.pi[k]) + pdf.logpdf(data))
        gamma = torch.stack(gamma).flatten(0)
        return torch.logsumexp(gamma, dim=0)

    def fit(self, data):
        gamma_z = self.e_stemp(data)
        self.m_step(data, gamma_z)
        return self.log_likelihood(data)


In [12]:
gmm = GMM(2,10)
gmm.init()
with torch.no_grad():
    for i in range(10):
        ll_list = []
        for j,(data,_) in enumerate(train_loader):
            data = pca(data,10)+1e-4
            ll = gmm.fit(data)
            ll_list.append(ll)
        print("epoch : ",i,  "  likelihood : ", np.mean(ll_list))

epoch :  0   likelihood :  37.24530428935217
epoch :  1   likelihood :  40.023783720592206
epoch :  2   likelihood :  38.5708325029266
epoch :  3   likelihood :  40.0139915135312
epoch :  4   likelihood :  34.28838873428781
epoch :  5   likelihood :  40.10558142627226
epoch :  6   likelihood :  38.309020173636625
epoch :  7   likelihood :  37.35500464891938
epoch :  8   likelihood :  39.9635536575236
epoch :  9   likelihood :  40.030559973814235
