# Stage 3 — Conditional GAN Demo

This notebook demonstrates **class-conditioned image generation** using the
Conditional DCGAN with a Projection Discriminator.

Key features:
- Generate specific digit classes on demand
- Projection Discriminator (Miyato & Koyama, 2018)
- Class embedding concatenated with latent vector
- Per-class quality analysis

In [None]:
import sys
from pathlib import Path

# Resolve the project root: when the notebook is started from inside a
# "notebooks" directory the root is its parent, otherwise it is the current
# directory. The root is put first on sys.path for the imports that follow.
_launch_dir = Path.cwd()
PROJECT_ROOT = _launch_dir.parent if _launch_dir.name == "notebooks" else _launch_dir
sys.path.insert(0, str(PROJECT_ROOT))

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

from src.utils.config_loader import load_config, get_device
from src.utils.checkpointing import load_checkpoint, find_latest_checkpoint
from src.models.conditional_gan import build_conditional_gan
from src.data.dataloaders import get_dataloader_from_config
from src.evaluation.visualization import GANVisualizer

%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['figure.dpi'] = 100

print(f"PyTorch: {torch.__version__}")

## 1. Load Model

In [None]:
# Load the experiment configuration and pull out the values reused by the
# cells below.
config = load_config(PROJECT_ROOT / "config" / "conditional_gan.yaml")
device = get_device(config)
model_cfg = config['model']
latent_dim, n_classes = model_cfg['latent_dim'], model_cfg['n_classes']

# Build both networks and move each one onto the selected device.
generator, discriminator = (
    net.to(device) for net in build_conditional_gan(config)
)

# Report the setup, including the total parameter count of each network.
gen_param_count = sum(p.numel() for p in generator.parameters())
disc_param_count = sum(p.numel() for p in discriminator.parameters())

print(f"Device: {device}")
print(f"Classes: {n_classes}")
print(f"Generator params:     {gen_param_count:>10,}")
print(f"Discriminator params: {disc_param_count:>10,}")

# Look up a checkpoint in the configured models directory. When none is
# found, only a notice is printed and the networks keep the weights they
# were built with.
ckpt_path = find_latest_checkpoint(config['paths']['models_dir'])
if not ckpt_path:
    print("No checkpoint found.")
else:
    info = load_checkpoint(ckpt_path, generator, discriminator, device=device)
    print(f"Loaded: epoch {info['epoch']}")

## 2. Class-Conditioned Generation Grid

Each row = one class (0–9). Every column reuses the same noise vector in all
rows, so each column shows how one latent code is interpreted per class.

In [None]:
generator.eval()
n_samples_per_class = 10
grid_padding = 2  # pixels make_grid places around every image in the grid

# The same batch of noise vectors is reused for every class, so each column
# shares one latent code and only the class label changes from row to row.
fixed_z = torch.randn(n_samples_per_class, latent_dim, device=device)
all_images = []

for class_idx in range(n_classes):
    labels = torch.full((n_samples_per_class,), class_idx, dtype=torch.long, device=device)
    with torch.no_grad():
        imgs = generator(fixed_z, labels)
    all_images.append(imgs)

# nrow is the number of images per grid row, so each class fills one row.
grid = make_grid(
    torch.cat(all_images).cpu(),
    nrow=n_samples_per_class,
    normalize=True, value_range=(-1, 1), padding=grid_padding
)

plt.figure(figsize=(14, 14))
plt.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
plt.title('Conditional GAN — Same Noise, Different Classes (rows 0-9)', fontsize=14)

# Add class labels on the y-axis. The image height is read from the generated
# batch (make_grid takes a (B, C, H, W) tensor) instead of assuming 64 px, so
# the labels stay aligned for any generator resolution.
img_height = all_images[0].shape[-2]
img_size = img_height + grid_padding  # vertical pitch of one grid row
for i in range(n_classes):
    # make_grid starts row i at y = i * pitch + padding; aim at its middle.
    row_center = grid_padding + i * img_size + img_height / 2
    plt.text(-15, row_center, str(i),
             fontsize=12, ha='center', va='center', fontweight='bold')

plt.axis('off')
plt.tight_layout()
plt.show()

## 3. Generate Specific Digits

In [None]:
def generate_digit(digit: int, n: int = 16):
    """Generate ``n`` images of class ``digit`` from fresh random noise.

    Reads the notebook-level ``generator``, ``latent_dim`` and ``device``.
    The generator is put into eval mode for the forward pass and its previous
    train/eval flag is restored afterwards, so the result does not depend on
    an earlier cell having called ``generator.eval()``.

    Args:
        digit: Class label passed to the generator for every sample.
        n: Number of images to generate.

    Returns:
        The generator output for the ``n`` samples, moved to the CPU.
    """
    z = torch.randn(n, latent_dim, device=device)
    labels = torch.full((n,), digit, dtype=torch.long, device=device)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            return generator(z, labels).cpu()
    finally:
        # Hand the generator back in the mode the caller left it in.
        generator.train(was_training)

# Show one row of eight freshly sampled images for a few chosen digits.
for digit in (0, 3, 7, 9):
    samples = generate_digit(digit, n=8)
    grid = make_grid(samples, nrow=8, normalize=True, value_range=(-1, 1))

    fig, ax = plt.subplots(figsize=(12, 2))
    ax.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
    ax.set_title(f'Generated digit: {digit}', fontsize=12)
    ax.axis('off')
    plt.show()

## 4. Class-Conditioned Interpolation

Interpolate in latent space while keeping the class label fixed.

In [None]:
visualizer = GANVisualizer(output_dir=str(PROJECT_ROOT / "outputs"))
n_steps = 12

# One subplot per class, stacked vertically.
fig, axes = plt.subplots(n_classes, 1, figsize=(16, n_classes * 1.5))
fig.suptitle('Latent Interpolation per Class', fontsize=14, y=1.01)

for class_idx, ax in enumerate(axes):
    # Two fresh endpoints per class; the label stays fixed along the path.
    z1 = torch.randn(latent_dim, device=device)
    z2 = torch.randn(latent_dim, device=device)
    # NOTE(review): presumably yields n_steps latent vectors from z1 to z2;
    # confirm against GANVisualizer.spherical_interpolation.
    z_interp = visualizer.spherical_interpolation(z1, z2, n_steps)
    labels = torch.full((n_steps,), class_idx, dtype=torch.long, device=device)

    with torch.no_grad():
        frames = generator(z_interp, labels).cpu()

    strip = make_grid(frames, nrow=n_steps, normalize=True, value_range=(-1, 1))
    ax.imshow(strip.permute(1, 2, 0).numpy(), cmap='gray')
    ax.set_ylabel(str(class_idx), fontsize=12, rotation=0, labelpad=15)
    ax.set(xticks=[], yticks=[])

plt.tight_layout()
plt.show()

## 5. Label Morphing — Switching Classes with Fixed Noise

Fix the noise vector and change the class label to see how the generator
interprets different classes for the same latent code.

In [None]:
# Each grid row uses one noise vector repeated for every class, so only the
# label differs between the images of a row.
n_noise_samples = 5
labels = torch.arange(n_classes, device=device)  # one label per class, reused for every row
all_morphs = []

with torch.no_grad():
    for _ in range(n_noise_samples):
        # expand() repeats the single noise row n_classes times without copying it.
        z = torch.randn(1, latent_dim, device=device).expand(n_classes, -1)
        all_morphs.append(generator(z, labels))

morph_grid = make_grid(
    torch.cat(all_morphs).cpu(),
    nrow=n_classes, normalize=True, value_range=(-1, 1)
)

fig, ax = plt.subplots(figsize=(14, 7))
ax.imshow(morph_grid.permute(1, 2, 0).numpy(), cmap='gray')
ax.set_title('Label Morphing — Same noise (rows), different classes (columns 0→9)', fontsize=14)
ax.axis('off')
plt.tight_layout()
plt.show()

## 6. Per-Class FID Evaluation

In [None]:
from src.evaluation.fid_score import FIDCalculator
from src.evaluation.inception_score import InceptionScoreCalculator

N_EVAL = 500  # per class
dataloader = get_dataloader_from_config(config)

# Bucket real images by label, keeping at most N_EVAL per class.
real_by_class = {class_id: [] for class_id in range(n_classes)}
for imgs, labels in dataloader:
    for img, lbl in zip(imgs, labels):
        bucket = real_by_class[lbl.item()]
        if len(bucket) < N_EVAL:
            bucket.append(img)
    # Stop reading batches as soon as no bucket is still short of N_EVAL.
    if not any(len(bucket) < N_EVAL for bucket in real_by_class.values()):
        break

print("Computing FID per class (this may take a while)...")
fid_calc = FIDCalculator(device=str(device), batch_size=32)
class_fids = {}

# For every class, compare the collected real images with N_EVAL freshly
# generated images of the same class.
for c in range(n_classes):
    real_batch = torch.stack(real_by_class[c][:N_EVAL])
    fake_batch = generate_digit(c, N_EVAL)
    class_fids[c] = fid_calc.compute_fid(real_batch, fake_batch)
    print(f"  Class {c}: FID = {class_fids[c]:.2f}")

# Bar chart of the FID obtained for each class.
class_ids = list(range(n_classes))
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(class_ids, [class_fids[c] for c in class_ids])
ax.set_xlabel('Class')
ax.set_ylabel('FID Score')
ax.set_title('Per-Class FID Scores (Conditional GAN)')
ax.set_xticks(class_ids)
ax.grid(axis='y', alpha=0.3)
fig.tight_layout()
plt.show()

## 7. Summary — Three-Stage Comparison

| Stage | Architecture | Key Feature | Expected FID |
|-------|-------------|-------------|-------------|
| 1. Vanilla GAN | MLP | Baseline | ~150-200 |
| 2. DCGAN | ConvNet | Spatial hierarchy | ~50-100 |
| 3. Conditional GAN | cDCGAN + Projection D | Class control | ~30-80 |

Lower FID = better quality. The convolutional backbone and conditioning
provide substantial improvements over the MLP baseline.