<a href="https://colab.research.google.com/github/WedyanFawaz/advanced_ai_exercises/blob/main/lab1_vanilla_gan_skeleton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Lab 1: Vanilla GAN (Goodfellow) — Skeleton

In [1]:
import torch, torchvision, torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
import math

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size=128; z_dim=64; g_lr=d_lr=2e-4; num_epochs=3
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_ds=torchvision.datasets.MNIST('./data',True,download=True,transform=transform)
train_loader=DataLoader(train_ds,batch_size=batch_size,shuffle=True,num_workers=2,pin_memory=True)

100%|██████████| 9.91M/9.91M [00:00<00:00, 36.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.21MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 10.1MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.99MB/s]


In [30]:
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 128, 7, 1, 0),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128*7*7, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)

In [31]:
criterion=nn.BCEWithLogitsLoss()
G=Generator(z_dim).to(device); D=Discriminator().to(device)
opt_g=torch.optim.Adam(G.parameters(),lr=g_lr,betas=(0.5,0.999))
opt_d=torch.optim.Adam(D.parameters(),lr=d_lr,betas=(0.5,0.999))

In [32]:
@torch.no_grad()
def sample_grid(G,n=16):
    z=torch.randn(n,64,device=device).view(n,64,1,1)
    fake=G(z).cpu()
    return make_grid(fake,nrow=int(math.sqrt(n)),normalize=True,value_range=(-1,1))

In [33]:
def train_discriminator_step(real):
    # TODO: Goodfellow D loss with detach on fake
    b=real.size(0)
    opt_d.zero_grad()
    z=torch.randn(b,z_dim,device=device);
    fake=G(z)
    loss_d=criterion(D(real),torch.ones(b,1,device=device))+criterion(D(fake.detach()),torch.zeros(b,1,device=device)) # Corrected labels and detached fake
    loss_d.backward(); opt_d.step() # Stepping Discriminator optimizer

def train_generator_step(b):
    opt_g.zero_grad()
    z=torch.randn(b,z_dim,device=device); fake=G(z)
    loss_g=criterion(D(fake),torch.ones(b,1,device=device)) # Corrected to use non-saturating loss
    loss_g.backward(); opt_g.step() # Stepping Generator optimizer

In [None]:
from tqdm import tqdm
for epoch in range(num_epochs):
    pbar=tqdm(train_loader)
    for real,_ in pbar:
        real=real.to(device)
        d=train_discriminator_step(real)
        g=train_generator_step(real.size(0))
        pbar.set_postfix(d_loss=d,g_loss=g)
print('Done')

100%|██████████| 469/469 [06:23<00:00,  1.22it/s, d_loss=None, g_loss=None]
 92%|█████████▏| 430/469 [05:53<00:37,  1.04it/s, d_loss=None, g_loss=None]