In [None]:
%matplotlib inline

Attribution: 

   * Most material is adapted from [Agustinus Kristiadi's blog](https://wiseodd.github.io/techblog/2017/01/24/vae-pytorch/) and [Generative models repository](https://github.com/wiseodd/generative-models)

# Variational Autoencoder

We're now going to use PyTorch to implement a *vanilla* version of one of the most popular generative models today, the VAE.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
from torch.autograd import Variable

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import os

import torchvision
import torch.utils.data as Data

In [None]:
# MNIST training split. ToTensor converts each PIL image to a float tensor of
# shape (C x H x W) with values scaled to [0.0, 1.0].
mnist_train = torchvision.datasets.MNIST(
    root='./data/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,  # fetch the files if they are not already under root
)

mb_size = 64   # minibatch size
Z_dim = 100    # dimensionality of the latent code z
# Image height/width are read from one transformed sample instead of the
# `train_data` attribute, which recent torchvision releases only keep as a
# deprecated alias that emits a warning.
n_r, n_c = mnist_train[0][0].shape[-2:]
X_dim = n_r * n_c   # length of a flattened image
h_dim = 128         # width of the single hidden layer in encoder and decoder
lr = 1e-3           # Adam learning rate
c = 0               # running index used to name the saved sample figures
n_samples = mb_size  # used for visualization during learning

# The training loop slices the first n_samples rows out of one minibatch.
assert n_samples <= mb_size, "n_samples must not exceed mb_size"

train_loader = Data.DataLoader(
    dataset=mnist_train,
    batch_size=mb_size,
    shuffle=True,
)

## Meet the Encoder and Decoder


Like the GAN, a VAE has two parts. The Decoder $P(X|z)$ is analogous to the Generator in the GAN.

The Encoder $Q(z|X)$ is used for approximate inference.

Let's start by constructing the Encoder:

In [None]:
# Encoder parameters: Xavier-normal weights, zero biases, all tracked by autograd.

# input -> hidden
Wxh = nn.init.xavier_normal_(torch.empty(X_dim, h_dim)).requires_grad_()
bxh = torch.zeros(h_dim, requires_grad=True)

# hidden -> mean of z
Whz_mu = nn.init.xavier_normal_(torch.empty(h_dim, Z_dim)).requires_grad_()
bhz_mu = torch.zeros(Z_dim, requires_grad=True)

# hidden -> second encoder head (consumed downstream as the log-variance of z)
Whz_var = nn.init.xavier_normal_(torch.empty(h_dim, Z_dim)).requires_grad_()
bhz_var = torch.zeros(Z_dim, requires_grad=True)


def Q(X):
    """Encoder: map a batch of flattened inputs to the parameters of q(z|X).

    X is multiplied by Wxh, so it must be a 2-D tensor whose second dimension
    matches Wxh's first. Returns a pair ``(z_mu, z_var)``, each with one row
    per input row. Despite its name, the second output is used by the rest of
    the notebook as a log-variance (it is exponentiated in ``sample_z`` and in
    the KL term).
    """
    hidden = F.relu(torch.mm(X, Wxh) + bxh)
    mean = torch.mm(hidden, Whz_mu) + bhz_mu
    log_variance = torch.mm(hidden, Whz_var) + bhz_var
    return mean, log_variance


def sample_z(mu, log_var):
    """Draw z ~ N(mu, exp(log_var)) using the reparameterization trick.

    The noise is sampled with the same shape, dtype and device as ``mu``
    rather than a fixed (mb_size, Z_dim) shape, so the function works for any
    batch size (e.g. a short final batch or a single example).

    Args:
        mu: tensor of means.
        log_var: tensor of log-variances, broadcastable against ``mu``.

    Returns:
        ``mu + exp(log_var / 2) * eps`` with ``eps`` standard normal; the
        result is differentiable with respect to ``mu`` and ``log_var``.
    """
    eps = torch.randn_like(mu)
    # exp(log_var / 2) is the standard deviation.
    return mu + torch.exp(log_var / 2) * eps

Note that `sample_z` is implementing the reparameterization trick.

Now, we define the Decoder:

In [None]:
# Decoder parameters: Xavier-normal weights, zero biases, all tracked by autograd.

# latent -> hidden
Wzh = nn.init.xavier_normal_(torch.empty(Z_dim, h_dim)).requires_grad_()
bzh = torch.zeros(h_dim, requires_grad=True)

# hidden -> reconstructed input
Whx = nn.init.xavier_normal_(torch.empty(h_dim, X_dim)).requires_grad_()
bhx = torch.zeros(X_dim, requires_grad=True)

def P(z):
    """Decoder: map a batch of latent codes to per-pixel values in (0, 1).

    z is multiplied by Wzh, so it must be a 2-D tensor whose second dimension
    matches Wzh's first. The output has one row per latent code and is the
    sigmoid of the final affine layer.
    """
    hidden = F.relu(torch.mm(z, Wzh) + bzh)
    logits = torch.mm(hidden, Whx) + bhx
    return torch.sigmoid(logits)

Set up the optimizer:

In [None]:
# Every trainable tensor of the model, so a single optimizer updates the
# encoder and the decoder jointly.
params = [
    # encoder
    Wxh, bxh,
    Whz_mu, bhz_mu,
    Whz_var, bhz_var,
    # decoder
    Wzh, bzh,
    Whx, bhx,
]

solver = optim.Adam(params, lr=lr)

def reset_grad():
    """Zero, in place, the gradient of every tensor listed in ``params``.

    Tensors that have no gradient yet (``grad is None``) are left untouched.
    """
    for param in params:
        grad = param.grad
        if grad is None:
            continue
        grad.data.zero_()

## Step-by-step

We'll wrap everything up in a training loop soon, but for now, let's walk through a single iteration.

First, we'll grab a batch of data for demonstration purposes:

In [None]:
# Grab one minibatch of images (labels are discarded) and flatten each image
# to a vector of length X_dim; this is the input to the encoder.
X, _ = next(iter(train_loader))
X = X.view(-1, X_dim)

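Now, push that data through our Encoder to get the means and log-variances of $z$ for the batch (the variable named `z_var` actually holds $\log \sigma^2$, which is why it is exponentiated when sampling and in the KL term):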

In [None]:
# Encode the batch; each output has one row per example and Z_dim columns.
z_mu, z_var = Q(X)
for stat in (z_mu, z_var):
    print(stat.size())

Then sample a batch of $z$s:

In [None]:
# One reparameterized draw of z per example in the batch.
z = sample_z(z_mu, z_var)
print(z.shape)

Now, take the batch of sampled $z$s and push them through the Decoder to get $X$ (note that it isn't sampled, we are just returning the mean computed by the Decoder):

In [None]:
# Decode the sampled latent codes; X_sample holds the decoder's sigmoid
# outputs (one row per example), not a random draw of pixels.
X_sample = P(z)

We won't even bother to show these, because the learning hasn't started yet so we would just see noise.

With the encoding and decoding done, we can compute the two loss terms:

In [None]:
# Reconstruction term: binary cross-entropy summed over all pixels and
# examples, then divided by the minibatch size.
total_bce = F.binary_cross_entropy(X_sample, X, reduction='sum')
recon_loss = total_bce / mb_size

# KL term between N(z_mu, exp(z_var)) and N(0, I): summed over the latent
# dimensions of each example, then averaged over the batch.
kl_per_example = 0.5 * (z_var.exp() + z_mu.pow(2) - 1. - z_var).sum(dim=1)
kl_loss = kl_per_example.mean()

loss = recon_loss + kl_loss

Now that we've computed the loss, we can backprop through the decoder and the encoder, thanks to the reparameterization trick, and then apply the parameter update:

In [None]:
# Backpropagate the loss into the .grad of every tensor it depends on ...
loss.backward()
# ... then let Adam apply one update to the tensors in `params`.
solver.step()

As usual in PyTorch, we need to clear gradients so that they don't accumulate:

In [None]:
# Zero the stored gradients so the next backward() starts from a clean slate.
reset_grad()

## Putting it together in a learning loop

Alright, let's put these steps together and see how it performs.

In [None]:
def _full_batches(loader, batch_size):
    """Yield image batches with exactly ``batch_size`` rows, forever.

    The loader is swept epoch by epoch through one persistent iteration per
    epoch, instead of building a brand-new iterator for every training step.
    Batches with fewer than ``batch_size`` rows (a short final batch) are
    skipped so that every step sees exactly ``batch_size`` examples, which the
    loss normalization below assumes.

    Raises:
        ValueError: if a full pass over ``loader`` contains no full batch
            (otherwise the generator would spin forever).
    """
    while True:
        produced = False
        for images, _ in loader:
            if images.size(0) == batch_size:
                produced = True
                yield images
        if not produced:
            raise ValueError('loader yields no batch of size {}'.format(batch_size))


batches = _full_batches(train_loader, mb_size)

for it in range(100000):
    # Flatten each image of the next minibatch to a vector of length X_dim.
    X = next(batches).view(-1, X_dim)

    # Forward
    z_mu, z_var = Q(X)
    z = sample_z(z_mu, z_var)
    X_sample = P(z)

    # Loss: per-example reconstruction error plus KL(q(z|X) || N(0, I))
    recon_loss = F.binary_cross_entropy(X_sample, X, reduction='sum') / mb_size
    kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
    loss = recon_loss + kl_loss

    # Backward
    loss.backward()

    # Update
    solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; Loss: {:.4}'.format(it, loss.item()))

        # Decode the current batch's latent codes with the just-updated
        # decoder; no autograd graph is needed for plotting.
        with torch.no_grad():
            samples = P(z).numpy()[:n_samples]

        # Smallest square grid with room for every sample, so indexing the
        # GridSpec cannot overflow when n_samples is not a perfect square.
        n_gs = int(np.ceil(np.sqrt(n_samples)))

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(n_gs, n_gs)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = fig.add_subplot(gs[i])
            ax.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            ax.imshow(sample.reshape(n_r, n_c), cmap='Greys_r')

        os.makedirs('out_vae/', exist_ok=True)

        # Files are numbered by the running counter c: 000.png, 001.png, ...
        fig.savefig('out_vae/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        plt.close(fig)