# Stage 1 — Vanilla GAN Analysis

This notebook analyzes the training and output quality of the **MLP-based Vanilla GAN** on MNIST.

Sections (after an unnumbered setup & imports cell):
1. Load configuration & model
2. Generate & visualize samples
3. Real vs generated comparison
4. Training loss curves
5. Latent space interpolation
6. Quantitative evaluation (FID & Inception Score)
7. Architecture summary

In [None]:
import sys
from pathlib import Path

# Project root: when the kernel starts inside `notebooks/`, the repository
# root is one directory up; otherwise the working directory is used as-is.
PROJECT_ROOT = Path.cwd()
if PROJECT_ROOT.name == "notebooks":
    PROJECT_ROOT = PROJECT_ROOT.parent

# Put the root at the front of sys.path so that `src.*` imports resolve.
# Any existing copy is removed first, so re-running this cell keeps exactly
# one entry instead of stacking a duplicate on every execution.
_root_entry = str(PROJECT_ROOT)
sys.path[:] = [entry for entry in sys.path if entry != _root_entry]
sys.path.insert(0, _root_entry)

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from IPython.display import display, Image as IPImage

from src.utils.config_loader import load_config, get_device
from src.utils.checkpointing import load_checkpoint, find_latest_checkpoint
from src.models.vanilla_gan import build_vanilla_gan
from src.data.dataloaders import get_dataloader_from_config
from src.evaluation.visualization import GANVisualizer

%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['figure.dpi'] = 100

print(f"PyTorch version: {torch.__version__}")
print(f"Project root: {PROJECT_ROOT}")

## 1. Load Configuration & Model

In [None]:
def _n_params(module):
    """Return the total number of elements over all parameters of `module`."""
    return sum(p.numel() for p in module.parameters())


# Read the experiment configuration and let it choose the compute device
config = load_config(PROJECT_ROOT / "config" / "vanilla_gan.yaml")
device = get_device(config)
print(f"Device: {device}")
print(f"Experiment: {config['experiment']['name']}")

# Instantiate both networks, then move them onto the selected device
generator, discriminator = build_vanilla_gan(config)
generator, discriminator = generator.to(device), discriminator.to(device)

print(f"\nGenerator parameters: {_n_params(generator):,}")
print(f"Discriminator parameters: {_n_params(discriminator):,}")

In [None]:
# Load trained checkpoint.
# A relative `models_dir` is anchored at PROJECT_ROOT instead of the kernel's
# working directory: when the notebook is started from `notebooks/`, a
# cwd-relative lookup would resolve against that folder rather than the
# project root, and a miss falls through to the untrained-model branch below.
# Absolute paths are left untouched.
ckpt_dir = Path(config['paths']['models_dir'])
if not ckpt_dir.is_absolute():
    ckpt_dir = PROJECT_ROOT / ckpt_dir
ckpt_path = find_latest_checkpoint(ckpt_dir)

if ckpt_path:
    # Restores weights into the already-built generator/discriminator and
    # returns bookkeeping (epoch, global step, metrics) that is reported here.
    info = load_checkpoint(ckpt_path, generator, discriminator, device=device)
    print(f"Loaded checkpoint: epoch {info['epoch']}, step {info['global_step']}")
    print(f"Metrics: {info['metrics']}")
else:
    # Every later cell still runs, but on randomly initialised weights.
    print("No checkpoint found — using untrained model for demonstration")

## 2. Generate & Visualize Samples

In [None]:
# Draw 64 latent vectors and decode them with the generator in eval mode
latent_dim = config['model']['latent_dim']
generator.eval()

with torch.no_grad():
    z = torch.randn(64, latent_dim, device=device)
    fake_images = generator(z)

# Tile the batch into an 8-per-row grid, rescaling from (-1, 1) for display
grid = make_grid(fake_images.cpu(), nrow=8, normalize=True, value_range=(-1, 1))

# make_grid yields a channels-first tensor; imshow wants channels last
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
ax.set_title('Vanilla GAN — Generated Samples')
ax.axis('off')
fig.tight_layout()
plt.show()

## 3. Real vs Generated Comparison

In [None]:
# Fetch one batch of real images from the configured dataloader
dataloader = get_dataloader_from_config(config)
real_batch, _ = next(iter(dataloader))

# Sample a matching number of generated images
with torch.no_grad():
    z = torch.randn(32, latent_dim, device=device)
    fakes = generator(z)

# Both panels share the same grid layout and (-1, 1) rescaling
grid_kwargs = dict(nrow=8, normalize=True, value_range=(-1, 1))
real_grid = make_grid(real_batch[:32], **grid_kwargs)
fake_grid = make_grid(fakes.cpu(), **grid_kwargs)

# Left panel: real data, right panel: generator output
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
panels = (
    (ax1, real_grid, 'Real MNIST Images'),
    (ax2, fake_grid, 'Generated Images (Vanilla GAN)'),
)
for panel_ax, panel_grid, panel_title in panels:
    panel_ax.imshow(panel_grid.permute(1, 2, 0).numpy(), cmap='gray')
    panel_ax.set_title(panel_title, fontsize=14)
    panel_ax.axis('off')

plt.tight_layout()
plt.show()

## 4. Training Loss Curves

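If training has been run with TensorBoard logging, the real loss curves can be inspected there. The cell below does not read any training logs: it plots **synthetic placeholder data** to show the intended figure layout, so replace it with your saved loss history when available.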

In [None]:
# Note: Run this after training. The training loop saves loss history.
# You can also launch TensorBoard:
#   tensorboard --logdir outputs/tensorboard

# Placeholder: if you have saved losses, plot them instead of the demo arrays.
# The curves below are SYNTHETIC (running means of exponential noise) and only
# illustrate the figure layout; they say nothing about the trained model.
# A dedicated, seeded generator keeps the figure identical across re-runs, and
# the step count has its own name so it cannot be confused with the
# interpolation cell's `n_steps`.
N_DEMO_STEPS = 1000
demo_rng = np.random.default_rng(0)
demo_step_index = np.arange(1, N_DEMO_STEPS + 1)
g_losses_demo = demo_rng.exponential(0.7, N_DEMO_STEPS).cumsum() / demo_step_index
d_losses_demo = demo_rng.exponential(0.7, N_DEMO_STEPS).cumsum() / demo_step_index

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(g_losses_demo, label='Generator Loss', alpha=0.7)
ax.plot(d_losses_demo, label='Discriminator Loss', alpha=0.7)
ax.set_xlabel('Training Step')
ax.set_ylabel('Loss')
# The title states that the data is synthetic so the figure cannot be
# mistaken for real training results when the notebook is skimmed.
ax.set_title('Vanilla GAN — Training Loss Curves (synthetic demo data)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Tip: Run `tensorboard --logdir outputs/tensorboard` for interactive plots")

## 5. Latent Space Interpolation

Smooth transitions between two random latent vectors suggest that
the generator has learned a meaningful latent representation.
Each row of the grid below interpolates (spherically) between one random pair of latent vectors.

In [None]:
visualizer = GANVisualizer(output_dir=str(PROJECT_ROOT / "outputs"))

# Generate interpolation
n_pairs = 5    # number of (z1, z2) pairs -> one grid row per pair
n_steps = 10   # interpolation points per pair -> grid columns

generator.eval()
all_interp = []

# No gradients are needed anywhere in this loop, so one no_grad scope covers it.
with torch.no_grad():
    for _ in range(n_pairs):
        z1 = torch.randn(latent_dim, device=device)
        z2 = torch.randn(latent_dim, device=device)
        # presumably returns a (n_steps, latent_dim) batch — confirm in GANVisualizer
        z_interp = visualizer.spherical_interpolation(z1, z2, n_steps)
        imgs = generator(z_interp)
        all_interp.append(imgs)

# nrow=n_steps places each pair's full interpolation on its own row
interp_grid = make_grid(
    torch.cat(all_interp, dim=0).cpu(),
    nrow=n_steps, normalize=True, value_range=(-1, 1)
)

fig, ax = plt.subplots(figsize=(16, 8))
ax.imshow(interp_grid.permute(1, 2, 0).numpy(), cmap='gray')
ax.set_title('Latent Space Interpolation (Spherical)', fontsize=14)
ax.set_xlabel('Interpolation steps →')
ax.set_ylabel('Different pairs ↓')
# `axis('off')` would also hide the two axis labels set above, so only the
# ticks and the frame are removed; the labels stay visible.
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
    spine.set_visible(False)
fig.tight_layout()
plt.show()

## 6. Quantitative Evaluation (FID & IS)

**FID Score**: Measures the Fréchet distance between feature distributions of real and generated images.  
$$FID = ||\mu_r - \mu_g||^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2})$$

**Inception Score**: Measures quality (sharp predictions) and diversity (varied predictions).  
$$IS = \exp(\mathbb{E}_x[KL(p(y|x) \| p(y))])$$

In [None]:
from src.evaluation.fid_score import FIDCalculator
from src.evaluation.inception_score import InceptionScoreCalculator

N_EVAL_SAMPLES = 1000  # Use 5000+ for accurate results

# Real images: pull batches until at least N_EVAL_SAMPLES have been gathered
# (or the dataloader runs out), then trim to the requested count.
real_list = []
n_real = 0
for batch_imgs, _ in dataloader:
    real_list.append(batch_imgs)
    n_real += batch_imgs.size(0)
    if n_real >= N_EVAL_SAMPLES:
        break
real_images = torch.cat(real_list)[:N_EVAL_SAMPLES]

# Fake images: full chunks of 64 plus one smaller final chunk if needed,
# generated without gradients and moved to the CPU as they are produced.
n_full_chunks, tail_size = divmod(N_EVAL_SAMPLES, 64)
chunk_sizes = [64] * n_full_chunks + ([tail_size] if tail_size else [])
with torch.no_grad():
    fake_list = [
        generator(torch.randn(chunk, latent_dim, device=device)).cpu()
        for chunk in chunk_sizes
    ]
fake_images = torch.cat(fake_list)[:N_EVAL_SAMPLES]

print(f"Real: {real_images.shape}, Fake: {fake_images.shape}")

In [None]:
# Compute FID (requires InceptionV3 — may take a while on first run)
# Scores the real and generated batches collected in the previous cell.
# Both calculators are given the device as a string and a batch size of 32.
# NOTE(review): the image tensors are passed exactly as produced by the
# dataloader/generator. The grids elsewhere in this notebook are rescaled with
# value_range=(-1, 1), so the tensors are presumably in [-1, 1] — confirm that
# the calculators apply whatever preprocessing (range, size, channels) their
# feature network expects.
print("Computing FID Score...")
fid_calc = FIDCalculator(device=str(device), batch_size=32)
# Compares real vs generated images; the result is formatted as a float below.
fid = fid_calc.compute_fid(real_images, fake_images)
print(f"FID Score: {fid:.2f}")

print("\nComputing Inception Score...")
is_calc = InceptionScoreCalculator(device=str(device), batch_size=32)
# Uses only the generated images; called with splits=5 and unpacked into two
# values that are reported as mean ± std.
is_mean, is_std = is_calc.compute_inception_score(fake_images, splits=5)
print(f"Inception Score: {is_mean:.2f} ± {is_std:.2f}")

## 7. Architecture Summary

| Component | Architecture | Parameters |
|-----------|-------------|------------|
| Generator | MLP: 100 → 256 → 512 → 1024 → 784 | ~1.5M |
| Discriminator | MLP: 784 → 512 → 256 → 1 | ~0.5M |
| Loss | BCE (Binary Cross-Entropy) with logits | — |
| Stability | Spectral Norm (D), Label Smoothing, Dropout | — |