In [1]:
# Let's build a Gaussian Mixture Variational Autoencoder (GMVAE) in PyTorch
# We'll use the MNIST dataset
# We'll use a Mixture of Gaussians as prior and Gaussian posterior

# Import libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
from torch.distributions import Normal, Multinomial, kl_divergence


The generative model is $p_{\beta, \theta}(y, x, w, z) = p(w)\,p(z)\,p_{\beta}(x|w,z)\,p_{\theta}(y|x)$

where: \
$w \sim N(0, Id)$ \
$z \sim Mult(\pi)$ \
$x|z,w \sim \prod_{k=1}^{K} N(\mu_{z_k}(w;\beta), diag(\sigma^2_{z_k}(w;\beta)))^{z_k}$ \
$y|x \sim N(\mu(x;\theta), diag(\sigma^2(x; \theta)))$

where $\mu_{z_k}(\cdot;\beta), \ \sigma^2_{z_k}(\cdot;\beta), \ \mu(\cdot;\theta), \ \sigma^2(\cdot;\theta)$ are given by neural networks with parameters $\beta$ and $\theta$ respectively.
That is, the observed sample $y$ is generated from a neural network observation model parametrised by $\theta$ and the continuous latent variable $x$. Furthermore, the distribution of $x|w$ is a Gaussian mixture with means and variances specified by another neural network model parametrised by $\beta$ and with input $w$.

In [35]:
# define the model

class GMVAE(nn.Module):
    """Autoencoder whose encoder emits the parameters of a K-component Gaussian mixture.

    For every input the encoder produces, per mixture component, a mean and a
    standard deviation of size ``W_DIM``, plus ``K`` mixing weights. ``forward``
    draws one reparametrised sample per component, blends the samples with the
    mixing weights and decodes the blended code back to input space.

    All methods accept inputs with any number of leading batch dimensions
    (including none), so both a single flattened image of shape ``(INPUT_DIM,)``
    and a mini-batch of shape ``(B, INPUT_DIM)`` work.

    Args:
        INPUT_DIM: size of a flattened input vector (and of the reconstruction).
        H_DIM: width of the hidden layer used by both encoder and decoder.
        W_DIM: dimensionality of the continuous latent code.
        K: number of mixture components.
    """

    def __init__(self, INPUT_DIM, H_DIM, W_DIM, K):
        super(GMVAE, self).__init__()
        self.INPUT_DIM = INPUT_DIM
        self.H_DIM = H_DIM
        self.W_DIM = W_DIM
        self.K = K

        # encoder: one shared hidden layer, then three heads
        self.img_2hid = nn.Linear(self.INPUT_DIM, self.H_DIM)
        self.hid_2mu = nn.Linear(self.H_DIM, self.W_DIM*K)      # K means, flattened
        self.hid_2sigma = nn.Linear(self.H_DIM, self.W_DIM*K)   # K log-std vectors, flattened
        self.hid_2pi = nn.Linear(self.H_DIM, self.K)            # mixing-weight logits

        # decoder
        self.z_2hid = nn.Linear(self.W_DIM, self.H_DIM)
        self.hid_2img = nn.Linear(self.H_DIM, self.INPUT_DIM)

        self.relu = nn.ReLU()

    def encode(self, x):
        """Map inputs to mixture parameters.

        Args:
            x: tensor of shape ``(..., INPUT_DIM)``.

        Returns:
            Tuple ``(mu, sigma, pi)`` where ``mu`` and ``sigma`` have shape
            ``(..., K, W_DIM)`` (``sigma`` is strictly positive) and ``pi`` has
            shape ``(..., K)`` and sums to one over the last dimension.
        """
        h = self.relu(self.img_2hid(x))
        # Keep whatever leading (batch) dimensions the input had.
        component_shape = (*h.shape[:-1], self.K, self.W_DIM)

        # No ReLU on the means: a mixture component may sit anywhere in latent
        # space, not only in the non-negative orthant.
        # No .detach(): the tensors must stay in the autograd graph, otherwise
        # the encoder never receives a gradient.
        mu = self.hid_2mu(h).reshape(component_shape)

        # The head output is read as log(sigma); exp() guarantees positivity
        # without forcing sigma >= 1 the way exp(relu(.)) would.
        sigma = torch.exp(self.hid_2sigma(h)).reshape(component_shape)

        # Mixing weights must form a probability vector; ReLU could return all
        # zeros, which would collapse the latent code to 0.
        pi = torch.softmax(self.hid_2pi(h), dim=-1)
        return mu, sigma, pi

    def decode(self, z):
        """Map latent codes of shape ``(..., W_DIM)`` to reconstructions in ``[0, 1]``."""
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: tensor of shape ``(..., INPUT_DIM)``.

        Returns:
            Tuple ``(x_hat, mu, sigma, pi)``; ``x_hat`` has the shape of ``x``,
            the remaining entries are as returned by :meth:`encode`.
        """
        mu, sigma, pi = self.encode(x)
        # One reparametrised draw per component: (..., K, W_DIM).
        component_samples = self.reparametrize(mu, sigma)
        # Weight each component's draw by its mixing weight and sum over the
        # component axis. Broadcasting keeps pi in the graph and on the same
        # device as the samples.
        z = (pi.unsqueeze(-1) * component_samples).sum(dim=-2)
        x_hat = self.decode(z)
        return x_hat, mu, sigma, pi

    def reparametrize(self, mu, sigma):
        """Draw ``mu + eps * sigma`` with ``eps ~ N(0, I)`` so gradients reach ``mu`` and ``sigma``."""
        eps = torch.randn_like(sigma)
        return mu + eps * sigma








x_hat shape: torch.Size([784])
mu shape: torch.Size([10, 2])
sigma shape: torch.Size([10, 2])
pi shape: torch.Size([10])


In [33]:
# training loop         

def train(model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` and printing progress.

    Each batch is moved to ``device``, flattened to ``(-1, INPUT_DIM)``, passed
    through the model and scored with ``loss_function``; both ``INPUT_DIM`` and
    ``loss_function`` are module-level names defined elsewhere in the notebook.

    Args:
        model: module returning ``(x_hat, mu, sigma, pi)`` when called.
        device: target device for each batch.
        train_loader: iterable of ``(inputs, labels)`` pairs; labels are ignored.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    running_loss = 0

    for step, (images, _) in enumerate(train_loader):
        images = images.to(device)
        flat_images = images.view(-1, INPUT_DIM)

        optimizer.zero_grad()
        x_hat, mu, sigma, pi = model(flat_images)
        loss = loss_function(x_hat, flat_images, mu, sigma, pi)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        # Report every 100th batch, with the loss normalised by batch size.
        if step % 100 != 0:
            continue
        batch_size = len(images)
        seen = step * batch_size
        percent_done = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset),
            percent_done,
            batch_loss / batch_size))

    # Epoch summary: accumulated loss divided by the number of dataset samples.
    average_loss = running_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, average_loss))