# A vanilla VAE implementation for MNIST

In [1]:
import os
import sys
import time
import random
import tempfile

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.distributions import MultivariateNormal

## Load MNIST:

In [2]:
def min_max_normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor


def tensor_round(tensor):
    return torch.round(tensor)

In [3]:
mnist_dataset = datasets.MNIST(tempfile.gettempdir(), train=True, download=True,
                       transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),
                            transforms.Lambda(lambda tensor:tensor_round(tensor))
                        ]))

## Define functions:

In [4]:
def make_encoder(dim, latent_dim):
    encoder = nn.Sequential(
        nn.Linear(dim, 500),
        nn.Tanh(),
        nn.Linear(500, 2*latent_dim),
    )
    
    return encoder

In [5]:
def make_decoder(dim, latent_dim):
    decoder = nn.Sequential(
        nn.Linear(latent_dim, 500),
        nn.Tanh(),
        nn.Linear(500, dim),
        nn.Sigmoid(),
    )
    
    return decoder

In [6]:
def compute_KL_divergence(mu, log_var):
    KL = torch.mean(0.5 * torch.sum(1 + log_var + mu**2 + torch.exp(log_var), dim=1))
    
    return KL

In [7]:
lossf = nn.BCELoss()

def compute_reconstruction_error(x, x_p):
    return lossf(x_p, x)

## Train VAE:

In [8]:
batch_size = 100
latent_dim = 10
dim = 784
epochs = 100

dev = torch.device('cpu')

In [9]:
base_distr = MultivariateNormal(torch.zeros(latent_dim), torch.eye(latent_dim))

In [10]:
encoder, decoder = make_encoder(dim, latent_dim), make_decoder(dim, latent_dim)

In [11]:
train_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

In [12]:
optimizer = torch.optim.Adam([{'params': encoder.parameters()},
                              {'params': decoder.parameters()}
                             ], lr=1e-5)

In [13]:
loss = []
logprior = []
logdet = []

# Train loop
t0 = time.time()
for e in range(epochs):

    cum_loss = torch.zeros(1, device=dev)
    
    for images, _ in train_loader:
        images = images.view(images.shape[0], -1)

        images = images.to(dev, non_blocking=True)

        optimizer.zero_grad()

        infer = encoder(images)
        
        mu, log_var = infer[:, :latent_dim], infer[:, latent_dim:]
        
        eps = base_distr.sample((batch_size,))
        
        z = mu + torch.sqrt(torch.exp(log_var)) * eps
        
        x_p = decoder(z)
        
        NLL = compute_reconstruction_error(images, x_p)
        KL = compute_KL_divergence(mu, log_var)
        
        _loss = NLL - KL


        cum_loss += _loss

        _loss.backward()
        optimizer.step()

    loss.append(cum_loss.item()/len(train_loader.dataset))

    if e%10 == 9:
        print('epoch: {}, at time: {:.2f}, loss: {:.3f}'.format(e, time.time()-t0, loss[-1]))

epoch: 9, at time: 481.30, loss: -14014025099674363904.000


KeyboardInterrupt: 