# Problem 2: VAE

The code snippets below implement a VAE for MNIST digits and some visualizations of the results. Check the PDF for instructions on what to do.

## Model definition and optimization

In [None]:
import torch
import torch.nn as nn
import torchvision
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Set hyperparameters of the model and optimization
K = 5              # dimensionality of the latent variable z
obs_sigma = 0.1    # standard deviation of the Gaussian observation model
batch_size = 50    # mini-batch size of the training DataLoader
# You will want to use a bigger number, but I set it small by default
# so that it is faster to run the code for the first time. Increasing
# numEpoch does not yet count as proper modification.
numEpoch = 5
lr = 0.001         # learning rate passed to the Adam optimizer

# Seed PyTorch's global random number generator so that the stochastic parts
# of the notebook (batch shuffling, weight initialisation, rsample() draws)
# are repeatable under "Restart Kernel -> Run All".
seed = 0
torch.manual_seed(seed)

# MNIST training data, stored under 'files/' (downloaded on first use).
# Each image is converted to a tensor by ToTensor; batches are reshuffled
# every time the loader is iterated.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
mnist_train = torchvision.datasets.MNIST(
    'files/', train=True, download=True, transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True)

# Prior distribution for latent variables: a scalar standard normal, which
# broadcasts against the latent samples when log_prob is evaluated.
p_z = torch.distributions.Normal(loc=0., scale=1.)

# Encoder and decoder specifications
D = 28*28   # length of a flattened 28x28 image
H = 20      # width of both hidden layers


def _mlp(n_in, n_hidden, n_out):
    """Build a ReLU network n_in -> n_hidden -> n_hidden -> n_out with a linear output layer."""
    return nn.Sequential(
        nn.Linear(n_in, n_hidden), nn.ReLU(),
        nn.Linear(n_hidden, n_hidden), nn.ReLU(),
        nn.Linear(n_hidden, n_out, bias=True),
    )


# Two separate encoders map an image to the mean and to the (unconstrained)
# scale of q(z|x); the decoder maps a latent vector back to image space.
encoder_mu = _mlp(D, H, K)
encoder_sigma = _mlp(D, H, K)
decoder = _mlp(K, H, D)

# Optimize over parameters of all networks
params = [
    *encoder_mu.parameters(),
    *encoder_sigma.parameters(),
    *decoder.parameters(),
]
optimizer = torch.optim.Adam(params, lr=lr)

# Training loop.
# After it finishes, `elbos` holds one Python float per epoch (the average of
# the per-batch ELBO estimates), and `x` / `x_mean` still refer to the inputs
# and decoder outputs of the last processed batch.
elbos = []
for epoch in tqdm(range(numEpoch)):
    # Accumulated as a plain float: adding the loss tensors themselves would
    # keep every batch's autograd graph alive until the end of the epoch and
    # leave graph-attached tensors in `elbos`.
    epochloss = 0.
    for batch_data, batch_targets in train_loader:
        optimizer.zero_grad()

        # Flatten every image to a vector. The row count is taken from the
        # batch itself so a shorter final batch does not break the reshape.
        x = batch_data.reshape((batch_data.shape[0], -1))

        # Form parameters of approximation
        # - the sigmoid keeps the standard deviation inside (0, 1)
        mu = encoder_mu(x)
        unconstrained_sigma = encoder_sigma(x)
        sigma = torch.sigmoid(unconstrained_sigma)

        # Sample from approximation
        # - rsample() handles reparameterization internally,
        #   so we do not need to do it manually
        # - Note that sample() would not work correctly
        q_z_x = torch.distributions.Normal(mu, sigma)
        z = q_z_x.rsample()

        # Find mean parameters of observed data
        x_mean = decoder(z)

        # Learning objective
        # - Sum over the columns to handle multivariate distributions
        # - Mean over the rows as we want expected loss per data point (not sum)
        # - KL is a single-sample estimate evaluated at the drawn z
        logp_x_z = torch.sum(torch.distributions.Normal(x_mean, obs_sigma).log_prob(x), 1)
        KL = torch.sum(q_z_x.log_prob(z) - p_z.log_prob(z), 1)
        loss = - torch.mean(logp_x_z - KL, 0)
        epochloss += loss.item()

        loss.backward()
        optimizer.step()
    # The loss is the negative ELBO, so flip the sign of the epoch average.
    elbos.append(-epochloss/len(train_loader))


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing
Done!


  0%|          | 0/5 [00:00<?, ?it/s]

## Plotting functionality

In [None]:
plt.rcParams["figure.figsize"] = (10, 5)
# Convert every entry to a plain float before plotting: `elbos` may contain
# 0-dim tensors that are still attached to the autograd graph, and those
# cannot be converted to a NumPy array by matplotlib. float() is a no-op for
# entries that are already Python floats.
plt.plot([float(elbo) for elbo in elbos])
plt.title('Training ELBO per epoch')
plt.xlabel('Epoch')
_ = plt.ylabel('ELBO')

In [None]:
plt.rcParams["figure.figsize"] = (10, 10)
# Note: Uses the values from the last iteration of the algorithm
# Eight (input, reconstruction) pairs laid out on a 4x4 grid: each input goes
# to an odd subplot slot and its decoder mean to the slot right after it.
for sam, input_slot in enumerate(range(1, 17, 2)):
    plt.subplot(4, 4, input_slot)
    plt.imshow(x[sam].reshape(28, 28))

    # detach() drops the autograd graph so the tensor can be drawn
    plt.subplot(4, 4, input_slot + 1)
    plt.imshow(x_mean[sam].detach().reshape(28, 28))

In [None]:
# Encode a random subset of the training images and scatter the first two
# coordinates of the encoder mean, one colour per digit class.
# NOTE: this cell overwrites `train_loader`, `x`, `mu` and the other batch
# variables set by the training cell, so the reconstruction plot (which reads
# `x` and `x_mean`) must be run before this cell, not after it.
n_points = 10000
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                             ])),
  batch_size=n_points, shuffle=True)

# Only the first (shuffled) batch is used.
batches = iter(train_loader)
batch_data, batch_targets = next(batches)
x = batch_data.reshape((batch_data.shape[0], -1))

# Pure inference: skip building an autograd graph for the forward passes.
with torch.no_grad():
    mu = encoder_mu(x)
    unconstrained_sigma = encoder_sigma(x)  # not used by the plot in this cell

plt.rcParams["figure.figsize"] = (10, 10)
for c in range(10):
    is_digit = batch_targets == c
    plt.plot(mu[is_digit, 0], mu[is_digit, 1], '.', alpha=0.8, label=str(c))
plt.xlabel('Latent mean, dimension 0')
plt.ylabel('Latent mean, dimension 1')
_ = plt.legend(title='Digit')