<a href="https://colab.research.google.com/github/AroubAlRizq/advanced_ai_exercises/blob/main/WGAN_GP_Lab2_Solution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

In [2]:
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# ---------------------------
# Data  (CIFAR-10 32×32)
# ---------------------------
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 per
# channel then maps them to [-1, 1], the same range the generator's Tanh
# output layer produces, so real and generated images are comparable.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
ds = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=transform)
loader = DataLoader(
    ds,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host-to-GPU copies. On a CPU-only
    # machine it brings no benefit and recent PyTorch versions warn about it,
    # so enable it only when CUDA is actually available.
    pin_memory=torch.cuda.is_available(),
    drop_last=True,  # discard the final short batch so every batch has 64 samples
)

100%|██████████| 170M/170M [00:04<00:00, 37.7MB/s]


In [4]:
# ---------------------------
# Hyperparams
# ---------------------------
# NOTE: the Setup cell further down assigns g_lr, d_lr, n_critic and lambda_gp
# again before the optimizers are created; those later values are the ones
# the training loop ends up using.
z_dim = 128            # length of the latent noise vector given to the generator
g_lr = d_lr = 2e-4     # generator / critic learning rates
n_critic = 5           # critic updates per generator update (must be 5, not 1)
lambda_gp = 10.0       # weight of the gradient-penalty term

In [5]:
# ---------------------------
# Models
# ---------------------------
class Critic(nn.Module):
    """Convolutional critic that gives every input image one raw score.

    Three stride-2 convolutions halve the spatial size each time and a final
    4x4 convolution without padding collapses the remaining map, so a 32x32
    input is reduced 32 -> 16 -> 8 -> 4 -> 1.

    Args:
        ch: Output channels of the first convolution; the next two stages use
            ``2 * ch`` and ``4 * ch``.
    """

    def __init__(self, ch=64):
        super().__init__()
        # Stage 1: down-sample without any normalisation layer.
        layers = [
            nn.Conv2d(3, ch, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Stages 2 and 3: down-sample, instance-normalise, activate.
        for width_in, width_out in ((ch, ch * 2), (ch * 2, ch * 4)):
            layers.extend([
                nn.Conv2d(width_in, width_out, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(width_out, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Scoring head. Deliberately no Sigmoid: the critic returns an
        # unbounded score rather than a probability.
        layers.append(nn.Conv2d(ch * 4, 1, kernel_size=4, stride=1, padding=0))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a 1-D tensor with one score per sample of ``x``."""
        batch = x.size(0)
        scores = self.net(x)
        # Fails if the network produced more than one value per sample.
        return scores.view(batch)

class Gen(nn.Module):
    """Transposed-convolution generator turning latent vectors into RGB images.

    The latent vector is treated as a 1x1 feature map and up-sampled
    1 -> 4 -> 8 -> 16 -> 32, ending in a Tanh so outputs lie in [-1, 1].

    Args:
        z: Number of latent channels expected by the first layer.
        ch: Base channel width; the stages use ``4 * ch``, ``2 * ch`` and ``ch``.
    """

    def __init__(self, z=128, ch=64):
        super().__init__()
        # (in_channels, out_channels, stride, padding) for each normalised stage.
        stages = (
            (z, ch * 4, 1, 0),
            (ch * 4, ch * 2, 2, 1),
            (ch * 2, ch, 2, 1),
        )
        layers = []
        for c_in, c_out, stride, pad in stages:
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride,
                                   padding=pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # Output stage: last doubling of the resolution, three image channels.
        layers += [
            nn.ConvTranspose2d(ch, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from ``z`` given as ``(B, z)`` or ``(B, z, 1, 1)``."""
        n_samples, n_latent = z.size(0), z.size(1)
        latent_map = z.view(n_samples, n_latent, 1, 1)
        return self.net(latent_map)

# Instantiate the critic first, then the generator (same creation order, so
# parameter initialisation consumes the RNG identically), and move each to
# the selected device.
D = Critic()
D = D.to(device)
G = Gen(z=z_dim)
G = G.to(device)

In [8]:
# ---------------------------
# Setup
# ---------------------------
# Both networks are trained, so put them in training mode.
D.train()
G.train()

# These assignments replace the values from the Hyperparams cell.
lambda_gp, n_critic = 10.0, 5   # gradient-penalty coefficient; critic steps per generator step
g_lr = d_lr = 5e-5

# RMSprop is the optimizer used by the original (vanilla) WGAN.
optG = torch.optim.RMSprop(params=G.parameters(), lr=g_lr)
optD = torch.optim.RMSprop(params=D.parameters(), lr=d_lr)

def sample_z(batch, z_dim, device):
    """Draw standard-normal latent noise shaped ``(batch, z_dim, 1, 1)`` on ``device``."""
    noise_shape = (batch, z_dim, 1, 1)
    return torch.randn(noise_shape, device=device)

# ---------------------------
# Losses (WGAN)
# ---------------------------
def d_loss(real_scores, fake_scores):
    """Return the critic's WGAN loss: mean fake score minus mean real score.

    Minimising it pushes real scores up and fake scores down.
    """
    mean_fake = fake_scores.mean()
    mean_real = real_scores.mean()
    return mean_fake - mean_real

def g_loss(fake_scores):
    """Return the generator's WGAN loss: the negated mean critic score of fakes."""
    mean_fake = fake_scores.mean()
    return mean_fake.neg()

# ---------------------------
# Gradient Penalty (WGAN-GP)
# ---------------------------
def gradient_penalty(Dnet, real, fake, lambda_gp=10.0):
    """Compute the WGAN-GP gradient penalty for one batch.

    Every real sample is blended with its fake counterpart using one uniform
    random weight per sample. The penalty is the mean squared deviation from 1
    of the L2 norm of the critic's gradient with respect to those blended
    inputs, scaled by ``lambda_gp``.

    Args:
        Dnet: Critic. Called with a tensor shaped like ``real``; must return
            exactly one score per sample.
        real: Batch of real samples; the first dimension is the batch.
        fake: Batch of generated samples with the same shape as ``real``.
        lambda_gp: Penalty coefficient.

    Returns:
        Scalar tensor ``lambda_gp * mean((||grad||_2 - 1) ** 2)``. The
        gradient is taken with ``create_graph=True`` so the penalty can itself
        be back-propagated into the critic's parameters.
    """
    b = real.size(0)

    # The penalty is a critic-only regulariser. Detaching guarantees it never
    # sends gradients back into whatever produced `real` / `fake` (e.g. a
    # generator whose output was passed in without torch.no_grad()).
    real = real.detach()
    fake = fake.detach()

    # One mixing weight per sample, broadcast over all remaining dimensions
    # (shape [B, 1, 1, 1] for image batches).
    eps_shape = (b,) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device, dtype=real.dtype)
    x_hat = (eps * real + (1.0 - eps) * fake).requires_grad_(True)

    d_hat = Dnet(x_hat).reshape(b)

    # Summing the scores yields each sample's own input-gradient in one call,
    # because sample i's score depends only on x_hat[i] as long as the critic
    # does not mix information across the batch.
    # create_graph=True keeps the graph of this gradient (and implies
    # retain_graph), which is what makes the penalty differentiable.
    grads = torch.autograd.grad(
        outputs=d_hat.sum(),
        inputs=x_hat,
        create_graph=True,
    )[0]                                    # same shape as x_hat

    # reshape (not view) so a non-contiguous gradient layout cannot fail here.
    grad_norm = grads.reshape(b, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1.0) ** 2).mean()

# ---------------------------
# Training loop (fixed)
# ---------------------------
# One pass over `loader`: for every real batch, run `n_critic` critic updates
# followed by a single generator update.
for step, (real, _) in enumerate(loader):
    # non_blocking lets the host-to-GPU copy overlap with compute when the
    # batch lives in pinned memory; otherwise it is an ordinary copy.
    real = real.to(device, non_blocking=True)
    b = real.size(0)

    # ---- Critic updates ----
    # NOTE(review): the same real batch is reused for all n_critic steps; the
    # published WGAN-GP algorithm draws a fresh real batch for each critic
    # step. Left unchanged here - confirm this simplification is intended.
    #
    # Make sure the critic is trainable, even if a previous run of this cell
    # was interrupted while it was frozen for the generator step below.
    D.requires_grad_(True)
    for _ in range(n_critic):
        # (1) Fresh fakes; no_grad because the generator is not updated here.
        z = sample_z(b, z_dim, device)
        with torch.no_grad():
            fake = G(z)

        # (2) Critic scores
        real_scores = D(real).view(b)
        fake_scores = D(fake).view(b)

        # (3) WGAN loss + GP
        gp = gradient_penalty(D, real, fake, lambda_gp=lambda_gp)
        lossD = d_loss(real_scores, fake_scores) + gp

        # (4) Optimize critic
        optD.zero_grad(set_to_none=True)
        lossD.backward()
        optD.step()

    # ---- Generator update ----
    # Only G is updated here. Freezing the critic skips computing gradients
    # for its weights, which would just be thrown away; the gradient that
    # reaches G through the critic's activations is unaffected.
    D.requires_grad_(False)
    z = sample_z(b, z_dim, device)
    fake = G(z)
    fake_scores = D(fake).view(b)
    lossG = g_loss(fake_scores)

    optG.zero_grad(set_to_none=True)
    lossG.backward()
    optG.step()
    D.requires_grad_(True)   # leave the critic trainable for whatever runs next

    # (Optional) short-run break for debugging
    # if step > 10: break
