In [12]:
import numpy as np
import matplotlib.pyplot as plt
import torch

from torch.distributions.multivariate_normal import MultivariateNormal

In [24]:
def gauss(x, mean, sigma):
  """Evaluate a multivariate normal density at every row of ``x``.

  Args:
    x: tensor of shape (n, d), one point per row.
    mean: tensor of shape (d,), the mean of the distribution.
    sigma: tensor of shape (d, d), the covariance matrix (must be invertible).

  Returns:
    Tensor of shape (n,) whose i-th entry is N(x[i] | mean, sigma).
  """
  diff = x - mean
  dim = sigma.shape[-1]
  # Row-wise quadratic form diff_i^T sigma^-1 diff_i. Summing the element-wise
  # product avoids materialising the (n, n) matrix whose diagonal was read
  # before.
  maha = ((diff @ torch.inverse(sigma)) * diff).sum(dim=-1)
  # Normaliser is (2*pi)^(d/2) * |sigma|^(1/2); it is kept in log space so a
  # large d does not overflow before the exponential is taken.
  two_pi = torch.tensor(2 * torch.pi, dtype=sigma.dtype, device=sigma.device)
  log_norm = 0.5 * (dim * torch.log(two_pi) + torch.logdet(sigma))
  return torch.exp(-0.5 * maha - log_norm)

In [38]:
def gaussian_mixture(x, k, iter=0, cus=None, reg_covar=1e-6):
  """Fit a k-component full-covariance Gaussian mixture to ``x`` with EM.

  The log-likelihood is printed once after initialisation and once after each
  EM iteration.

  Args:
    x: tensor of shape (n, d), one sample per row.
    k: number of mixture components.
    iter: number of EM iterations to run (the name shadows the builtin but is
      kept so existing keyword calls still work).
    cus: optional (k, d) initial means. When omitted, k distinct rows of ``x``
      are drawn at random. The argument is copied, never modified.
    reg_covar: non-negative value added to the diagonal of every re-estimated
      covariance so it stays positive definite. Raise it if the distribution
      constructor rejects a covariance.

  Returns:
    Tuple ``(means, covar, pi, gammas)`` with shapes (k, d), (k, d, d), (k,)
    and (n, k); each row of ``gammas`` holds the responsibilities of one
    sample and sums to 1.

  Raises:
    ValueError: if more components than samples are requested without ``cus``,
      or if the initial means do not have shape (k, d).
  """
  if not torch.is_floating_point(x):
    x = x.to(torch.float)
  n, d = x.shape

  # Initialise: means = random data points (or the caller's), covariances =
  # identity, mixing weights = uniform.
  if cus is not None:
    # Work on a copy so the caller's tensor is left untouched.
    means = torch.as_tensor(cus, dtype=x.dtype, device=x.device).clone()
  else:
    if k > n:
      raise ValueError(f"cannot pick {k} initial means from {n} samples")
    means = x[torch.randperm(n, device=x.device)[:k]]
  if means.shape != (k, d):
    raise ValueError(
        f"initial means must have shape {(k, d)}, got {tuple(means.shape)}")

  eye = torch.eye(d, dtype=x.dtype, device=x.device)
  covar = eye.repeat(k, 1, 1)
  pi = torch.full((k,), 1.0 / k, dtype=x.dtype, device=x.device)
  tiny = torch.finfo(x.dtype).tiny

  # Everything is kept in log space: exp(log_prob) underflows to 0 for samples
  # far from every mean (or for large d), which would turn the normalisation
  # below into 0 / 0.
  log_joint = _gmm_log_joint(x, means, covar, pi)
  log_norm = torch.logsumexp(log_joint, dim=1, keepdim=True)
  loglike = log_norm.sum()

  print(loglike)
  for _ in range(iter):
    # E step: responsibilities are the softmax of the joint log densities.
    gammas = torch.exp(log_joint - log_norm)
    # The clamp keeps an emptied component from dividing by zero.
    nk = gammas.sum(dim=0).clamp_min(tiny)

    # M step
    means = (gammas.T @ x) / nk[:, None]
    for j in range(k):
      centered = x - means[j]
      scatter = (gammas[:, [j]] * centered).T @ centered / nk[j]
      # Averaging with the transpose removes rounding asymmetry that the
      # distribution's argument validation could otherwise reject.
      covar[j] = 0.5 * (scatter + scatter.T) + reg_covar * eye
    pi = nk / n

    # Eval step
    log_joint = _gmm_log_joint(x, means, covar, pi)
    log_norm = torch.logsumexp(log_joint, dim=1, keepdim=True)
    loglike = log_norm.sum()

    print(loglike)

  gammas = torch.exp(log_joint - log_norm)
  return means, covar, pi, gammas


def _gmm_log_joint(x, means, covar, pi):
  """Return the (n, k) matrix of log(pi_j) + log N(x_n | means_j, covar_j)."""
  # Components are evaluated one at a time so the largest temporary stays at
  # O(n * d) instead of the O(n * k * d) a fully batched call would need.
  columns = [
      torch.log(pi[j]) + MultivariateNormal(means[j], covar[j]).log_prob(x)
      for j in range(pi.size(0))
  ]
  return torch.stack(columns, dim=1)

In [39]:
# Seed the global RNG so the synthetic data (and every random draw after it)
# is the same on each Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Per-axis means and standard deviations of the three 2-D clusters.
mean_1 = [3,6]
mean_2 = [2,2]
mean_3 = [6,3]

std_1 = [0.5,2]
std_2 = [0.5,1]
std_3 = [0.5,0.5]

# Number of points drawn per cluster.
data_points = 200


def sample_diag_gaussian(mean, std, n):
  """Draw ``n`` points whose coordinates are independent normals.

  Args:
    mean: list of per-axis means.
    std: list of per-axis standard deviations, same length as ``mean``.
    n: number of points to draw.

  Returns:
    Float tensor of shape (n, len(mean)).
  """
  mean_t = torch.tensor(mean, dtype=torch.float)
  std_t = torch.tensor(std, dtype=torch.float)
  # Tile the parameters n times, sample one flat vector, then fold it back
  # into one row per point.
  flat = torch.normal(mean=mean_t.repeat(n), std=std_t.repeat(n))
  return flat.reshape(n, -1)


normal_1 = sample_diag_gaussian(mean_1, std_1, data_points)
normal_2 = sample_diag_gaussian(mean_2, std_2, data_points)
normal_3 = sample_diag_gaussian(mean_3, std_3, data_points)

In [40]:
# Stack the three clusters row-wise into a single sample matrix.
data = torch.cat((normal_1, normal_2, normal_3), dim=0)

In [48]:
# Fit a 3-component mixture to the synthetic clusters with 10 EM iterations;
# unpacks means, covariances, mixing weights and responsibilities.
m, c, p, g = gaussian_mixture(data, k=3, iter=10)

tensor(-5276.0586)
tensor(-2377.1118)
tensor(-2194.2305)
tensor(-1986.4454)
tensor(-1947.0319)
tensor(-1940.6316)
tensor(-1935.0790)
tensor(-1926.8478)
tensor(-1913.5624)
tensor(-1903.7061)
tensor(-1898.7549)


In [49]:
import torchvision
import torchvision.transforms as transforms
# MNIST training split stored under ./data, with each image converted to a
# tensor by the ToTensor transform.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

In [50]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=1024,
                                           shuffle=True)

# Gather the batches in lists and concatenate once at the end. Calling
# torch.cat inside the loop re-copies everything accumulated so far on every
# batch, and starting from a float torch.empty(0) promoted the integer labels
# to float.
image_batches = []
label_batches = []
for batch_images, batch_labels in train_loader:
  image_batches.append(batch_images)
  label_batches.append(batch_labels)

x = torch.cat(image_batches)
# NOTE(review): y now keeps the labels' own integer dtype instead of float32.
y = torch.cat(label_batches)

In [51]:
# Flatten each 28x28 image into a single row of pixel values.
n_pixels = 28 * 28
x_reshape = x.reshape(-1, n_pixels)

In [52]:
# Sanity check: size of one flattened image.
x_reshape[0].size()

torch.Size([784])

In [39]:
# Fit a 10-component mixture to the flattened MNIST images.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt, so no
# result is shown. Each component carries a full 784x784 covariance; presumably
# that cost is why the run was stopped -- TODO confirm.
m, c, p, g = gaussian_mixture(x_reshape, k=10, iter=10)

KeyboardInterrupt: 