<table style="background-color:#FFFFFF">   
  <tr>     
  <td><img src="https://upload.wikimedia.org/wikipedia/commons/9/95/Logo_EPFL_2019.svg" width="150"/>
  </td>     
  <td>
  <h1> <b>CS-461: Foundation Models and Generative AI</b> </h1>
  Prof. Charlotte Bunne  
  </td>   
  </tr>
</table>

# 📚 Exercise Session 3 - Code Demonstration: GANs

In this notebook, we demonstrate a simple implementation of a Generative Adversarial Network (GAN) using the MNIST dataset.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Compute device: prefer the CUDA GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Hyperparameters & Data Preparation

In [None]:
# Hyperparameters
seed = 0            # seeds weight init, latent noise and data shuffling (GPU kernels may still be nondeterministic)
batch_size = 64
lr = 0.0002         # Adam learning rate, shared by G and D
z_dim = 100         # dimensionality of the latent noise vector z
epochs = 10
img_size = 28       # MNIST resolution; note the models below hard-code 28x28 via their 7x7 feature maps
channels = 1        # MNIST is grayscale

# Seed torch's global RNG before any model is built or any batch is shuffled,
# so that "Restart & Run All" reproduces the same run.
torch.manual_seed(seed)

# Data: ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them
# to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

### Generator and Discriminator Architectures

In [None]:
# Generator with CNN (Transposed Convolutions)
class Generator(nn.Module):
    """Convolutional generator mapping latent noise vectors to 28x28 images.

    A linear layer projects ``z`` to a 128x7x7 feature map; two stride-2
    transposed convolutions upsample it to 14x14 and then 28x28. The final
    Tanh keeps pixel values in [-1, 1], the same range as the normalized
    training images.

    Args:
        latent_dim: Length of the input noise vector. When omitted, the
            notebook-level ``z_dim`` is used (original behaviour).
        out_channels: Number of channels of the generated images. When omitted,
            the notebook-level ``channels`` is used (original behaviour).
    """
    def __init__(self, latent_dim=None, out_channels=None):
        super().__init__()
        # Resolve defaults at construction time (not at class-definition time) so
        # that `Generator()` keeps picking up the notebook hyperparameters.
        if latent_dim is None:
            latent_dim = z_dim
        if out_channels is None:
            out_channels = channels
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128 * 7 * 7),
            nn.ReLU(True),
            nn.Unflatten(1, (128, 7, 7)),
            # (N, 128, 7, 7) -> (N, 64, 14, 14)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            # (N, 64, 14, 14) -> (N, out_channels, 28, 28)
            nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # Output in [-1, 1]
        )

    def forward(self, z):
        """Map latent vectors of shape (N, latent_dim) to images of shape (N, out_channels, 28, 28)."""
        return self.model(z)

# Discriminator with CNN
class Discriminator(nn.Module):
    """Convolutional discriminator scoring 28x28 images as real (close to 1) or fake (close to 0).

    Two stride-2 convolutions downsample the input to a 128x7x7 feature map,
    which is flattened and reduced to one score per image. The final Sigmoid
    turns that score into a probability, as expected by ``nn.BCELoss``.

    Args:
        in_channels: Number of channels of the input images. When omitted, the
            notebook-level ``channels`` is used (original behaviour).
    """
    def __init__(self, in_channels=None):
        super().__init__()
        # Resolve the default at construction time so `Discriminator()` keeps
        # picking up the notebook hyperparameter.
        if in_channels is None:
            in_channels = channels
        self.model = nn.Sequential(
            # (N, in_channels, 28, 28) -> (N, 64, 14, 14)
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # (N, 64, 14, 14) -> (N, 128, 7, 7)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return, for images of shape (N, in_channels, 28, 28), the probability of being real, shape (N, 1)."""
        return self.model(x)

In [None]:
# Build the generator first, then the discriminator (this order fixes which random
# draws each network's initial weights consume), and move both onto `device`.
G, D = Generator().to(device), Discriminator().to(device)

### GAN Training
Here, we train the generator and the discriminator by alternately optimizing the original GAN min-max objective:

\begin{equation}
\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}(\mathbf{x})} [\log D(\mathbf{x})] 
                        + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})} [\log (1 - D(G(\mathbf{z})))]
\end{equation}

Maximizing $V$ with respect to $D$ is equivalent to minimizing the discriminator loss
\begin{equation}
\mathcal{L}_D = - \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}(\mathbf{x})} \big[\log D(\mathbf{x})\big] 
                 - \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})} \big[\log (1 - D(G(\mathbf{z})))\big],
\end{equation}

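For the generator, the min-max objective would ask us to minimize $\mathbb{E}_{\mathbf{z}}[\log(1 - D(G(\mathbf{z})))]$. Early in training, however, $D$ rejects generated samples easily, so this term saturates and gives vanishing gradients. We therefore use the standard *non-saturating* generator loss, which has the same fixed point but much stronger gradients early on. In the code, this is the binary cross-entropy of $D(G(\mathbf{z}))$ against the "real" label $1$: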

\begin{equation}
\mathcal{L}_G = - \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})} \big[\log D(G(\mathbf{z}))\big].
\end{equation}


In [None]:
# Loss and optimizers.
# nn.BCELoss applied to D's Sigmoid output yields the -log D(.) / -log(1 - D(.)) terms
# of the losses above, depending on whether the target label is 1 or 0.
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# Training loop.
# Put both networks in training mode explicitly, so that re-running this cell after
# a sampling cell switched them to eval mode still trains correctly.
G.train()
D.train()
for epoch in range(epochs):
    # Accumulate losses on-device; calling .item() every step would force a GPU sync.
    epoch_loss_D = 0.0
    epoch_loss_G = 0.0
    n_batches = 0
    for imgs, _ in train_loader:
        imgs = imgs.to(device)
        batch = imgs.size(0)  # the final batch can be smaller than batch_size
        real_labels = torch.ones(batch, 1, device=device)
        fake_labels = torch.zeros(batch, 1, device=device)

        # Train Discriminator: minimize -log D(x) - log(1 - D(G(z)))
        optimizer_D.zero_grad()
        outputs_real = D(imgs)
        loss_real = criterion(outputs_real, real_labels)

        z = torch.randn(batch, z_dim, device=device)
        fake_imgs = G(z)
        # detach(): the discriminator step must not backpropagate into G.
        outputs_fake = D(fake_imgs.detach())
        loss_fake = criterion(outputs_fake, fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # Train Generator: non-saturating loss -log D(G(z)), i.e. BCE against label 1.
        # This backward also writes gradients into D's parameters; they are discarded
        # by optimizer_D.zero_grad() at the start of the next iteration.
        optimizer_G.zero_grad()
        outputs = D(fake_imgs)
        loss_G = criterion(outputs, real_labels)
        loss_G.backward()
        optimizer_G.step()

        epoch_loss_D += loss_D.detach()
        epoch_loss_G += loss_G.detach()
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError("train_loader yielded no batches; check the dataset and batch_size")
    # Report the epoch average rather than the (noisy) loss of the last batch only.
    avg_loss_D = float(epoch_loss_D) / n_batches
    avg_loss_G = float(epoch_loss_G) / n_batches
    print(f"Epoch [{epoch+1}/{epochs}]  Avg Loss D: {avg_loss_D:.4f}, Avg Loss G: {avg_loss_G:.4f}")

In [None]:
# Generate sample images from fresh latent noise and show them in a grid.
n_rows, n_cols = 5, 5

# Inference only: eval mode (a no-op for this G, which has no BatchNorm/Dropout, but
# correct in general) and no autograd graph, which saves memory and compute.
G.eval()
with torch.no_grad():
    z = torch.randn(n_rows * n_cols, z_dim, device=device)
    fake_imgs = G(z).cpu()
fake_imgs = (fake_imgs + 1) / 2  # Rescale Tanh output from [-1, 1] to [0, 1]

fig, axes = plt.subplots(n_rows, n_cols, figsize=(6, 6))
for img, ax in zip(fake_imgs, axes.flatten()):
    # Use a fixed [0, 1] range so samples are not individually contrast-stretched.
    ax.imshow(img.squeeze(), cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
fig.suptitle("Generated MNIST samples")
plt.show()