# Generative Models Part 1: An Introduction to GANs
This notebook trains **three** generative models on MNIST and displays a grid of *generated* samples (16 x 8) for each. It is meant to reproduce what you see in the article:

- **Vanilla GAN (Adam)** (logistic discriminator, non-saturating generator loss)
- **Vanilla GAN (SGD)** (same logistic discriminator and non-saturating generator loss, optimized with plain SGD)
- **WGAN-GP** (critic + gradient penalty)

## Imports & Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
import torch.utils.data as data

import matplotlib.pyplot as plt
import os
import numpy as np
import time

from IPython.display import display

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device is', device)

# Displaying function 
def imshow(img, size=None):
    """Render an image tensor inline in the notebook.

    The tensor is rescaled with ``img * 0.5 + 0.5`` (undoing a
    mean-0.5 / std-0.5 normalization), optionally resized with
    nearest-neighbour interpolation, converted to a PIL image and passed to
    IPython's ``display``.

    Args:
        img: Image tensor (e.g. the output of ``make_grid``).
        size: Target size forwarded to ``transforms.Resize``. ``None`` keeps
            the original resolution.

    Returns:
        None.
    """
    picture = img * 0.5 + 0.5
    if size is not None:
        resizer = transforms.Resize(
            size=size,
            interpolation=transforms.InterpolationMode.NEAREST,
            antialias=True,
        )
        picture = resizer(picture)
    display(torchvision.transforms.functional.to_pil_image(picture))


##  Fetch the DATA (MNIST normalized to [-1, 1])

In [None]:
batch_size = 128

# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_set = MNIST(os.getcwd(), train=True, transform=transform, download=True)
train_loader = data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,  # drop the final incomplete batch so every batch has `batch_size` images
)

# Visual sanity check: pull one batch and show it as a grid with 16 images per row.
real, _ = next(iter(train_loader))
print(real.shape)
preview_grid = torchvision.utils.make_grid(real.to('cpu'), nrow=16)
imshow(preview_grid)
print("Sample of the data that we want to generate")


##  Class of the Models

In [None]:
# Number of channels of the latent vector fed to the generator
nz = 100

# Base number of feature maps in the generator (ngf) and in the discriminator (ndf)
ngf = 64
ndf = 64

class Generator(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    A latent tensor of shape ``(N, nz, 1, 1)`` is upsampled
    1 -> 4 -> 8 -> 16 -> 32 and then cropped to 28 x 28 by the last layer,
    whose ``Tanh`` keeps the output in ``[-1, 1]``.
    """

    @staticmethod
    def _upsample(in_ch, out_ch, stride, padding):
        """Return the layers of one stage: ConvTranspose2d (4x4 kernel), BatchNorm2d, ReLU."""
        return [
            nn.ConvTranspose2d(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=4,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        ]

    def __init__(self):
        super().__init__()
        stages = [
            # 1 x 1 latent -> (ngf*8) x 4 x 4
            *self._upsample(nz, ngf * 8, stride=1, padding=0),
            # -> (ngf*4) x 8 x 8
            *self._upsample(ngf * 8, ngf * 4, stride=2, padding=1),
            # -> (ngf*2) x 16 x 16
            *self._upsample(ngf * 4, ngf * 2, stride=2, padding=1),
            # -> (ngf) x 32 x 32
            *self._upsample(ngf * 2, ngf, stride=2, padding=1),
            # 1x1 kernel with padding=2 trims two pixels per border: 32 x 32 -> 1 x 28 x 28
            nn.ConvTranspose2d(in_channels=ngf, out_channels=1, kernel_size=1, stride=1, padding=2, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Map a batch of latent tensors to a batch of generated images."""
        return self.main(input)


class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator.

    Four stride-2 convolutions reduce a ``1 x 28 x 28`` image to a single
    unbounded score (no sigmoid is applied here), so ``forward`` returns one
    value per image for 28 x 28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Spatial sizes follow floor((H + 2*padding - kernel) / stride) + 1.
        self.main = nn.Sequential(
            # input is 1 x 28 x 28
            nn.Conv2d(in_channels=1, out_channels=ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 14 x 14
            nn.Conv2d(in_channels=ndf, out_channels=ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 7 x 7
            nn.Conv2d(in_channels=ndf * 2, out_channels=ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 3 x 3
            nn.Conv2d(in_channels=ndf * 4, out_channels=1, kernel_size=4, stride=2, padding=1, bias=False)
            # output size. 1 x 1 x 1
        )

    def forward(self, input):
        """Return the scores as a 1-D tensor (shape ``(N,)`` for 28 x 28 inputs)."""
        # (N, 1, 1, 1) -> (N,)
        return self.main(input).flatten()


##  Utilities 

In [None]:
# function to display samples of the generator (reused)
def show(G, z=None, batch_size=128, nz=100):
    """Display a grid (16 images per row) of samples produced by generator ``G``.

    Args:
        G: Generator module, called as ``G(z)``.
        z: Latent batch to decode. When ``None``, a fresh standard-normal
            batch of shape ``(batch_size, nz, 1, 1)`` is drawn and moved to
            the global ``device``.
        batch_size: Number of samples drawn when ``z`` is ``None``.
        nz: Latent channel count used when ``z`` is ``None``.

    Returns:
        None.
    """
    with torch.no_grad():
        latent = z if z is not None else torch.randn(batch_size, nz, 1, 1).to(device)
        samples = G(latent)
    grid = torchvision.utils.make_grid(samples.to('cpu'), nrow=16)
    imshow(grid)


def weights_init(m):
    """Re-initialize one layer in place; intended for ``module.apply(weights_init)``.

    The layer type is matched on its class name:

    * name contains ``'Conv'``: weights drawn from a normal with mean 0, std 0.02;
    * otherwise, name contains ``'BatchNorm'``: weights drawn from a normal with
      mean 1, std 0.02, and bias set to 0;
    * any other layer is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)


## Vanilla GAN Training (Adam)

We keep the DCGAN-style architecture defined above unchanged and train it with Adam. More specifically, we use the **vanilla GAN** objective with the **non-saturating generator loss**.


In [None]:
torch.manual_seed(1)

num_epochs = 20        # set to 100 to match the article snapshots
log_every = 200        # batches
lr = 2e-4

G_adam = Generator().to(device)
D_adam = Discriminator().to(device)
G_adam.apply(weights_init)
D_adam.apply(weights_init)

optimD = optim.Adam(D_adam.parameters(), lr=lr, betas=(0.5, 0.999))
optimG = optim.Adam(G_adam.parameters(), lr=lr, betas=(0.5, 0.999))

# The discriminator outputs raw scores, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

# Noise reused for every end-of-epoch snapshot, so snapshots are comparable across epochs.
fixed_z = torch.randn(batch_size, nz, 1, 1, device=device)

D_losses_adam, G_losses_adam = [], []

t0 = time.time()
for epoch in range(num_epochs):
    for i, (real, _) in enumerate(train_loader):
        real = real.to(device)
        # Size the noise from the batch actually received, so the loop stays
        # correct if the loader is rebuilt with a different batch size.
        n_samples = real.size(0)

        # --------------------
        # (1) Update Discriminator: maximize log D(x) + log(1 - D(G(z)))
        # --------------------
        optimD.zero_grad(set_to_none=True)

        # Real images are labelled 1
        logits_real = D_adam(real)
        lossD_real = criterion(logits_real, torch.ones_like(logits_real))

        # Fake images are labelled 0. The generator is not updated in this step,
        # so its forward pass runs without recording an autograd graph.
        with torch.no_grad():
            fake = G_adam(torch.randn(n_samples, nz, 1, 1, device=device))
        logits_fake = D_adam(fake)
        lossD_fake = criterion(logits_fake, torch.zeros_like(logits_fake))

        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimD.step()

        # --------------------
        # (2) Update Generator: maximize log D(G(z)) (non-saturating)
        # --------------------
        optimG.zero_grad(set_to_none=True)
        fake2 = G_adam(torch.randn(n_samples, nz, 1, 1, device=device))
        logits_fake2 = D_adam(fake2)
        # Target label 1: the generator is rewarded when D scores its samples as real.
        lossG = criterion(logits_fake2, torch.ones_like(logits_fake2))
        lossG.backward()
        optimG.step()

        # Logging
        if i % log_every == 0:
            d_val, g_val = lossD.item(), lossG.item()
            D_losses_adam.append(d_val)
            G_losses_adam.append(g_val)
            print(f"[Adam] epoch {epoch:03d} | iter {i:04d} | lossD {d_val:.3f} | lossG {g_val:.3f} | t={(time.time()-t0)/60:.1f}m")

    # End of epoch snapshot (eval mode so BatchNorm uses its running statistics)
    print(f"[Adam] Snapshot at epoch {epoch+1}")
    G_adam.eval()
    show(G_adam, z=fixed_z, batch_size=batch_size, nz=nz)
    G_adam.train()


In [None]:
# Loss curves (Adam): one point per logged step for the discriminator and the generator
fig, ax = plt.subplots()
ax.plot(D_losses_adam, label='D (Adam)')
ax.plot(G_losses_adam, label='G (Adam)')
ax.set_title('Vanilla GAN (Adam) — losses (logged every few batches)')
ax.set_xlabel('log step')
ax.legend()
plt.show()


##  Vanilla GAN Training (SGD)

Same objective as above, but we replace Adam by **plain SGD**. On MNIST with this setup, training is typically much less stable and can exhibit **mode collapse**, as discussed in the article.

In [None]:
torch.manual_seed(1)

num_epochs = 20        # set to 100 to reproduce longer-run collapse
log_every = 200

# SGD hyperparameters: feel free to tune to observe collapse more clearly
lr_sgd = 2e-4
momentum = 0.0 # changing momentum to 0.9 can help show mode collapse

# If you aren't getting any luck, try lowering the batch size (to 16 for example).
# This requires rebuilding `train_loader` with the new `batch_size`.

G_sgd = Generator().to(device)
D_sgd = Discriminator().to(device)
G_sgd.apply(weights_init)
D_sgd.apply(weights_init)

optimD = optim.SGD(D_sgd.parameters(), lr=lr_sgd, momentum=momentum)
optimG = optim.SGD(G_sgd.parameters(), lr=lr_sgd, momentum=momentum)

# The discriminator outputs raw scores, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

# Noise reused for every end-of-epoch snapshot, so snapshots are comparable across epochs.
fixed_z = torch.randn(batch_size, nz, 1, 1, device=device)

D_losses_sgd, G_losses_sgd = [], []

t0 = time.time()
for epoch in range(num_epochs):
    for i, (real, _) in enumerate(train_loader):
        real = real.to(device)
        # Size the noise from the batch actually received, so the loop stays
        # correct if the loader is rebuilt with a different batch size.
        n_samples = real.size(0)

        # (1) Discriminator: real images are labelled 1, generated images 0
        optimD.zero_grad(set_to_none=True)

        logits_real = D_sgd(real)
        lossD_real = criterion(logits_real, torch.ones_like(logits_real))

        # The generator is not updated in this step, so its forward pass
        # runs without recording an autograd graph.
        with torch.no_grad():
            fake = G_sgd(torch.randn(n_samples, nz, 1, 1, device=device))
        logits_fake = D_sgd(fake)
        lossD_fake = criterion(logits_fake, torch.zeros_like(logits_fake))

        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimD.step()

        # (2) Generator (non-saturating): target label 1 for generated images
        optimG.zero_grad(set_to_none=True)
        fake2 = G_sgd(torch.randn(n_samples, nz, 1, 1, device=device))
        logits_fake2 = D_sgd(fake2)
        lossG = criterion(logits_fake2, torch.ones_like(logits_fake2))
        lossG.backward()
        optimG.step()

        if i % log_every == 0:
            d_val, g_val = lossD.item(), lossG.item()
            D_losses_sgd.append(d_val)
            G_losses_sgd.append(g_val)
            print(f"[SGD] epoch {epoch:03d} | iter {i:04d} | lossD {d_val:.3f} | lossG {g_val:.3f} | t={(time.time()-t0)/60:.1f}m")

    # End of epoch snapshot (eval mode so BatchNorm uses its running statistics)
    print(f"[SGD] Snapshot at epoch {epoch+1}")
    G_sgd.eval()
    show(G_sgd, z=fixed_z, batch_size=batch_size, nz=nz)
    G_sgd.train()


In [None]:
# Loss curves (SGD): one point per logged step for the discriminator and the generator
fig, ax = plt.subplots()
ax.plot(D_losses_sgd, label='D (SGD)')
ax.plot(G_losses_sgd, label='G (SGD)')
ax.set_title('Vanilla GAN (SGD) — losses (logged every few batches)')
ax.set_xlabel('log step')
ax.legend()
plt.show()


## WGAN-GP

We now switch to the WGAN objective (critic outputs real values) and enforce the 1-Lipschitz constraint using the **gradient penalty**:

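$$GP = \mathbb{E}_{\hat x}\left[\left(\|\nabla_{\hat x} D(\hat x)\|_2 - 1\right)^2\right]$$

where $\hat x = \alpha x + (1 - \alpha)\, G(z)$ is a random interpolation between a real and a generated sample, with $\alpha \sim \mathcal{U}[0, 1]$.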

We reuse the same `Discriminator` class, but interpret it as a **critic**.

In [None]:
# Gradient penalty 
def gradient_penalty(D, x_real, x_fake):
    """Compute the WGAN-GP gradient penalty ``E[(||grad_x D(x_hat)||_2 - 1)^2]``.

    ``x_hat`` is a per-sample random interpolation between ``x_real`` and
    ``x_fake``. The result stays attached to the autograd graph of ``D``
    (``create_graph=True``), so calling ``backward()`` on it produces
    gradients for the critic's parameters.

    Args:
        D: Critic; a callable mapping a batch to one score per sample.
        x_real: Batch of real samples, shape ``(N, ...)``.
        x_fake: Batch of generated samples, same shape as ``x_real``.

    Returns:
        Scalar tensor holding the mean penalty over the batch.

    Raises:
        ValueError: If ``x_real`` and ``x_fake`` have different shapes.
    """
    if x_real.shape != x_fake.shape:
        raise ValueError(
            f"x_real and x_fake must have the same shape, got "
            f"{tuple(x_real.shape)} and {tuple(x_fake.shape)}"
        )

    # The penalty should only train the critic: cut any graph leading back to the generator.
    x_real = x_real.detach()
    x_fake = x_fake.detach()

    # One mixing coefficient per sample, broadcast over all remaining dimensions.
    alpha_shape = (x_real.size(0),) + (1,) * (x_real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=x_real.device, dtype=x_real.dtype)
    x_hat = (alpha * x_real + (1 - alpha) * x_fake).requires_grad_(True)

    d_hat = D(x_hat)

    # grad_outputs of ones differentiates the sum of the scores with respect to x_hat.
    # create_graph=True keeps the result differentiable so the penalty can be backpropagated.
    (gradients,) = torch.autograd.grad(
        outputs=d_hat,
        inputs=x_hat,
        grad_outputs=torch.ones_like(d_hat),
        create_graph=True,
    )

    grad_norm = gradients.flatten(start_dim=1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


## WGAN-GP Training

In [None]:
torch.manual_seed(1)

num_epochs = 20       # set to 100 for long training
log_every = 200

# WGAN-GP settings
lr = 2e-4
betas = (0.5, 0.999)
n_critic = 5          # critic updates per generator update
lambda_gp = 10.0      # weight of the gradient penalty


def _replace_batchnorm(module):
    """Swap every BatchNorm2d inside `module` for a single-group GroupNorm, in place.

    Returns `module` so the call can be chained.
    """
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.GroupNorm(1, child.num_features))
        else:
            _replace_batchnorm(child)
    return module


G_wgan = Generator().to(device)
# The gradient penalty constrains the critic's gradient for each input
# independently. BatchNorm makes each output depend on the whole batch, which
# breaks that assumption, so the critic uses a per-sample normalization
# instead (single-group GroupNorm, i.e. layer-norm statistics).
D_wgan = _replace_batchnorm(Discriminator()).to(device)
G_wgan.apply(weights_init)
D_wgan.apply(weights_init)

optimD = optim.Adam(D_wgan.parameters(), lr=lr, betas=betas)
optimG = optim.Adam(G_wgan.parameters(), lr=lr, betas=betas)

# Noise reused for every end-of-epoch snapshot, so snapshots are comparable across epochs.
fixed_z = torch.randn(batch_size, nz, 1, 1, device=device)

D_losses_wgan, G_losses_wgan = [], []

t0 = time.time()
for epoch in range(num_epochs):
    for i, (real, _) in enumerate(train_loader):
        real = real.to(device)
        # Size the noise from the batch actually received, so real and fake
        # batches always match (the gradient penalty interpolates between them).
        n_samples = real.size(0)

        # --------------------
        # (1) Critic updates (n_critic steps on the same real batch, fresh fakes each step)
        # --------------------
        for _ in range(n_critic):
            optimD.zero_grad(set_to_none=True)

            # The generator is not updated here, so no autograd graph is recorded for it.
            with torch.no_grad():
                fake = G_wgan(torch.randn(n_samples, nz, 1, 1, device=device))

            d_real = D_wgan(real).mean()
            d_fake = D_wgan(fake).mean()
            gp = gradient_penalty(D_wgan, real, fake)

            # We *minimize* the negative of the WGAN objective
            lossD = -(d_real - d_fake) + lambda_gp * gp
            lossD.backward()
            optimD.step()

        # --------------------
        # (2) Generator update: raise the critic's score on generated samples
        # --------------------
        optimG.zero_grad(set_to_none=True)
        fake2 = G_wgan(torch.randn(n_samples, nz, 1, 1, device=device))
        lossG = -D_wgan(fake2).mean()
        lossG.backward()
        optimG.step()

        if i % log_every == 0:
            # lossD is the value from the last of the n_critic critic steps
            d_val, g_val = lossD.item(), lossG.item()
            D_losses_wgan.append(d_val)
            G_losses_wgan.append(g_val)
            print(f"[WGAN-GP] epoch {epoch:03d} | iter {i:04d} | lossD {d_val:.3f} | lossG {g_val:.3f} | t={(time.time()-t0)/60:.1f}m")

    # End of epoch snapshot (eval mode so the generator's BatchNorm uses its running statistics)
    print(f"[WGAN-GP] Snapshot at epoch {epoch+1}")
    G_wgan.eval()
    show(G_wgan, z=fixed_z, batch_size=batch_size, nz=nz)
    G_wgan.train()


In [None]:
# Loss curves (WGAN-GP): one point per logged step for the critic and the generator
fig, ax = plt.subplots()
ax.plot(D_losses_wgan, label='Critic (WGAN-GP)')
ax.plot(G_losses_wgan, label='G (WGAN-GP)')
ax.set_title('WGAN-GP — losses (logged every few batches)')
ax.set_xlabel('log step')
ax.legend()
plt.show()


## Quick qualitative comparison

We sample 128 images from each trained generator (using the same noise batch for all three), display them as grids and compare the results.

In [None]:
# Titles are printed in this order, each followed by the matching grid
trained_generators = {
    'Vanilla GAN (Adam)': G_adam,
    'Vanilla GAN (SGD)': G_sgd,
    'WGAN-GP': G_wgan,
}

# Put generators in eval mode for clean BatchNorm behavior
for gen in trained_generators.values():
    gen.eval()

# One shared noise batch, so the three models are compared on the same latent codes
with torch.no_grad():
    z = torch.randn(batch_size, nz, 1, 1, device=device)
    samples = {title: gen(z).cpu() for title, gen in trained_generators.items()}

# Build each grid (16 images per row) and display them sequentially
for title, batch in samples.items():
    print(title)
    imshow(torchvision.utils.make_grid(batch, nrow=16, normalize=False))
