<a href="https://colab.research.google.com/github/RDGopal/IB9AU-2026/blob/main/SD2_Variational_Autoencoder_(VAE)_Illustration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Understanding Variational Autoencoders (VAEs)

This notebook demonstrates a Variational Autoencoder (VAE) applied to the MNIST dataset. Before diving into the code, let's clarify what a VAE is and how it relates to a standard Autoencoder.

### Autoencoders (AE)

A traditional Autoencoder is a neural network designed to learn an efficient, compressed representation (encoding) of input data. It consists of two main parts:
1.  **Encoder**: Maps the input data to a lower-dimensional latent space representation.
2.  **Decoder**: Reconstructs the original input data from the latent space representation.

The goal of an AE is to minimize the reconstruction error, meaning the output should be as similar as possible to the input. While AEs can learn useful representations, nothing in this objective constrains the latent space to be continuous or to follow a known distribution. Decoding an arbitrary latent point therefore often yields output that does not resemble the data, which makes a plain AE a poor generator.
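
As a point of comparison, a plain autoencoder can be sketched in a few lines. This is a minimal illustration rather than part of the notebook's model: the class name `PlainAutoencoder`, the layer sizes, and the choice of mean squared error are assumptions, and the random batch `x` only stands in for real images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PlainAutoencoder(nn.Module):
    """Deterministic autoencoder: each input maps to one fixed latent point."""
    def __init__(self, input_dim=784, hidden_dim=128, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
        )

    def forward(self, x):
        z = self.encoder(x)          # a single point, not a distribution
        return self.decoder(z), z

x = torch.rand(4, 784)               # stand-in batch of flattened images in [0, 1]
ae = PlainAutoencoder()
x_hat, z = ae(x)
loss = F.mse_loss(x_hat, x)          # reconstruction error only, no regularizer
```

The loss here contains only a reconstruction term, so nothing shapes where the latent points `z` end up.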

### Variational Autoencoders (VAE)

A Variational Autoencoder addresses the limitations of standard Autoencoders by introducing a probabilistic approach to the latent space. Instead of mapping an input to a fixed point in the latent space, a VAE maps it to a **probability distribution** (specifically, a Gaussian distribution) in the latent space.

Key differences and features of VAEs:
*   **Probabilistic Latent Space**: The encoder outputs two vectors for each input: a mean vector (`z_mean`) and a logarithm of variance vector (`z_log_var`). These define the parameters of a Gaussian distribution.
*   **Sampling**: Instead of directly using `z_mean` as the latent representation, a VAE samples from the distribution defined by `z_mean` and `z_log_var`. This sampling process introduces stochasticity, making the latent space smoother and more continuous.
*   **Loss Function**: A VAE's loss function has two components:
    1.  **Reconstruction Loss**: Measures how well the decoder reconstructs the input from the sampled latent vector (similar to AE).
    2.  **KL Divergence Loss**: A regularization term that pulls the learned latent distribution for each input toward a standard normal distribution (mean 0, variance 1). This encourages a well-structured, continuous latent space in which sampling is meaningful.
*   **Generative Capability**: Because the latent space is structured and continuous, we can sample arbitrary points from the standard normal distribution, pass them through the decoder, and generate new, coherent data that resembles the training data.

In essence, a VAE learns a compressed and continuous latent representation of the data, which supports both reconstruction and the generation of novel data points by sampling from a simple prior distribution in the latent space. Note that a standard VAE does not guarantee a disentangled representation, that is, one in which each latent dimension corresponds to a separate factor of variation.
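
The two loss terms can be written out explicitly. The notation is an assumption chosen to match the variable names used in the code: $\mu$ stands for `z_mean`, $\log\sigma^2$ for `z_log_var`, and $d$ for the latent dimension. For a single input $x$, assuming a diagonal Gaussian encoder $q(z \mid x)$ and a standard normal prior:

$$
\mathcal{L}(x) = \underbrace{-\,\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]}_{\text{reconstruction loss}} + \underbrace{D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,\mathcal{N}(0, I)\big)}_{\text{KL divergence loss}}
$$

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))\,\|\,\mathcal{N}(0, I)\big) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

Sampling is expressed as $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. The randomness then lives in $\epsilon$ alone, and gradients can flow through $\mu$ and $\sigma$.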

### 1. Setup and Data Loading

The following cell imports the necessary libraries, defines global configuration parameters for our VAE model, and loads the MNIST dataset. MNIST, a collection of handwritten digits, is a classic choice for demonstrating generative models. We preprocess the images by scaling pixel values to the range [0, 1] and flattening each 28x28 image into a 784-element vector suitable for our dense network.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

print("PyTorch Version:", torch.__version__)

# --- Configuration ---
LATENT_DIM_VAE = 2   # Use 2 for easy 2D visualization, or e.g., 32 for better quality
IMAGE_SIZE = 784    # 28x28 pixels flattened
EPOCHS = 10         # Number of training epochs (keep low for quick demo)
BATCH_SIZE = 128    # Number of samples per gradient update
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {DEVICE}")

# --- 1. Load and Preprocess MNIST Data ---
print("Loading MNIST data...")

# Define a transform to flatten the images and normalize them
transform = transforms.Compose([
    transforms.ToTensor(), # Converts a PIL Image or numpy.ndarray to a FloatTensor of shape (C, H, W)
    transforms.Lambda(lambda x: x.view(-1)), # Flatten 28x28 to 784
    # Raw MNIST pixel values are 0-255; ToTensor scales them to 0-1, so no further normalization is applied
])

# Load MNIST training dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Load MNIST test dataset
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Training data samples: {len(train_dataset)}")
print(f"Test data samples: {len(test_dataset)}")

# Verify data shape
for images, labels in train_loader:
    print(f"Shape of images batch: {images.shape}") # Expected: (BATCH_SIZE, IMAGE_SIZE)
    print(f"Shape of labels batch: {labels.shape}") # Expected: (BATCH_SIZE)
    break
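
To see what the transform pipeline does without downloading anything, the sketch below applies the same two steps by hand to a synthetic image. The random `raw` tensor is a stand-in rather than MNIST data, and the manual division by 255 is an imitation of what `ToTensor` does for 8-bit grayscale input, not a call into torchvision.

```python
import torch

# Synthetic stand-in for one 28x28 grayscale image with uint8 values in 0..255.
raw = torch.randint(0, 256, (28, 28), dtype=torch.uint8)

scaled = raw.to(torch.float32) / 255.0  # value range becomes [0, 1]
scaled = scaled.unsqueeze(0)            # add a channel axis: (1, 28, 28)
flat = scaled.view(-1)                  # flatten to a 784-element vector

print(raw.shape, scaled.shape, flat.shape)
```

Values in [0, 1] matter here because the decoder ends in a sigmoid and the reconstruction loss is binary cross-entropy, both of which assume targets in that range.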


### 2. Defining VAE Components

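Here, we define the core components of our Variational Autoencoder:
*   **`Encoder`**: Maps a flattened image to `z_mean` and `z_log_var`, then draws a latent sample `z` using the reparameterization trick.
*   **`Decoder`**: Maps a latent vector back to 784 pixel values, using a sigmoid so that outputs lie in [0, 1].
*   **`VAE`**: Chains the encoder and decoder and returns the reconstruction together with `z_mean` and `z_log_var`, which the loss function needs.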


In [None]:
import torch.nn as nn
import torch.nn.functional as F

# --- Encoder ---
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        # x is assumed to be flattened (batch_size, input_dim)
        h = F.relu(self.fc1(x))
        z_mean = self.fc_mean(h)
        z_log_var = self.fc_log_var(h)

        # Reparameterization trick
        std = torch.exp(0.5 * z_log_var)
        epsilon = torch.randn_like(std) # Sample from standard normal
        z = z_mean + epsilon * std

        return z_mean, z_log_var, z

# --- Decoder ---
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        reconstruction = torch.sigmoid(self.fc2(h)) # Sigmoid for [0,1] pixel values
        return reconstruction

# --- VAE Model Class ---
class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        z_mean, z_log_var, z = self.encoder(x)
        reconstruction = self.decoder(z)
        return reconstruction, z_mean, z_log_var


hidden_dim = 128 # Example hidden dimension for both encoder and decoder

# Create instances for summary output, similar to Keras summary
encoder_temp = Encoder(IMAGE_SIZE, hidden_dim, LATENT_DIM_VAE).to(DEVICE)
decoder_temp = Decoder(LATENT_DIM_VAE, hidden_dim, IMAGE_SIZE).to(DEVICE)

print("\nEncoder Architecture:")
print(encoder_temp)

print("\nDecoder Architecture:")
print(decoder_temp)
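
The reparameterization step inside `Encoder.forward` is what keeps sampling differentiable. The sketch below isolates it to show that gradients reach both `z_mean` and `z_log_var`. The helper name `reparameterize` and the all-zero inputs are assumptions made for illustration only.

```python
import torch

def reparameterize(z_mean, z_log_var):
    """Draws z = mean + std * epsilon, with epsilon ~ N(0, I)."""
    std = torch.exp(0.5 * z_log_var)   # log-variance -> standard deviation
    epsilon = torch.randn_like(std)    # the only source of randomness
    return z_mean + epsilon * std

# Toy parameters for a batch of 3 inputs and a 2D latent space.
z_mean = torch.zeros(3, 2, requires_grad=True)
z_log_var = torch.zeros(3, 2, requires_grad=True)

z = reparameterize(z_mean, z_log_var)
z.sum().backward()

print(z_mean.grad)      # all ones: dz/d(mean) = 1
print(z_log_var.grad)   # 0.5 * std * epsilon, so it varies with the noise drawn
```

Sampling `z` directly from a distribution object without this rewrite would give no gradient path back to the encoder's outputs.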


### 3. Build and Train the VAE

In this section, we instantiate our encoder and decoder and combine them into the full VAE model. PyTorch has no separate compile step, so we create an Adam optimizer over the model's parameters, define the loss function, and train on the preprocessed MNIST training data, evaluating on the test set after each epoch. During training, the VAE learns both to reconstruct the input images accurately and to maintain a well-structured, continuous latent space.

In [None]:
print("\n--- Building VAE (PyTorch) ---")
# Hidden dimension is used in the Encoder and Decoder for the intermediate layer
hidden_dim = 128

# Instantiate Encoder, Decoder, and VAE
encoder = Encoder(IMAGE_SIZE, hidden_dim, LATENT_DIM_VAE).to(DEVICE)
decoder = Decoder(LATENT_DIM_VAE, hidden_dim, IMAGE_SIZE).to(DEVICE)
vae = VAE(encoder, decoder).to(DEVICE)

# Define Optimizer
optimizer = optim.Adam(vae.parameters())

# --- VAE Loss Function ---
def vae_loss(reconstruction, x, z_mean, z_log_var):
    # Reconstruction loss (Binary Cross-Entropy)
    # F.binary_cross_entropy_with_logits is often used for raw logits, but our decoder output is sigmoid (0-1)
    # So we use F.binary_cross_entropy
    reconstruction_loss = F.binary_cross_entropy(reconstruction, x, reduction='sum')

    # KL Divergence Loss
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # log(sigma^2) is z_log_var
    kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())

    # Total VAE loss
    total_loss = reconstruction_loss + kl_loss
    return total_loss, reconstruction_loss, kl_loss

# --- Training Loop ---
print("\n--- Training Variational Autoencoder (PyTorch) ---")
history_vae = {'loss': [], 'reconstruction_loss': [], 'kl_loss': [], 'val_loss': [], 'val_reconstruction_loss': [], 'val_kl_loss': []}

for epoch in range(EPOCHS):
    vae.train() # Set model to training mode
    train_loss = 0.0
    train_recon_loss = 0.0
    train_kl_loss = 0.0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(DEVICE) # Move data to appropriate device
        optimizer.zero_grad() # Zero the gradients

        reconstruction, z_mean, z_log_var = vae(data)
        total_batch_loss, recon_batch_loss, kl_batch_loss = vae_loss(reconstruction, data, z_mean, z_log_var)

        total_batch_loss.backward() # Backpropagate
        optimizer.step() # Update weights

        train_loss += total_batch_loss.item()
        train_recon_loss += recon_batch_loss.item()
        train_kl_loss += kl_batch_loss.item()

    avg_train_loss = train_loss / len(train_dataset)
    avg_train_recon_loss = train_recon_loss / len(train_dataset)
    avg_train_kl_loss = train_kl_loss / len(train_dataset)

    history_vae['loss'].append(avg_train_loss)
    history_vae['reconstruction_loss'].append(avg_train_recon_loss)
    history_vae['kl_loss'].append(avg_train_kl_loss)

    # --- Evaluation Loop ---
    vae.eval() # Set model to evaluation mode
    val_loss = 0.0
    val_recon_loss = 0.0
    val_kl_loss = 0.0
    with torch.no_grad(): # Disable gradient calculations during evaluation
        for data, _ in test_loader:
            data = data.to(DEVICE)
            reconstruction, z_mean, z_log_var = vae(data)
            total_batch_loss, recon_batch_loss, kl_batch_loss = vae_loss(reconstruction, data, z_mean, z_log_var)
            val_loss += total_batch_loss.item()
            val_recon_loss += recon_batch_loss.item()
            val_kl_loss += kl_batch_loss.item()

    avg_val_loss = val_loss / len(test_dataset)
    avg_val_recon_loss = val_recon_loss / len(test_dataset)
    avg_val_kl_loss = val_kl_loss / len(test_dataset)

    history_vae['val_loss'].append(avg_val_loss)
    history_vae['val_reconstruction_loss'].append(avg_val_recon_loss)
    history_vae['val_kl_loss'].append(avg_val_kl_loss)

    print(f"Epoch {epoch+1}/{EPOCHS}, "
          f"Train Loss: {avg_train_loss:.4f}, "
          f"Train Recon Loss: {avg_train_recon_loss:.4f}, "
          f"Train KL Loss: {avg_train_kl_loss:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, "
          f"Val Recon Loss: {avg_val_recon_loss:.4f}, "
          f"Val KL Loss: {avg_val_kl_loss:.4f}")

print("Training complete.")
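
As a sanity check on the KL term used in `vae_loss`, the closed-form expression can be compared against PyTorch's own `torch.distributions.kl_divergence`. This is a sketch: the helper name `kl_closed_form` and the toy values for `z_mean` and `z_log_var` are assumptions, not outputs of the trained model.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_closed_form(z_mean, z_log_var):
    """Same expression as the KL term in vae_loss, summed over all elements."""
    return -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())

# Toy encoder outputs: row 0 equals the prior exactly, row 1 deviates from it.
z_mean = torch.tensor([[0.0, 0.0], [1.0, -2.0]])
z_log_var = torch.tensor([[0.0, 0.0], [0.5, -1.0]])

q = Normal(z_mean, torch.exp(0.5 * z_log_var))                  # encoder distribution
p = Normal(torch.zeros_like(z_mean), torch.ones_like(z_mean))   # standard normal prior

kl_formula = kl_closed_form(z_mean, z_log_var)
kl_reference = kl_divergence(q, p).sum()

print(kl_formula.item(), kl_reference.item())
```

The two numbers should agree up to floating-point error, and the first row contributes zero because it already matches the prior.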

### 4. Visualize VAE Reconstructions

After training, a first check on a VAE is how well it reconstructs images. The function below takes a few images from the test set (which the model did not see during training), passes them through the trained VAE (encoder, then decoder), and displays each original above its reconstruction. This gives a qualitative sense of how much information survives the trip through the latent space.

In [None]:
print("\n--- Visualizing VAE Reconstructions (PyTorch) ---")

# Function to plot reconstructions
def plot_reconstructions_vae(vae_model, test_loader, n=10):
    """Plots original and VAE reconstructed images using the PyTorch VAE model."""
    vae_model.eval() # Set model to evaluation mode
    with torch.no_grad(): # Disable gradient calculations
        # Get a batch of test data
        data_iter = iter(test_loader)
        images, _ = next(data_iter)

        # Ensure we have at least n images
        if images.shape[0] < n:
            print(f"Not enough images in the batch ({images.shape[0]}) to plot {n}. Using available images.")
            n = images.shape[0]

        # Move images to the device and get reconstructions
        images = images.to(DEVICE)
        reconstructed_images, _, _ = vae_model(images[:n])

        # Move images back to CPU and convert to numpy for plotting
        original_imgs_np = images[:n].cpu().numpy()
        reconstructed_imgs_np = reconstructed_images.cpu().numpy()

        plt.figure(figsize=(20, 4))
        plt.suptitle("VAE: Original vs Reconstructed Images (PyTorch)", fontsize=16)
        for i in range(n):
            # Display original
            ax = plt.subplot(2, n, i + 1)
            plt.imshow(original_imgs_np[i].reshape(28, 28), cmap='gray')
            plt.title("Original")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

            # Display reconstruction
            ax = plt.subplot(2, n, i + 1 + n)
            plt.imshow(reconstructed_imgs_np[i].reshape(28, 28), cmap='gray')
            plt.title("Reconstructed")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        plt.show()

# Plot VAE Reconstructions
plot_reconstructions_vae(vae, test_loader, n=10)


### 5. Generate New Images from Latent Space

A key advantage of VAEs over traditional Autoencoders is their ability to generate entirely new data. Because the KL term pulls the VAE's latent space toward a standard normal distribution, we can sample random points from that distribution and feed them to the decoder, which produces novel images that were not part of the training data. With a 2-dimensional latent space and only a few epochs of training, expect recognizable but somewhat blurry digits.

In [None]:
# --- VAE Specific Visualizations ---

# 1. Generate new digits by sampling from the latent space prior
def plot_generated_images(decoder_model, n=15, latent_dim=LATENT_DIM_VAE):
    """Generates images by sampling from the standard normal prior."""
    decoder_model.eval() # Set decoder to evaluation mode
    with torch.no_grad():
        random_latent_vectors = torch.randn(n, latent_dim, device=DEVICE)
        generated_images = decoder_model(random_latent_vectors)

        # Move images back to CPU and convert to numpy for plotting
        generated_images_np = generated_images.cpu().numpy()

    plt.figure(figsize=(15, 3))
    plt.suptitle("VAE Generated Images (from random latent samples) (PyTorch)", fontsize=16)
    for i in range(n):
        ax = plt.subplot(1, n, i + 1)
        plt.imshow(generated_images_np[i].reshape(28, 28), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

print("\nGenerating new images with VAE Decoder (PyTorch)...")
# Use the standalone decoder model accessed via vae.decoder
plot_generated_images(vae.decoder, n=15, latent_dim=LATENT_DIM_VAE)


### 6. Visualize the 2D Latent Space

If we set `LATENT_DIM_VAE` to 2, we can visually inspect the learned latent space. The function below encodes the test set and plots the `z_mean` values (up to 10,000 points) in a 2D scatter plot, with each point colored by its digit class. A well-trained VAE typically shows a cluster for each digit class, although with only two latent dimensions neighbouring clusters usually overlap to some degree. Because the KL term keeps the clusters packed around the origin rather than far apart, moving between them in latent space tends to produce smooth transitions between generated digits.

In [None]:
print("\nVisualizing VAE 2D Latent Space (PyTorch)...")

# 2. Visualize the learned VAE Latent Space (if LATENT_DIM_VAE == 2)
def plot_latent_space(encoder_model, test_loader, latent_dim=LATENT_DIM_VAE):
    """Plots the 2D latent space (z_mean) colored by digit class for PyTorch."""
    if latent_dim != 2:
        print(f"Latent space visualization requires LATENT_DIM_VAE=2 (currently {latent_dim}). Skipping.")
        return

    encoder_model.eval() # Set encoder to evaluation mode
    z_mean_values = []
    y_labels = []

    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(DEVICE)
            # The encoder returns z_mean, z_log_var, z. We only need z_mean for this plot.
            z_mean, _, _ = encoder_model(data)
            z_mean_values.append(z_mean.cpu().numpy())
            y_labels.append(target.cpu().numpy())

    z_mean_values = np.concatenate(z_mean_values, axis=0)
    y_labels = np.concatenate(y_labels, axis=0)

    # Plot only a subset for clarity if dataset is too large
    num_samples_plot = min(10000, len(z_mean_values))

    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(z_mean_values[:num_samples_plot, 0],
                          z_mean_values[:num_samples_plot, 1],
                          c=y_labels[:num_samples_plot],
                          cmap='viridis', alpha=0.7, s=5)
    plt.colorbar(scatter, label='Digit Class')
    plt.xlabel("Latent Dimension 1 (z_mean)")
    plt.ylabel("Latent Dimension 2 (z_mean)")
    plt.title("VAE Latent Space (Mean Vectors - z_mean) (PyTorch)")
    plt.grid(True)
    plt.show()

if LATENT_DIM_VAE == 2:
    plot_latent_space(vae.encoder, test_loader, latent_dim=LATENT_DIM_VAE)
else:
    print(f"\nSkipping Latent Space plot because LATENT_DIM_VAE is {LATENT_DIM_VAE} (requires 2).")
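
The smooth transitions mentioned above can be inspected by decoding a regular grid of 2D latent points. The sketch below is self-contained, so it uses an untrained stand-in decoder. `stand_in_decoder`, `decode_latent_grid`, and the `span` of 2.0 are hypothetical names and values introduced here for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in with the same 2 -> 784 shape as the notebook's decoder.
stand_in_decoder = nn.Sequential(
    nn.Linear(2, 128), nn.ReLU(),
    nn.Linear(128, 784), nn.Sigmoid(),
)

def decode_latent_grid(decoder_model, n=5, span=2.0):
    """Decodes an n x n grid of 2D latent points covering [-span, span]^2."""
    axis = torch.linspace(-span, span, n)
    z = torch.cartesian_prod(axis, axis)          # (n*n, 2), first column varies slowest
    device = next(decoder_model.parameters()).device
    decoder_model.eval()
    with torch.no_grad():
        images = decoder_model(z.to(device)).cpu()
    # image_grid[i, j] is the decoding of the latent point (axis[i], axis[j])
    return z, images.view(n, n, 28, 28)

z_grid, image_grid = decode_latent_grid(stand_in_decoder, n=5)
print(z_grid.shape, image_grid.shape)
```

In the notebook, passing `vae.decoder` instead of the stand-in and showing each `image_grid[i, j]` with `plt.imshow` should reveal how digit shapes change from one grid cell to the next.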