# Stage 2 — DCGAN Analysis

This notebook analyzes the **Deep Convolutional GAN** (Radford et al., 2016).
The FID / IS scores it reports at the end can be compared against the Vanilla GAN baseline.

Key architectural improvements:
- Strided convolutions replace pooling
- BatchNorm stabilizes training
- No fully connected layers in the backbone
- Spectral Normalization in the Discriminator (a later addition from Miyato et al., 2018 — not part of the original DCGAN recipe)

In [None]:
import sys
from pathlib import Path

PROJECT_ROOT = Path.cwd()
if PROJECT_ROOT.name == "notebooks":
    PROJECT_ROOT = PROJECT_ROOT.parent
sys.path.insert(0, str(PROJECT_ROOT))

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

from src.utils.config_loader import load_config, get_device
from src.utils.checkpointing import load_checkpoint, find_latest_checkpoint
from src.models.dcgan import build_dcgan
from src.data.dataloaders import get_dataloader_from_config
from src.evaluation.visualization import GANVisualizer

%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['figure.dpi'] = 100

print(f"PyTorch: {torch.__version__}")

## 1. Load Config & DCGAN Model

In [None]:
config = load_config(PROJECT_ROOT / "config" / "dcgan.yaml")
device = get_device(config)
print(f"Device: {device}")

# Build both networks from the config and place each on the selected device
# (assumes build_dcgan returns exactly (generator, discriminator)).
generator, discriminator = (net.to(device) for net in build_dcgan(config))

n_gen_params = sum(p.numel() for p in generator.parameters())
n_disc_params = sum(p.numel() for p in discriminator.parameters())
print(f"Generator params:     {n_gen_params:>10,}")
print(f"Discriminator params: {n_disc_params:>10,}")

# Restore the most recent checkpoint, if any; otherwise the rest of the
# notebook runs on randomly initialised weights.
ckpt_path = find_latest_checkpoint(config['paths']['models_dir'])
if not ckpt_path:
    print("\nNo checkpoint found — using untrained model.")
else:
    info = load_checkpoint(ckpt_path, generator, discriminator, device=device)
    print(f"\nLoaded: epoch {info['epoch']}, step {info['global_step']}")

## 2. Architecture Visualization

**Generator** (64×64 output):
```
z [100, 1, 1]
  → ConvT 4×4, s1  → [512, 4, 4]  + BN + ReLU
  → ConvT 4×4, s2  → [256, 8, 8]  + BN + ReLU
  → ConvT 4×4, s2  → [128, 16, 16] + BN + ReLU
  → ConvT 4×4, s2  → [64, 32, 32]  + BN + ReLU
  → ConvT 4×4, s2  → [1, 64, 64]   + Tanh
```

**Discriminator** (64×64 input):
```
[1, 64, 64]
  → Conv 4×4, s2  → [64, 32, 32]  + LeakyReLU
  → Conv 4×4, s2  → [128, 16, 16] + LeakyReLU
  → Conv 4×4, s2  → [256, 8, 8]   + LeakyReLU
  → Conv 4×4, s2  → [512, 4, 4]   + LeakyReLU
  → Conv 4×4, s1  → [1, 1, 1]     (logit)
```

In [None]:
# Print the module trees so the actual layer stacks can be checked against the diagrams above.
print("=== Generator ===")
print(generator)
print("\n=== Discriminator ===")
print(discriminator)

## 3. Generate & Visualize Samples

In [None]:
# Switch to inference mode before sampling (the generator's BatchNorm layers
# then use their running statistics rather than per-batch statistics).
generator.eval()
latent_dim = config['model']['latent_dim']

# 64 latent vectors → one 8×8 grid of generated digits.
with torch.no_grad():
    z = torch.randn(64, latent_dim, device=device)
    fake_images = generator(z)

# Generator output is Tanh-bounded, so map [-1, 1] back to [0, 1] for display.
grid = make_grid(fake_images.cpu(), nrow=8, normalize=True, value_range=(-1, 1))

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
ax.set_title('DCGAN — Generated Samples (64×64)', fontsize=14)
ax.axis('off')
fig.tight_layout()
plt.show()

## 4. Real vs DCGAN Comparison

In [None]:
dataloader = get_dataloader_from_config(config)
real_batch, _ = next(iter(dataloader))

# Sample 32 fresh fakes to show next to 32 real images.
with torch.no_grad():
    z = torch.randn(32, latent_dim, device=device)
    fakes = generator(z)

real_grid = make_grid(real_batch[:32], nrow=8, normalize=True, value_range=(-1, 1))
fake_grid = make_grid(fakes.cpu(), nrow=8, normalize=True, value_range=(-1, 1))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
panels = (
    (ax1, real_grid, 'Real MNIST (resized to 64×64)'),
    (ax2, fake_grid, 'DCGAN Generated'),
)
for ax, panel_grid, panel_title in panels:
    ax.imshow(panel_grid.permute(1, 2, 0).numpy(), cmap='gray')
    ax.set_title(panel_title, fontsize=14)
    ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Latent Space Interpolation (Slerp)

In [None]:
visualizer = GANVisualizer(output_dir=str(PROJECT_ROOT / "outputs"))

n_pairs = 6    # grid rows: one independent (start, end) latent pair per row
n_steps = 12   # grid columns: points requested from spherical_interpolation per pair

all_interp = []
with torch.no_grad():
    for _ in range(n_pairs):
        # Endpoints are drawn start-then-end so the RNG stream matches row order.
        z1 = torch.randn(latent_dim, device=device)
        z2 = torch.randn(latent_dim, device=device)
        z_interp = visualizer.spherical_interpolation(z1, z2, n_steps)
        imgs = generator(z_interp)
        all_interp.append(imgs)

frames = torch.cat(all_interp).cpu()
interp_grid = make_grid(frames, nrow=n_steps, normalize=True, value_range=(-1, 1))

fig, ax = plt.subplots(figsize=(18, 9))
ax.imshow(interp_grid.permute(1, 2, 0).numpy(), cmap='gray')
ax.set_title('DCGAN — Spherical Interpolation in Latent Space', fontsize=14)
ax.axis('off')
fig.tight_layout()
plt.show()

## 6. Feature Map Visualization

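Inspect the outputs of the discriminator's early convolutional layers for a single real image. These are pre-activation conv outputs, showing up to the first 16 channels per layer. They give a sense of which features the discriminator responds to.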

In [None]:
# Hook into discriminator first conv layer
activations = {}

def hook_fn(name):
    def hook(module, input, output):
        activations[name] = output.detach().cpu()
    return hook

# Register hooks on first few conv layers
hooks = []
for i, layer in enumerate(discriminator.net):
    if isinstance(layer, torch.nn.Conv2d) and i < 6:
        h = layer.register_forward_hook(hook_fn(f"conv_{i}"))
        hooks.append(h)

# Forward pass
with torch.no_grad():
    sample = real_batch[:1].to(device)
    _ = discriminator(sample)

# Plot feature maps
for name, feat in activations.items():
    n_maps = min(16, feat.shape[1])
    fig, axes = plt.subplots(2, 8, figsize=(16, 4))
    fig.suptitle(f'Feature maps: {name} ({feat.shape[1]} channels)', fontsize=12)
    for idx, ax in enumerate(axes.flat):
        if idx < n_maps:
            ax.imshow(feat[0, idx].numpy(), cmap='viridis')
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# Remove hooks
for h in hooks:
    h.remove()

## 7. Quantitative Evaluation

In [None]:
from src.evaluation.fid_score import FIDCalculator
from src.evaluation.inception_score import InceptionScoreCalculator

# Number of real and generated images fed to FID / IS. A small N keeps this cell fast,
# but FID is biased at small sample sizes — only compare scores computed at the same N.
N_EVAL = 1000
GEN_BATCH_SIZE = 64

# Sample in inference mode regardless of what earlier cells did.
generator.eval()

# Real images: accumulate batches until at least N_EVAL are collected (running count,
# instead of re-summing every collected batch on each iteration).
real_list = []
n_collected = 0
for imgs, _ in dataloader:
    real_list.append(imgs)
    n_collected += imgs.size(0)
    if n_collected >= N_EVAL:
        break
real_imgs = torch.cat(real_list)[:N_EVAL]
if real_imgs.size(0) < N_EVAL:
    # The dataloader ran out early; FID would otherwise silently compare unequal sample sizes.
    print(f"Warning: only {real_imgs.size(0)} real images available (N_EVAL={N_EVAL}); "
          "FID compares unequal sample sizes.")

# Fake images, generated in batches of at most GEN_BATCH_SIZE.
fake_list = []
remaining = N_EVAL
with torch.no_grad():
    while remaining > 0:
        bs = min(GEN_BATCH_SIZE, remaining)
        z = torch.randn(bs, latent_dim, device=device)
        fakes = generator(z)
        fake_list.append(fakes.cpu())
        remaining -= bs
fake_imgs = torch.cat(fake_list)[:N_EVAL]

print("Computing FID...")
fid_calc = FIDCalculator(device=str(device), batch_size=32)
fid = fid_calc.compute_fid(real_imgs, fake_imgs)
print(f"DCGAN FID: {fid:.2f}")

print("\nComputing IS...")
is_calc = InceptionScoreCalculator(device=str(device), batch_size=32)
is_mean, is_std = is_calc.compute_inception_score(fake_imgs, splits=5)
print(f"DCGAN IS: {is_mean:.2f} ± {is_std:.2f}")