# Variational Autoencoder (VAE) Implementation on CIFAR-10

This notebook implements a complete Variational Autoencoder for image reconstruction using PyTorch. The implementation includes:
- An encoder network that maps input images to a latent space distribution
- A decoder network that reconstructs images from latent representations
- Training loop with loss visualization and model checkpointing
- Utility functions for saving reconstructed images and plotting training curves

## Import Required Libraries

First, let's import all the necessary libraries for our VAE implementation.

In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import wandb
from tqdm import tqdm
from config import Config

## Configuration Class

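All hyperparameters and configuration settings for the VAE model are centralized in the `Config` class, which is imported from the `config` module in the imports cell above rather than defined in this notebook.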

## Encoder Network

The encoder network takes input images and maps them to a latent space. It outputs two vectors:
- **mu**: Mean of the latent distribution
- **logvar**: Log variance of the latent distribution

The encoder uses three convolutional layers with batch normalization and ReLU activations.

In [4]:
class Encoder(nn.Module):
    """Convolutional encoder that outputs the parameters of the latent Gaussian.

    Three stride-2 convolution stages (convolution -> batch norm -> ReLU) are
    followed by two parallel linear heads producing ``mu`` and ``logvar``,
    each of size ``Config.latent_dim``.
    """

    def __init__(self):
        super().__init__()

        # Channel width entering and leaving each of the three stages.
        channel_plan = (
            Config.in_channels,
            Config.encoder_channels[0],
            Config.encoder_channels[1],
            Config.encoder_channels[2],
        )
        final_stage = len(channel_plan) - 2

        stages = []
        for stage, (c_in, c_out) in enumerate(zip(channel_plan, channel_plan[1:])):
            # The first two convolutions drop their bias; the last one keeps it.
            stages.append(
                nn.Conv2d(c_in, c_out, Config.kernel_size, stride=2, padding=1,
                          bias=(stage == final_stage))
            )
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(True))
        self.conv = nn.Sequential(*stages)

        self.flatten = nn.Flatten()

        # The heads expect flattened 4x4 feature maps.
        # NOTE(review): assumes the conv stack reduces the input to 4x4
        # (e.g. 32x32 images halved three times) - confirm against Config.
        flat_features = Config.encoder_channels[2] * 4 * 4
        self.mu = nn.Linear(flat_features, Config.latent_dim)
        self.logvar = nn.Linear(flat_features, Config.latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` for the batch of images ``x``."""
        features = self.flatten(self.conv(x))
        return self.mu(features), self.logvar(features)

## Decoder Network

The decoder network reconstructs images from the latent space representation. It uses:
- A fully connected layer to expand the latent vector to the required size
- Transposed convolutions to upsample the feature maps back to the original image size
- Tanh activation at the output to ensure pixel values are in the range [-1, 1]

In [5]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors back to images.

    A linear layer expands the latent vector into a 4x4 feature map, which
    three stride-2 transposed convolutions upsample; the final layer is
    followed by ``Tanh`` so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        wide, mid, narrow = (Config.decoder_channels[i] for i in range(3))

        # Expands a latent vector to wide * 4 * 4 features.
        self.fc = nn.Linear(Config.latent_dim, wide * 4 * 4)

        # The first two upsampling stages are bias-free and batch-normalised;
        # the output stage keeps its bias and has no normalisation.
        up_first = nn.ConvTranspose2d(wide, mid, Config.kernel_size,
                                      stride=2, padding=1, bias=False)
        norm_first = nn.BatchNorm2d(mid)
        up_second = nn.ConvTranspose2d(mid, narrow, Config.kernel_size,
                                       stride=2, padding=1, bias=False)
        norm_second = nn.BatchNorm2d(narrow)
        to_image = nn.ConvTranspose2d(narrow, Config.out_channels, Config.kernel_size,
                                      stride=2, padding=1)

        self.deconv = nn.Sequential(
            up_first, norm_first, nn.ReLU(True),
            up_second, norm_second, nn.ReLU(True),
            to_image, nn.Tanh(),
        )

    def forward(self, z):
        """Decode latent vectors ``z`` of shape [B, latent_dim] into images."""
        expanded = self.fc(z)  # [B, latent_dim] -> [B, decoder_channels[0] * 4 * 4]
        # Fold the flat features back into a 4-D [B, C, 4, 4] feature map.
        feature_map = expanded.view(-1, Config.decoder_channels[0], 4, 4)
        return self.deconv(feature_map)

## Complete VAE Model

The VAE combines the encoder and decoder with a reparameterization trick. The reparameterization trick allows us to sample from the latent distribution while maintaining the ability to backpropagate gradients.

**Reparameterization Trick**: Instead of sampling directly from N(μ, σ²), we sample ε ~ N(0,1) and compute z = μ + σ × ε

In [6]:
class VAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, 1)``.

        Sampling the noise separately keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``.
        """
        # logvar is log(sigma^2), so sigma = exp(logvar / 2).
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, x):
        """Encode ``x``, sample a latent, and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

## Utility Functions

Helper functions for saving reconstructed images during training and plotting the training loss curve.

In [7]:
def save_reconstructed_image(model, dataloader, epoch, max_samples=8):
    """Save a grid of reconstructions for the first batch of ``dataloader``.

    Args:
        model: VAE whose forward pass returns ``(reconstruction, mu, logvar)``.
        dataloader: Iterable yielding ``(images, labels)`` batches; only the
            first batch is used.
        epoch: Epoch number embedded in the output file name.
        max_samples: Number of images taken from the batch (also the number
            of columns in the grid).

    Returns:
        The image grid tensor that was written to
        ``Config.reconstruction_save_path``.

    The model is switched to eval mode for the forward pass and restored to
    whatever mode it was in on entry, even if saving fails.
    """
    was_training = model.training

    # Run on the device the model actually lives on, so the inputs can never
    # mismatch the weights; fall back to Config.device for parameter-free models.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else Config.device

    os.makedirs(Config.reconstruction_save_path, exist_ok=True)

    model.eval()
    try:
        with torch.no_grad():
            images = next(iter(dataloader))[0][:max_samples].to(target_device)
            recon, _, _ = model(images)

            # Combine the reconstructions into a single-row grid.
            grid = make_grid(recon, nrow=max_samples, normalize=True)
            # NOTE(review): ':' is not a valid file-name character on Windows;
            # the name is kept unchanged for compatibility with existing outputs.
            file_path = os.path.join(Config.reconstruction_save_path, f"recon_epoch:{epoch}.png")
            save_image(grid, file_path)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)

    return grid

def plot_training_curve(loss_list, save_path="saves/vae_loss_plot.png"):
    """Plot the per-epoch training loss and optionally save the figure.

    Args:
        loss_list: Sequence of loss values, one per epoch, in training order.
        save_path: Where to write the figure; a falsy value skips saving.
            Missing parent directories are created.
    """
    # Epochs are numbered from 1, so label the x-axis the same way.
    epochs = range(1, len(loss_list) + 1)

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, loss_list, label="Epoch Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("VAE Training Loss")
    plt.legend()

    if save_path:
        # savefig does not create directories, so make sure the target exists.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save_path)
    plt.show()

## Setup Training Environment

Select the device, set the random seed for reproducibility, start Weights & Biases logging, and create the model, optimizer, and reconstruction criterion.

In [8]:
# Choose the compute device and seed the RNG before any weights are created
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(Config.seed)
print(f"Using device: {device}")

# Start the Weights & Biases run, recording every non-dunder Config attribute
run_config = {
    key: value
    for key, value in vars(Config).items()
    if not key.startswith("__")
}
wandb.init(
    project=Config.wandb_project,
    name=f"{Config.wandb_run_name}_2",
    config=run_config,
)

# Build the model, its optimizer, and the summed-MSE reconstruction criterion
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)
reconstruction_criterion = nn.MSELoss(reduction="sum")

total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")

Using device: cpu


wandb: Currently logged in as: atharv3105 (atharv3105-dr-a-p-j-abdul-kalam-technical-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Model parameters: 1,120,259


## Loss Function

The VAE loss function consists of two components:
1. **Reconstruction Loss**: Measures how well the decoder reconstructs the input (using MSE)
2. **KL Divergence**: Regularizes the latent space to follow a standard normal distribution

Total Loss = Reconstruction Loss + KL Divergence

In [9]:
def loss_fn(recon_x, x, mu, logvar):
    """Compute the VAE objective for one batch.

    Args:
        recon_x: Decoder output for the batch.
        x: Original input batch.
        mu: Latent means from the encoder.
        logvar: Latent log-variances from the encoder.

    Returns:
        ``(total_loss, recon_loss, kl_divergence)``, where the total is the
        sum of the other two. Both terms are summed over the batch, not
        averaged.
    """
    # Reconstruction term from the module-level criterion (summed MSE).
    recon_loss = reconstruction_criterion(recon_x, x)

    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over all elements.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)

    return recon_loss + kl_divergence, recon_loss, kl_divergence

## Data Loading

Load the CIFAR-10 dataset with the appropriate transformations. We normalize the images to the [-1, 1] range to match the Tanh output of the decoder.

In [None]:
# Preprocessing: convert to tensors, then map each channel from [0, 1] to [-1, 1]
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# CIFAR-10 training split (downloaded if missing) wrapped in a shuffling loader
train_dataset = datasets.CIFAR10(
    root=Config.dataset_path,
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Number of batches: {len(train_loader)}")

## Training Loop

Main training loop that:
1. Performs forward pass through the VAE
2. Calculates reconstruction and KL divergence losses
3. Performs backpropagation and parameter updates
4. Logs metrics to Weights & Biases
5. Saves reconstructed images and model checkpoints at specified intervals

In [None]:
# Create directories for saving
os.makedirs("saves/reconstructions", exist_ok=True)
os.makedirs("saves/checkpoints", exist_ok=True)

# Training Loop
total_loss_list = []  # per-sample average loss, one entry per epoch
model.train()

for epoch in range(1, Config.num_epochs + 1):
    epoch_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch: [{epoch}/{Config.num_epochs}]", leave=False)

    for x, _ in loop:
        x = x.to(device)
        batch_size = x.size(0)

        # Forward pass
        optimizer.zero_grad()
        recon_x, mu, logvar = model(x)

        # Calculate loss (all three values are sums over the batch)
        loss, recon_loss, kl_loss = loss_fn(recon_x, x, mu, logvar)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call forces a device sync
        batch_loss = loss.item()
        epoch_loss += batch_loss
        loop.set_postfix(loss=batch_loss / batch_size)

        # Log per-sample batch metrics
        wandb.log({
            "Batch Loss": batch_loss / batch_size,
            "Reconstruction Loss": recon_loss.item() / batch_size,
            "KL Loss": kl_loss.item() / batch_size,
        })

    # Calculate average epoch loss (per sample)
    avg_loss = epoch_loss / len(train_loader.dataset)
    total_loss_list.append(avg_loss)
    wandb.log({"Epoch Loss": avg_loss, "Epoch": epoch})

    # Save Reconstruction: originals on the first row, reconstructions below
    if epoch % Config.save_reconstruction_interval == 0:
        model.eval()
        with torch.no_grad():
            x, _ = next(iter(train_loader))
            x = x.to(device)[:8]
            recon_x, _, _ = model(x)
            compare = torch.cat([x, recon_x])
            # Inputs and Tanh outputs live in [-1, 1], but save_image clips to
            # [0, 1]; undo the (0.5, 0.5) normalization so dark pixels survive.
            compare = compare.mul(0.5).add(0.5).clamp(0, 1)
            # NOTE(review): ':' in these file names is not valid on Windows.
            save_image(compare.cpu(), f"saves/reconstructions/recon_epoch:{epoch}.png", nrow=8)
        model.train()

    # Save Checkpoint
    if epoch % 25 == 0:
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        }
        torch.save(checkpoint, f"saves/checkpoints/vae_epoch:{epoch}.pt")

print("Training completed!")

## Visualize Training Results

Plot the training loss curve and save the final results.

In [None]:
# Plot training curve
# Use one path for saving, uploading, and reporting, so wandb.save and the
# message below refer to the file that plot_training_curve actually writes.
loss_plot_path = "saves/vae_training_loss.png"
plot_training_curve(total_loss_list, save_path=loss_plot_path)
wandb.save(loss_plot_path)

# Guard against an empty history (e.g. zero epochs configured)
if total_loss_list:
    print(f"Final training loss: {total_loss_list[-1]:.6f}")
print(f"Training loss curve saved to {loss_plot_path}")

## Generate Final Reconstructions

Generate and save final reconstructions of a batch of training images to visualize the model's performance.

In [None]:
# Generate final reconstructions
model.eval()
with torch.no_grad():
    # Get a batch of images (drawn from the training loader)
    test_images, _ = next(iter(train_loader))
    test_images = test_images[:8].to(device)

    # Generate reconstructions
    reconstructions, mu, logvar = model(test_images)

    # Save comparison: originals first, reconstructions second
    comparison = torch.cat([test_images, reconstructions])
    # Both halves are in [-1, 1], but save_image clips to [0, 1]; undo the
    # (0.5, 0.5) normalization so negative values are not rendered black.
    comparison = comparison.mul(0.5).add(0.5).clamp(0, 1)
    save_image(comparison.cpu(), "saves/final_reconstruction_comparison.png", nrow=8)

    print("Final reconstructions saved to saves/final_reconstruction_comparison.png")
    print("Top row: Original images")
    print("Bottom row: Reconstructed images")

# Close wandb run
wandb.finish()