# Variational Autoencoder (VAE) Tutorial

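This notebook trains a Variational Autoencoder (VAE) on the MNIST dataset and uses it for generative modeling: reconstructing test digits, sampling new digits, and exploring the learned latent space.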

In [None]:
import sys
sys.path.append('../..')

import torch
import matplotlib.pyplot as plt

from src.generative import VAE, VAETrainer
from src.utils import load_mnist, get_device, set_seed, plot_training_curves, plot_image_grid, plot_reconstruction

# Seed the RNGs via set_seed before any stochastic step runs, so that
# Restart & Run All reproduces the same results.
SEED = 42
set_seed(SEED)

# Choose the compute device once; every later cell reads `device`.
device = get_device()
print(f"Using device: {device}")

## Load MNIST Dataset

In [None]:
# Build the MNIST train and test loaders with a shared batch size
# (train split first, then test split).
BATCH_SIZE = 128
train_loader, test_loader = (
    load_mnist(batch_size=BATCH_SIZE, train=is_train, download=True)
    for is_train in (True, False)
)

print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

## Create and Train VAE

In [None]:
# Build the VAE. Inputs are flattened 28x28 MNIST images (784 values) and are
# mapped to a 20-dimensional latent code; hidden_dims presumably lists the
# widths of the intermediate layers (see src.generative.VAE).
model = VAE(
    input_dim=28 * 28,
    latent_dim=20,
    hidden_dims=[256, 128],
).to(device)  # later cells pass `device` tensors straight into `model`, so place it explicitly

print(model)

In [None]:
# Fit the VAE. The test split doubles as the validation set.
# beta is the beta-VAE coefficient (presumably the weight on the KL term;
# 1.0 would be the standard VAE) — confirm in VAETrainer.
trainer = VAETrainer(model, device=device)
history = trainer.train(
    train_loader,
    val_loader=test_loader,
    n_epochs=10,
    learning_rate=1e-3,
    beta=1.0,
    verbose=True,
)

## Plot Training History

In [None]:
# Plot the training and validation loss curves that trainer.train recorded in `history`.
train_curve = history['train_losses']
val_curve = history['val_losses']
plot_training_curves(train_curve, val_curve, title='VAE Loss Curves')

## Test Reconstruction

In [None]:
# Reconstruct one test batch and compare the model's inputs with its outputs.
n_show = 8  # number of original/reconstruction pairs to plot

images, _ = next(iter(test_loader))
images = images.to(device)

model.eval()
with torch.no_grad():
    images_flat = images.view(images.size(0), -1)
    # Batch-wide min-max rescale to [0, 1].
    # NOTE(review): assumes VAETrainer feeds the model data in the same range — confirm.
    # clamp_min guards against division by zero for a constant batch (max == min).
    lo, hi = images_flat.min(), images_flat.max()
    images_flat = (images_flat - lo) / (hi - lo).clamp_min(1e-8)
    reconstructed, _, _ = model(images_flat)
    reconstructed = reconstructed.view(-1, 1, 28, 28)
    # Plot the rescaled inputs the model actually encoded (not the raw loader
    # tensors) so both rows of the figure share the same [0, 1] scale.
    model_inputs = images_flat.view(-1, 1, 28, 28)

# Plot original vs reconstructed
plot_reconstruction(model_inputs[:n_show], reconstructed[:n_show], n_images=n_show)

## Generate New Samples

In [None]:
# Generate new images with model.sample (presumably decoding random latent codes).
# The grid shape defines the sample count, so the two can never disagree.
n_rows, n_cols = 4, 8
n_samples = n_rows * n_cols

model.eval()  # don't rely on an earlier cell having switched to inference mode
with torch.no_grad():  # pure inference: no autograd history, so results can be converted for plotting
    samples = model.sample(n_samples, device=device)
samples = samples.view(-1, 1, 28, 28)

# Plot generated samples
plot_image_grid(samples, n_rows=n_rows, n_cols=n_cols, title='Generated Samples from VAE')

## Latent Space Visualization

In [None]:
# Encode a subset of the test set into latent vectors for visualization.
n_latent_samples = 1000  # exact number of test images to embed

model.eval()
latent_chunks = []
label_chunks = []
n_collected = 0

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_flat = batch_images.view(batch_images.size(0), -1)
        # Batch-wide min-max rescale to [0, 1]; clamp_min avoids dividing by zero.
        lo, hi = batch_flat.min(), batch_flat.max()
        batch_flat = (batch_flat - lo) / (hi - lo).clamp_min(1e-8)

        # Keep `mu` (presumably the posterior mean) as a deterministic embedding;
        # the second output of encode is not needed here.
        mu, _ = model.encode(batch_flat)
        latent_chunks.append(mu.cpu())
        label_chunks.append(batch_labels)

        # Count the images actually seen (the last batch may be short), not
        # batches x current batch size, and stop once the cap is reached.
        n_collected += batch_images.size(0)
        if n_collected >= n_latent_samples:
            break

# Trim the overshoot from the final batch so exactly n_latent_samples vectors
# are kept (or the whole test set, if it is smaller).
latent_vectors = torch.cat(latent_chunks)[:n_latent_samples]
labels_all = torch.cat(label_chunks)[:n_latent_samples]

print(f"Latent vectors shape: {latent_vectors.shape}")

In [None]:
# Scatter the first two latent coordinates, colored by digit label.
# Only 2 of the model.latent_dim dimensions are shown, so overlap between classes
# here does not necessarily mean they overlap in the full latent space.
fig, ax = plt.subplots(figsize=(10, 8))
points = ax.scatter(
    latent_vectors[:, 0], latent_vectors[:, 1],
    c=labels_all, cmap='tab10', alpha=0.6, edgecolors='k', s=20,
)
fig.colorbar(points, ax=ax, label='Digit')
ax.set_xlabel('Latent Dimension 1')
ax.set_ylabel('Latent Dimension 2')
ax.set_title('VAE Latent Space Visualization (First 2 Dimensions)')
ax.grid(True, alpha=0.3)
plt.show()

## Latent Space Interpolation

In [None]:
# Decode points on the straight line between two random latent codes to see
# how smoothly the decoder morphs one sample into another.
n_steps = 10

model.eval()
with torch.no_grad():
    # Draw on the CPU generator, then move to the device (same RNG stream as before).
    z1 = torch.randn(1, model.latent_dim).to(device)
    z2 = torch.randn(1, model.latent_dim).to(device)

    # Weights of shape (n_steps, 1) broadcast against (1, latent_dim): one latent
    # code per step, all decoded in a single batched forward pass.
    alphas = torch.linspace(0, 1, n_steps, device=device)
    z_path = torch.lerp(z1, z2, alphas.unsqueeze(1))
    interpolated = model.decode(z_path).view(n_steps, 1, 28, 28).cpu()

# squeeze=False keeps `axes` 2-D (and iterable) even when n_steps == 1.
fig, axes = plt.subplots(1, n_steps, figsize=(2 * n_steps, 2), squeeze=False)
for ax, img, alpha in zip(axes[0], interpolated, alphas.tolist()):
    ax.imshow(img.squeeze(), cmap='gray')
    ax.axis('off')
    ax.set_title(f'{alpha:.1f}')
fig.suptitle('Latent Space Interpolation', fontsize=16)
plt.tight_layout()
plt.show()