# A vanilla VAE implementation for MNIST

In [1]:
import os
import sys
import time
import random
import tempfile

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.distributions import MultivariateNormal

In [2]:
%matplotlib inline
plt.ioff()

In [3]:
latent_dim = 10
dim = 784

dev = torch.device('cpu')

## Load MNIST:

In [4]:
def min_max_normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor


def tensor_round(tensor):
    return torch.round(tensor)

In [5]:
mnist_dataset = datasets.MNIST(tempfile.gettempdir(), train=True, download=True,
                       transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),
                            transforms.Lambda(lambda tensor:tensor_round(tensor))
                        ]))

## Define functions:

In [6]:
def make_encoder(dim, latent_dim):
    encoder = nn.Sequential(
        nn.Linear(dim, 500),
        nn.Tanh(),
        nn.Linear(500, 2*latent_dim),
    )
    
    return encoder

In [7]:
def make_decoder(dim, latent_dim):
    decoder = nn.Sequential(
        nn.Linear(latent_dim, 500),
        nn.Tanh(),
        nn.Linear(500, dim),
        nn.Sigmoid(),
    )
    
    return decoder

In [8]:
def compute_KL_divergence(mu, log_var):
    KL = -torch.mean(0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var), dim=1))
    
    return KL

In [9]:
def compute_reconstruction_error(x, x_p):
    return -torch.mean(x*torch.log(x_p))

In [10]:
fixed_z = np.random.randn(16, latent_dim)
fixed_z = torch.Tensor(fixed_z, device=dev)
def sample_digits(model):
    imas = model(fixed_z).detach().numpy()
    imas = imas.reshape([16, 28, 28])
    return imas

In [23]:
def create_animation_digits(intermediate_results):
    
    fig, ax_arr = plt.subplots(4, 4, figsize=(18, 18))
    
    imas = intermediate_results[0]
    
    for i in range(16):
        ax_arr[i//4, i%4].imshow(imas[i], cmap='gray', vmin=0.0, vmax=1.0)
        ax_arr[i//4, i%4].axis('off')
    
    
    def update_scat(i):
        imas = intermediate_results[i]
        for i in range(16):
            ax_arr[i//4, i%4].imshow(imas[i], cmap='gray', vmin=0.0, vmax=1.0)
            ax_arr[i//4, i%4].axis('off')

    ani = animation.FuncAnimation(fig, update_scat, frames=len(intermediate_results), interval=30)
    plt.close(fig)
    
    return ani

## Train VAE:

In [12]:
batch_size = 100
epochs = 5

In [13]:
base_distr = MultivariateNormal(torch.zeros(latent_dim), torch.eye(latent_dim))

In [14]:
encoder, decoder = make_encoder(dim, latent_dim), make_decoder(dim, latent_dim)

In [15]:
train_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

In [16]:
optimizer = torch.optim.Adam([{'params': encoder.parameters()},
                              {'params': decoder.parameters()}
                             ], lr=1e-5)

In [17]:
loss = []
logprior = []
logdet = []


intermediate_results = []

# Train loop
t0 = time.time()
for e in range(epochs):

    cum_loss = torch.zeros(1, device=dev)
    count = 0
    for images, _ in train_loader:
        images = images.view(images.shape[0], -1)

        images = images.to(dev, non_blocking=True)

        optimizer.zero_grad()

        infer = encoder(images)
        
        mu, log_var = infer[:, :latent_dim], infer[:, latent_dim:]
        
        eps = base_distr.sample((batch_size,))
        
        z = mu + torch.sqrt(torch.exp(log_var)) * eps
        
        x_p = decoder(z)
        
        NLL = compute_reconstruction_error(images, x_p)
        KL = compute_KL_divergence(mu, log_var)
        
        _loss = NLL + KL


        cum_loss += _loss
        count += 1

        _loss.backward()
        optimizer.step()

    loss.append(cum_loss.item()/count)
    intermediate_results.append(sample_digits(decoder))

    if e%5 == 4:
        print('epoch: {}, at time: {:.2f}, loss: {:.3f}'.format(e, time.time()-t0, loss[-1]))

epoch: 4, at time: 197.66, loss: 0.029


In [18]:
ima = x_p.detach().numpy()[0]

In [19]:
plt.imshow(ima.reshape(28, 28))

<matplotlib.image.AxesImage at 0x1d4a7e603c8>

In [20]:
ima

array([0.47942612, 0.46537244, 0.45356458, 0.55653423, 0.48218894,
       0.47628957, 0.5331422 , 0.5026423 , 0.5055079 , 0.6135903 ,
       0.50283885, 0.4845178 , 0.5097339 , 0.43261817, 0.44950885,
       0.5204032 , 0.47666147, 0.5748963 , 0.51903695, 0.565524  ,
       0.41957557, 0.50138015, 0.5895015 , 0.49672297, 0.49020857,
       0.4680981 , 0.4492906 , 0.4754856 , 0.4145845 , 0.5802914 ,
       0.51635844, 0.541656  , 0.5291821 , 0.4837913 , 0.5765901 ,
       0.5398134 , 0.50033486, 0.56109977, 0.6168123 , 0.56650305,
       0.52689743, 0.58848006, 0.58789456, 0.53260607, 0.5504477 ,
       0.5920115 , 0.56833464, 0.6302428 , 0.5812208 , 0.41271177,
       0.5041862 , 0.48066127, 0.5230881 , 0.41480753, 0.53949165,
       0.4439067 , 0.55645025, 0.50793827, 0.48971385, 0.4154917 ,
       0.47229335, 0.4587353 , 0.5429384 , 0.50884455, 0.6425105 ,
       0.4938033 , 0.60024875, 0.7532779 , 0.6636206 , 0.6984115 ,
       0.77232486, 0.7139251 , 0.70707226, 0.7553589 , 0.70002

In [24]:
ani = create_animation_digits(intermediate_results)

In [25]:
HTML(ani.to_html5_video())