# Chapter 3: Generative Adversarial Networks (GANs)

## Introduction

Generative Adversarial Networks (GANs) are among the most influential developments in generative AI. Introduced by Ian Goodfellow and colleagues in 2014, GANs train two neural networks that compete against each other:

- **Generator (G)**: Creates fake data that tries to fool the discriminator
- **Discriminator (D)**: Distinguishes between real and fake data

As the discriminator gets better at spotting fakes, the generator is pushed to produce increasingly realistic samples.


In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)


## 3.1 GAN Theory and Architecture

### The GAN Game Theory

GANs can be understood as a minimax game between two players:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

Where:
- $G$ is the generator that maps noise $z$ to fake data
- $D$ is the discriminator that outputs the probability that input is real
- $p_{data}$ is the real data distribution
- $p_z$ is the noise distribution (usually Gaussian)
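
As a rough numerical illustration of the objective above, the following sketch estimates $V(D, G)$ from a handful of discriminator outputs. The helper name `value_function` and the sample probabilities are hypothetical and chosen only for illustration.

```python
import math

def value_function(d_real, d_fake):
    """Monte Carlo estimate of V(D, G).

    d_real: discriminator outputs D(x) on real samples
    d_fake: discriminator outputs D(G(z)) on generated samples
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# A confident, correct discriminator pushes V toward its upper bound of 0.
v_strong = value_function([0.9, 0.95, 0.99], [0.1, 0.05, 0.01])

# A discriminator that cannot tell real from fake outputs 0.5 everywhere,
# which gives V = log(0.5) + log(0.5) = -log(4).
v_confused = value_function([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

print(f"V (confident D): {v_strong:.4f}")
print(f"V (D = 0.5):     {v_confused:.4f}")
```

The value $-\log 4$ is what the objective evaluates to when the generator's distribution matches $p_{data}$, because the best possible discriminator can then only output 0.5.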

### Training Process

1. **Train Discriminator**: Maximize ability to distinguish real from fake
2. **Train Generator**: Minimize discriminator's ability to detect fakes
3. **Alternate**: Between generator and discriminator training
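
Step 2 can be implemented either by minimizing $\log(1 - D(G(z)))$ directly or by maximizing $\log D(G(z))$, the "non-saturating" form that the training loop in this chapter uses. The sketch below compares the gradient each form sends back through the discriminator's pre-sigmoid output $a$, where $D(G(z)) = \sigma(a)$. The function names are hypothetical.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a)
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da [-log(sigmoid(a))] = -(1 - sigmoid(a))
    return -(1.0 - sigmoid(a))

# Very negative logits mean the discriminator confidently rejects the fake.
for a in (-6.0, -2.0, 0.0, 2.0):
    print(f"logit={a:+.1f}  D(G(z))={sigmoid(a):.4f}  "
          f"|grad| saturating={abs(grad_saturating(a)):.4f}  "
          f"non-saturating={abs(grad_non_saturating(a)):.4f}")
```

Early in training the discriminator usually rejects fakes with high confidence. In that regime the saturating form yields a gradient close to zero, while the non-saturating form still provides a strong signal.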


In [None]:
# Define the Generator Network
class Generator(nn.Module):
    def __init__(self, nz=100, ngf=64, nc=1):
        """
        nz: size of latent vector
        ngf: generator feature map size
        nc: number of output channels
        """
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size: (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size: (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size: (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size: (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size: (nc) x 64 x 64
        )
    
    def forward(self, input):
        return self.main(input)


In [None]:
# Define the Discriminator Network
class Discriminator(nn.Module):
    def __init__(self, nc=1, ndf=64):
        """
        nc: number of input channels
        ndf: discriminator feature map size
        """
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # output size: 1 x 1 x 1
        )
    
    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)
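
The `state size` comments in both networks can be checked by hand with the standard output-size formulas for `Conv2d` and `ConvTranspose2d`. The sketch below is a minimal check, assuming `dilation=1` and `output_padding=0` (the defaults used above); the helper names are hypothetical.

```python
def conv_transpose_out(size, kernel, stride, padding):
    # Output size of ConvTranspose2d with dilation=1 and output_padding=0
    return (size - 1) * stride - 2 * padding + kernel

def conv_out(size, kernel, stride, padding):
    # Output size of Conv2d with dilation=1
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) per layer, copied from the two networks above
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
discriminator_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 1  # the latent vector enters as a 1 x 1 feature map
g_sizes = []
for kernel, stride, padding in generator_layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    g_sizes.append(size)

size = 64  # the discriminator receives 64 x 64 images
d_sizes = []
for kernel, stride, padding in discriminator_layers:
    size = conv_out(size, kernel, stride, padding)
    d_sizes.append(size)

print("Generator spatial sizes:    ", g_sizes)
print("Discriminator spatial sizes:", d_sizes)
```

Each stride-2 layer doubles the spatial size in the generator and halves it in the discriminator, which is why the two networks mirror each other.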


In [None]:
# Data loading and preprocessing
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])

# Load MNIST dataset
dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# Create data loader
batch_size = 128
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

print(f"Dataset size: {len(dataset)}")
print(f"Number of batches: {len(dataloader)}")

# Visualize some real images
def show_images(images, title="Images"):
    """Display up to 8 images in a 2 x 4 grid."""
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i, ax in enumerate(axes.flat):
        ax.axis('off')  # hide axes even for unused grid cells
        if i < len(images):
            # Denormalize from [-1, 1] to [0, 1] for visualization
            img = images[i].squeeze() * 0.5 + 0.5
            ax.imshow(img, cmap='gray')
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# Show some real images
real_batch = next(iter(dataloader))
show_images(real_batch[0][:8], "Real MNIST Images")


## 3.2 Training Setup and Hyperparameters

Key training considerations for GANs:

1. **Learning Rates**: Can differ for G and D; the setup below uses 0.0002 for both
2. **Loss Functions**: Binary cross-entropy for standard GAN
3. **Initialization**: Proper weight initialization is crucial
4. **Training Balance**: Ensuring neither network dominates
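
To make the link between binary cross-entropy and the GAN objective concrete, the sketch below computes mean BCE by hand for constant labels. It assumes the default mean reduction of `nn.BCELoss`; the helper `bce` and the probability values are hypothetical.

```python
import math

def bce(probs, target):
    """Mean binary cross-entropy for a constant target label (0.0 or 1.0)."""
    total = 0.0
    for p in probs:
        total += -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return total / len(probs)

d_real = [0.8, 0.9, 0.7]  # hypothetical D(x) values
d_fake = [0.3, 0.2, 0.4]  # hypothetical D(G(z)) values

# Discriminator: real samples get label 1, generated samples get label 0.
loss_d = bce(d_real, 1.0) + bce(d_fake, 0.0)

# Generator (non-saturating form): generated samples are labeled as real.
loss_g = bce(d_fake, 1.0)

# The discriminator loss is the negative of the estimated value function.
value_estimate = (sum(math.log(p) for p in d_real) / len(d_real)
                  + sum(math.log(1.0 - p) for p in d_fake) / len(d_fake))

print(f"loss_d = {loss_d:.4f}, -V = {-value_estimate:.4f}, loss_g = {loss_g:.4f}")
```

Minimizing the discriminator's BCE loss is therefore the same as maximizing $V(D, G)$ over $D$, and switching the label of generated samples to "real" gives the generator its $-\log D(G(z))$ loss.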


In [None]:
# Initialize networks
nz = 100  # Size of latent vector
ngf = 64  # Generator feature map size
ndf = 64  # Discriminator feature map size
nc = 1    # Number of channels (grayscale)

# Create networks
netG = Generator(nz, ngf, nc).to(device)
netD = Discriminator(nc, ndf).to(device)

# Weight initialization function
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Apply weight initialization
netG.apply(weights_init)
netD.apply(weights_init)

print("Generator Architecture:")
print(netG)
print("\nDiscriminator Architecture:")
print(netD)


In [None]:
# Training setup
criterion = nn.BCELoss()
lr = 0.0002
beta1 = 0.5

# Optimizers
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Labels
real_label = 1.
fake_label = 0.

# Create batch of latent vectors for visualization
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Training loop
num_epochs = 25
G_losses = []
D_losses = []

print("Starting Training...")

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        real_batch = data[0].to(device)
        b_size = real_batch.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        
        output = netD(real_batch)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                  f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                  f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

    # Check how the generator is doing by displaying G's output on fixed_noise
    if (epoch % 5 == 0) or (epoch == num_epochs-1):
        with torch.no_grad():
            fake = netG(fixed_noise).cpu()  # no_grad already disables gradient tracking
            show_images(fake[:8], f"Generated Images - Epoch {epoch}")

print("Training completed!")


In [None]:
# Plot training losses
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Generate new samples (show_images displays a 2 x 4 grid, so 8 samples fill it)
with torch.no_grad():
    noise = torch.randn(8, nz, 1, 1, device=device)
    generated_images = netG(noise).cpu()
    show_images(generated_images, "Final Generated Samples")


## 3.3 Common GAN Problems and Solutions

### Mode Collapse
- **Problem**: Generator produces limited variety of samples
- **Solutions**: Unrolled GANs, Feature matching, Minibatch discrimination

### Training Instability
- **Problem**: Loss oscillations, non-convergence
- **Solutions**: Proper learning rates, spectral normalization, gradient penalties

### Vanishing Gradients
- **Problem**: Generator receives no useful gradient signal
- **Solutions**: Alternative loss functions (LSGAN, WGAN), careful architecture design
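
As one example of an alternative loss, the sketch below writes out the least-squares (LSGAN) losses with the common 0/1 target coding. It assumes the discriminator outputs an unbounded score (no final `Sigmoid`); the function names and score values are hypothetical.

```python
def lsgan_d_loss(real_scores, fake_scores):
    # Push scores of real samples toward 1 and scores of fakes toward 0.
    real_term = sum((s - 1.0) ** 2 for s in real_scores) / len(real_scores)
    fake_term = sum(s ** 2 for s in fake_scores) / len(fake_scores)
    return 0.5 * real_term + 0.5 * fake_term

def lsgan_g_loss(fake_scores):
    # The generator wants its samples to be scored as 1 (real).
    return 0.5 * sum((s - 1.0) ** 2 for s in fake_scores) / len(fake_scores)

def lsgan_g_grad(score):
    # Derivative of 0.5 * (score - 1)^2 with respect to one sample's score
    return score - 1.0

print(lsgan_d_loss([0.9, 1.1], [0.2, -0.1]))
print(lsgan_g_loss([0.2, -0.1]))

# The further a fake is from the target, the larger the gradient magnitude.
for score in (-3.0, 0.0, 0.9):
    print(f"score={score:+.1f}  grad={lsgan_g_grad(score):+.1f}")
```

Because the penalty is quadratic, the gradient grows with the distance from the target instead of flattening out the way a saturated sigmoid does.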


## 3.4 GAN Variants and Improvements

### DCGAN (Deep Convolutional GAN)
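- Uses strided convolutions throughout instead of pooling or fully connected hidden layers
- Batch normalization in both networks; ReLU in the generator (Tanh at the output) and LeakyReLU in the discriminator
- Architecture guidelines for stable training, which the implementation in this chapter follows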

### WGAN (Wasserstein GAN)
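- Replaces the standard loss with one based on the Earth Mover's (Wasserstein-1) distance
- Training tends to be more stable, provided the critic is kept Lipschitz-constrained (weight clipping or a gradient penalty)
- The critic loss tends to track sample quality, giving meaningful loss curves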
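
A single WGAN critic update might look like the following minimal sketch. It uses a tiny stand-in critic on 2-D toy data rather than the image networks from this chapter; the RMSprop learning rate and clipping bound follow the values reported in the original WGAN paper, and all variable names are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in critic; note there is no Sigmoid at the end.
critic = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
optimizer = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
clip_value = 0.01  # weight-clipping bound

real = torch.randn(64, 2) + 3.0  # hypothetical "real" samples
fake = torch.randn(64, 2)        # hypothetical generator output

# The critic maximizes E[critic(real)] - E[critic(fake)],
# so we minimize the negative of that difference.
optimizer.zero_grad()
critic_loss = critic(fake).mean() - critic(real).mean()
critic_loss.backward()
optimizer.step()

# Clip weights so the critic stays within a bounded family of functions.
with torch.no_grad():
    for p in critic.parameters():
        p.clamp_(-clip_value, clip_value)

max_abs_weight = max(p.abs().max().item() for p in critic.parameters())
print(f"critic loss: {critic_loss.item():.4f}, max |weight|: {max_abs_weight:.4f}")
```

The generator step would then minimize `-critic(G(z)).mean()`. Weight clipping is the simplest way to constrain the critic; a gradient penalty is a common replacement.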

### StyleGAN
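- Builds on progressive growing from ProGAN (later dropped in StyleGAN2)
- Style-based generator architecture: a mapping network supplies per-layer style inputs
- High-resolution, high-quality image generation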

### Conditional GANs
- Add label information to both G and D
- Control over generated content
- Applications in image-to-image translation


In [None]:
# Example: Conditional GAN for MNIST
class ConditionalGenerator(nn.Module):
    def __init__(self, nz=100, num_classes=10, ngf=64, nc=1):
        super(ConditionalGenerator, self).__init__()
        self.num_classes = num_classes
        
        # Embedding for class labels
        self.label_emb = nn.Embedding(num_classes, num_classes)
        
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz + num_classes, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    
    def forward(self, noise, labels):
        # Embed labels and concatenate with noise
        label_embedding = self.label_emb(labels).unsqueeze(2).unsqueeze(3)
        gen_input = torch.cat((noise, label_embedding), 1)
        return self.main(gen_input)

# Example usage of conditional generator
cond_gen = ConditionalGenerator(nz=nz).to(device)
print("Conditional Generator created!")
print(f"Input size: noise({nz}) + labels({cond_gen.num_classes}) = {nz + cond_gen.num_classes}")

# Generate one sample per digit class
with torch.no_grad():
    # Create noise
    test_noise = torch.randn(10, nz, 1, 1, device=device)
    # Create labels (digits 0-9)
    test_labels = torch.arange(0, 10, device=device)

    # Note: cond_gen is untrained here, so the outputs are noise until it is trained
    cond_samples = cond_gen(test_noise, test_labels)
    print(f"Generated samples of shape {tuple(cond_samples.shape)} for digits: {test_labels.tolist()}")
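
Since a conditional GAN feeds the label to D as well as G, a matching discriminator is sketched below. This is one possible design, not the only one: it assumes the label is embedded as an extra 64 x 64 input channel, and the class name `ConditionalDiscriminator` and the demo tensors are hypothetical.

```python
import torch
import torch.nn as nn

class ConditionalDiscriminator(nn.Module):
    def __init__(self, num_classes=10, nc=1, ndf=64, image_size=64):
        super().__init__()
        self.image_size = image_size
        # Each label is embedded as one extra image_size x image_size channel
        self.label_emb = nn.Embedding(num_classes, image_size * image_size)
        self.main = nn.Sequential(
            nn.Conv2d(nc + 1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, images, labels):
        # Reshape the label embedding into a single-channel map and stack it on the image
        label_map = self.label_emb(labels).view(-1, 1, self.image_size, self.image_size)
        return self.main(torch.cat((images, label_map), 1)).view(-1)

cond_disc = ConditionalDiscriminator(ndf=8)  # small ndf keeps the demo light
demo_images = torch.randn(4, 1, 64, 64)      # stand-in for a batch of images
demo_labels = torch.tensor([0, 3, 7, 9])
demo_scores = cond_disc(demo_images, demo_labels)
print(demo_scores.shape)
```

During training, real images would be paired with their true labels and generated images with the labels they were conditioned on, so D learns to reject samples that do not match their label.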


## 3.5 Evaluation Metrics for GANs

### Inception Score (IS)
- Measures quality and diversity of generated images
- Uses pre-trained Inception network
- Higher scores indicate better quality
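
The score itself is $\exp\left(\mathbb{E}_x\left[\mathrm{KL}(p(y|x) \,\|\, p(y))\right]\right)$, where $p(y|x)$ is the classifier's prediction for one image and $p(y)$ is the average prediction over all images. The sketch below computes it from hand-made probability arrays standing in for Inception outputs; the function name and the inputs are hypothetical.

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """probs: array of shape (num_images, num_classes); each row sums to 1."""
    probs = np.asarray(probs, dtype=np.float64)
    marginal = probs.mean(axis=0, keepdims=True)  # p(y)
    kl = (probs * (np.log(probs + eps) - np.log(marginal + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))

num_classes = 10
# One confidently classified image per class: sharp and diverse.
sharp_and_diverse = np.eye(num_classes)
# Every prediction is uniform: the classifier recognizes nothing.
uninformative = np.full((num_classes, num_classes), 1.0 / num_classes)
# Every image is confidently class 0: sharp but collapsed to one mode.
collapsed = np.tile(np.eye(num_classes)[0], (num_classes, 1))

score_diverse = inception_score(sharp_and_diverse)
score_uninformative = inception_score(uninformative)
score_collapsed = inception_score(collapsed)
print(score_diverse, score_uninformative, score_collapsed)
```

The score ranges from 1 up to the number of classes. Both blurry predictions and a collapsed generator push it toward 1, which is why it is read as a joint measure of quality and diversity.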

### Fréchet Inception Distance (FID)
- Compares feature distributions of real and generated images
- Lower scores indicate better quality
- Generally considered more robust than IS, because it compares generated images against statistics of real images
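
FID fits a Gaussian to each feature set and computes $\|\mu_r - \mu_g\|^2 + \mathrm{Tr}\left(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\right)$. The sketch below implements that formula with NumPy on random 4-dimensional stand-in features; real FID uses Inception network features, and the function name is hypothetical.

```python
import numpy as np

def frechet_distance(feats_a, feats_b):
    """Frechet distance between Gaussians fitted to two feature sets.

    feats_a, feats_b: arrays of shape (num_samples, feature_dim)
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((cov_a cov_b)^(1/2)) is the sum of the square roots of the eigenvalues
    # of cov_a @ cov_b, which are real and non-negative for PSD inputs.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)

rng = np.random.default_rng(0)
real_feats = rng.normal(size=(500, 4))
# Same spread, but the mean is moved by one unit along the first axis.
shifted_feats = real_feats + np.array([1.0, 0.0, 0.0, 0.0])

fid_same = frechet_distance(real_feats, real_feats)
fid_shifted = frechet_distance(real_feats, shifted_feats)
print(f"identical sets: {fid_same:.6f}, shifted set: {fid_shifted:.6f}")
```

Identical feature sets give a distance of (numerically) zero, and a pure shift of the mean contributes exactly its squared length.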

### Precision and Recall
- Precision: Quality of generated samples
- Recall: Diversity/coverage of the data distribution
- Helps understand mode collapse
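
One common way to estimate these two numbers uses k-nearest-neighbour balls in feature space: a generated sample counts toward precision if it lies inside the ball of some real sample, and a real sample counts toward recall if it lies inside the ball of some generated sample. The sketch below is a simplified version of that idea on 2-D toy points; the function names and data are hypothetical.

```python
import numpy as np

def pairwise_dist(a, b):
    return np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)

def knn_radii(feats, k):
    # Column 0 of the sorted distances is each point's distance to itself (0).
    return np.sort(pairwise_dist(feats, feats), axis=1)[:, k]

def coverage(query, ref, k=3):
    """Fraction of query points inside the k-NN ball of at least one ref point."""
    radii = knn_radii(ref, k)
    inside = pairwise_dist(query, ref) <= radii[None, :]
    return float(inside.any(axis=1).mean())

# Real data has two modes; the "generator" reproduces only the first one.
offsets = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [0.1, 0.1]])
mode_a = offsets
mode_b = offsets + 10.0
real_feats = np.vstack([mode_a, mode_b])
fake_feats = mode_a.copy()

precision = coverage(fake_feats, real_feats)
recall = coverage(real_feats, fake_feats)
print(f"precision: {precision:.2f}, recall: {recall:.2f}")
```

High precision with low recall is the signature of mode collapse: every generated sample looks plausible, but half of the real distribution is never produced.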

### Human Evaluation
- Still the gold standard for many applications
- Subjective and costly to collect, but captures perceptual quality that automated metrics can miss
- Important for applications like art generation


## Summary

In this chapter, we covered:

1. **GAN Theory**: The adversarial training framework and game theory
2. **Implementation**: Complete PyTorch implementation of a DCGAN
3. **Training**: Practical considerations and common challenges
4. **Variants**: Overview of important GAN improvements
5. **Evaluation**: Methods to assess GAN performance

### Key Takeaways:
- GANs use adversarial training between generator and discriminator
- Proper architecture and training techniques are crucial for stability
- Many variants exist to address specific problems and applications
- Evaluation remains challenging, but multiple complementary metrics are available

### Next Steps:
- Experiment with different architectures
- Try conditional GANs for controlled generation
- Explore advanced variants like StyleGAN or BigGAN
- Apply GANs to your specific domain/dataset
