<a href="https://colab.research.google.com/github/WedyanFawaz/advanced_ai_exercises/blob/main/lab2_wgan_gp_bug_fix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WGAN-GP — Bug-Fix Lab: 10 Bugs to Fix

In [1]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ---------------------------
# Data  (CIFAR-10 32×32)
# ---------------------------
# Per-channel normalisation with mean 0.5 / std 0.5 maps ToTensor's [0, 1]
# output onto [-1, 1], the same range the generator's final Tanh produces.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [3]:
# CIFAR-10 training split, fetched into ./data when it is not already there.
ds = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# drop_last=True discards a trailing partial batch, so every batch has 64 images.
loader = DataLoader(
    dataset=ds,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)

# ---------------------------
# Hyperparams
# ---------------------------
z_dim = 128               # length of the latent vector fed to the generator
g_lr, d_lr = 2e-4, 2e-4   # Adam learning rates for generator / critic
n_critic = 5              # critic updates performed per generator update
lambda_gp = 10.0          # weight multiplying the gradient-penalty term

100%|██████████| 170M/170M [00:01<00:00, 91.9MB/s]


In [4]:
# ---------------------------
# Models
# ---------------------------
class Critic(nn.Module):
    """Convolutional critic that maps 3-channel images to one raw score each.

    Three stride-2 4x4 convolutions halve the spatial size at every stage
    (32 -> 16 -> 8 -> 4 for 32x32 inputs) and a final unpadded 4x4 convolution
    collapses the remaining 4x4 map to a single value. The second and third
    stages use affine InstanceNorm2d; no activation follows the last layer, so
    the returned scores are unbounded.

    Args:
        ch: number of channels after the first convolution; the following two
            stages use ``2 * ch`` and ``4 * ch``.
    """

    def __init__(self, ch=64):
        super().__init__()
        c1, c2, c3 = ch, ch * 2, ch * 4
        layers = [
            nn.Conv2d(3, c1, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(c1, c2, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(c2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(c2, c3, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(c3, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(c3, 1, kernel_size=4, stride=1, padding=0),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch ``x``; returns a 1-D tensor of length ``x.size(0)``.

        The final ``view`` only succeeds when the network output holds exactly
        one element per sample (e.g. 32x32 inputs).
        """
        scores = self.net(x)
        return scores.view(x.size(0))

class Gen(nn.Module):
    """Transposed-convolution generator turning latent vectors into images.

    The latent vector is viewed as a ``z x 1 x 1`` feature map. Four 4x4
    transposed convolutions (all without bias) grow it to 4x4, 8x8, 16x16 and
    finally 32x32 with 3 channels. The first three are each followed by
    BatchNorm2d and ReLU; the last is followed by Tanh, so outputs lie in
    [-1, 1].

    Args:
        z: number of latent channels expected by the first layer.
        ch: base channel width; the hidden stages use ``4*ch``, ``2*ch``, ``ch``.
    """

    def __init__(self, z=128, ch=64):
        super().__init__()
        # (in_channels, out_channels, stride, padding) for the hidden stages.
        hidden_specs = (
            (z,      ch * 4, 1, 0),
            (ch * 4, ch * 2, 2, 1),
            (ch * 2, ch,     2, 1),
        )
        layers = []
        for c_in, c_out, stride, pad in hidden_specs:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Output stage: project to RGB and squash into [-1, 1].
        layers.append(nn.ConvTranspose2d(ch, 3, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from ``z``, viewed as ``(batch, channels, 1, 1)``."""
        latent_map = z.view(z.size(0), z.size(1), 1, 1)
        return self.net(latent_map)

# Build the critic first, then the generator, and place both on `device`.
D = Critic()
D = D.to(device)
G = Gen(z=z_dim)
G = G.to(device)

In [8]:
# ---------------------------
# Losses & Optimizers
# ---------------------------
# NOTE(review): `bce` is not referenced by any other visible cell; the critic and
# generator losses defined below are plain means of critic scores, not BCE.
bce = nn.BCEWithLogitsLoss()

def d_c(d_real, d_fake):
    """Critic loss: mean score on fake samples minus mean score on real ones.

    Args:
        d_real: critic scores for real samples.
        d_fake: critic scores for generated samples.

    Returns:
        A 0-dim tensor ``mean(d_fake) - mean(d_real)``.
    """
    fake_mean = d_fake.mean()
    real_mean = d_real.mean()
    return fake_mean - real_mean

def g_c(d_fake):
    """Generator loss: the negated mean critic score of the generated samples.

    Args:
        d_fake: critic scores for generated samples.

    Returns:
        A 0-dim tensor ``-mean(d_fake)``.
    """
    return d_fake.mean().neg()

# Both networks use Adam with beta1 = 0.0 and beta2 = 0.9; only the learning
# rates are taken from separate hyperparameters.
optG = torch.optim.Adam(
    params=G.parameters(),
    lr=g_lr,
    betas=(0.0, 0.9),
)
optD = torch.optim.Adam(
    params=D.parameters(),
    lr=d_lr,
    betas=(0.0, 0.9),
)


In [9]:
# ---------------------------
# Gradient penalty (WGAN-GP)
# ---------------------------
def gradient_penalty(Dnet, real, fake, gp_weight=None):
    """Return the weighted WGAN-GP gradient penalty as a 0-dim tensor.

    Random points ``x_hat`` are drawn on the straight lines between paired
    real and fake samples, and the critic is penalised by
    ``gp_weight * mean((||grad_x_hat Dnet(x_hat)||_2 - 1) ** 2)``.

    Args:
        Dnet: callable mapping a batch of samples to one score per sample.
        real: batch of real samples, batch dimension first.
        fake: batch of generated samples with the same shape as ``real``.
        gp_weight: multiplier for the penalty. When ``None`` (the default) the
            module-level ``lambda_gp`` hyperparameter is used.

    Returns:
        A scalar tensor that is differentiable with respect to the parameters
        of ``Dnet``, so it can be added to the critic loss before ``backward``.
    """
    if gp_weight is None:
        gp_weight = lambda_gp

    b = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dims.
    # It must be uniform in [0, 1] (torch.rand, not torch.randn) so that x_hat
    # stays on the segment between the real and the fake sample.
    eps_shape = (b,) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device, dtype=real.dtype)

    # Detach the inputs so x_hat is a fresh leaf: the penalty should only
    # train the critic, never push gradients back into the generator.
    x_hat = (eps * real.detach() + (1.0 - eps) * fake.detach()).requires_grad_(True)
    d_hat = Dnet(x_hat)

    # create_graph=True keeps the graph of this gradient computation, which is
    # what makes the penalty itself differentiable w.r.t. the critic weights.
    grads = torch.autograd.grad(
        outputs=d_hat.sum(),
        inputs=x_hat,
        create_graph=True,
        retain_graph=True,
    )[0]

    # Per-sample L2 norm of the input gradient; the squared deviation from 1
    # penalises norms both above and below the target.
    grad_norm = grads.reshape(b, -1).norm(2, dim=1)
    return gp_weight * ((grad_norm - 1.0) ** 2).mean()


In [11]:

# ---------------------------
# Training loop (WGAN-GP)
# ---------------------------
for step, (real, _) in enumerate(loader):
    real = real.to(device)
    b = real.size(0)

    # -- Critic updates: n_critic optimiser steps on the same real batch --
    for _ in range(n_critic):
        z = torch.randn(b, z_dim, device=device)
        # Detach so the critic's backward pass stops at the generator output.
        fake = G(z).detach()

        d_real = D(real)
        d_fake = D(fake)

        # No weight clipping here; the gradient_penalty term is added to the
        # critic loss instead.
        lossD = d_c(d_real, d_fake) + gradient_penalty(D, real, fake)

        optD.zero_grad()
        lossD.backward()
        optD.step()

    # -- Generator update: fresh noise, gradients flow through D into G --
    z = torch.randn(b, z_dim, device=device)
    fake = G(z)
    lossG = g_c(D(fake))
    optG.zero_grad()
    lossG.backward()
    optG.step()

    # Keep the demo short: stop once step exceeds 10.
    if step > 10:
        break

print("Your task: fix all bugs until the WGAN-GP training runs stably.")




Your task: fix all bugs until the WGAN-GP training runs stably.
