# PyTorch Tutorial: Autoencoders and Variational Autoencoders (VAEs)

Autoencoders are neural networks that learn to compress data into a lower-dimensional **latent space** and then reconstruct it. Variational Autoencoders (VAEs) extend this idea by learning a **probabilistic** latent space, enabling generation of new data.

## Learning Objectives
- Understand the encoder-decoder architecture
- Implement a vanilla autoencoder for image reconstruction
- Learn the VAE reparameterization trick and ELBO loss
- Visualize and navigate latent spaces
- Implement beta-VAE for disentangled representations
- Build a Conditional VAE (CVAE) for controlled generation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Data Loading

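We'll use MNIST for simplicity: 28x28 grayscale images of handwritten digits. `transforms.ToTensor()` converts each image to a float tensor with pixel values in [0, 1]. This range matches the decoder's `Sigmoid` output, and the binary cross-entropy loss used later requires it.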

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 2. Vanilla Autoencoder

An autoencoder has two parts:
- **Encoder**: Compresses input x to latent representation z
- **Decoder**: Reconstructs x_hat from z

The bottleneck (latent dimension) forces the network to learn meaningful features.

In [None]:
class Autoencoder(nn.Module):
    def __init__(self, latent_dim=32):
        super().__init__()
        
        # Encoder: 784 -> 256 -> 128 -> latent_dim
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        
        # Decoder: latent_dim -> 128 -> 256 -> 784
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid(),
        )
    
    def encode(self, x):
        return self.encoder(x)
    
    def decode(self, z):
        return self.decoder(z).view(-1, 1, 28, 28)
    
    def forward(self, x):
        z = self.encode(x)
        return self.decode(z)

ae = Autoencoder(latent_dim=32).to(device)
print(f"Autoencoder parameters: {sum(p.numel() for p in ae.parameters()):,}")

### Training the Autoencoder

Loss: Mean Squared Error between input and reconstruction.
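`F.mse_loss` defaults to `reduction='mean'`, so the value this loop prints is a per-pixel average. The VAE section later reports a per-image *sum*, so the two numbers are not directly comparable. The minimal sketch below uses random tensors as hypothetical stand-ins for a real batch and shows how the two reductions relate:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical stand-ins for a batch of 4 reconstructions and targets (1x28x28 each)
recon = torch.rand(4, 1, 28, 28)
target = torch.rand(4, 1, 28, 28)

mean_loss = F.mse_loss(recon, target)                  # default: average over every pixel
sum_loss = F.mse_loss(recon, target, reduction='sum')  # total squared error over the batch
per_image = sum_loss / recon.size(0)                   # sum per image, like the VAE's printout

# The mean equals the sum divided by the element count (4 * 784 here)
print(mean_loss.item(), per_image.item())
```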

In [None]:
def train_autoencoder(model, train_loader, epochs=10, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            
            optimizer.zero_grad()
            recon = model(data)
            loss = F.mse_loss(recon, data)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
    
    return model

ae = train_autoencoder(ae, train_loader, epochs=10)

### Visualize Reconstructions

In [None]:
def visualize_reconstructions(model, test_loader, n=10):
    model.eval()
    data, _ = next(iter(test_loader))
    data = data[:n].to(device)
    
    with torch.no_grad():
        recon = model(data)
    
    fig, axes = plt.subplots(2, n, figsize=(15, 3))
    for i in range(n):
        axes[0, i].imshow(data[i].cpu().squeeze(), cmap='gray')
        axes[0, i].axis('off')
        axes[0, i].set_title('Original')
        
        axes[1, i].imshow(recon[i].cpu().squeeze(), cmap='gray')
        axes[1, i].axis('off')
        axes[1, i].set_title('Reconstructed')
    
    plt.tight_layout()
    plt.show()

visualize_reconstructions(ae, test_loader)

## 3. Variational Autoencoder (VAE)

The key difference: VAE learns a **probability distribution** in latent space, not just point embeddings.

### The Reparameterization Trick

Instead of encoding to a single point z, we encode to:
- Mean: mu
- Log variance: log(sigma^2)

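Then sample: z = mu + sigma * epsilon, where sigma = exp(0.5 * log(sigma^2)) and epsilon ~ N(0, I)

All the randomness lives in epsilon, so z is a deterministic, differentiable function of mu and sigma. This lets gradients flow through the sampling step back into the encoder.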
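A quick autograd check makes this concrete. In this sketch `mu` and `logvar` are hypothetical encoder outputs created directly as leaf tensors. After backpropagating through a reparameterized sample, both receive gradients, while `epsilon` needs none:

```python
import torch

torch.manual_seed(0)
# Hypothetical encoder outputs for a batch of 3 inputs and a 2D latent space
mu = torch.zeros(3, 2, requires_grad=True)
logvar = torch.zeros(3, 2, requires_grad=True)

std = torch.exp(0.5 * logvar)     # sigma = exp(log(sigma^2) / 2)
epsilon = torch.randn_like(std)   # noise from a fixed N(0, I); no gradient needed
z = mu + std * epsilon            # differentiable in mu and logvar

# Any scalar function of z works; here we simply sum it
z.sum().backward()

# Expected: dz/dmu = 1 and dz/dlogvar = 0.5 * sigma * epsilon
print(mu.grad)
print(logvar.grad)
```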

In [None]:
class VAE(nn.Module):
    def __init__(self, latent_dim=2):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )
        
        # Outputs mu and log_var
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid(),
        )
    
    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """Reparameterization trick: z = mu + std * epsilon"""
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mu + std * epsilon
    
    def decode(self, z):
        return self.decoder(z).view(-1, 1, 28, 28)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

vae = VAE(latent_dim=2).to(device)
print(f"VAE parameters: {sum(p.numel() for p in vae.parameters()):,}")

### ELBO Loss

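The VAE is trained by maximizing the evidence lower bound (ELBO):

ELBO = E_q(z|x)[log p(x|z)] - D_KL(q(z|x) || p(z))

In code we minimize the negative ELBO, which has two parts:

- **Reconstruction loss** (-E[log p(x|z)]): How well can we reconstruct the input? For a Bernoulli decoder over pixels, this is the binary cross-entropy.
- **KL divergence**: How close is the encoder's distribution q(z|x) to the standard normal prior p(z) = N(0, I)?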
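You can sanity-check both terms against `torch.distributions`. Summed BCE should equal the negative Bernoulli log-likelihood, and the closed-form KL used in `vae_loss` should match `kl_divergence` between diagonal Gaussians. The tensors in this minimal sketch are random stand-ins, not outputs of the trained model:

```python
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli, Normal, kl_divergence

torch.manual_seed(0)
# Hypothetical decoder probabilities and binary targets for 2 flattened images
recon = torch.rand(2, 784).clamp(0.01, 0.99)
x = (torch.rand(2, 784) > 0.5).float()
# Hypothetical encoder outputs for a 4D latent space
mu = torch.randn(2, 4)
logvar = torch.randn(2, 4)

# Reconstruction term: summed BCE equals the negative Bernoulli log-likelihood
bce = F.binary_cross_entropy(recon, x, reduction='sum')
neg_log_lik = -Bernoulli(probs=recon).log_prob(x).sum()

# KL term: the closed form from vae_loss vs. the library's Normal-Normal KL
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_lib = kl_divergence(q, p).sum()

print(bce.item(), neg_log_lik.item())
print(kl_closed.item(), kl_lib.item())
```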

In [None]:
def vae_loss(recon, x, mu, logvar):
    """Negative ELBO = reconstruction loss + KL divergence, summed over the batch"""
    # Reconstruction loss: BCE = -log p(x|z) for a Bernoulli decoder (targets must lie in [0, 1])
    recon_loss = F.binary_cross_entropy(recon.view(-1, 784), x.view(-1, 784), reduction='sum')
    
    # KL Divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return recon_loss + kl_loss

def train_vae(model, train_loader, epochs=20, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            
            optimizer.zero_grad()
            recon, mu, logvar = model(data)
            loss = vae_loss(recon, data, mu, logvar)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader.dataset)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.2f}")
    
    return model

vae = train_vae(vae, train_loader, epochs=20)

## 4. Latent Space Visualization

With a 2D latent space, we can visualize where different digits are encoded!

In [None]:
def plot_latent_space(model, test_loader):
    model.eval()
    z_points = []
    labels = []
    
    with torch.no_grad():
        for data, label in test_loader:
            data = data.to(device)
            mu, _ = model.encode(data)
            z_points.append(mu.cpu())
            labels.append(label)
    
    z_points = torch.cat(z_points).numpy()
    labels = torch.cat(labels).numpy()
    
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(z_points[:, 0], z_points[:, 1], c=labels, cmap='tab10', alpha=0.5, s=2)
    plt.colorbar(scatter, label='Digit')
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.title('VAE Latent Space (2D)')
    plt.show()

plot_latent_space(vae, test_loader)
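`plot_latent_space` only works directly when `latent_dim=2`. For a higher-dimensional model, such as the 10D Beta-VAE trained later, one option is to project the encoder means onto their top two principal components with `torch.linalg.svd`. This projection is an assumption of this sketch, not part of the original tutorial, and `project_2d` is a hypothetical helper:

```python
import torch

def project_2d(z_points):
    """Project an (N, D) tensor of latent means onto its top-2 principal components."""
    centered = z_points - z_points.mean(dim=0, keepdim=True)
    # Rows of Vh are principal directions, ordered by decreasing singular value
    _, _, Vh = torch.linalg.svd(centered, full_matrices=False)
    return centered @ Vh[:2].T  # (N, 2) coordinates suitable for plt.scatter

# Usage with hypothetical 10D latent means whose spread shrinks across dimensions
torch.manual_seed(0)
z = torch.randn(500, 10) * torch.linspace(3.0, 0.1, 10)
coords = project_2d(z)
print(coords.shape)  # torch.Size([500, 2])
```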

### Generate from Latent Space Grid

We can sample a grid of points in latent space and decode them to see what the model has learned.

In [None]:
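def plot_latent_grid(model, n=20, range_val=3):
    """Decode an n x n grid of 2D latent points (requires latent_dim=2)."""
    model.eval()
    
    # Create grid of latent points; grid_y runs from +range_val down to -range_val
    # so the top row of the image corresponds to large z[1], as in the scatter plot
    grid_x = np.linspace(-range_val, range_val, n)
    grid_y = np.linspace(range_val, -range_val, n)
    
    figure = np.zeros((28 * n, 28 * n))
    
    with torch.no_grad():
        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z = torch.tensor([[xi, yi]], dtype=torch.float32).to(device)
                decoded = model.decode(z)
                digit = decoded[0].cpu().squeeze().numpy()
                figure[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = digit
    
    plt.figure(figsize=(12, 12))
    # extent maps image pixels to latent coordinates, so the axis labels are meaningful
    # (axis('off') would have hidden the z[0]/z[1] labels)
    plt.imshow(figure, cmap='gray',
               extent=[-range_val, range_val, -range_val, range_val])
    plt.title('Generated Digits from Latent Space Grid')
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.show()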

plot_latent_grid(vae)
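`plot_latent_grid` calls the decoder once per grid point. A batched alternative builds all n*n latent points with `torch.meshgrid` (the `indexing` argument needs PyTorch 1.10 or newer), decodes them in one call, and tiles the results with a permute and reshape. This is a sketch: `latent_grid_image` is a hypothetical helper, and the stand-in decoder only exists to show where each tile lands:

```python
import torch

def latent_grid_image(decode, n=20, range_val=3.0):
    """Decode an n x n grid of 2D latent points in one batch and tile the results.

    `decode` maps a (B, 2) tensor to (B, 1, 28, 28), e.g. a trained VAE's model.decode.
    """
    grid_x = torch.linspace(-range_val, range_val, n)  # columns: left to right
    grid_y = torch.linspace(range_val, -range_val, n)  # rows: top (+) to bottom (-)
    yy, xx = torch.meshgrid(grid_y, grid_x, indexing='ij')
    z = torch.stack([xx.flatten(), yy.flatten()], dim=1)  # (n*n, 2) in row-major order

    with torch.no_grad():
        imgs = decode(z).reshape(n, n, 28, 28)  # indexed [row, col, h, w]
    # Reorder to [row, h, col, w] so the reshape places each tile in its grid cell
    return imgs.permute(0, 2, 1, 3).reshape(n * 28, n * 28)

# Hypothetical stand-in decoder: paints each 28x28 tile with its z[0] value
def fake_decode(z):
    return z[:, :1].reshape(-1, 1, 1, 1).expand(-1, 1, 28, 28)

figure = latent_grid_image(fake_decode, n=5, range_val=3.0)
print(figure.shape)  # torch.Size([140, 140])
```

With a real model, `plt.imshow(latent_grid_image(model.decode).cpu(), cmap='gray')` would show the same kind of grid, provided the latent points are moved to the model's device first.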

### Latent Space Interpolation

One powerful property of VAEs: we can smoothly interpolate between two images in latent space!

In [None]:
def interpolate(model, x1, x2, steps=10):
    model.eval()
    
    with torch.no_grad():
        mu1, _ = model.encode(x1.to(device))
        mu2, _ = model.encode(x2.to(device))
        
        # Linear interpolation in latent space
        interpolations = []
        for alpha in np.linspace(0, 1, steps):
            z = (1 - alpha) * mu1 + alpha * mu2
            decoded = model.decode(z)
            interpolations.append(decoded.cpu())
    
    # Plot
    fig, axes = plt.subplots(1, steps, figsize=(15, 2))
    for i, img in enumerate(interpolations):
        axes[i].imshow(img.squeeze(), cmap='gray')
        axes[i].axis('off')
    plt.suptitle('Latent Space Interpolation')
    plt.tight_layout()
    plt.show()

# Get two different digits
test_data, test_labels = next(iter(test_loader))
idx1 = (test_labels == 3).nonzero()[0].item()
idx2 = (test_labels == 8).nonzero()[0].item()

interpolate(vae, test_data[idx1:idx1+1], test_data[idx2:idx2+1])
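`interpolate` builds the path one alpha at a time. You can also build the whole path at once by broadcasting a column of alphas, then decode it with a single `model.decode(...)` call. In this sketch, `lerp_latents` is a hypothetical helper and the two endpoints are made-up 2D codes:

```python
import torch

def lerp_latents(mu1, mu2, steps=10):
    """Return `steps` latent codes on the straight line from mu1 to mu2.

    mu1, mu2: (1, latent_dim) tensors; output: (steps, latent_dim).
    """
    alphas = torch.linspace(0, 1, steps).unsqueeze(1)  # (steps, 1)
    return (1 - alphas) * mu1 + alphas * mu2           # broadcasts to (steps, latent_dim)

# Hypothetical endpoint codes in a 2D latent space
mu1 = torch.tensor([[-2.0, 0.0]])
mu2 = torch.tensor([[2.0, 4.0]])
z_path = lerp_latents(mu1, mu2, steps=5)
print(z_path)
```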

## 5. Beta-VAE for Disentangled Representations

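Beta-VAE adds a weight beta to the KL term to encourage **disentangled** latent factors. The objective to maximize is

E_q(z|x)[log p(x|z)] - beta * D_KL(q(z|x) || p(z))

so the loss to minimize is reconstruction loss + beta * KL, as in `beta_vae_loss` below.

- beta = 1: Standard VAE
- beta > 1: Stronger regularization; tends to give more disentangled factors, but worse reconstruction
- beta < 1: Better reconstruction, less regularized latent space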

In [None]:
def beta_vae_loss(recon, x, mu, logvar, beta=4.0):
    """Beta-VAE loss with adjustable KL weight"""
    recon_loss = F.binary_cross_entropy(recon.view(-1, 784), x.view(-1, 784), reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

def train_beta_vae(model, train_loader, beta=4.0, epochs=20, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            
            optimizer.zero_grad()
            recon, mu, logvar = model(data)
            loss = beta_vae_loss(recon, data, mu, logvar, beta=beta)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader.dataset)
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.2f}")
    
    return model

# Train a Beta-VAE with beta=4.0 (retrain with other values, e.g. 1.0 or 0.5, to compare)
print("Training Beta-VAE with beta=4.0...")
beta_vae = VAE(latent_dim=10).to(device)
beta_vae = train_beta_vae(beta_vae, train_loader, beta=4.0, epochs=20)

### Traverse Individual Latent Dimensions

With disentangled representations, each latent dimension ideally controls a single factor of variation. Traversing one dimension while holding the others fixed shows what that dimension encodes.

In [None]:
def latent_traversal(model, base_z, dim, range_val=3, steps=10):
    """Traverse a single latent dimension while keeping others fixed"""
    model.eval()
    
    images = []
    with torch.no_grad():
        for val in np.linspace(-range_val, range_val, steps):
            z = base_z.clone()
            z[0, dim] = val
            decoded = model.decode(z)
            images.append(decoded.cpu().squeeze())
    
    return images

# Get a base encoding
test_img = test_data[0:1].to(device)
with torch.no_grad():
    base_mu, _ = beta_vae.encode(test_img)

# Traverse first 5 latent dimensions
fig, axes = plt.subplots(5, 10, figsize=(15, 8))
for dim in range(5):
    images = latent_traversal(beta_vae, base_mu.clone(), dim)
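    for j, img in enumerate(images):
        axes[dim, j].imshow(img, cmap='gray')
        # Hide ticks instead of calling axis('off'), which would also hide the row label
        axes[dim, j].set_xticks([])
        axes[dim, j].set_yticks([])
    axes[dim, 0].set_ylabel(f'z[{dim}]', fontsize=12)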

plt.suptitle('Latent Dimension Traversal (Beta-VAE)', fontsize=14)
plt.tight_layout()
plt.show()
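With beta > 1, some of the 10 latent dimensions may end up unused: their KL to the prior is near zero, and traversing them changes little. A simple way to decide which dimensions to traverse is to average the per-dimension KL over a batch. In this sketch, `kl_per_dim` is a hypothetical helper that reuses the same KL formula as `vae_loss`, and the 0.01 threshold is a heuristic assumption:

```python
import torch

def kl_per_dim(mu, logvar):
    """Average KL(q(z|x) || N(0, I)) per latent dimension over a batch.

    mu, logvar: (batch, latent_dim) encoder outputs. Returns a (latent_dim,) tensor.
    """
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())  # elementwise KL, as in vae_loss
    return kl.mean(dim=0)

# Hypothetical encoder outputs: dim 0 carries information, dim 1 matches the prior exactly
mu = torch.stack([torch.linspace(-2, 2, 100), torch.zeros(100)], dim=1)
logvar = torch.stack([torch.full((100,), -2.0), torch.zeros(100)], dim=1)

kl = kl_per_dim(mu, logvar)
print(kl)
active = (kl > 0.01).nonzero().flatten()  # heuristic threshold for "active" dimensions
print(active)  # tensor([0])
```

With the trained model, `beta_vae.encode(test_data.to(device))` (under `torch.no_grad()`) would supply `mu` and `logvar` for this check.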

## 6. Conditional VAE (CVAE)

A CVAE conditions on additional information (like class labels) to enable controlled generation.
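For reference, the CVAE below maximizes a conditional version of the ELBO. Both the encoder and decoder receive the label y. This bound assumes the prior is the unconditional p(z) = N(0, I), which matches the code: it reuses `vae_loss`, so q(z|x, y) is compared against a standard normal.

```latex
\log p(x \mid y) \;\ge\; \mathbb{E}_{q(z \mid x, y)}\big[\log p(x \mid z, y)\big]
\;-\; D_{KL}\big(q(z \mid x, y) \,\|\, p(z)\big), \qquad p(z) = \mathcal{N}(0, I)
```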

In [None]:
class CVAE(nn.Module):
    def __init__(self, latent_dim=2, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        
        # Encoder: image + one-hot label
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28 + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        
        # Decoder: latent + one-hot label
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid(),
        )
    
    def encode(self, x, y):
        y_onehot = F.one_hot(y, self.num_classes).float()
        x_flat = x.view(-1, 28 * 28)
        h = self.encoder(torch.cat([x_flat, y_onehot], dim=1))
        return self.fc_mu(h), self.fc_logvar(h)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mu + std * epsilon
    
    def decode(self, z, y):
        y_onehot = F.one_hot(y, self.num_classes).float()
        h = torch.cat([z, y_onehot], dim=1)
        return self.decoder(h).view(-1, 1, 28, 28)
    
    def forward(self, x, y):
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, y)
        return recon, mu, logvar

cvae = CVAE(latent_dim=2, num_classes=10).to(device)
print(f"CVAE parameters: {sum(p.numel() for p in cvae.parameters()):,}")

In [None]:
def train_cvae(model, train_loader, epochs=20, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, labels) in enumerate(train_loader):
            data, labels = data.to(device), labels.to(device)
            
            optimizer.zero_grad()
            recon, mu, logvar = model(data, labels)
            loss = vae_loss(recon, data, mu, logvar)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader.dataset)
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.2f}")
    
    return model

cvae = train_cvae(cvae, train_loader, epochs=20)

### Generate Specific Digits

In [None]:
def generate_digits(model, n_per_digit=5):
    model.eval()
    
    fig, axes = plt.subplots(10, n_per_digit, figsize=(10, 20))
    
    with torch.no_grad():
        for digit in range(10):
            for i in range(n_per_digit):
                z = torch.randn(1, model.latent_dim).to(device)
                y = torch.tensor([digit]).to(device)
                generated = model.decode(z, y)
                
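                axes[digit, i].imshow(generated.cpu().squeeze(), cmap='gray')
                # Hide ticks instead of axis('off') so the row label stays visible
                axes[digit, i].set_xticks([])
                axes[digit, i].set_yticks([])
            axes[digit, 0].set_ylabel(f'{digit}', fontsize=14)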
    
    plt.suptitle('CVAE: Generated Digits by Class', fontsize=14)
    plt.tight_layout()
    plt.show()

generate_digits(cvae)

## 7. FAANG Interview Questions

### Q1: What is the reparameterization trick and why is it necessary?

**Answer**:

The reparameterization trick allows us to backpropagate through a sampling operation.

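**Problem**: We want to sample z from q(z|x) = N(mu, sigma^2). A sample drawn directly from this distribution is not a differentiable function of mu and sigma, so gradients cannot reach the encoder.

**Solution**: Express sampling as a deterministic function of the parameters plus independent noise:
z = mu + sigma * epsilon, where epsilon ~ N(0, I)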

Now:
- mu and sigma are outputs of the encoder (differentiable)
- epsilon is sampled from a fixed distribution (no gradients needed)
- Gradients can flow from the loss through z back to the encoder

---

### Q2: Explain the ELBO loss and why VAEs use it.

**Answer**:

ELBO = Evidence Lower BOund. We want to maximize log p(x) but it's intractable.

log p(x) >= E_q(z|x)[log p(x|z)] - D_KL(q(z|x) || p(z))

- **Reconstruction term**: The decoder should reconstruct the input well
- **KL term**: The encoder distribution should be close to the prior p(z) = N(0, I)

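The KL term keeps each q(z|x) close to the prior, so the encoder cannot map inputs to arbitrary, isolated points in latent space. The result is a smooth, well-covered latent space in which samples z ~ p(z) decode to plausible outputs.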
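The inequality follows from a standard identity: the gap between log p(x) and the ELBO is exactly the KL divergence from q(z|x) to the true posterior p(z|x), which is non-negative. The derivation uses the same symbols as the bound above:

```latex
\begin{aligned}
\log p(x) &= \mathbb{E}_{q(z \mid x)}\big[\log p(x)\big]
           = \mathbb{E}_{q(z \mid x)}\left[\log \frac{p(x, z)}{p(z \mid x)}\right] \\
          &= \mathbb{E}_{q(z \mid x)}\left[\log \frac{p(x, z)}{q(z \mid x)}\right]
           + \mathbb{E}_{q(z \mid x)}\left[\log \frac{q(z \mid x)}{p(z \mid x)}\right] \\
          &= \underbrace{\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]
             - D_{KL}\big(q(z \mid x) \,\|\, p(z)\big)}_{\text{ELBO}}
           + \underbrace{D_{KL}\big(q(z \mid x) \,\|\, p(z \mid x)\big)}_{\ge 0}
\end{aligned}
```

Because p(z|x) is unknown, the last term cannot be computed, but dropping it yields the lower bound.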

---

### Q3: What is posterior collapse and how do you prevent it?

**Answer**:

**Posterior collapse**: The encoder ignores the input and outputs the prior (q(z|x) approx p(z)), while the decoder ignores z and models p(x) directly.

**Causes**:
- Powerful decoder (e.g., autoregressive) that doesn't need z
- KL term dominates early in training

**Solutions**:
1. **KL annealing**: Start with low KL weight, gradually increase
2. **Free bits**: Don't penalize the KL below a threshold (per latent dimension or group), so the model can encode that much information about x at no cost
3. **Beta-VAE with beta < 1**: Weaker regularization
4. **Weaker decoders**: Force the model to use z

---

### Q4: Compare VAE vs GAN for generative modeling.

**Answer**:

| Aspect | VAE | GAN |
|--------|-----|-----|
| **Training** | Stable, uses reconstruction | Adversarial, can be unstable |
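| **Sample quality** | Often blurry (especially with pixel-wise likelihoods) | Typically sharper, more realistic |
| **Latent space** | Smooth, regularized toward the prior | No explicit regularization of the latent space |
| **Inference** | Has encoder (can get z from x) | No encoder by default |
| **Mode coverage** | Generally good (the likelihood objective penalizes missing modes) | May suffer mode collapse |
| **Density estimation** | Provides a lower bound on log-likelihood | No explicit density (implicit model) |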

**Use VAE when**: You need a latent space, inference, or smooth interpolation.
**Use GAN when**: Sample quality is paramount.

---

### Q5: What is a Beta-VAE and what does disentanglement mean?

**Answer**:

**Beta-VAE**: Modifies the VAE objective to E_q(z|x)[log p(x|z)] - beta * D_KL(q(z|x) || p(z)), i.e. loss = reconstruction loss + beta * KL

**Disentanglement**: Each latent dimension controls ONE independent factor of variation.

Example (faces): One dimension controls hair color, another controls smile, another controls age - independently.

**Why beta > 1 helps**:
- Stronger constraint to match prior N(0, I)
- The prior is factorized (independent dimensions), so pulling q(z|x) toward it pressures the latent dimensions toward statistical independence
- Trade-off: Worse reconstruction for better disentanglement

## 8. Key Takeaways

1. **Autoencoders** learn compressed representations through encoder-decoder architecture
2. **VAEs** extend AEs with probabilistic latent spaces using the reparameterization trick
3. **ELBO loss** balances reconstruction quality with latent space regularity
4. **Latent space** enables interpolation, generation, and understanding of data structure
5. **Beta-VAE** (beta > 1) encourages disentangled representations at the cost of reconstruction
6. **CVAE** enables conditional generation by incorporating class information
7. **Posterior collapse** is a key failure mode - mitigate with KL annealing or architectural choices