# Generative Adversarial Networks (GANs): Implementation Guide

This notebook walks through GAN implementations in PyTorch, from a basic fully-connected (vanilla) GAN to convolutional (DCGAN) and conditional architectures.

## Table of Contents
1. [Setup and Imports](#setup)
2. [Vanilla GAN](#vanilla)
3. [Deep Convolutional GAN (DCGAN)](#dcgan)
4. [Training on MNIST](#mnist)
5. [Conditional GAN](#conditional)
6. [Evaluation and Visualization](#evaluation)

## 1. Setup and Imports <a name="setup"></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

# Reproducibility: seed NumPy's global RNG and PyTorch's default generators
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## 2. Vanilla GAN <a name="vanilla"></a>

Let's start with a simple fully-connected GAN.

In [None]:
class Generator(nn.Module):
    """Fully-connected generator mapping a latent vector to an image.

    Args:
        latent_dim: length of the input noise vector ``z``.
        img_shape: shape of one generated image, e.g. ``(1, 28, 28)``. The last
            linear layer emits ``prod(img_shape)`` values that ``forward``
            reshapes to this shape.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            """Return the layers Linear -> [BatchNorm1d] -> LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Default eps (1e-5). BatchNorm1d's second positional argument is
                # `eps`, not momentum: a positional 0.8 would let eps dominate the
                # denominator sqrt(var + eps) and largely cancel the normalisation.
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),  # outputs in [-1, 1], the range the MNIST pipeline normalises to
        )

    def forward(self, z):
        """Map noise ``z`` of shape ``(batch, latent_dim)`` to images of shape ``(batch, *img_shape)``."""
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """Fully-connected discriminator scoring how likely each image is to be real.

    Args:
        img_shape: shape of one input image, e.g. ``(1, 28, 28)``; images are
            flattened to ``prod(img_shape)`` features before the MLP.
    """

    def __init__(self, img_shape):
        super().__init__()
        widths = [int(np.prod(img_shape)), 512, 256]

        layers = []
        # Hidden layers: Linear followed by LeakyReLU(0.2), built in order
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        # Output head: a single probability per image
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return probabilities of shape ``(batch, 1)`` that each image is real."""
        return self.model(img.view(img.size(0), -1))

# Hyperparameters
latent_dim = 100          # length of the noise vector z fed to the generators
img_shape = (1, 28, 28)   # (channels, height, width) of one MNIST image

# Build the fully-connected models on the selected device
generator = Generator(latent_dim, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)

# Each call prints the heading, then the module's layer summary on the next line
print("Generator architecture:", generator, sep="\n")
print("\nDiscriminator architecture:", discriminator, sep="\n")

## 3. Deep Convolutional GAN (DCGAN) <a name="dcgan"></a>

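DCGAN replaces the fully-connected layers with convolutional ones. The generator upsamples a small feature map with `Upsample` + `Conv2d`, and the discriminator downsamples with strided convolutions. Because these layers exploit spatial structure, the resulting images are usually sharper than those of the fully-connected GAN.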

In [None]:
class DCGenerator(nn.Module):
    """Convolutional generator: linear projection, then two 2x upsampling stages.

    Args:
        latent_dim: length of the input noise vector ``z``.
        channels: number of channels of the generated image (1 for MNIST).
        img_size: height/width of the square output image. Must be divisible
            by 4 because the feature map is upsampled by 2 twice. Defaults to 28 (MNIST).

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of 4.
    """

    def __init__(self, latent_dim, channels, img_size=28):
        super().__init__()
        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(f'img_size must be a positive multiple of 4, got {img_size}')

        # Spatial size before the two x2 upsampling stages (7 for 28x28 output)
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # BatchNorm layers use the default eps (1e-5). The second positional
        # argument of BatchNorm2d is `eps`, so a positional 0.8 would set eps=0.8
        # and largely cancel the normalisation.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise of shape ``(batch, latent_dim)`` to images of shape ``(batch, channels, img_size, img_size)``."""
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(out)

class DCDiscriminator(nn.Module):
    """Convolutional discriminator: four stride-2 conv blocks and a sigmoid head.

    Args:
        channels: number of channels of the input images (1 for MNIST).
        img_size: height/width of the square input images. Defaults to 28 (MNIST).
    """

    def __init__(self, channels, img_size=28):
        super().__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Return Conv(3x3, stride 2) -> LeakyReLU -> Dropout2d -> [BatchNorm2d]."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Dropout2d(0.25)]
            if bn:
                # Default eps (1e-5); a positional 0.8 here would set eps, not momentum.
                block.append(nn.BatchNorm2d(out_filters))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # A 3x3 conv with stride 2 and padding 1 maps H -> floor((H - 1) / 2) + 1
        # = ceil(H / 2), so 28 -> 14 -> 7 -> 4 -> 2. Plain integer division
        # (28 // 2**4 == 1) under-counts and mismatches the flattened features.
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        """Return probabilities of shape ``(batch, 1)`` that each image is real."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        return self.adv_layer(out)

# Build the convolutional models; img_shape[0] is the channel count (1 for MNIST)
dc_generator = DCGenerator(latent_dim, img_shape[0]).to(device)
dc_discriminator = DCDiscriminator(img_shape[0]).to(device)

# Heading followed by the layer summary on the next line
print("DCGAN Generator:", dc_generator, sep="\n")
print("\nDCGAN Discriminator:", dc_discriminator, sep="\n")

## 4. Training on MNIST <a name="mnist"></a>

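Let's train the DCGAN on the MNIST dataset. Pixels are normalized to [-1, 1] so that real images span the same range as the generator's `tanh` output.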

In [None]:
# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generators' Tanh output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=64, shuffle=True)

# Binary cross-entropy on the discriminators' sigmoid probabilities
adversarial_loss = nn.BCELoss()

# Adam with lr 2e-4 and beta1 0.5, the settings commonly used for DCGAN training
optimizer_G = optim.Adam(dc_generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(dc_discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Training function
def train_gan(generator, discriminator, dataloader, num_epochs=10, *,
              optimizer_G=None, optimizer_D=None, sample_every=5):
    """Train an unconditional GAN with the binary cross-entropy objective.

    For every batch the generator is updated once, so that the discriminator
    labels its samples as real. The discriminator is then updated once to
    separate real images from those samples. Uses the notebook-level
    ``latent_dim``, ``device`` and ``adversarial_loss``.

    Args:
        generator: maps noise of shape ``(batch, latent_dim)`` to images.
        discriminator: maps images to ``(batch, 1)`` probabilities of being real.
        dataloader: iterable of ``(images, labels)`` batches; labels are ignored.
        num_epochs: number of passes over ``dataloader``.
        optimizer_G: optimizer over ``generator``'s parameters. If omitted, a
            fresh ``Adam(lr=0.0002, betas=(0.5, 0.999))`` is created for the
            generator actually passed in.
        optimizer_D: like ``optimizer_G``, but for ``discriminator``.
        sample_every: show a 4x4 grid of samples every ``sample_every`` epochs.
            ``0`` or ``None`` disables it.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Optimizers must be bound to the models being trained. Stepping optimizers
    # that belong to other models would leave these ones unchanged.
    if optimizer_G is None:
        optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    if optimizer_D is None:
        optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Models may have been left in eval mode (e.g. by a visualisation helper);
    # BatchNorm/Dropout must use training behaviour here.
    generator.train()
    discriminator.train()

    for epoch in range(num_epochs):
        d_loss = g_loss = None
        for imgs, _ in tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            batch_size = imgs.size(0)
            real_imgs = imgs.to(device)

            # Adversarial ground truths
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            # ---------------------
            #  Train Generator
            # ---------------------
            optimizer_G.zero_grad()
            z = torch.randn(batch_size, latent_dim, device=device)
            gen_imgs = generator(z)
            # Generator is rewarded when the discriminator calls its samples real
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # zero_grad also clears the gradients g_loss.backward() left on D
            optimizer_D.zero_grad()
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            # detach() keeps the discriminator update from backpropagating into G
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

        if d_loss is None:
            raise ValueError('dataloader yielded no batches')

        print(f"[Epoch {epoch+1}/{num_epochs}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

        # Show a grid of samples every `sample_every` epochs
        if sample_every and (epoch + 1) % sample_every == 0:
            _show_training_samples(generator, epoch)


def _show_training_samples(generator, epoch):
    """Display a 4x4 grid of generator samples titled with epoch number ``epoch + 1``."""
    was_training = generator.training
    # Eval mode: use BatchNorm running statistics instead of updating them with samples
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            # The generators here end in Tanh; map (-1, 1) to (0, 1) for display
            gen_imgs = ((generator(z) + 1) / 2).cpu()
    finally:
        generator.train(was_training)

    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for ax, img in zip(axes.flatten(), gen_imgs):
        ax.imshow(img[0], cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
    plt.suptitle(f'Generated Images - Epoch {epoch+1}')
    plt.tight_layout()
    plt.show()

# Train the model (uncomment to run)
# train_gan(dc_generator, dc_discriminator, dataloader, num_epochs=20)

## 5. Conditional GAN <a name="conditional"></a>

Conditional GANs allow us to generate specific types of images by conditioning on labels.

In [None]:
class ConditionalGenerator(nn.Module):
    """Fully-connected generator conditioned on a class label.

    A learned label embedding (of size ``num_classes``) is concatenated with the
    noise vector before the MLP.

    Args:
        latent_dim: length of the input noise vector.
        num_classes: number of class labels (also the embedding width).
        img_shape: shape of one generated image, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, latent_dim, num_classes, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers Linear -> [BatchNorm1d] -> LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Default eps (1e-5). BatchNorm1d's second positional argument is
                # `eps`, not momentum: a positional 0.8 would let eps dominate the
                # denominator sqrt(var + eps) and largely cancel the normalisation.
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + num_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate ``(batch, *img_shape)`` images from noise and integer class labels."""
        # Label embedding first, then noise: the first Linear expects this order
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        return img.view(img.size(0), *self.img_shape)

class ConditionalDiscriminator(nn.Module):
    """Fully-connected discriminator that also sees the image's class label.

    The flattened image is concatenated with a learned label embedding (of size
    ``num_classes``) before the MLP.

    Args:
        num_classes: number of class labels (also the embedding width).
        img_shape: shape of one input image, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, num_classes, img_shape):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        input_width = num_classes + int(np.prod(img_shape))
        layers = [
            nn.Linear(input_width, 512), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1), nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Return ``(batch, 1)`` probabilities that each (image, label) pair is real."""
        flat_img = img.view(img.size(0), -1)
        label_vec = self.label_embedding(labels)
        # Image features first, then the label embedding
        return self.model(torch.cat([flat_img, label_vec], dim=-1))

# MNIST has ten digit classes (0-9)
num_classes = 10

# Build the conditional models on the selected device
cond_generator = ConditionalGenerator(latent_dim, num_classes, img_shape).to(device)
cond_discriminator = ConditionalDiscriminator(num_classes, img_shape).to(device)

print("Conditional Generator initialized", "Conditional Discriminator initialized", sep="\n")

## 6. Evaluation and Visualization <a name="evaluation"></a>

Functions for evaluating and visualizing GAN outputs.

In [None]:
def generate_and_visualize(generator, num_samples=16, conditional=False, labels=None):
    """Sample images from ``generator`` and display them in a near-square grid.

    The generator is put in eval mode for sampling and then restored to its
    previous mode, so calling this between training runs does not affect training.

    Args:
        generator: unconditional ``G(z)``, or ``G(z, labels)`` when ``conditional``.
        num_samples: number of images to draw when sampling unconditionally.
            When ``conditional`` is True, one image is drawn per entry of ``labels``.
        conditional: whether ``generator`` expects class labels.
        labels: sequence of integer class labels; required when ``conditional``.

    Raises:
        ValueError: if ``conditional`` is True but ``labels`` is None, or if
            there is nothing to draw.
    """
    if conditional:
        if labels is None:
            raise ValueError('conditional=True requires labels')
        labels = list(labels)
        # One sample per label keeps noise and label batches the same size
        num_samples = len(labels)
    if num_samples < 1:
        raise ValueError('num_samples must be at least 1')

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_samples, latent_dim, device=device)
            if conditional:
                labels_tensor = torch.as_tensor(labels, dtype=torch.long, device=device)
                gen_imgs = generator(z, labels_tensor)
            else:
                gen_imgs = generator(z)
            # Denormalize Tanh output from (-1, 1) to (0, 1)
            gen_imgs = ((gen_imgs + 1) / 2).cpu()
    finally:
        generator.train(was_training)

    # Near-square grid; 16 samples give a 4x4 layout at 2.5 inches per cell
    ncols = int(np.ceil(np.sqrt(num_samples)))
    nrows = int(np.ceil(num_samples / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2.5 * nrows), squeeze=False)
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx >= num_samples:
            continue  # unused cell in the last row
        # Fixed vmin/vmax so imshow does not re-stretch the denormalized values
        ax.imshow(gen_imgs[idx, 0], cmap='gray', vmin=0, vmax=1)
        if conditional:
            ax.set_title(f'Label: {labels[idx]}')
    plt.tight_layout()
    plt.show()

def interpolate_latent_space(generator, num_steps=10):
    """Show images decoded along a straight line between two random latent codes.

    The generator is put in eval mode for sampling and then restored to its
    previous mode.

    Args:
        generator: unconditional generator taking noise of shape ``(batch, latent_dim)``.
        num_steps: number of evenly spaced points, including both endpoints.

    Raises:
        ValueError: if ``num_steps`` is less than 1.
    """
    if num_steps < 1:
        raise ValueError('num_steps must be at least 1')

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # Two random endpoints in latent space
            z1 = torch.randn(1, latent_dim, device=device)
            z2 = torch.randn(1, latent_dim, device=device)

            # (num_steps, 1) weights broadcast against (1, latent_dim) endpoints,
            # so all interpolated codes are decoded in one batched forward pass
            alphas = torch.linspace(0, 1, num_steps, device=device).unsqueeze(1)
            z = (1 - alphas) * z1 + alphas * z2
            # Denormalize Tanh output from (-1, 1) to (0, 1)
            imgs = ((generator(z) + 1) / 2).cpu()
    finally:
        generator.train(was_training)

    # squeeze=False keeps a 2-D Axes array even when num_steps == 1
    fig, axes = plt.subplots(1, num_steps, figsize=(2 * num_steps, 2), squeeze=False)
    for ax, img in zip(axes[0], imgs):
        ax.imshow(img[0], cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
    plt.suptitle('Latent Space Interpolation')
    plt.tight_layout()
    plt.show()

# Example usage (uncomment after training)
# generate_and_visualize(dc_generator)
# interpolate_latent_space(dc_generator)

# For conditional GAN
# labels = list(range(10)) + list(range(6))  # Generate digits 0-9 and 0-5
# generate_and_visualize(cond_generator, conditional=True, labels=labels)

## Summary

In this notebook, we covered:

1. **Vanilla GAN**: Basic fully-connected architecture
2. **DCGAN**: Convolutional architecture for better image generation
3. **Training**: Complete training loop with visualization
4. **Conditional GAN**: Generating specific classes of images
5. **Evaluation**: Visualization and latent space interpolation

## Next Steps

- Implement WGAN and WGAN-GP for more stable training
- Try StyleGAN for high-quality image generation
- Experiment with different datasets (CIFAR-10, CelebA)
- Implement evaluation metrics (Inception Score, FID)
- Explore image-to-image translation (Pix2Pix, CycleGAN)

## Tips for Better Results

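1. Keep the generator and discriminator balanced: if the discriminator becomes too confident, the generator receives little useful gradient
2. Use one-sided label smoothing for the discriminator (e.g., real targets of 0.9 instead of 1.0)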
3. Add noise to discriminator inputs
4. Monitor losses and generated samples regularly
5. Experiment with different architectures and hyperparameters