# A Very Brief Overview of Variational Autoencoders in PyTorch (WIP)

This notebook implements a Variational Autoencoder (VAE) in PyTorch. VAEs are described in detail in this [tutorial](https://arxiv.org/abs/1606.05908). The implementation follows the VAE described in that tutorial, using both fully connected layers and convolutional layers. There are a number of architecture changes that can be made. Of particular interest to me are the following:
* Using the labels in order to implement a Conditional VAE (CVAE) as described in the tutorial above. This type of VAE uses extra information (the labels) as part of the encoding and decoding process.
* Using Inverse Autoregressive Flow modules as described in [IAF](https://arxiv.org/abs/1606.04934). This allows us to model more complex distributions by providing an efficient way to have non-diagonal covariance matrices. In particular, this allows for covariance matrices that separate the data rather than give them elliptic 'islands.'
* Using an [adversarial model](https://arxiv.org/abs/1512.09300) to measure the difference between images. In this notebook, we use binary cross-entropy to measure the difference between each pixel in the image. Measuring the per-pixel difference is a low-level way to measure the similarity between images. It gives us a start, but it suffers from issues that we would like to avoid. One apparent issue is that it does not align with a natural interpretation of 'similar' images. This is shown in Appendix A at the end of the notebook.

In [None]:
##Import relevant libraries
#Libraries used to create network
import torch as t
import torch.nn as nn
import torch.nn.functional as F

#Libraries used to retrieve data set and load it
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomAffine, ToTensor

file_path = './models/{0}-{1}.pkl'
use_gpu = True

## Loss functions

The first loss function measures how well the network reconstructs the image. The closer each pixel in the reconstructed image is to the same pixel in the original image, the smaller this loss will be. This loss forces each reconstruction to be similar to its input image. The `BCEWithLogitsLoss` function takes the raw network outputs (logits) and is more numerically stable than using a sigmoid activation layer followed by BCE loss.
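
As a rough illustration of the point about `BCEWithLogitsLoss`, the sketch below compares it with an explicit sigmoid followed by `BCELoss` on random data. The tensors `logits` and `targets` are hypothetical stand-ins for the decoder output and the input images; for moderate logits the two values should agree closely, and the fused version is the safer choice when logits become large.

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)

# Hypothetical stand-ins: raw decoder outputs (logits) and target pixels in [0, 1].
logits = t.randn(4, 1, 28, 28)
targets = t.rand(4, 1, 28, 28)

# Fused version: the sigmoid is applied inside the loss.
fused = nn.BCEWithLogitsLoss(reduction='sum')(logits, targets)

# Two-step version: explicit sigmoid followed by plain BCE.
two_step = nn.BCELoss(reduction='sum')(t.sigmoid(logits), targets)

print(fused.item(), two_step.item())
```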

The second loss function is the KL divergence of the encoded mean and variance from the standard normal distribution. This imposes structure on the latent space by pushing the latent representation towards being a standard normal. The further the encoded means and variances are from a $N(0, 1)$ distribution, the more the encoding is penalized. Effectively, we are measuring how much information is lost by using $N(0, 1)$ rather than the true distribution, and we are penalized based on that. To encourage the network to encode information, we set a minimum for the KL divergence, so an encoding is not pushed any closer to $N(0, 1)$ once its divergence falls below that floor. This was noted to improve performance in the IAF paper.
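
For a diagonal Gaussian with means $\mu_i$ and variances $\sigma_i^2$, this KL divergence has a closed form:

$$\mathrm{KL}\big(N(\mu, \mathrm{diag}(\sigma^2)) \,\|\, N(0, I)\big) = \frac{1}{2}\sum_i \left(\sigma_i^2 + \mu_i^2 - \log \sigma_i^2 - 1\right)$$

The sketch below checks this formula against `torch.distributions.kl_divergence` and then applies a per-example floor. The batch of `means` and `variances` is made up for illustration, and the floor of 0.5 is only an example value.

```python
import torch as t
from torch.distributions import Normal, kl_divergence

t.manual_seed(0)

# Hypothetical batch of 5 encodings in an 8-dimensional latent space.
means = t.randn(5, 8)
variances = t.rand(5, 8) + 0.1  # strictly positive

# Closed form, summed over the latent dimensions of each example.
closed_form = 0.5 * t.sum(variances + means.pow(2) - t.log(variances) - 1, dim=1)

# Reference value; Normal takes a standard deviation, not a variance.
prior = Normal(t.zeros_like(means), t.ones_like(means))
reference = kl_divergence(Normal(means, variances.sqrt()), prior).sum(dim=1)

# Floor each example's divergence at 0.5, then average over the batch.
floored = t.clamp(closed_form, min=0.5).mean()
print(closed_form, reference, floored)
```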

In [None]:
bce = nn.BCEWithLogitsLoss(reduction='sum')

def KL(means, variances, minimum=0.5, epsilon=1e-8):
    '''
    Computes the KL divergence of N(means, variances) with N(0, 1) on a batch,
    then returns the mean of the batch.

    :param means: (torch.Tensor)
    :param variances: (torch.Tensor)
    :param minimum: (float, optional) Minimum value allowed for each KL divergence.
    :param epsilon: (float, optional) Small constant that keeps the log finite.
    :return: The KL divergence
    '''
    loss = t.sum(variances+means.pow(2)-t.log(variances+epsilon)-1,1)/2.
    #Floor each example's divergence at `minimum` (element-wise), then average.
    #t.max(stack) without a dim would return the single largest value instead.
    return t.clamp(loss, min=minimum).mean()

## Utility Layers

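This creates the latent layer and a Reshape layer. The latter is exactly as it sounds; it reshapes the input that passes through it.

The latent layer is what differentiates a VAE from a plain autoencoder. Instead of producing a single code for each input, it produces a mean and a variance and returns a sample drawn from the corresponding normal distribution.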
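
The sampling step can be written so that gradients still flow back to the layers that produce the mean and variance (the reparameterization trick). Below is a minimal sketch of that idea on its own; `mu` and `raw_var` are hypothetical tensors standing in for the outputs of the two linear layers.

```python
import torch as t
import torch.nn.functional as F

t.manual_seed(0)

# Hypothetical pre-activations standing in for the two linear layers' outputs.
mu = t.randn(3, 4, requires_grad=True)
raw_var = t.randn(3, 4, requires_grad=True)

# Softplus keeps the variance strictly positive.
var = F.softplus(raw_var)

# Reparameterization: the noise is sampled independently of the parameters,
# so the sample is a differentiable function of mu and var.
noise = t.randn_like(mu)
z = mu + noise * var.sqrt()

# Gradients reach both mu and raw_var through the sample.
z.sum().backward()
print(mu.grad, raw_var.grad)
```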

In [None]:
class LatentLayer(nn.Module):
    '''
    Creates a layer that takes an input and outputs a sample of latent variables. This is intended to
    be a connecting layer, and so it has two fully connected layers - one learns the means in the
    latent space, while the other learns the variances. The output for the variance has a softplus
    activator as it must be positive.
    '''
    def __init__(self, input_size, latent_size):
        super(LatentLayer, self).__init__()
        self.layer_mu = nn.Linear(input_size, latent_size)
        self.layer_var = nn.Linear(input_size, latent_size)
        
    def forward(self, x):
        #Compute statistic parameters
        mu = self.layer_mu(x)
        var = F.softplus(self.layer_var(x))
        #Create sample
        samples = mu+t.randn_like(mu)*var.sqrt()
        return samples, mu, var

class ReshapeLayer(nn.Module):
    '''
    Reshapes the input to be the output using view. Shapes are checked
    on forward pass to verify that they are compatible.

    :param view_shape: The shape to cast the input to. Given a batch
        input of shape (n, _) will be cast to (n, view_shape).
    '''
    def __init__(self, view_shape):
        super(ReshapeLayer, self).__init__()
        self.view_shape = view_shape

    def forward(self, x):
        '''
        Reshapes x to initialized shape. Assumes that x.shape[0] is
        the batch size.

        :param x: (torch.tensor)
        :return: (torch.tensor)
        '''
        output_shape = (x.shape[0],) + self.view_shape
        assert self.dimension(x.shape) == self.dimension(output_shape), '{0} and {1} are not compatible'.format(x.shape, output_shape)
        return x.view(output_shape)

    @staticmethod
    def dimension(shape):
        #Helper function for checking dimensions
        out = 1
        for s in shape:
            out *= s
        return out

## VAE Class

This implements the VAE class itself. It combines an encoder network, a latent layer, and a decoder network into a single module.

In [None]:
class VAE(nn.Module):
    '''
    Creates a Variational Autoencoder in Torch.
    '''
    def __init__(self, encoder_network, latent_layer, decoder_network):
        super(VAE, self).__init__()
        
        self.add_module('encoder', encoder_network)
        self.add_module('latent', latent_layer)
        self.add_module('decoder', decoder_network)
        
    def forward(self, x):
        '''
        Passes forward through encoder to create a sample latent
        representation, then decodes that representation.
        
        :param x: (torch.Tensor) 
        :return: (torch.Tensor, tuple) Reconstruction of x (as logits) and the mean and variance of its representation in the latent space
        '''
        #encode
        x = self.encoder(x)
        #Sample from a normal distribution with mean mu and variance var
        z, mu, var = self.latent(x)
        #Decode and return latent variables
        return self.decoder(z), (mu, var)

## Loading the data

Here we use the torch tools to load the MNIST (or FashionMNIST) dataset. `compose_train` is the sequence of transformations we want to apply to the images before they are used by the model. ToTensor converts the PIL image (with values in $[0, 255]$) to a tensor (with values in $[0, 1]$). For both datasets, we use RandomAffine, which here applies a small random rotation (up to 1 degree) and a random translation (up to 10% of the image size in each direction). For the FashionMNIST data, we also use a RandomHorizontalFlip transformation. These transformations effectively increase the number of distinct training examples, which has a positive effect on fitting.

We also call DataLoader, which handles batching and shuffling our data.
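
To make the batching behaviour concrete, here is a small sketch that uses a synthetic `TensorDataset` in place of FashionMNIST (the random `images` and `labels` are placeholders, not real data). It shows that the final batch can be smaller than `batch_size` when the dataset size is not a multiple of it.

```python
import torch as t
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for an image dataset: 300 "images" with integer labels.
images = t.rand(300, 1, 28, 28)
labels = t.randint(0, 10, (300,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=128, shuffle=True)

# 300 = 128 + 128 + 44, so the final batch is smaller than batch_size.
batch_sizes = [x.shape[0] for x, _ in loader]
print(batch_sizes)  # [128, 128, 44]
```

Any per-example average computed inside a training loop should therefore divide by the size of the batch actually received rather than by a fixed constant.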

In [None]:
batch_size = 128
compose_train = Compose([RandomAffine(1, (0.1, 0.1)), RandomHorizontalFlip(), ToTensor()])

training_set = datasets.FashionMNIST(root='./', train=True, download=True, transform=compose_train)
train_loader = t.utils.data.DataLoader(dataset=training_set, batch_size=batch_size, shuffle=True)

testing_set = datasets.FashionMNIST(root='./', train=False, download=True, transform=ToTensor())

data_name = 'Fashion'

## Building the network

A basic variational autoencoder can be broken down into three modules: the encoder network, the latent layer, and the decoder network. 

To build the networks, we use the Sequential object to build both the encoder and the decoder. To fit them together, we use the latent layer and the VAE class defined earlier. Each network uses SELU activation functions (see [here](https://arxiv.org/abs/1706.02515)).

The sizes of the hidden layers are chosen after some experimentation.
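
Before training, it can help to confirm that the layer sizes fit together. The sketch below pushes a dummy batch through an encoder and decoder of the same sizes. It is an approximation of the networks in this notebook: `nn.Flatten` and `nn.Unflatten` are built-in layers used here in place of the `ReshapeLayer`, and weight normalization is left out for brevity.

```python
import torch as t
import torch.nn as nn

# Assumption: nn.Flatten / nn.Unflatten stand in for the notebook's ReshapeLayer,
# and weight normalization is omitted to keep the sketch short.
encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28**2, 256), nn.SELU(),
    nn.Linear(256, 128), nn.SELU(),
    nn.Linear(128, 64), nn.SELU(),
)
decoder = nn.Sequential(
    nn.Linear(32, 64), nn.SELU(),
    nn.Linear(64, 128), nn.SELU(),
    nn.Linear(128, 256), nn.SELU(),
    nn.Linear(256, 28**2),
    nn.Unflatten(1, (1, 28, 28)),
)

x = t.rand(5, 1, 28, 28)   # dummy batch of five images
h = encoder(x)             # features that would feed the latent layer
z = t.randn(5, 32)         # dummy latent sample
x_logits = decoder(z)      # reconstruction logits
print(h.shape, x_logits.shape)
```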

In [None]:
# We name the network here for convenience.
model_name = 'flat_vae_model'
latent_dim = 32

#Build networks
encoder_net = nn.Sequential(
    ReshapeLayer((28**2,)),
    nn.utils.weight_norm(nn.Linear(28**2, 256)),
    nn.SELU(),
    nn.utils.weight_norm(nn.Linear(256, 128)),
    nn.SELU(),
    nn.utils.weight_norm(nn.Linear(128, 64)),
    nn.SELU(),
)

latent_net = LatentLayer(64, latent_dim)

decoder_net = nn.Sequential(
    nn.utils.weight_norm(nn.Linear(latent_dim, 64)),
    nn.SELU(),
    nn.utils.weight_norm(nn.Linear(64, 128)),
    nn.SELU(),
    nn.utils.weight_norm(nn.Linear(128, 256)),
    nn.SELU(),
    nn.utils.weight_norm(nn.Linear(256, 28**2)),
    ReshapeLayer((1, 28, 28)),
)

model = VAE(encoder_net, latent_net, decoder_net)

## Training

The goal of training is two-fold: minimize reconstruction error and minimize information loss as measured by KL divergence. Because of this, we train the model to optimize over these two conditions. The `Adam` algorithm is used to optimize the model.

The model is put in training mode explicitly using `model.train()` and then moved to the gpu using `model.cuda(0)`. To make it run on the cpu, set `use_gpu` to False at the start. If the model is going to be moved to the gpu, it should be done before loading the parameters into the optimizer.

To train the model, we iterate through the data set. We have to send the inputs to the gpu before we can pass them through the model. We keep a running average of the loss to keep us informed on progress. The last three lines of the inner loop are where the training happens in PyTorch. `optimizer.zero_grad()` zeros out all derivatives. `loss.backward()` uses automatic differentiation to compute derivatives starting at loss. `optimizer.step()` updates the model parameters based on the derivatives from the previous step.
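
The three-call pattern is easier to see on a tiny problem. The sketch below fits a one-parameter linear model to made-up data with `Adam`; the toy model, data, and learning rate are illustrative choices and are unrelated to the VAE settings used in this notebook.

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)

# Toy regression problem standing in for the VAE: learn y = 2x + 1.
x = t.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1

toy_model = nn.Linear(1, 1)
optimizer = t.optim.Adam(toy_model.parameters(), lr=1e-1)

losses = []
for step in range(200):
    loss = nn.functional.mse_loss(toy_model(x), y)
    optimizer.zero_grad()   # clear gradients left from the previous step
    loss.backward()         # populate .grad on every parameter
    optimizer.step()        # update parameters from those gradients
    losses.append(loss.item())

print(losses[0], losses[-1])
```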

If the model has a pretrained model file available, then that can be loaded by using the second code block instead.
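
A minimal sketch of saving and restoring parameters with `state_dict` is shown below. It uses a small stand-in network and a temporary file with a hypothetical name; for the VAE in this notebook, the same architecture would need to be rebuilt before calling `load_state_dict`.

```python
import os
import tempfile

import torch as t
import torch.nn as nn

# Small stand-in network; the same pattern applies to any nn.Module.
def make_model():
    return nn.Sequential(nn.Linear(4, 3), nn.SELU(), nn.Linear(3, 2))

trained = make_model()

with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, 'example-model.pkl')  # hypothetical file name
    t.save(trained.state_dict(), path)

    # Rebuild the same architecture, then copy the saved parameters into it.
    restored = make_model()
    restored.load_state_dict(t.load(path, map_location='cpu'))

restored.eval()  # switch to evaluation mode before inference
```

Passing `map_location='cpu'` lets a file saved from a gpu model be loaded on a machine without one.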

In [None]:
def train_vae(model, data_loader, optimizer, use_gpu=True):
    average_recon = 0
    average_latent = 0
    for n_batch, (x, _) in enumerate(data_loader):
        if use_gpu:
            x = x.cuda()
        
        x_pred, (mu, var) = model(x)
        
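        #compute loss, averaging the summed BCE over the actual batch size
        #(the final batch can be smaller than batch_size)
        latent_loss = KL(mu, var)
        recon_loss = bce(x_pred, x)/x.shape[0]
        loss = recon_loss+latent_loss
        
        #Update running averages; .item() keeps them out of the autograd graph
        average_latent += (latent_loss.item()-average_latent)/(n_batch+1)
        average_recon += (recon_loss.item()-average_recon)/(n_batch+1)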

        #Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #Print percent complete
        print(percent_complete.format(batch_size*n_batch/n_data), end='')
        
    return average_recon, average_latent

In [None]:
epochs = 10
learning_rate = 1e-4
model.train()

#Add to gpu before making the optimizer
if use_gpu:
    model.cuda(0)

optimizer = t.optim.Adam(model.parameters(), lr=learning_rate)

epoch_summary_text = '\rEpoch {0}/{1}'
loss_summary_text = '{0} loss: {1:.5}'

percent_complete = '\r{0:.1%}'

n_data = len(training_set)
for epoch in range(epochs):
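    #Pass use_gpu through so that setting it to False really keeps the data on the cpu
    recon_loss, latent_loss = train_vae(model, train_loader, optimizer, use_gpu=use_gpu)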

    #Create summary loss
    epoch_summary = epoch_summary_text.format(epoch+1, epochs)
    latent_summary = loss_summary_text.format('latent', latent_loss)
    recon_summary = loss_summary_text.format('recon', recon_loss)
    total_summary = loss_summary_text.format('total', latent_loss+recon_loss)
    print(epoch_summary, recon_summary, latent_summary, total_summary, sep=' - ')
    
t.save(model.state_dict(), file_path.format(data_name, model_name))

# Image Generation

## Image Reconstruction

A VAE can be used to reconstruct examples. This is shown below. The images in the top row are the originals, while the images in the bottom row are the reconstructions. This VAE tends to blur examples together. That is, reconstructions drift towards the most generic version of each item, with shading and edges being blurred. There are a few ways to approach fixing this, which are mentioned in the introduction.

In [None]:
import matplotlib.pyplot as plt
from random import sample

def image_plot(axis, image, **kwargs):
    axis.imshow(image.view(28,28), cmap='Greys', **kwargs)

In [None]:
#Set up picture
n_images = 11
ax_settings = {'aspect':'equal', 'xticklabels':[], 'yticklabels':[], 'xticks':[], 'yticks':[]}
fig, ax = plt.subplots(2, n_images, subplot_kw=ax_settings, figsize=(n_images*3//2, 3))


model.cpu()
model.eval()
with t.no_grad():
    image_sample = sample(list(testing_set), n_images)
    #Reconstruct each sampled image and plot it below the original
    for index, (image, _) in enumerate(image_sample):
        image = image.view((1,)+image.shape)
        recon_image, _ = model(image)

        image_plot(ax[0, index], image)
        image_plot(ax[1, index], t.sigmoid(recon_image))
        
plt.savefig('./images/{0}-{1}-reconstructions.png'.format(data_name, model_name))

## Latent Sampling

My interest in VAEs is mostly tied to the fact that you can use them to generate new images. Once a network is trained, we can sample a random latent vector from $N(0, 1)$ and pass it through the decoder to create a new image.

In [None]:
n_images = 15
fig, ax = plt.subplots(1, n_images, subplot_kw=ax_settings, figsize=(n_images*3/2, 3/2))
with t.no_grad():
    for index in range(n_images):
        latent_sample = t.randn(1, latent_dim)
        image = model.decoder(latent_sample)
        image_plot(ax[index], t.sigmoid(image))
        
plt.savefig('./images/{0}-{1}-latent_samples.png'.format(data_name, model_name))

## Distribution of examples in the latent space

Below we can see a scatter plot of the means of each example. The groups tend towards being distributed in an elliptic manner. This is due to our diagonal assumption on variance. Using IAF can improve the distribution of the examples.

Note: The plot is misleading unless the latent dimension is equal to 2. To get around this difficulty, you can use a dimension reduction technique (e.g. t-SNE, umap) to reduce to two dimensions.
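
As a simple linear alternative to the techniques named above, the sketch below projects latent means to two dimensions with PCA via `torch.pca_lowrank`. PCA is swapped in here only because it needs no extra library, and it will not separate clusters as well as t-SNE or umap. The `latent_means` tensor is random placeholder data, not output from the trained model.

```python
import torch as t

t.manual_seed(0)

# Hypothetical latent means: 200 examples in a 32-dimensional latent space.
latent_means = t.randn(200, 32)

# torch.pca_lowrank centers the data by default and returns (U, S, V);
# the columns of V are the principal directions.
_, _, V = t.pca_lowrank(latent_means, q=2)

# Project the centered means onto the top two directions.
centered = latent_means - latent_means.mean(dim=0, keepdim=True)
projected = centered @ V[:, :2]
print(projected.shape)  # torch.Size([200, 2])
```

The two columns of `projected` can then be used as the x and y coordinates of the scatter plot.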

In [None]:
fig, ax = plt.subplots(1, 1)
#Number of test examples to plot
n_images = 1000

means, labels = [], []
with t.no_grad():
    for index in range(n_images):
        image, label = testing_set[index]
        _, (mean, _) = model(image)
        #Keep only the first two latent coordinates
        means.append(mean[0, :2])
        labels.append(label)

means = t.stack(means)
labels = t.tensor(labels)
#Plot each class separately so that only its own examples appear
for index, label in enumerate(testing_set.classes):
    mask = labels == index
    ax.scatter(means[mask, 0].numpy(), means[mask, 1].numpy(), label=label)
ax.legend()


## Appendix A: The problem with per-pixel difference measures

![](./images/per-pixel-difference.png)

The image above demonstrates why using a per-pixel similarity measurement is undesirable. We have three images. One is nearly identical (made by shifting every pixel of the original up by 1). Another is similar, but blurred. The last is blank. I would intuitively say the first image is the most similar to the original. However, it is actually the least similar as measured per-pixel. The image that is the most similar is actually the blank image! The problem is that the similarity measurement we are using does not look at any feature of the image. Instead, it pushes us towards recreating each image exactly. This leads to blurred images as we hedge towards generic images rather than sharp images. 
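
The same effect can be reproduced without a trained model. The sketch below uses a synthetic image (a single bright line) instead of a FashionMNIST example, and it assumes the candidate image is passed as the prediction and the original as the target. Under those assumptions the shifted copy scores worse than the blank image.

```python
import torch as t
import torch.nn.functional as F

# Synthetic 28x28 "image": a single bright horizontal line on a dark background.
original = t.zeros(28, 28)
original[10, :] = 1.0

# Shift every pixel up by one row; the bottom row is left dark.
shifted = t.zeros_like(original)
shifted[:-1, :] = original[1:, :]

blank = t.zeros_like(original)

def dissimilarity(candidate, target, eps=0.01):
    # Clamp away from 0 and 1 so the logarithms stay finite.
    candidate = candidate.clamp(eps, 1 - eps)
    return F.binary_cross_entropy(candidate, target, reduction='sum')

shifted_loss = dissimilarity(shifted, original)
blank_loss = dissimilarity(blank, original)
print(shifted_loss.item(), blank_loss.item())
```

The shifted image is penalized twice, once for the line it is missing and once for the line it has in the wrong place, while the blank image is only penalized for the missing line.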

This problem can be alleviated by using an adversarial model. The adversarial model adds a separate network that is tasked with telling original images apart from generated ones. This encourages the generator network to produce sharper, more realistic images. In essence, it is measuring the difference between the images by looking at features rather than at pixels.

The code below will create an image like the above. However, the reconstruction will vary due to the stochastic nature of the model. Because of this, you may end up with a higher or lower loss than image 1.

In [None]:
#Pick example image
image = testing_set[index][0]

#Naming it for convenience
dissimilarity = lambda x, y: F.binary_cross_entropy(x, y, reduction='sum')

#Create a blank image
blank_image = t.zeros_like(image)

#Create reconstruction image
with t.no_grad():
    recon_image, _ = model(image)
    recon_image = t.sigmoid(recon_image[0])

#Create shifted image
shifted_image = t.zeros_like(image)
shifted_image[0,:-1,:] = image[0,1:,:]

#Measure dissimilarities
blank_loss = dissimilarity(image, blank_image)
recon_loss = dissimilarity(image, recon_image)
shifted_loss = dissimilarity(image, shifted_image)

In [None]:
#Create subplot
fig, ax = plt.subplots(1, 4, subplot_kw=ax_settings, figsize=(7, 2))

#Plot and label images.
image_plot(ax[0], image)
ax[0].set_title('Original')

image_plot(ax[1], shifted_image)
ax[1].set_title('Image 1')

image_plot(ax[2], recon_image)
ax[2].set_title('Image 2')

image_plot(ax[3], blank_image)
ax[3].set_title('Image 3')

fig.suptitle('Per-pixel differences for various images')


loss_text = 'Difference between original and image {0}: {1:.5}\n'
shifted_text = loss_text.format(1, shifted_loss)
recon_text = loss_text.format(2, recon_loss)
blank_text = loss_text.format(3, blank_loss)
ax[0].text(0, 32, shifted_text+recon_text+blank_text, verticalalignment='top')