In [131]:
import torch
import torch.optim as optim
import tqdm.notebook as tqdm

from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR


import os
import sys
sys.path.append('/Users/Matt/projects/entrovae/')
sys.path.append('/Users/Matt/projects/sgpvae/')

import sgpvae
import entrovae

## Data preparation.

In [2]:
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST('../data', train=True, download=True,
                               transform=transform)
test_dataset = datasets.MNIST('../data', train=False,
                              transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

## Classifier training.

In [3]:
def train_cls(model, train_loader, optimiser, epoch):
    model.train()
    batch_iter = tqdm.tqdm(enumerate(train_loader), desc='Batch')
    for batch_idx, (x, y) in batch_iter:
        optimiser.zero_grad()
        loss, _ = model.nll(x, y)
        loss.backward()
        optimiser.step()
        
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
def test_cls(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            loss, output = model.nll(x, y)
            test_loss += loss.item()
            
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(y.view_as(pred)).sum().item()
            
    test_loss /= len(test_loader.dataset)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [6]:
cls = entrovae.classifiers.MNISTClassificationNet()
cls_optimiser = optim.Adadelta(cls.parameters(), lr=1.0)
cls_scheduler = StepLR(cls_optimiser, step_size=1, gamma=0.7)

In [10]:
for epoch in range(1, 14+1):
    train_cls(cls, train_loader, cls_optimiser, epoch)
    test_cls(cls, test_loader)
    cls_scheduler.step()

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…



Test set: Average loss: 0.0001, Accuracy: 9821/10000 (98%)



HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…




KeyboardInterrupt: 

## Set GMMVAE datasets. 

In [12]:
pred_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=False)
cls_output = torch.zeros(len(train_dataset), 10)
with torch.no_grad():
    for batch_idx, (x, y) in enumerate(pred_loader):
        cls_output[batch_idx*64:(batch_idx+1)*64] = cls(x).detach().exp()

In [13]:
class GMMVAEDataset(torch.utils.data.Dataset):
    def __init__(self, x, cls_output):
        self.x = x
        self.cls_output = cls_output
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        return self.x[idx, ...], self.cls_output[idx]
    
gmmvae_dataset = GMMVAEDataset(train_dataset.data, cls_output)
gmmvae_loader = torch.utils.data.DataLoader(gmmvae_dataset, batch_size=64)

In [116]:
from torch import autograd

def train_gmmvae(model, loader, optimiser, epoch):
    model.train()
    train_loss = 0
    batch_iter = tqdm.tqdm(enumerate(loader), desc='Batch')
    
    for batch_idx, (x, pi) in batch_iter:
        optimiser.zero_grad()
        
        x = x.view(-1, 784).float() / 255
        
        loss = -model.elbo(x, pi)
        loss.backward()
        
        for name, param in model.named_parameters():
            if (param.grad != param.grad).any():
                pdb.set_trace()
                print('wtf')
        
        optimiser.step()
        
        train_loss += loss.item()        
        batch_iter.set_postfix(loss=loss.item())
            
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(loader.dataset)))

In [128]:
z_dim = 2
x_dim = 784

encoder = sgpvae.networks.LinearGaussian(x_dim, z_dim, [512, 256], min_sigma=1e-3)
loglikelihood = entrovae.loglikelihoods.NNBernoulli(z_dim, x_dim, [256, 512])

gmmvae_model = GMMVAE(encoder, loglikelihood, z_dim, 10)

## GMMVAE training.

In [129]:
gmmvae_optimiser = optim.Adam(gmmvae_model.parameters())
for epoch in range(1, 14+1):
    train_gmmvae(gmmvae_model, gmmvae_loader, gmmvae_optimiser, epoch)

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 1 Average loss: 2.8620


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 2 Average loss: 2.4977


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 3 Average loss: 2.4022


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…




KeyboardInterrupt: 

In [132]:
with torch.no_grad():            
    sample = gmmvae_model.sample(num_samples=100)
    
    filename = './samples/gmmvae_sample'
    
    if os.path.exists(filename + '.png'):
        i = 1
        while os.path.exists(filename + '_' + str(i) + '.png'):
            i += 1
            
        filename = filename + '_' + str(i) + '.png'
        
    else:
        filename = filename + '.png'
    
    save_image(sample.view(100, 1, 28, 28), filename)

In [126]:
import torch.nn as nn

from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
from torch.distributions.kl import kl_divergence


class GMMVAE(nn.Module):

    def __init__(self, encoder, loglikelihood, z_dim, k):
        super().__init__()

        self.encoder = encoder
        self.loglikelihood = loglikelihood
        self.z_dim = z_dim
        self.k = k

        # Initialise GMM parameters.
        self.pz_y_mu = nn.Parameter(torch.randn((k, z_dim)),
                                    requires_grad=True)
        self.pz_y_logsigma = nn.Parameter(torch.zeros((k, z_dim)),
                                          requires_grad=True)

    def qz(self, x):
        qz_mu, qz_sigma = self.encoder(x)
        qz = Normal(qz_mu, qz_sigma)

        return qz

    def py_z(self, z, pi):
        # Compute the marginal likelihood, p(z) = \sum_k p(z|y)p(y).
        pzy = torch.zeros_like(pi)
        for k in range(self.k):
            pz_y = Normal(self.pz_y_mu[k, :], self.pz_y_logsigma[k, :].exp())
            pzy[:, k] = pz_y.log_prob(z).sum(1)
            pzy[:, k] += pi[:, k].log()

        pz = torch.logsumexp(pzy, dim=1)

        # Compute the posterior p(y|z) = p(z, y) / p(z)
        py_z = pzy - pz.unsqueeze(1)
        py_z = Categorical(py_z.exp())

        return py_z

    def elbo(self, x, pi, num_samples=1):
        """Monte Carlo estimate of the evidence lower bound."""
        qz = self.qz(x)

        # z_samples is shape (num_samples, batch, z_dim).
        z_samples = qz.rsample((num_samples,))

        log_px_z = 0
        kl_y = 0
        kl_z = 0
        for z in z_samples:
            log_px_z += self.loglikelihood(z, x).sum()

            py_z = self.py_z(z, pi)
            kl_y += kl_divergence(py_z, Categorical(pi)).sum()

            for k in range(self.k):
                pz_y = Normal(
                    self.pz_y_mu[k, :].repeat(x.shape[0], 1), 
                    self.pz_y_logsigma[k, :].exp().repeat(x.shape[0], 1))
                
                kl_z_k = py_z.probs[:, k] * kl_divergence(qz, pz_y).sum(1)
                kl_z += kl_z_k.sum()

        log_px_z /= num_samples
        kl_y /= num_samples
        kl_z /= num_samples
        elbo = (log_px_z - kl_y - kl_z) / x.shape[0]

        return elbo

    def sample(self, pi=None, num_samples=1):
        if pi is None:
            pi = torch.ones(self.k) / self.k

        # Sample p(y).
        py = Categorical(pi)
        y = py.sample((num_samples,))

        # Sample p(z|y).
        pz_y = Normal(self.pz_y_mu[y, :], self.pz_y_logsigma[y, :].exp())
        z = pz_y.sample()

        # Sample p(x|z).
        samples = self.loglikelihood.predict(z)

        return samples

    def predict_x(self, z):
        x = self.loglikelihood.predict(z)

        return x

    def reconstruct_x(self, x):
        z, _ = self.encoder(x)
        x_recon = self.loglikelihood.predict(z)

        return x_recon