# First step to implement a Variational AutoEncoder (VAE)

The main goal of this lab is to implement by hand the main steps of a VAE in Pytorch.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import torchvision.utils
import numpy as np

import matplotlib.pyplot as plt

Definition of the most important constant values

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 64         # number of data points in each batch
N_EPOCHS = 10           # times to run the model on complete data
INPUT_DIM = 28 * 28     # size of each input
HIDDEN_DIM = 256        # hidden dimension
LATENT_DIM = 2          # latent vector dimension
lr = 1e-3               # learning rate

## Load the dataset

Import the MNIST dataset.

This dataset is described [here](https://en.wikipedia.org/wiki/MNIST_database)

### Question 1: what is the size of a MNIST image? Is it a color image?

#### Answer:



In [3]:
transforms = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transforms)

test_dataset = datasets.MNIST(
    './data',
    train=False,
    download=True,
    transform=transforms
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 85261298.58it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 118181164.71it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 29686396.65it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 21143761.12it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






### Question 2: what is the size of "train_dataset"?

#### Answer:



### Question 3: plot the tenth image of the "train_dataset"? What is its label?

In [4]:
# Fill in this cell


### Question 4: how many batches are composing "train_iterator" in the following cell? What is the size of a batch?

In [5]:
train_iterator = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_iterator = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [6]:
# Fill in this cell to anwser the question


## Implement the VAE

### Question 5: describe carefully the architecture of the "Encoder". What are "z_mu" and "z_var"?

#### Answer:



In [7]:
class Encoder(nn.Module):
    # This the encoder part of VAE

    def __init__(self, input_dim, hidden_dim, z_dim):
        '''
        Args:
            input_dim: A integer indicating the size of input (in case of MNIST 28 * 28).
            hidden_dim: A integer indicating the size of hidden dimension.
            z_dim: A integer indicating the latent dimension.
        '''
        super().__init__()

        self.linear = nn.Linear(input_dim, hidden_dim)
        self.mu = nn.Linear(hidden_dim, z_dim)
        self.var = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        # x is of shape [batch_size, input_dim]

        hidden = F.relu(self.linear(x))
        # hidden is of shape [batch_size, hidden_dim]
        z_mu = self.mu(hidden)
        # z_mu is of shape [batch_size, latent_dim]
        z_var = self.var(hidden)
        # z_var is of shape [batch_size, latent_dim]

        return z_mu, z_var


### Question 6: describe carefully the architecture of the "Decoder". Why are we using "torch.sigmoid"? What is "predicted"?

#### Answer:



In [8]:
class Decoder(nn.Module):
    ''' This the decoder part of VAE

    '''
    def __init__(self, z_dim, hidden_dim, output_dim):
        '''
        Args:
            z_dim: A integer indicating the latent size.
            hidden_dim: A integer indicating the size of hidden dimension.
            output_dim: A integer indicating the output dimension (in case of MNIST it is 28 * 28)
        '''
        super().__init__()

        self.linear = nn.Linear(z_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x is of shape [batch_size, latent_dim]

        hidden = F.relu(self.linear(x))
        # hidden is of shape [batch_size, hidden_dim]

        predicted = torch.sigmoid(self.out(hidden))
        # predicted is of shape [batch_size, output_dim]
        return predicted



### Question 7: complete the following class. You must add the commands to generate the latent variables "z" and to generate the "predicted" output.


#### Answer:



In [9]:
class VAE(nn.Module):
    def __init__(self, enc, dec):
        ''' This the VAE, which takes a encoder and decoder.

        '''
        super().__init__()

        self.enc = enc
        self.dec = dec

    def forward(self, x):

        # Fill in this forward function
        ...



        return predicted, z_mu, z_var


Initialize all the components of the VAE.

In [10]:
# encoder
encoder = Encoder(INPUT_DIM, HIDDEN_DIM, LATENT_DIM)

# decoder
decoder = Decoder(LATENT_DIM, HIDDEN_DIM, INPUT_DIM)

# vae
model = VAE(encoder, decoder).to(device)



Print the number of parameters of the VAE

### Question 8: what is the role of "p.requires_grad" in the following cell?

#### Answer:



In [11]:
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of parameters: %d' % num_params)

Number of parameters: 404244


## Train the VAE

Define the optimizer

In [12]:
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

### Question 9: complete the following function "train" to train the VAE, i.e., you must compute the loss.

In [13]:
def train():
    # set the train mode
    model.train()

    # loss of the epoch
    train_loss = 0

    for i, (x, _) in enumerate(train_iterator):
        # reshape the data into [batch_size, 784]
        x = x.view(-1, 28 * 28)
        x = x.to(device)

        # update the gradients to zero
        optimizer.zero_grad()

        # forward pass
        x_sample, z_mu, z_var = model(x)


        # Fill in this place to compute the two losses: recon_loss and kl_loss
        ...

        # total loss
        # WARNING: Pytorch minimizes the loss. Hence, we compute the opposite of the loss given in the lecture!
        loss = recon_loss + kl_loss

        # backward pass
        loss.backward()
        train_loss += loss.item()

        # update the weights
        optimizer.step()

    return train_loss



### Question 10: what is the role of "model.eval()" in the folllowing "test" function?

#### Answer:



### Question 11: what is the role of "torch.no_grad()" in the folllowing "test" function?

#### Answer:



### Question 12: complete the following function "test" to train the VAE, i.e., you must compute the loss. The answer is the same as for the training step.

#### Answer:



In [14]:
def test():
    # set the evaluation mode
    model.eval()

    # test loss for the data
    test_loss = 0

    # we don't need to track the gradients, since we are not updating the parameters during evaluation / testing
    with torch.no_grad():
        for i, (x, _) in enumerate(test_iterator):
            # reshape the data
            x = x.view(-1, 28 * 28)
            x = x.to(device)

            # forward pass
            x_sample, z_mu, z_var = model(x)

            # Fill in this place to compute the two losses: recon_loss and kl_loss
            ...

            # total loss
            loss = recon_loss + kl_loss
            test_loss += loss.item()

    return test_loss


### Question 13: What is the goal of the following cell? What is the role of "if patience_counter > 3"?

#### Answer:



In [15]:
best_test_loss = float('inf')

for e in range(1,N_EPOCHS+1):

    train_loss = train()
    test_loss = test()

    train_loss /= len(train_dataset)
    test_loss /= len(test_dataset)

    print(f'Epoch {e}/{N_EPOCHS}, Train Loss: {train_loss:.2f}, Test Loss: {test_loss:.2f}')

    if best_test_loss > test_loss:
        best_test_loss = test_loss
        patience_counter = 1
    else:
        patience_counter += 1

    if patience_counter > 3:
        break


NameError: name 'predicted' is not defined

### Question 14: the following cell shows some images and their approximation with the VAE. What do you think of the result? Test again this cell when LATENT_DIM > 2

#### Answer:



In [None]:
import numpy as np
import matplotlib.pyplot as plt
plt.ion() # Turn the interactive mode on.

import torchvision.utils

model.eval()

# Make sure that the image pixels are between 0 and 1
def to_img(x):
    x = x.clamp(0, 1)
    return x

# To plot an image
def show_image(img):
    img = to_img(img)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# To visualize the images reconstructed by a given model
def visualise_output(images, model):
    with torch.no_grad():
        images = images.view(-1, 28 * 28)
        images = images.to(device)
        images, _, _ = model(images)
        images = images.cpu()
        images = to_img(images)
        images = images.view(-1, 1, 28, 28)
        np_imagegrid = torchvision.utils.make_grid(images, 10, 7).numpy()
        plt.imshow(np.transpose(np_imagegrid, (1, 2, 0)))
        plt.show()

test_iterator_iter = iter(test_iterator)
images, labels = next(test_iterator_iter)
#print(images.size())

# First visualise the original images
print('Original images')
show_image(torchvision.utils.make_grid(images,10,7))
plt.show()

# Reconstruct and visualise the images using the vae
print('VAE reconstruction:')
visualise_output(images, model)

#### Show 2D Latent Space

### Question 15: the following cell shows some images when the latent space is sampled. It works only when when LATENT_DIM = 2. What do you think of the generated images?

#### Answer:



In [None]:
# load a network that was trained with a 2d latent space
if LATENT_DIM != 2:
    raise Exception('Please change the parameters to two latent dimensions.')

# set the evaluation mode
model.eval()

with torch.no_grad():

    # create a sample grid in 2d latent space
    latent_x = np.linspace(-1.5,1.5,20)
    latent_y = np.linspace(-1.5,1.5,20)
    latents = torch.FloatTensor(len(latent_y), len(latent_x), 2)
    for i, lx in enumerate(latent_x):
        for j, ly in enumerate(latent_y):
            latents[j, i, 0] = lx
            latents[j, i, 1] = ly
    latents = latents.view(-1, 2) # flatten grid into a batch

    # reconstruct images from the latent vectors
    latents = latents.to(device)
    image_recon = model.dec(latents)
    image_recon = image_recon.cpu()
    image_recon = to_img(image_recon)
    image_recon = image_recon.view(-1, 1, 28, 28)

    fig, ax = plt.subplots(figsize=(10, 10))
    show_image(torchvision.utils.make_grid(image_recon.data[:400],20,5))
    plt.xlabel("1st latent dimension")
    plt.ylabel("2nd latent dimension")
    plt.show()