## Introduction

In this assignment you will explore **Generative Adversarial Networks (GANs)** through a curated, step-by-step notebook. The notebook runs end-to-end out of the box so you can focus on the **ideas** and **observations**, not on boilerplate coding.

### Big picture (what you’ll see as you run cells)
- **Data → Model → Training → Monitoring → Samples**: we load a simple image dataset, train a generator and a critic with WGAN-GP, track a few curves, and visualize generated images over epochs.
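- **Why WGAN-GP?** The critic is trained to estimate the Wasserstein distance between real and generated images, which requires it to be (approximately) 1-Lipschitz. WGAN-GP enforces this softly by penalizing the critic whenever the norm of its gradient with respect to the input departs from 1 at points interpolated between real and generated samples. In practice this tends to make training smoother and more stable, improve sample quality, and reduce pathologies such as mode collapse.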
- **What to pay attention to:** how the samples evolve, how losses behave, and how the gradient penalty relates to stability.

### Your job in this notebook
1. **Run** the provided cells in order to understand the workflow and see the model learn.
2. **Observe & reflect** (brief notes):
   - How do samples improve across epochs?
   - What do the training curves suggest about stability?
   - Do you notice mode collapse or artifacts?
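Mode collapse is easiest to judge by eye, but a crude numerical proxy can help: the average pairwise distance between generated samples. The NumPy sketch below uses a hypothetical helper `mean_pairwise_distance` (not part of the notebook). It returns a value near zero when all samples are near-duplicates; it is not a standard metric like FID, and it is most informative when compared with the same quantity computed on a batch of real images.

```python
import numpy as np

def mean_pairwise_distance(samples):
    """Crude diversity proxy: mean L2 distance over all pairs of distinct samples.

    samples: array of shape (N, ...) with N >= 2, e.g. generated images in [0, 1].
    """
    flat = samples.reshape(len(samples), -1).astype(np.float64)
    sq = np.sum(flat ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * flat @ flat.T  # pairwise squared distances
    d = np.sqrt(np.clip(d2, 0.0, None))                    # clip tiny negatives from rounding
    np.fill_diagonal(d, 0.0)                               # ignore self-distances
    n = len(flat)
    return d.sum() / (n * (n - 1))                         # average over ordered pairs i != j

rng = np.random.default_rng(0)
collapsed = np.repeat(rng.random((1, 1, 32, 32)), 64, axis=0)  # 64 copies of one "image"
diverse = rng.random((64, 1, 32, 32))                           # 64 independent random "images"
print(mean_pairwise_distance(collapsed), mean_pairwise_distance(diverse))
```

In the notebook, one might call it on `G(z).detach().cpu().numpy()` for a fixed batch of latents and on a real batch from `train_loader`, then compare the two values.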

### Final task (what you will change)
- **Replace the dataset** (e.g., Fashion-MNIST, another small image dataset of similar size, or your own dataset).
- **Train another GAN** configuration of your choice (e.g., switch objective or architecture at a high level; you can reuse most of the notebook flow).
- **Report**:
  - A few representative sample grids,
  - 1–2 concise plots you find most informative,
  - **Your own exploration:** try adjusting hyperparameters (e.g., learning rate, betas, batch size, critic updates), modifying architecture (depth/width, normalization, activations), or even proposing a new loss/objective (creative/random attempts are welcome). Summarize what you changed and what you observed.

> Keep your discussion conceptual and visual: focus on *what* changes and *why* you think it happens, rather than implementation details.
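For the "switch objective" option, two widely used alternatives to the Wasserstein loss are the non-saturating GAN loss and the hinge loss. The sketch below writes both on raw scores (logits), which is what the notebook's `Discriminator` already outputs; the function names are hypothetical. Swapping them in would mean replacing the `d_loss` and `g_loss` lines in `Trainer`; whether to keep the gradient penalty with these objectives is itself a design choice worth exploring.

```python
import torch
import torch.nn.functional as F

# All functions take raw (pre-Sigmoid) scores, like the notebook's Discriminator output.

def nonsaturating_d_loss(d_real, d_fake):
    # Standard GAN discriminator loss on logits:
    # -log sigmoid(d_real) - log(1 - sigmoid(d_fake)) = softplus(-d_real) + softplus(d_fake)
    return F.softplus(-d_real).mean() + F.softplus(d_fake).mean()

def nonsaturating_g_loss(d_fake):
    # Non-saturating generator loss: -log sigmoid(d_fake) = softplus(-d_fake)
    return F.softplus(-d_fake).mean()

def hinge_d_loss(d_real, d_fake):
    # Hinge loss: penalize real scores below +1 and fake scores above -1
    return F.relu(1.0 - d_real).mean() + F.relu(1.0 + d_fake).mean()

def hinge_g_loss(d_fake):
    # Same form as the WGAN generator loss used in the notebook
    return -d_fake.mean()

# Tiny hand-checkable example
d_real = torch.tensor([2.0, 0.5])
d_fake = torch.tensor([-3.0, 0.0])
print(hinge_d_loss(d_real, d_fake).item())          # 0.25 + 0.5 = 0.75
print(nonsaturating_g_loss(torch.zeros(4)).item())  # log(2) when the discriminator is undecided
```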


In [None]:
# === Environment & Utilities ===
import os, math, time
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output, display

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
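torch.backends.cudnn.benchmark = True  # faster fixed-size convs; cuDNN's algorithm choice may vary, so runs are not bit-for-bit reproducible even with seeding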

def seed_everything(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything(42)
print("Device:", device)

In [None]:
# === Dataloaders ===
def get_mnist_dataloaders(batch_size=128):
    tfm = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),              # scales pixels to [0, 1], matching the generator's Sigmoid output
    ])
    train = datasets.MNIST(root='./data/mnist', train=True,  download=True, transform=tfm)
    test  = datasets.MNIST(root='./data/mnist', train=False, download=True, transform=tfm)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=(device.type=='cuda'), drop_last=True)
    test_loader  = DataLoader(test,  batch_size=batch_size, shuffle=False,
                              num_workers=2, pin_memory=(device.type=='cuda'), drop_last=False)
    return train_loader, test_loader

def get_fashion_mnist_dataloaders(batch_size=128):
    tfm = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])
    train = datasets.FashionMNIST(root='./data/fashion', train=True,  download=True, transform=tfm)
    test  = datasets.FashionMNIST(root='./data/fashion', train=False, download=True, transform=tfm)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=(device.type=='cuda'), drop_last=True)
    test_loader  = DataLoader(test,  batch_size=batch_size, shuffle=False,
                              num_workers=2, pin_memory=(device.type=='cuda'), drop_last=False)
    return train_loader, test_loader

def get_lsun_dataloader(path_to_data='./data/lsun', dataset='bedroom_train', batch_size=64):
    tfm = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
    ])
    dset = datasets.LSUN(root=path_to_data, classes=[dataset], transform=tfm)
    loader = DataLoader(dset, batch_size=batch_size, shuffle=True,
                        num_workers=4, pin_memory=(device.type=='cuda'), drop_last=True)
    return loader

dataset_name = 'mnist'
BATCH_SIZE = 128

if dataset_name == 'mnist':
    train_loader, _ = get_mnist_dataloaders(batch_size=BATCH_SIZE)
    IMG_SIZE = (32, 32, 1)
elif dataset_name == 'fashion_mnist':
    train_loader, _ = get_fashion_mnist_dataloaders(batch_size=BATCH_SIZE)
    IMG_SIZE = (32, 32, 1)
elif dataset_name == 'lsun':
    train_loader = get_lsun_dataloader(path_to_data='./data/lsun',
                                       dataset='bedroom_train', batch_size=BATCH_SIZE)
    IMG_SIZE = (128, 128, 3)
else:
    raise ValueError("Unknown dataset_name")

print(f"Dataset: {dataset_name} | IMG_SIZE = {IMG_SIZE} | batches/epoch ≈ {len(train_loader)}")


## Model Definitions

### Generator Network
The generator maps latent noise $z \sim \mathcal{N}(0, I)$ to images; training pushes the distribution of generated images toward the real data distribution.

### Discriminator (Critic) Network

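The discriminator, called the **critic** in WGAN terminology, outputs an unbounded real-valued score rather than a probability (there is no Sigmoid). If it is (approximately) **1-Lipschitz**, the difference between its mean score on real images and its mean score on generated images approximates the Wasserstein-1 distance between the two distributions; the gradient penalty is what keeps the critic close to this constraint.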
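For reference, the losses minimized by the training code can be summarized as follows; $\lambda$ corresponds to `gp_weight`, and each expectation is estimated by a mini-batch mean.

$$
\mathcal{L}_D = \mathbb{E}_{z \sim \mathcal{N}(0, I)}\big[D(G(z))\big] - \mathbb{E}_{x \sim p_{\text{data}}}\big[D(x)\big] + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \alpha\, x + (1 - \alpha)\, G(z),\quad \alpha \sim \mathcal{U}[0, 1]
$$

$$
\mathcal{L}_G = -\,\mathbb{E}_{z \sim \mathcal{N}(0, I)}\big[D(G(z))\big]
$$

The first two terms of $\mathcal{L}_D$ are the negated Wasserstein estimate, and the last term is the gradient penalty.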


In [None]:
# === Models ===
class Generator(nn.Module):
    def __init__(self, img_size, latent_dim, dim):
        """
        img_size: (H, W, C)
        """
        super().__init__()
        H, W, C = img_size
        feat_h, feat_w = H // 16, W // 16
        assert H % 16 == 0 and W % 16 == 0, "Image height and width must be multiples of 16"
        self.dim = dim
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.feat_h, self.feat_w = int(feat_h), int(feat_w)

        self.latent_to_features = nn.Sequential(
            nn.Linear(latent_dim, 8 * dim * self.feat_h * self.feat_w),
            nn.ReLU(inplace=True),
        )
        self.features_to_image = nn.Sequential(
            nn.ConvTranspose2d(8*dim, 4*dim, 4, 2, 1, bias=False), nn.BatchNorm2d(4*dim), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(4*dim, 2*dim, 4, 2, 1, bias=False), nn.BatchNorm2d(2*dim), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(2*dim, 1*dim, 4, 2, 1, bias=False), nn.BatchNorm2d(1*dim), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(1*dim, C, 4, 2, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        x = self.latent_to_features(z)
        x = x.view(-1, 8 * self.dim, self.feat_h, self.feat_w)
        return self.features_to_image(x)

    def sample_latent(self, num_samples, device=None):
        device = device or next(self.parameters()).device
        return torch.randn((num_samples, self.latent_dim), device=device)


class Discriminator(nn.Module):
    def __init__(self, img_size, dim):
        super().__init__()
        H, W, C = img_size
        feat_h, feat_w = H // 16, W // 16
        assert H % 16 == 0 and W % 16 == 0
        self.img_size = img_size
        self.feat_h, self.feat_w = int(feat_h), int(feat_w)

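        # No BatchNorm in the critic: the gradient penalty is computed per sample,
        # and batch statistics would couple the samples within a batch
        self.image_to_features = nn.Sequential(
            nn.Conv2d(C, dim, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim, 2*dim, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(2*dim, 4*dim, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(4*dim, 8*dim, 4, 2, 1)  # no Sigmoid: the critic outputs unbounded scores
        )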
        self.head = nn.Sequential(
            nn.Conv2d(8*dim, 1, 1),        # 1x1 conv to 1 channel
            nn.AdaptiveAvgPool2d(1),       # Global Average Pooling -> (B,1,1,1)
        )

    def forward(self, x):
        h = self.image_to_features(x)
        h = self.head(h).view(x.size(0), 1)  # (B,1)
        return h
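Both networks rely on the same size arithmetic: every `ConvTranspose2d(k=4, s=2, p=1)` doubles the spatial size and every `Conv2d(k=4, s=2, p=1)` halves it, so four layers map `H // 16` up to `H` (and `H` down to `H // 16`). The small pure-Python sketch below applies PyTorch's documented output-size formulas (assuming dilation 1 and no output padding) to show why the image side must be a multiple of 16, and why 28×28 MNIST is resized to 32. The helper names are hypothetical.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    # ConvTranspose2d with dilation=1, output_padding=0: (in - 1) * s - 2p + k
    return (size - 1) * stride - 2 * padding + kernel

def conv_out(size, kernel=4, stride=2, padding=1):
    # Conv2d with dilation=1: floor((in + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1

def generator_sizes(img_h):
    # Spatial size after latent_to_features (H // 16) and after each of the 4 upsampling layers
    sizes = [img_h // 16]
    for _ in range(4):
        sizes.append(conv_transpose_out(sizes[-1]))
    return sizes

def critic_sizes(img_h):
    # Spatial size of the input and after each of the 4 downsampling convs
    sizes = [img_h]
    for _ in range(4):
        sizes.append(conv_out(sizes[-1]))
    return sizes

print(generator_sizes(32))   # [2, 4, 8, 16, 32]
print(critic_sizes(32))      # [32, 16, 8, 4, 2]
print(generator_sizes(128))  # [8, 16, 32, 64, 128]  (the LSUN setting)
print(generator_sizes(28))   # [1, 2, 4, 8, 16]      (native MNIST size would not round-trip)
```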

In [None]:
# === Training Utils ===
@torch.no_grad()
def preview_batch(loader, n=32):
    x, _ = next(iter(loader))
    grid = make_grid(x[:n], nrow=int(math.sqrt(n)), padding=2)
    plt.figure(figsize=(4,4))
    plt.imshow(np.transpose(grid.numpy(), (1,2,0)), cmap='gray')
    plt.axis('off'); plt.title('Preview batch'); plt.show()

def gradient_penalty(D, real, fake, gp_weight=10.0):
    """WGAN-GP penalty: gp_weight * E[(||grad_x D(x_hat)||_2 - 1)^2] at random interpolates x_hat."""
    bsz = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining (C, H, W) dims
    alpha = torch.rand(bsz, *([1] * (real.dim() - 1)), device=real.device)
    interpolated = alpha * real + (1 - alpha) * fake
    interpolated.requires_grad_(True)

    d_interpolated = D(interpolated)
    grad_outputs = torch.ones_like(d_interpolated)

    # create_graph=True makes the penalty differentiable w.r.t. the critic's weights
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(bsz, -1)             # flatten per sample
    grad_norm = gradients.norm(2, dim=1)               # (B,)
    gp = gp_weight * ((grad_norm - 1.0) ** 2).mean()
    return gp, grad_norm.mean()

@torch.no_grad()
def grid_from_latents(G, latents, nrow=8):
    imgs = G(latents).clamp(0,1).cpu()
    grid = make_grid(imgs, nrow=nrow, padding=2)
    return np.transpose(grid.numpy(), (1,2,0))
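The `create_graph=True` flag in `gradient_penalty` matters because the penalty is itself a function of a gradient: without it, the computed input gradients would be treated as constants and the penalty could not be back-propagated into the critic's weights. The sketch below checks this on a hypothetical linear critic $D(x) = w^\top x$ with $\lVert w \rVert_2 = 2$, whose input gradient is $w$ everywhere, so the penalty and its gradient can be worked out by hand.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical linear critic D(x) = w . x, so dD/dx = w for every input x
critic = nn.Linear(3, 1, bias=False)
with torch.no_grad():
    critic.weight.copy_(torch.tensor([[2.0, 0.0, 0.0]]))  # ||w||_2 = 2

x_hat = torch.randn(5, 3, requires_grad=True)  # stand-in for interpolated samples
d_out = critic(x_hat)                          # (5, 1)

# create_graph=True builds a graph for the gradient itself,
# so the penalty below can be back-propagated into the critic weights
grads = torch.autograd.grad(outputs=d_out, inputs=x_hat,
                            grad_outputs=torch.ones_like(d_out),
                            create_graph=True)[0]
grad_norm = grads.norm(2, dim=1)                   # every entry equals ||w||_2 = 2
penalty = 10.0 * ((grad_norm - 1.0) ** 2).mean()   # lambda * (2 - 1)^2 = 10

penalty.backward()
print(grad_norm.detach(), penalty.item())
print(critic.weight.grad)  # 2 * lambda * (||w|| - 1) * w / ||w|| = [[20, 0, 0]]
```

By hand: $\nabla_w\, \lambda(\lVert w\rVert_2 - 1)^2 = 2\lambda(\lVert w\rVert_2 - 1)\, w/\lVert w\rVert_2 = [20, 0, 0]$ for $\lambda = 10$, so a critic whose gradients are too steep is pushed back toward unit norm.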


In [None]:
# === Trainer (notebook version) ===
class Trainer:
    def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer,
                 gp_weight=10.0, critic_iterations=5, print_every=50, use_cuda=None):
        self.G = generator
        self.D = discriminator
        self.G_opt = gen_optimizer
        self.D_opt = dis_optimizer

        self.gp_weight = gp_weight
        self.critic_iterations = critic_iterations
        self.print_every = print_every
        self.num_steps = 0

        self.use_cuda = (use_cuda if use_cuda is not None else torch.cuda.is_available())
        if self.use_cuda:
            self.G.cuda()
            self.D.cuda()

        self.losses = {'G': [], 'D': [], 'GP': [], 'grad_norm': []}

    def _critic_train_iteration(self, real_batch):
        self.D_opt.zero_grad(set_to_none=True)
        bsz = real_batch.size(0)

        z = self.G.sample_latent(bsz, device=real_batch.device)
        with torch.no_grad():  # the critic step does not update G, so skip building G's graph
            fake = self.G(z)

        d_real = self.D(real_batch)
        d_fake = self.D(fake)

        gp, gnorm = gradient_penalty(self.D, real_batch, fake, gp_weight=self.gp_weight)

        d_loss = d_fake.mean() - d_real.mean() + gp
        d_loss.backward()
        self.D_opt.step()

        self.losses['D'].append(d_loss.item())
        self.losses['GP'].append(gp.item())
        self.losses['grad_norm'].append(gnorm.item())

    def _generator_train_iteration(self, real_batch):
        self.G_opt.zero_grad(set_to_none=True)
        bsz = real_batch.size(0)
        z = self.G.sample_latent(bsz, device=real_batch.device)
        fake = self.G(z)
        g_loss = - self.D(fake).mean()
        g_loss.backward()
        self.G_opt.step()
        self.losses['G'].append(g_loss.item())

    def train(self, data_loader, epochs=200, save_training_gif=True, gif_every_epoch=True, gif_name='training.gif'):
        if save_training_gif:
            fixed_latents = self.G.sample_latent(64, device=device)
            progress_frames = []

        for epoch in range(1, epochs + 1):
            pbar = tqdm(data_loader, desc=f"[Epoch {epoch}/{epochs}]", leave=False)
            for i, (data, *_) in enumerate(pbar):
                data = data.to(device, non_blocking=True)
                self.num_steps += 1

                self._critic_train_iteration(data)

                if self.num_steps % self.critic_iterations == 0:
                    self._generator_train_iteration(data)

                if (i % self.print_every == 0) and len(self.losses['G']) > 0:
                    pbar.set_postfix(D=f"{self.losses['D'][-1]:.3f}",
                                     GP=f"{self.losses['GP'][-1]:.3f}",
                                     G=f"{self.losses['G'][-1]:.3f}",
                                     gN=f"{self.losses['grad_norm'][-1]:.2f}")

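            if save_training_gif and (gif_every_epoch or epoch == epochs):
                frame = grid_from_latents(self.G, fixed_latents, nrow=8)
                progress_frames.append((frame * 255).astype(np.uint8))

                clear_output(wait=True)
                fig, ax = plt.subplots(figsize=(4,4))
                ax.imshow(frame, cmap='gray')
                ax.axis('off')
                ax.set_title(f"Epoch {epoch}")
                display(fig)
                plt.close(fig)

        # Save the collected sample grids as an animated GIF
        if save_training_gif and progress_frames:
            from PIL import Image  # Pillow is installed as a torchvision dependency
            frames = [Image.fromarray(f) for f in progress_frames]
            frames[0].save(gif_name, save_all=True, append_images=frames[1:],
                           duration=200, loop=0)
            print(f"Saved {len(frames)} frames to {gif_name}")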
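Note how `train` schedules updates: every mini-batch drives one critic step, and the generator steps only when the global counter `self.num_steps` (which keeps running across epochs) is a multiple of `critic_iterations`. The hypothetical helper below counts the resulting generator updates per epoch. With MNIST, `BATCH_SIZE = 128` and `drop_last=True` there are 60000 // 128 = 468 batches per epoch, so with `CRITIC_ITERS = 5` the generator gets 93 or 94 updates per epoch, and its loss curve has roughly one fifth as many points as the critic's.

```python
def generator_updates_per_epoch(batches_per_epoch, critic_iterations, epochs):
    """Count G steps per epoch when G is updated whenever the global step counter
    (incremented once per batch and never reset) is a multiple of critic_iterations."""
    counts, step = [], 0
    for _ in range(epochs):
        start = step
        step += batches_per_epoch
        # number of multiples of critic_iterations in the range (start, step]
        counts.append(step // critic_iterations - start // critic_iterations)
    return counts

print(generator_updates_per_epoch(468, 5, 3))         # [93, 94, 93]
print(sum(generator_updates_per_epoch(468, 5, 200)))  # 18720 G steps vs 93600 critic steps
```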

In [None]:
# === Quick Preview ===
preview_batch(train_loader, n=64)

In [None]:
# === Instantiate models & optimizers ===
LATENT_DIM = 100
DIM        = 16
LR         = 1e-4
BETAS      = (0.9, 0.99)   # the WGAN-GP paper uses (0.0, 0.9); a natural candidate for your exploration
EPOCHS     = 200
CRITIC_ITERS = 5
GP_WEIGHT    = 10.0
PRINT_EVERY  = 50

G = Generator(img_size=IMG_SIZE, latent_dim=LATENT_DIM, dim=DIM).to(device)
D = Discriminator(img_size=IMG_SIZE, dim=DIM).to(device)

G_opt = torch.optim.Adam(G.parameters(), lr=LR, betas=BETAS)
D_opt = torch.optim.Adam(D.parameters(), lr=LR, betas=BETAS)

sum_params = lambda m: sum(p.numel() for p in m.parameters())
print(G)
print(D)
print(f"G params: {sum_params(G):,} | D params: {sum_params(D):,}")


In [None]:
# === Train ===
trainer = Trainer(G, D, G_opt, D_opt,
                  gp_weight=GP_WEIGHT,
                  critic_iterations=CRITIC_ITERS,
                  print_every=PRINT_EVERY,
                  use_cuda=(device.type=='cuda'))

trainer.train(train_loader, epochs=EPOCHS,
              save_training_gif=True,
              gif_every_epoch=True,
              gif_name=f'training_{dataset_name}_{EPOCHS}e.gif')

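# D and GP are logged every critic step, G only every CRITIC_ITERS steps,
# so the x-axes count different kinds of iterations.
plt.figure(figsize=(14,4))
plt.subplot(1,3,1)
plt.plot(trainer.losses['D'], label='D (incl. GP)')
plt.plot(trainer.losses['GP'], label='GP')
plt.xlabel('critic iteration'); plt.legend(); plt.title("Critic loss / GP")

plt.subplot(1,3,2)
plt.plot(trainer.losses['G'], label='G', color='orange')
plt.xlabel('generator iteration'); plt.legend(); plt.title("Generator loss")

plt.subplot(1,3,3)
plt.plot(trainer.losses['grad_norm'], label='mean grad norm', color='green')
plt.axhline(1.0, linestyle='--', color='gray')  # target norm enforced by the penalty
plt.xlabel('critic iteration'); plt.legend(); plt.title("Critic gradient norm")
plt.tight_layout(); plt.show()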
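The raw curves are noisy, and the logged critic loss mixes the Wasserstein estimate with the penalty. A minimal NumPy sketch for easier reading, assuming the `losses` dictionary layout used by `Trainer`: separate the critic's Wasserstein estimate, $\text{mean}\,D(\text{real}) - \text{mean}\,D(\text{fake}) = -(D - \text{GP})$, and smooth it with a moving average. The helper names are hypothetical, and the numbers in the demo are made up purely to show the bookkeeping.

```python
import numpy as np

def moving_average(values, window=50):
    """Sliding-window mean; the result is window - 1 entries shorter than the input."""
    values = np.asarray(values, dtype=np.float64)
    if len(values) < window:
        return values.copy()  # too short to smooth; return unchanged
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")

def wasserstein_estimate(losses):
    """Critic's Wasserstein estimate per critic step.

    The logged D loss is mean D(fake) - mean D(real) + GP, so
    mean D(real) - mean D(fake) = -(D - GP).
    """
    d = np.asarray(losses["D"], dtype=np.float64)
    gp = np.asarray(losses["GP"], dtype=np.float64)
    return -(d - gp)

# Made-up numbers, only to show the bookkeeping (not results from this notebook)
toy_losses = {"D": [-1.0, -0.5, 0.2], "GP": [0.5, 0.3, 0.1]}
w_est = wasserstein_estimate(toy_losses)
smoothed = moving_average([1.0, 2.0, 3.0, 4.0], window=2)
print(w_est, smoothed)
```

In the notebook, something like `plt.plot(moving_average(wasserstein_estimate(trainer.losses)))` would plot the smoothed estimate against critic iterations. It is commonly used as a rough progress signal, but it is only meaningful to the extent that the critic stays close to 1-Lipschitz.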
