# Variational Autoencoders

Learning Outcomes:
- Train convolutional VAEs using the reparametrization trick.
- Generate new, unseen data by sampling from the latent space.
- Illustrate interpolation between different images using latent representations.
- Visualize the effect of different weights of the regularization term on the learnt latent space.

## Library Imports

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## The MNIST dataset
The following code loads the MNIST dataset and builds the dataloaders needed for training.

In [3]:
from torchvision import datasets, transforms
batch_size = 128

# Data loading
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='../../data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='../../data', train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


## 1. Define the VAE class
In this section we will define the VAE class that we will train and use for image generation. We choose to train a convolutional VAE with the following architecture (once more, we leave the latent dimension $p$ as a free parameter):
- **The Encoder:** The encoder will consist of the following layers:
    - Convolution layer with 32 filters, a kernel size of 4, stride 2 and padding 1.
    - BatchNorm layer keeping the same number of features
    - ReLU activation
    - Convolution layer with 64 filters, a kernel size of 4, stride 2 and padding 1.
    - BatchNorm layer keeping the same number of features
    - ReLU activation
    - Convolution layer with 128 filters, a kernel size of 3, stride 2 and padding 1.
    - BatchNorm layer keeping the same number of features
    - ReLU activation

- **The Latent Space:** The encoder outputs are converted into the mean vector $\mu$ and the log-variance vector $\log\sigma^2$ via two parallel fully connected (FC) layers. We will need to define:
    - A FC layer to map the output of the encoder $E(x)$ to the mean vector $\mu(x)$.
    - A FC layer to map the output of the encoder $E(x)$ to the log-variance vector $\log\sigma^2(x)$.
    - A FC layer to map the sampled hidden state $z(x)\sim\mathcal{N}(\mu(x),\mathrm{Diag}(\sigma^2(x)))$ to the decoder input.

- **The Decoder:** The decoder will consist of the following layers:
    - Deconvolution layer with 64 filters, a kernel size of 3, stride 2 and padding 1.
    - BatchNorm layer keeping the same number of features
    - ReLU activation
    - Deconvolution layer with 32 filters, a kernel size of 4, stride 2 and padding 1.
    - BatchNorm layer keeping the same number of features
    - ReLU activation
    - Deconvolution layer with 1 filter, a kernel size of 4, stride 2 and padding 1.
    - Sigmoid layer
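
To size the fully connected layers, it helps to trace the spatial dimensions through the layers listed above. The sketch below uses the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d`, assuming no dilation and no output padding (the PyTorch defaults). The helper names `conv_out` and `deconv_out` are ours, for illustration only.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel, stride, padding):
    """Spatial output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) for each layer, as listed above
encoder_layers = [(4, 2, 1), (4, 2, 1), (3, 2, 1)]
decoder_layers = [(3, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 28  # MNIST images are 28 x 28
encoder_sizes = [size]
for k, s, p in encoder_layers:
    size = conv_out(size, k, s, p)
    encoder_sizes.append(size)

decoder_sizes = [size]
for k, s, p in decoder_layers:
    size = deconv_out(size, k, s, p)
    decoder_sizes.append(size)

# Flattened encoder output: channels * height * width
flat_dim = 128 * encoder_sizes[-1] ** 2

print("encoder:", encoder_sizes)
print("decoder:", decoder_sizes)
print("flattened encoder output:", flat_dim)
```

Under these assumptions the spatial size goes 28, 14, 7, 4 through the encoder and 4, 7, 14, 28 through the decoder, so the encoder output flattens to 128 · 4 · 4 = 2048 features. This is the input size of the two latent FC layers and the output size of the FC layer feeding the decoder.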

In [17]:
class ConvVAE(nn.Module):
    def __init__(self, latent_dim=10):
        super(ConvVAE, self).__init__()
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),  # Output: (32, 14, 14)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # TODO: Add the missing encoder layers
        )

        # Fully connected layers for mean and log variance
        self.fc_mu = # TODO: Add the missing FC layer
        self.fc_logvar = # TODO: Add the missing FC layer
        self.fc_decode = # TODO: Add the missing FC layer
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1),  # Output: (64, 7, 7)
            # TODO: Add the missing decoder layers
        )
    
    def encode(self, x):
        x = # TODO: Encode the input image
        x = # TODO: Flatten the output of the convolutional layers
        mu = # TODO: Apply the corresponding FC layer
        logvar = # TODO: Apply the corresponding FC layer
        return mu, logvar
    

    def sample(self, mu, logvar):
        # TODO: Sample z from the Gaussian with the given mu and logvar
        # by using the reparameterization trick: z = mu + sigma * epsilon
    
    # TODO: Implement the decode function


    # TODO: Implement the forward function
    # by combining the encode, sample and decode functions


In [None]:
# %load solutions/conv_vae.py

## 2. Define the loss function
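
As a reminder, the loss used later in this notebook (Section 7) is the sum of a reconstruction term and a KL term weighted by $\beta$. Assuming a Bernoulli decoder and a standard normal prior, and writing $\hat{x}$ for the reconstruction, it can be written per image as:

$$\mathcal{L}(x) = \underbrace{-\sum_{i}\Big[x_i\log\hat{x}_i + (1-x_i)\log(1-\hat{x}_i)\Big]}_{\text{BCE}} \;+\; \beta\,\underbrace{\Big(-\frac{1}{2}\sum_{j=1}^{p}\big(1+\log\sigma_j^2-\mu_j^2-\sigma_j^2\big)\Big)}_{\text{KLD}}$$

Here $i$ runs over the pixels and $j$ over the $p$ latent dimensions. With `reduction='sum'`, both sums also run over the images in the batch.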

In [6]:
# TODO: Implement the loss function

In [None]:
# %load solutions/vae_loss.py

## 3. Train the VAE
In order to be able to visualize the latent space of the VAE, we will choose a latent dimension equal to 2.

In [None]:
# Hyperparameters
latent_dim = 2
learning_rate = 1e-3
epochs = 30
beta = 1

# TODO: Initialize the VAE model and the Adam optimizer
# and move the model to the device

# TODO: Train the model for the given number of epochs
# At the end of each epoch, print the training loss

In [None]:
# %load solutions/train_vae.py

## 4. Visualize the results
We first check whether the VAE has learnt meaningful features by plotting a batch of images from the test set alongside their reconstructions.

We already defined a function called `image_comparison` in the previous notebook that does exactly what we want. We can either copy-paste it below or, better yet, create a file called `utils.py` in the current folder, copy the function there along with the necessary library imports, and then import `image_comparison` in the cell below.
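
If the function from the previous notebook is not at hand, a minimal stand-in could look like the sketch below. This is a hypothetical version, not the original: it assumes both inputs are arrays of shape `(N, H, W)` and adds an optional argument `n` (our own addition) to limit how many image pairs are shown.

```python
import matplotlib.pyplot as plt


def image_comparison(originals, reconstructions, n=8):
    """Plot the first n originals (top row) above their reconstructions (bottom row).

    Hypothetical stand-in: both inputs are assumed to be arrays of shape (N, H, W).
    """
    n = min(n, len(originals))
    # squeeze=False keeps `axes` two-dimensional even when n == 1
    fig, axes = plt.subplots(2, n, figsize=(1.5 * n, 3), squeeze=False)
    for i in range(n):
        axes[0, i].imshow(originals[i], cmap="gray")
        axes[1, i].imshow(reconstructions[i], cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].axis("off")
    axes[0, 0].set_title("Originals", loc="left")
    axes[1, 0].set_title("Reconstructions", loc="left")
    plt.show()
```

Saved in `utils.py`, it could then be imported with `from utils import image_comparison`.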

In [None]:
# TODO: Define or import the image_comparison function

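# Select the first batch of images from the test dataset (test_loader is not shuffled)
random_images = next(iter(test_loader))

# Get the reconstructions of the selected images
vae.eval()  # BatchNorm layers use their running statistics
with torch.no_grad():
    recons, _, _ = vae(random_images[0].to(device))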

# Reshape the images for plotting
random_images = random_images[0].cpu().numpy().squeeze()
recons = recons.detach().cpu().numpy().squeeze()

# Plot the original images and their reconstructions
image_comparison(random_images, recons)

## 5. Image generation
The purpose of this section is to generate new images that look like MNIST digits. In order to do so, we follow the steps below:
- Sample $z$ from a $\mathcal{N}(0, I)$ distribution ($I$ being the identity matrix of size $p$).
- Decode $z$ using the decoder of the VAE to generate a new image.
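
The two steps above can be sketched as follows. The decoder here is a small untrained stand-in (`toy_decoder`, a hypothetical name) so that the snippet runs on its own; with the trained model, the decoding path of the VAE would take its place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

latent_dim = 2    # p, the latent dimension
num_samples = 10

# Stand-in decoder (assumption): maps z in R^p to a 1 x 28 x 28 image in [0, 1]
toy_decoder = nn.Sequential(
    nn.Linear(latent_dim, 28 * 28),
    nn.Sigmoid(),
    nn.Unflatten(1, (1, 28, 28)),
)
toy_decoder.eval()

with torch.no_grad():
    # Step 1: draw z ~ N(0, I); each row is one latent vector
    z = torch.randn(num_samples, latent_dim)
    # Step 2: decode z into images
    samples = toy_decoder(z)

print(samples.shape)  # torch.Size([10, 1, 28, 28])
```

Since the stand-in is untrained, its outputs are noise; the point is only the shape of the computation.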

**Question.** Why are we sampling $z$ from a $\mathcal{N}(0,I)$ distribution? What happened to the learnt mean and variance?

In [None]:
def generate_sample(num_samples=10):
    vae.eval()
    with torch.no_grad():
        # TODO: Sample random latent vectors
        samples = # TODO: Decode the latent vectors
        samples = samples.cpu().view(num_samples, 1, 28, 28) # Reshape the samples

        fig, ax = plt.subplots(1, num_samples, figsize=(15, 2))
        for i in range(num_samples):
            ax[i].imshow(samples[i].squeeze(0), cmap='gray')
            ax[i].axis('off')
        plt.show()

generate_sample()

In [None]:
# %load solutions/sample_gen.py

**Exercise.** As a follow-up exercise, you can check whether the quality of the generated samples improves when using a VAE trained with a larger latent dimension.

## 6. Interpolation between Images
The objective of this section is to visualize the difference between the space of latent representations and the (original) pixel space. In order to do so, we will perform *image interpolation*, i.e., we will take two images $x_1$ and $x_2$ from the test set and interpolate between them. For a given number of interpolation steps $n$:
- In pixel space, the interpolated image $x_t$ at step $t=0,\dots,n$ is given by taking, for each pixel, the linear interpolation 
$$\frac{n-t}{n}x_1 + \frac{t}{n}x_2.$$
- In the latent space, the interpolated image $x_t$ at step $t=0,\dots,n$ is given by first computing the linear interpolation $z_t$ between the encodings $z_1$ of $x_1$ and $z_2$ of $x_2$, and then decoding $z_t$.
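
Both variants rely on the same linear interpolation, applied either to pixel arrays or to latent vectors. A minimal NumPy sketch (the function name `linear_interpolation` and the toy inputs are illustrative, not part of the notebook's solutions):

```python
import numpy as np


def linear_interpolation(a, b, n_steps=10):
    """Return the n_steps + 1 points (n - t)/n * a + t/n * b for t = 0, ..., n."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return np.stack([(n_steps - t) / n_steps * a + t / n_steps * b
                     for t in range(n_steps + 1)])


# Works for any shape: 28 x 28 images in pixel space ...
img_a, img_b = np.zeros((28, 28)), np.ones((28, 28))
pixel_path = linear_interpolation(img_a, img_b, n_steps=4)

# ... or p-dimensional latent codes (here p = 2)
z_a, z_b = np.array([-1.0, 0.0]), np.array([1.0, 2.0])
latent_path = linear_interpolation(z_a, z_b, n_steps=4)

print(pixel_path.shape)  # (5, 28, 28)
print(latent_path)
```

In the latent variant, each row of `latent_path` would then be passed through the decoder to obtain an image.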

In [None]:
# TODO: Implement the interpolate_pixel_space function
# the function should take two images as input and the number of interpolation steps
# and plot the interpolated images in a single row

x1, x2 = test_dataset[3][0], test_dataset[2][0]
interpolate_pixel_space(x1, x2)

In [None]:
# %load solutions/pixel_interp.py

In [None]:
# TODO: Implement the interpolate_latent_space function
# the function should take two images as input and the number of interpolation steps
# and plot the interpolated images in a single row

x1, x2 = test_dataset[3][0], test_dataset[2][0]
interpolate_latent_space(x1, x2)

In [None]:
# %load solutions/latent_interp.py

**Questions.** 
1. How do the two interpolation processes differ?
2. Are the first and last images the same for both interpolation processes? Why?

## 7. Visualizing the latent space
The objective of this section is to visualize the latent space and to see how it changes according to which term in the loss function we give more weight to.
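
As a reminder of what the regularization term measures, the closed-form KL divergence used in the loss below can be evaluated by hand for a few values of $\mu$ and $\log\sigma^2$. A small NumPy sketch (the helper `kl_to_standard_normal` is an illustrative name):

```python
import numpy as np


def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) in closed form."""
    mu = np.asarray(mu, dtype=float)
    logvar = np.asarray(logvar, dtype=float)
    return -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))


kl_prior = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])     # posterior equals the prior
kl_shifted = kl_to_standard_normal([1.0, 1.0], [0.0, 0.0])   # mean moved away from 0
kl_narrow = kl_to_standard_normal([0.0, 0.0], [-2.0, -2.0])  # variance shrunk below 1

print(kl_prior, kl_shifted, kl_narrow)
```

The term is zero only when the approximate posterior coincides with the prior and is positive otherwise.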

In [58]:
# Rewrite loss function to return BCE and KLD separately as well
def loss_function(recon_x, x, mu, logvar, beta=1):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + beta * KLD, BCE, KLD

In [None]:
# Hyperparameters
latent_dims = 2
batch_size = 128
num_epochs = 10
learning_rate = 1e-3
kl_weights = [1, 10, 100]  # Different weights for the KL divergence term

# Training and plotting function
def train_and_plot(kl_weight):
    model = ConvVAE(latent_dims).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    model.train()
    train_losses = []
    
    for epoch in range(num_epochs):
        epoch_loss = 0
        bce_loss = 0
        kld_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            x_recon, mu, logvar = model(data)
            loss, bce, kld = loss_function(x_recon, data, mu, logvar, kl_weight)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            bce_loss += bce.item()
            kld_loss += kld.item()
        
        average_loss = epoch_loss / len(train_loader.dataset)
        average_bce = bce_loss / len(train_loader.dataset)
        average_kld = kld_loss / len(train_loader.dataset)
        print(f'Epoch {epoch+1}: Average Loss: {average_loss:.4f}, BCE: {average_bce:.4f}, KLD: {average_kld:.4f}')
    
    # Plot latent space
    plot_latent_space(model, kl_weight)

# Function to plot latent space
def plot_latent_space(model, kl_weight):
    model.eval()
    with torch.no_grad():
        test_loader = DataLoader(dataset=test_dataset, batch_size=10000, shuffle=False)
        data, labels = next(iter(test_loader))
        data = data.to(device)
        mu, logvar = model.encode(data)
        z = mu  # For visualization, we use the mean
        z = z.cpu().numpy()
        labels = labels.numpy()
        
        plt.figure(figsize=(8,6))
        scatter = plt.scatter(z[:, 0], z[:, 1], c=labels, cmap='tab10', alpha=0.7)
        plt.colorbar(scatter, ticks=range(10))
        plt.clim(-0.5, 9.5)
        plt.title(f'Latent Space with KL Weight = {kl_weight}')
        plt.xlabel('Z1')
        plt.ylabel('Z2')
        plt.show()

# Run training and plotting for different KL weights
for kl_weight in kl_weights:
    print(f'\nTraining VAE with KL Weight = {kl_weight}')
    train_and_plot(kl_weight)


**Questions.** 
1. Describe the effect of the KL weight $\beta$ on the latent space.
2. Explain why the described effect happens, and link it with the objective of each of the terms in the loss function.
