# 🎓 VAE Tutorial on Fashion-MNIST

This notebook demonstrates how to train a Variational Autoencoder (VAE) on the Fashion-MNIST dataset using PyTorch. You’ll learn how the model works, how to train it, and how to visualize the latent space.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import os

## 🔧 Step 1: Define Hyperparameters and Device

In [None]:
# Hyperparameters
latent_dim = 2        # Dimensionality of latent space (2 allows a direct 2-D scatter plot)
hidden_dim = 500      # Hidden layer size for encoder and decoder MLP
batch_size = 128
learning_rate = 1e-3
epochs = 20

# Reproducibility: seed PyTorch's random number generators so weight
# initialisation, DataLoader shuffling, reparameterization noise and prior
# samples repeat on every Restart & Run All.
# NOTE: bit-exact GPU runs may additionally need deterministic algorithms.
seed = 42
torch.manual_seed(seed)

# Device configuration: use the GPU when one is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 📦 Step 2: Load the Dataset

In [None]:
# Fashion-MNIST: 28x28 grayscale clothing images, downloaded to ./data on first run.
# ToTensor() turns each image into a float tensor scaled to [0, 1], which the
# binary cross-entropy reconstruction loss requires of its targets.
transform = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(
    root='./data', train=True, transform=transform, download=True,
)
test_dataset = datasets.FashionMNIST(
    root='./data', train=False, transform=transform, download=True,
)
# Only the training set is shuffled; the test set keeps a fixed order, so the
# "first batch" used for reconstruction images is the same every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False,
)

100%|██████████| 26.4M/26.4M [00:02<00:00, 12.3MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 210kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.87MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 12.0MB/s]


## 🧠 Step 3: Define the VAE Model

In [None]:
# VAE Model Definition
class VAE(nn.Module):
    """Fully-connected variational autoencoder for 28x28 images.

    Encoder: 784 -> hidden_dim (ReLU) -> two linear heads giving mu and log-variance.
    Decoder: latent_dim -> hidden_dim (ReLU) -> 784 (sigmoid), so every output
    pixel lies in (0, 1).

    Layer sizes come from the notebook-level ``hidden_dim`` / ``latent_dim``
    settings, read when the model is constructed.
    """

    def __init__(self):
        super().__init__()
        n_pixels = 28 * 28
        # Encoder: flattened image -> hidden features -> (mu, logvar)
        self.fc1 = nn.Linear(n_pixels, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: latent code -> hidden features -> flattened image
        self.fc_dec1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_dec2 = nn.Linear(hidden_dim, n_pixels)

    def encode(self, x):
        """Return the posterior parameters ``(mu, logvar)`` for flattened images ``x``."""
        hidden = self.fc1(x).relu()
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) as mu + eps * sigma, with eps ~ N(0, I).

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()       # log-variance -> standard deviation
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes ``z`` to flattened images with pixel values in (0, 1)."""
        hidden = self.fc_dec1(z).relu()
        return self.fc_dec2(hidden).sigmoid()

    def forward(self, x):
        """Flatten ``x``, encode, sample a latent code and decode it.

        Returns:
            (reconstruction, mu, logvar)
        """
        flat = x.view(-1, 28 * 28)
        mu, logvar = self.encode(flat)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

## ⚙️ Steps 4–5: Initialize the Model and Optimizer, and Define the VAE Loss

In [None]:
# Build the network and place its parameters on `device`
# (nn.Module.to moves the parameters in place and returns the same module).
model = VAE()
model.to(device)
# Adam over every trainable parameter, using the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# VAE loss: binary cross-entropy (reconstruction) + Kullback-Leibler divergence
def vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO: reconstruction BCE plus KL(q(z|x) || N(0, I)).

    Args:
        recon_x: decoder output, one row of 28*28 pixel values in [0, 1] per image.
        x: target images; flattened to rows of 28*28 pixels before comparison.
        mu: encoder means of q(z|x).
        logvar: encoder log-variances of q(z|x).

    Returns:
        Scalar tensor summed over the whole batch (not averaged per sample).
    """
    targets = x.view(-1, 28 * 28)
    # Reconstruction term, summed over every pixel of every image.
    reconstruction = F.binary_cross_entropy(recon_x, targets, reduction='sum')
    # Closed-form KL divergence to the standard normal prior:
    #   0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2)
    kl_divergence = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
    return reconstruction + kl_divergence

## 💾 Step 6: Set Up the Output Directory

In [None]:
# Every image this notebook writes (reconstructions, prior samples,
# latent-space plots) goes under this folder; exist_ok keeps re-runs safe.
os.makedirs(name='vae_outputs', exist_ok=True)

## 🔁 Step 7: Training and Validation Functions

In [None]:
# Training and Validation loops
def train(epoch, log_interval=100):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``vae_loss``.

    Args:
        epoch: epoch number, used only in the log messages.
        log_interval: print the per-sample loss of every ``log_interval``-th batch.

    Returns:
        Average training loss per sample for this epoch.
    """
    model.train()
    loss_sum = 0.0
    n_samples = len(train_loader.dataset)
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)                   # forward pass through VAE
        loss = vae_loss(recon_batch, data, mu, logvar)          # summed over the batch
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        if batch_idx % log_interval == 0:
            print(f"Train Epoch {epoch} [{batch_idx*len(data)}/{n_samples}] Loss: {loss.item()/len(data):.4f}")
    # vae_loss sums over the batch, so dividing by the dataset size gives a per-sample average.
    avg_loss = loss_sum / n_samples
    print(f"===> Epoch {epoch} Complete: Avg Train Loss: {avg_loss:.4f}")
    return avg_loss

def validate(epoch, n_compare=8):
    """Evaluate on ``test_loader`` and save an input-vs-reconstruction image.

    Uses the notebook-level ``model``, ``test_loader``, ``device`` and ``vae_loss``.
    The model's forward pass still samples z, so the loss includes sampling noise.

    Args:
        epoch: epoch number, used in the log message and the image file name.
        n_compare: number of test images (from the first batch) in the comparison image.

    Returns:
        Average test loss per sample.
    """
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss_sum += vae_loss(recon_batch, data, mu, logvar).item()
            if batch_idx == 0:
                # Top row: original images; bottom row: their reconstructions.
                n = min(data.size(0), n_compare)
                comparison = torch.cat([data[:n], recon_batch.view(-1, 1, 28, 28)[:n]])
                utils.save_image(comparison.cpu(), f"vae_outputs/reconstruction_epoch{epoch}.png", nrow=n)
    avg_loss = loss_sum / len(test_loader.dataset)
    print(f"===> Validation: Avg Loss: {avg_loss:.4f}")
    return avg_loss

## 🚀 Step 8: Train the VAE and Generate Samples

In [None]:
# Training loop with image generation.
# Each epoch runs one training pass and one validation pass; every
# `sample_every` epochs we also decode fresh draws from the prior N(0, I).
sample_every = 5
n_prior_samples = 64
for epoch in range(1, epochs + 1):
    train(epoch)
    validate(epoch)
    if epoch % sample_every == 0:
        model.eval()
        with torch.no_grad():
            # Draw the latent vectors directly on the model's device.
            z = torch.randn(n_prior_samples, latent_dim, device=device)
            sample_imgs = model.decode(z).view(-1, 1, 28, 28)
            utils.save_image(sample_imgs.cpu(), f"vae_outputs/sample_epoch{epoch}.png", nrow=8)
            print(f"[Saved {n_prior_samples} sampled images at epoch {epoch}]")

# After training, if latent_dim == 2, visualize the latent space structure
if latent_dim == 2:
    model.eval()
    latent_means, label_batches = [], []
    with torch.no_grad():
        for data, label in test_loader:
            mu, _ = model.encode(data.to(device).view(-1, 28 * 28))
            latent_means.append(mu.cpu())   # posterior mean as the representative point
            label_batches.append(label)
    zs = torch.cat(latent_means, dim=0).numpy()
    labels = torch.cat(label_batches, dim=0).numpy()

    # Scatter plot of latent encodings colored by true class label
    fig, ax = plt.subplots(figsize=(6, 6))
    points = ax.scatter(zs[:, 0], zs[:, 1], c=labels, cmap='tab10', s=5, alpha=0.7)
    fig.colorbar(points, ax=ax).set_label("Fashion-MNIST class")
    ax.set_title("VAE latent space (2D) for test images")
    ax.set_xlabel("z1")
    ax.set_ylabel("z2")
    fig.savefig("vae_outputs/latent_space_scatter.png")
    plt.show()

    # Decode a regular grid over [-3, 3]^2 in one batch.
    # Columns run z1 = -3 -> +3 (left to right). Rows run z2 = +3 -> -3
    # (top to bottom), so the image has the same orientation as the scatter plot.
    grid_size = 20
    grid_x = torch.linspace(-3, 3, steps=grid_size)
    grid_y = torch.linspace(3, -3, steps=grid_size)
    # Row-major order: z1 varies fastest within each row of the saved image.
    z_grid = torch.stack(
        [grid_x.repeat(grid_size), grid_y.repeat_interleave(grid_size)], dim=1
    ).to(device)
    with torch.no_grad():
        imgs = model.decode(z_grid).view(-1, 1, 28, 28)
    # Save the grid_size x grid_size mosaic of generated images
    utils.save_image(imgs.cpu(), "vae_outputs/latent_space_grid.png", nrow=grid_size)
print("VAE training complete. Outputs saved to 'vae_outputs/' directory.")

Train Epoch 1 [0/60000] Loss: 553.1377
Train Epoch 1 [12800/60000] Loss: 281.6621
Train Epoch 1 [25600/60000] Loss: 299.7235
Train Epoch 1 [38400/60000] Loss: 275.5094
Train Epoch 1 [51200/60000] Loss: 278.4427
===> Epoch 1 Complete: Avg Train Loss: 291.2910
===> Validation: Avg Loss: 275.7749
Train Epoch 2 [0/60000] Loss: 269.3788
Train Epoch 2 [12800/60000] Loss: 270.6135
Train Epoch 2 [25600/60000] Loss: 261.4965
Train Epoch 2 [38400/60000] Loss: 285.4909
Train Epoch 2 [51200/60000] Loss: 269.4290
===> Epoch 2 Complete: Avg Train Loss: 272.2193
===> Validation: Avg Loss: 271.5556
Train Epoch 3 [0/60000] Loss: 269.3295
Train Epoch 3 [12800/60000] Loss: 280.1384
Train Epoch 3 [25600/60000] Loss: 278.4810
Train Epoch 3 [38400/60000] Loss: 256.9223
Train Epoch 3 [51200/60000] Loss: 266.5945
===> Epoch 3 Complete: Avg Train Loss: 268.9934
===> Validation: Avg Loss: 268.7676
Train Epoch 4 [0/60000] Loss: 268.6575
Train Epoch 4 [12800/60000] Loss: 285.8950
Train Epoch 4 [25600/60000] Loss: