# Variational Autoencoder

This notebook shows how to build and train a Variational Autoencoder (VAE) in PyTorch on the MNIST dataset. It walks through the key steps (data preparation, model definition, training, and evaluation) as a practical guide to implementing VAEs for unsupervised learning on images of handwritten digits.

***

## Loading Libraries

Library | Version | Channel
--- | --- | ---
NumPy | 1.26.4 | default
PyTorch | 2.2.2 | pytorch
Torchvision | 0.17.2 | pytorch
TensorBoard | not pinned | conda-forge

```python
# Built-in libraries
from dataclasses import dataclass
from datetime import datetime

# Third-party libraries
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import v2
```

## Loading Data

The [MNIST](http://yann.lecun.com/exdb/mnist/) dataset is a widely used benchmark in machine learning. It contains 70,000 images of handwritten digits from 0 to 9, split into 60,000 training and 10,000 test images. Each image is a 28×28 grayscale pixel grid, which gives 784 values once flattened. Its simplicity and consistent format make MNIST a good starting point for developing and testing machine learning models, particularly for image recognition and classification.

### Hyperparameters

```python
# Model size
input_layer = 784                # 28 * 28 pixels, flattened
layer_one = input_layer // 2     # 392
layer_two = input_layer // 4     # 196
layer_three = input_layer // 8   # 98
latent_space = 4                 # dimensionality of the latent space
# Dataloaders
batch_size = 128
# Optimizer
learning_rate = 1e-3
weight_decay = 1e-2
# Training
folds = 5
epochs = 15
```
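A quick sanity check of what these settings imply, assuming the standard 60,000/10,000 MNIST split and the `DataLoader` default `drop_last=False` (so the final, smaller batch is kept). It explains why the training log below prints its last epoch-1 line at step 400 and its first epoch-2 line at step 500: one epoch is 469 updates, and `train` only prints on multiples of 100. It also shows that the last test batch, whose images `test` logs to TensorBoard, holds only 16 images.

```python
import math

# Values copied from the hyperparameter cell; 60,000 / 10,000 are the standard MNIST split sizes
input_layer, batch_size, epochs = 784, 128, 15
n_train, n_test = 60_000, 10_000

hidden_sizes = [input_layer // 2, input_layer // 4, input_layer // 8]
updates_per_epoch = math.ceil(n_train / batch_size)   # the final partial batch still counts as an update
total_updates = updates_per_epoch * epochs
last_train_batch = n_train - (updates_per_epoch - 1) * batch_size
test_batches = math.ceil(n_test / batch_size)
last_test_batch = n_test - (test_batches - 1) * batch_size

print(hidden_sizes)                                     # [392, 196, 98]
print(updates_per_epoch, total_updates)                 # 469 7035
print(last_train_batch, test_batches, last_test_batch)  # 96 79 16
```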

## The Model

The **Variational Autoencoder (VAE)** was first introduced by Kingma and Welling in 2013 [[1]](https://arxiv.org/abs/1312.6114). A VAE is a generative model with two components. The **encoder** maps each input to a probability distribution over a latent space (typically a Gaussian described by a mean and a variance). It is also called the recognition model, because it is responsible for recognising the important patterns in the data. The **decoder** maps a point in the latent space back to data space, producing a reconstruction of the input; for this reason it is also called the generative model [[2]](https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/). Training minimises the sum of a reconstruction loss and a Kullback-Leibler (KL) divergence term that keeps the encoder's distributions close to a standard normal prior; these are the `Recon` and `KL` values printed during training.
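The training cell below instantiates `VariationalAutoencoder(input_layer, layer_one, layer_two, layer_three, latent_space)`, but its definition does not appear in this section. What follows is one possible sketch, reconstructed only from how the training and test code use the model: `model(x)` and `model(x, compute_loss=True)` return an object with `loss`, `loss_recon`, `loss_kl` and `x_recon`, and `model.decode(z)` maps latent vectors to 784 pixel values. The layer layout, the SiLU activations, the sigmoid output, the hypothetical `VAEOutput` dataclass (and its extra fields), the diagonal Gaussian posterior and the summed BCE reconstruction loss are all assumptions, not the author's code.

```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class VAEOutput:
    """Hypothetical container for the fields the train/test loops read."""
    z_mean: torch.Tensor
    z_logvar: torch.Tensor
    z_sample: torch.Tensor
    x_recon: torch.Tensor
    loss: Optional[torch.Tensor] = None
    loss_recon: Optional[torch.Tensor] = None
    loss_kl: Optional[torch.Tensor] = None


class VariationalAutoencoder(nn.Module):
    """Sketch of an MLP VAE: diagonal Gaussian posterior, standard normal prior."""

    def __init__(self, input_dim, hidden_1, hidden_2, hidden_3, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_1), nn.SiLU(),
            nn.Linear(hidden_1, hidden_2), nn.SiLU(),
            nn.Linear(hidden_2, hidden_3), nn.SiLU(),
            nn.Linear(hidden_3, 2 * latent_dim),  # mean and log-variance, concatenated
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_3), nn.SiLU(),
            nn.Linear(hidden_3, hidden_2), nn.SiLU(),
            nn.Linear(hidden_2, hidden_1), nn.SiLU(),
            nn.Linear(hidden_1, input_dim), nn.Sigmoid(),  # pixel intensities in [0, 1]
        )

    def encode(self, x):
        z_mean, z_logvar = self.encoder(x).chunk(2, dim=-1)
        return z_mean, z_logvar

    def reparameterize(self, z_mean, z_logvar):
        # z = mu + sigma * eps keeps the sample differentiable with respect to mu and sigma
        std = torch.exp(0.5 * z_logvar)
        return z_mean + std * torch.randn_like(std)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x, compute_loss=True):
        z_mean, z_logvar = self.encode(x)
        z = self.reparameterize(z_mean, z_logvar)
        x_recon = self.decode(z)
        if not compute_loss:
            return VAEOutput(z_mean, z_logvar, z, x_recon)

        # The transform shifts pixels to [-0.5, 0.5]; undo the shift so BCE targets lie in [0, 1]
        target = x + 0.5
        # Reconstruction: BCE summed over pixels, averaged over the batch
        loss_recon = F.binary_cross_entropy(x_recon, target, reduction="none").sum(dim=-1).mean()
        # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over the batch
        loss_kl = (-0.5 * (1 + z_logvar - z_mean.pow(2) - z_logvar.exp()).sum(dim=-1)).mean()
        loss = loss_recon + loss_kl
        return VAEOutput(z_mean, z_logvar, z, x_recon, loss, loss_recon, loss_kl)
```

With this sketch the existing cells run unchanged: `model(data)` computes the loss by default, and `model.decode(z)` returns flattened 784-pixel images that `test` reshapes to `(-1, 1, 28, 28)`. Summing the BCE over pixels puts the reconstruction term on a per-image scale comparable to the KL term, which is consistent with the magnitudes in the training log (Recon around 140, KL around 7), although that alone does not show the original used the same reduction.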

## Data Preparation

```python
transform = v2.Compose([
    v2.ToImage(),                                # PIL image -> uint8 tensor of shape (1, 28, 28)
    v2.ToDtype(torch.float32, scale=True),       # rescale pixel values to [0, 1]
    v2.Lambda(lambda x: x.view(-1) - 0.5),       # flatten to 784 values and shift to [-0.5, 0.5]
])

# Download and load the training data
train_data = datasets.MNIST(
    '~/.pytorch/MNIST_data/',
    download=True,
    train=True,
    transform=transform,
)
# Download and load the test data
test_data = datasets.MNIST(
    '~/.pytorch/MNIST_data/',
    download=True,
    train=False,
    transform=transform,
)

# Create data loaders (shuffle only the training set)
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
)
```
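To see what `transform` produces without downloading anything, here is a NumPy mirror of the same three steps applied to a synthetic 28×28 `uint8` image. It assumes that `ToDtype(torch.float32, scale=True)` divides `uint8` values by 255, which is its behaviour for integer-to-float conversion; `mnist_transform_np` is a hypothetical helper for illustration, not part of torchvision.

```python
import numpy as np


def mnist_transform_np(img_uint8):
    """Hypothetical NumPy mirror of the v2 pipeline for one 28x28 uint8 image."""
    x = img_uint8.astype(np.float32) / 255.0  # ToImage + ToDtype(float32, scale=True): [0, 255] -> [0, 1]
    return x.reshape(-1) - 0.5                # Lambda: flatten to 784 values, shift to [-0.5, 0.5]


rng = np.random.default_rng(0)
img = rng.integers(0, 255, size=(28, 28), endpoint=True, dtype=np.uint8)  # stand-in for a digit
img[0, 0], img[0, 1] = 0, 255                                             # force both extremes
x = mnist_transform_np(img)
print(x.shape, x.min(), x.max())  # (784,) -0.5 0.5
```

The zero-centred values are what the model receives, so anything that compares them with quantities in [0, 1] (a binary cross-entropy target, an image logger) has to add 0.5 back first.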

## Training

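```python
# Run on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VariationalAutoencoder(input_layer, layer_one, layer_two, layer_three, latent_space).to(device)
# AdamW applies weight decay directly to the weights, decoupled from the gradient-based update
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# One timestamped TensorBoard run directory per launch
writer = SummaryWriter(f'runs/mnist/vae_{datetime.now().strftime("%Y%m%d-%H%M%S")}')
```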
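AdamW differs from Adam with an L2 penalty in that the decay shrinks the weights directly instead of being added to the gradient. The NumPy sketch below performs one AdamW step for a single parameter array, following the update described in the PyTorch `AdamW` documentation with its default `betas=(0.9, 0.999)` and `eps=1e-8`; `adamw_step` is a hypothetical helper, and a real optimizer keeps `m`, `v` and the step count `t` as per-parameter state.

```python
import numpy as np


def adamw_step(param, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """Hypothetical helper: one AdamW update for a single parameter array (t starts at 1)."""
    beta1, beta2 = betas
    param = param * (1 - lr * weight_decay)    # decoupled decay: shrinks weights regardless of grad
    m = beta1 * m + (1 - beta1) * grad         # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad * grad  # second-moment (uncentred variance) estimate
    m_hat = m / (1 - beta1 ** t)               # bias corrections for the zero initialisation
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


p0 = np.array([1.0, -2.0, 0.5])
g = np.array([0.3, -10.0, 0.0])
p1, m1, v1 = adamw_step(p0, g, np.zeros(3), np.zeros(3), t=1)
print(p1)  # approximately [0.99899, -1.99898, 0.499995]
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so each weight moves by about `lr` whatever the gradient's magnitude. With `learning_rate = 1e-3` and `weight_decay = 1e-2`, the decay alone multiplies every weight by (1 - 1e-5) per step, which compounds to roughly 0.93 over the 7,035 updates of a 15-epoch run if the gradient term is ignored.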

```python
def train(model, dataloader, optimizer, prev_updates, writer=None):
    """
    Trains the model for one epoch on the given data.

    Args:
        model (nn.Module): The model to train.
        dataloader (torch.utils.data.DataLoader): The data loader.
        optimizer (torch.optim.Optimizer): The optimizer.
        prev_updates (int): Number of optimizer steps taken before this epoch.
        writer (SummaryWriter, optional): The TensorBoard writer.

    Returns:
        int: The updated total number of optimizer steps.
    """
    model.train()  # Set the model to training mode

    for batch_idx, (data, _) in enumerate(dataloader):  # labels are not needed
        n_upd = prev_updates + batch_idx

        data = data.to(device)

        optimizer.zero_grad()  # Zero the gradients

        output = model(data)  # Forward pass
        loss = output.loss

        loss.backward()  # Backward pass

        # Clip the global gradient norm to 1.0; the return value is the norm *before* clipping
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        if n_upd % 100 == 0:
            grad_norm = total_norm.item()
            print(f'Step {n_upd:,} (N samples: {n_upd*batch_size:,}), Loss: {loss.item():.4f} (Recon: {output.loss_recon.item():.4f}, KL: {output.loss_kl.item():.4f}) Grad: {grad_norm:.4f}')

            if writer is not None:
                global_step = n_upd
                writer.add_scalar('Loss/Train', loss.item(), global_step)
                writer.add_scalar('Loss/Train/BCE', output.loss_recon.item(), global_step)
                writer.add_scalar('Loss/Train/KLD', output.loss_kl.item(), global_step)
                writer.add_scalar('GradNorm/Train', grad_norm, global_step)

        optimizer.step()  # Update the model parameters

    return prev_updates + len(dataloader)
```
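`train` logs the global L2 norm of all gradients and clips it to 1.0 with `torch.nn.utils.clip_grad_norm_`. The NumPy sketch below mirrors what that function does according to its documentation: compute one norm over all gradients as if they were a single vector, then scale every gradient by `max_norm / (norm + eps)`, capped at 1 so small gradients are left alone. `clip_by_global_norm` is a hypothetical helper name for illustration.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Hypothetical helper: scale all gradient arrays jointly so their global L2 norm <= max_norm.

    Returns the clipped arrays and the norm measured *before* clipping.
    """
    total_norm = float(np.sqrt(sum(np.sum(np.square(g)) for g in grads)))
    coef = min(1.0, max_norm / (total_norm + eps))  # never scales gradients up
    return [g * coef for g in grads], total_norm


# Two "parameters" whose gradients have a combined norm of sqrt(3^2 + 4^2) = 5
grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(np.square(g)) for g in clipped))
print(norm_before, round(norm_after, 4))  # 5.0 1.0
```

Because every gradient is multiplied by the same factor, the update direction is preserved and only its length changes. The norms printed in the training log (roughly 24 to 64) are pre-clipping values, so with `max_norm=1.0` the clip is active on every logged step.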

```python
def test(model, dataloader, cur_step, writer=None):
    """
    Evaluates the model on the given data and optionally logs to TensorBoard.

    Args:
        model (nn.Module): The model to evaluate.
        dataloader (torch.utils.data.DataLoader): The data loader.
        cur_step (int): The global training step, used as the TensorBoard x-axis.
        writer (SummaryWriter, optional): The TensorBoard writer.
    """
    model.eval()  # Set the model to evaluation mode
    test_loss = 0.0
    test_recon_loss = 0.0
    test_kl_loss = 0.0

    with torch.no_grad():
        for data, _ in dataloader:
            data = data.to(device)
            data = data.view(data.size(0), -1)  # Flatten the data

            output = model(data, compute_loss=True)  # Forward pass

            test_loss += output.loss.item()
            test_recon_loss += output.loss_recon.item()
            test_kl_loss += output.loss_kl.item()

    # Average the per-batch losses
    test_loss /= len(dataloader)
    test_recon_loss /= len(dataloader)
    test_kl_loss /= len(dataloader)
    print(f'====> Test set loss: {test_loss:.4f} (BCE: {test_recon_loss:.4f}, KLD: {test_kl_loss:.4f})')

    if writer is not None:
        # Log the averages over the whole test set, not just the last batch
        writer.add_scalar('Loss/Test', test_loss, global_step=cur_step)
        writer.add_scalar('Loss/Test/BCE', test_recon_loss, global_step=cur_step)
        writer.add_scalar('Loss/Test/KLD', test_kl_loss, global_step=cur_step)

        # Log reconstructions of the last batch; shift the inputs back from [-0.5, 0.5] to [0, 1]
        writer.add_images('Test/Reconstructions', output.x_recon.view(-1, 1, 28, 28), global_step=cur_step)
        writer.add_images('Test/Originals', (data + 0.5).view(-1, 1, 28, 28), global_step=cur_step)

        # Log random samples from the latent space (no gradients needed)
        with torch.no_grad():
            z = torch.randn(16, latent_space, device=device)
            samples = model.decode(z)
        writer.add_images('Test/Samples', samples.view(-1, 1, 28, 28), global_step=cur_step)
```
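`test` reports a `KLD` term next to `BCE`. For a VAE whose encoder outputs a diagonal Gaussian $q(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}\sigma^2)$ and whose prior is $p(z) = \mathcal{N}(0, I)$, as in Kingma and Welling [1], this term has the closed form $D_{\mathrm{KL}} = -\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. The sketch below checks that formula against a Monte Carlo estimate of $\mathbb{E}_q[\log q(z) - \log p(z)]$ for one made-up 4-dimensional latent distribution. The function names and example values are illustrative assumptions, and the model definition is not shown in this section, so this describes the standard formulation rather than confirming the notebook's exact code.

```python
import numpy as np


def kl_closed_form(mean, logvar):
    """KL(N(mean, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return -0.5 * np.sum(1.0 + logvar - mean**2 - np.exp(logvar))


def kl_monte_carlo(mean, logvar, n_samples=200_000, seed=0):
    """Estimate E_q[log q(z) - log p(z)] by sampling z from q."""
    rng = np.random.default_rng(seed)
    std = np.exp(0.5 * logvar)
    z = mean + std * rng.standard_normal((n_samples, mean.size))
    log_q = -0.5 * np.sum(((z - mean) / std) ** 2 + logvar + np.log(2 * np.pi), axis=1)
    log_p = -0.5 * np.sum(z**2 + np.log(2 * np.pi), axis=1)
    return np.mean(log_q - log_p)


mean = np.array([0.5, -1.0, 0.0, 2.0])
logvar = np.array([0.0, -0.5, 0.3, -1.0])
print(round(kl_closed_form(mean, logvar), 4))  # 2.8871
print(kl_monte_carlo(mean, logvar))            # close to the closed-form value
```

The closed form is exact and cheap, which is why VAE implementations usually prefer it to sampling; it is zero only when every $\mu_j = 0$ and $\sigma_j = 1$, i.e. when the encoder output matches the prior.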

```python
prev_updates = 0
for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')
    prev_updates = train(model, train_loader, optimizer, prev_updates, writer=writer)
    test(model, test_loader, prev_updates, writer=writer)
```

Epoch 1/15
Step 0 (N samples: 0), Loss: 164.3871 (Recon: 159.9528, KL: 4.4342) Grad: 29.7743
Step 100 (N samples: 12,800), Loss: 156.7009 (Recon: 150.5463, KL: 6.1546) Grad: 23.8390
Step 200 (N samples: 25,600), Loss: 166.9792 (Recon: 160.3606, KL: 6.6186) Grad: 64.0568
Step 300 (N samples: 38,400), Loss: 152.1218 (Recon: 145.3236, KL: 6.7981) Grad: 39.0422
Step 400 (N samples: 51,200), Loss: 145.5018 (Recon: 138.7176, KL: 6.7842) Grad: 30.9859
====> Test set loss: 146.6436 (BCE: 139.5255, KLD: 7.1181)
Epoch 2/15
Step 500 (N samples: 64,000), Loss: 147.8998 (Recon: 140.7776, KL: 7.1221) Grad: 38.2662
Step 600 (N samples: 76,800), Loss: 145.5654 (Recon: 137.3705, KL: 8.1949) Grad: 45.3294
Step 700 (N samples: 89,600), Loss: 140.7345 (Recon: 132.3809, KL: 8.3535) Grad: 31.1977
Step 800 (N samples: 102,400), Loss: 131.4105 (Recon: 122.7430, KL: 8.6675) Grad: 34.2402
Step 900 (N samples: 115,200), Loss: 136.9081 (Recon: 127.9592, KL: 8.9490) Grad: 44.5080
====> Test set loss: 134.6257 (BCE