<a href="https://colab.research.google.com/github/AmjadNasser1/advanced_ai_exercises/blob/main/DEBUGGED_lab2_vanilla_gan_debugging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Lab 2: Debug a Broken Vanilla GAN (find 12+ issues)

In [None]:
import torch, torchvision, torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix: ToTensor() yields pixels in [0, 1]. Normalize((0.5,), (0.5,)) maps them to
# [-1, 1], the same range as a Tanh-terminated generator, so the discriminator
# sees real and fake images on the same scale.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

mnist_train = torchvision.datasets.MNIST('./data', True, download=True, transform=transform)
loader = DataLoader(
    mnist_train,
    batch_size=256,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only helps when batches are copied to a CUDA device.
    pin_memory=device.type == 'cuda',
)

z_dim = 100  # length of the latent noise vector fed to the generator
# Fix: the original rates were badly unbalanced (G: 2e-2, 100x too large;
# D: 2e-5, 10x too small), letting G take huge steps against a nearly frozen D.
# 2e-4 for both networks is the setting used by DCGAN (Radford et al., 2016).
g_lr = 2e-4
d_lr = 2e-4
class D(nn.Module):
    """DCGAN-style discriminator for single-channel 28x28 images.

    Maps a batch of shape (B, 1, 28, 28) to one raw logit per image, shape (B, 1).
    Spatial sizes: 28 -> 14 -> 7 -> 1 (the final 7x7 conv collapses the map).

    Fix: the original ended in nn.Sigmoid(), but the training criterion is
    nn.BCEWithLogitsLoss, which applies the sigmoid itself, so the sigmoid was
    applied twice and gradients were squashed. The network now returns logits.
    Pair it with BCEWithLogitsLoss; do not add a Sigmoid back or switch to BCELoss.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1),    # 28x28 -> 14x14
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 64, 4, 2, 1),   # 14x14 -> 7x7
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 1, 7, 1, 0),    # 7x7 -> 1x1 raw logit (no Sigmoid)
        )

    def forward(self, x):
        """Return logits of shape (B, 1) for images x of shape (B, 1, 28, 28)."""
        return self.net(x).view(x.size(0), 1)
class G(nn.Module):
    """DCGAN-style generator producing single-channel 28x28 images in [-1, 1].

    Args:
        z_dim: length of the latent noise vector (default 100, the channel count
            the first layer of the original network expected).

    Fixes relative to the original:
      * The first ConvTranspose2d used kernel 4, so the spatial size went
        1 -> 4 -> 8 -> 16 and produced 16x16 images that the 28x28 discriminator
        cannot consume. Kernel 7 gives 1 -> 7 -> 14 -> 28.
      * A final nn.Tanh() bounds the output to [-1, 1], matching real data
        normalized with Normalize((0.5,), (0.5,)).
      * forward() reshaped the noise to 64 channels although the first layer
        expects z_dim channels (the cause of the "shape '[256, 64, 1, 1]' is
        invalid" RuntimeError).
    """

    def __init__(self, z_dim: int = 100):
        super().__init__()
        self.z_dim = z_dim
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 128, 7, 1, 0),  # 1x1 -> 7x7
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),     # 7x7 -> 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1),       # 14x14 -> 28x28
            nn.Tanh(),                                # output range [-1, 1]
        )

    def forward(self, z):
        """Map noise of shape (B, z_dim) or (B, z_dim, 1, 1) to images (B, 1, 28, 28)."""
        # view() raises immediately if z does not hold exactly z_dim values per sample.
        return self.net(z.view(z.size(0), self.z_dim, 1, 1))

Dnet = D().to(device)
Gnet = G().to(device)

# BCEWithLogitsLoss applies the sigmoid internally, which is numerically more
# stable than Sigmoid + BCELoss. It therefore requires D to return raw logits:
# D's network must NOT end in nn.Sigmoid(), or the sigmoid is applied twice.
crit = nn.BCEWithLogitsLoss()

# Fix: Adam betas (0.9, 0.999) -> (0.5, 0.999). DCGAN (Radford et al., 2016)
# recommends beta1 = 0.5; lower momentum is commonly used to stabilize the
# adversarial G/D updates.
opt_d = torch.optim.Adam(Dnet.parameters(), lr=d_lr, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(Gnet.parameters(), lr=g_lr, betas=(0.5, 0.999))

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 485kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.47MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.30MB/s]


In [None]:
from tqdm import tqdm

# Fix: the original made a single pass over the data; train for several epochs.
EPOCHS = 5

Dnet.train()
Gnet.train()
for epoch in range(EPOCHS):
    for real, _ in tqdm(loader, desc=f'epoch {epoch + 1}/{EPOCHS}'):
        real = real.to(device, non_blocking=True)
        b = real.size(0)  # the last batch can be smaller than batch_size
        real_labels = torch.ones(b, 1, device=device)
        fake_labels = torch.zeros(b, 1, device=device)

        # ---- Discriminator step: push D(real) -> 1 and D(fake) -> 0 ----
        z = torch.randn(b, z_dim, device=device)
        fake = Gnet(z.view(b, z_dim, 1, 1))
        # Fix: gradients were never reset, so they accumulated across iterations.
        opt_d.zero_grad()
        # Fix: labels were flipped (real -> 0, fake -> 1).
        # Fix: fake is detached here so the D loss does not backpropagate into G.
        loss_d = (crit(Dnet(real), real_labels)
                  + crit(Dnet(fake.detach()), fake_labels))
        loss_d.backward()
        # Fix: the original called opt_g.step() here, updating G with D's loss.
        opt_d.step()

        # ---- Generator step (non-saturating loss): make D call fakes real ----
        # The same fake batch is reused WITHOUT detach; detaching it here would
        # cut every gradient path to G.
        opt_g.zero_grad()
        # Fix: G's target was 0; it must be 1 so G is rewarded for fooling D.
        loss_g = crit(Dnet(fake), real_labels)
        loss_g.backward()
        # Fix: G's optimizer was never stepped. (D's gradients from this
        # backward are cleared by opt_d.zero_grad() on the next iteration.)
        opt_g.step()

    print(f'epoch {epoch + 1}/{EPOCHS}  loss_d={loss_d.item():.4f}  loss_g={loss_g.item():.4f}')


print('Now fix all the issues.')

  0%|          | 0/235 [00:00<?, ?it/s]


RuntimeError: shape '[256, 64, 1, 1]' is invalid for input of size 25600