
## **Project: A Trilogy of GANs on MNIST**

This notebook implements and compares three foundational types of Generative Adversarial Networks.

*   **Part 1: Vanilla GAN**: The original GAN architecture using Binary Cross-Entropy loss.
*   **Part 2: Wasserstein GAN (WGAN)**: An improved architecture using Wasserstein loss and weight clipping to enhance training stability.
*   **Part 3: Spectral Normalization GAN (SNGAN)**: An approach that applies Spectral Normalization to the discriminator and trains with Hinge Loss, aiming to stabilize training further and improve sample quality.

### **Common Setup: Imports and Reproducibility**

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
from torchvision.utils import save_image
import numpy as np
import random

In [2]:
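# --- Set a Global Random Seed for Reproducibility ---
# Seeding makes weight initialization, data shuffling, and noise sampling repeat from one
# run of the notebook to the next. The three variants are trained back to back, so each one
# still starts from a different RNG state unless the seed is reset before it is built.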
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # For multi-GPU setups
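    # Ask cuDNN for deterministic algorithms and disable autotuning; this improves
    # reproducibility on GPU but can slow training, and some ops may still be nondeterministic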
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
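
Because the three variants are trained one after another, a single seed at the top does not give each of them the same starting RNG state. A minimal sketch of one way to handle that, assuming a hypothetical helper `seed_everything` (not part of the notebook) that would be called right before each model is built:

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 42) -> None:
    """Hypothetical helper: reset the Python, NumPy, and PyTorch RNGs to `seed`."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Re-seeding before each draw reproduces the same noise vector.
seed_everything(42)
z_first = torch.randn(4, 100)
seed_everything(42)
z_second = torch.randn(4, 100)
print(torch.equal(z_first, z_second))
```

Calling the helper before constructing each generator/discriminator pair would give every variant the same initial random state, at the cost of a few extra lines per part.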

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
latent_dim = 100
image_size = 28 * 28
batch_size = 128
num_epochs = 50

# Image processing and dataset loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
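
`ToTensor()` scales pixels to [0, 1], and `Normalize((0.5,), (0.5,))` then computes `(x - 0.5) / 0.5`, so the training images lie in [-1, 1], the same range as the `Tanh` output of the generators defined later. A small sketch of that arithmetic and its inverse in plain PyTorch; `normalize` and `denormalize` are illustrative names, not functions from the notebook:

```python
import torch


def normalize(x: torch.Tensor, mean: float = 0.5, std: float = 0.5) -> torch.Tensor:
    """Same arithmetic as transforms.Normalize((0.5,), (0.5,)): (x - mean) / std."""
    return (x - mean) / std


def denormalize(x: torch.Tensor, mean: float = 0.5, std: float = 0.5) -> torch.Tensor:
    """Inverse mapping, useful before plotting: [-1, 1] back to [0, 1]."""
    return x * std + mean


pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # values in the range ToTensor() produces
scaled = normalize(pixels)
restored = denormalize(scaled)
print(scaled)    # black maps to -1, mid-gray to 0, white to +1
print(restored)  # back to the original values
```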

---

## **Part 1: Vanilla GAN with BCE Loss**

In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid() # Outputs a probability
        )

    def forward(self, x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, image_size),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

In [5]:
learning_rate_vanilla = 0.0002

discriminator_vanilla = Discriminator().to(device)
generator_vanilla = Generator().to(device)
criterion = nn.BCELoss()
d_optimizer_vanilla = optim.Adam(discriminator_vanilla.parameters(), lr=learning_rate_vanilla)
g_optimizer_vanilla = optim.Adam(generator_vanilla.parameters(), lr=learning_rate_vanilla)
d_losses_vanilla, g_losses_vanilla = [], []

os.makedirs('vanilla_gan_samples', exist_ok=True)

print("Starting Vanilla GAN Training...")
for epoch in range(num_epochs):
    epoch_d_loss, epoch_g_loss = 0.0, 0.0
    for i, (images, _) in enumerate(train_loader):
        # Targets for BCE: 1 for real images, 0 for generated ones
        real_labels = torch.ones(images.size(0), 1).to(device)
        fake_labels = torch.zeros(images.size(0), 1).to(device)
        images = images.reshape(-1, image_size).to(device)

        # --- Discriminator step: push D(real) toward 1 and D(fake) toward 0 ---
        d_optimizer_vanilla.zero_grad()
        d_real_outputs = discriminator_vanilla(images)
        d_loss_real = criterion(d_real_outputs, real_labels)
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images = generator_vanilla(z)
        # detach() keeps this backward pass from reaching the generator
        d_fake_outputs = discriminator_vanilla(fake_images.detach())
        d_loss_fake = criterion(d_fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer_vanilla.step()

        # --- Generator step: non-saturating loss, fakes are scored against real labels ---
        g_optimizer_vanilla.zero_grad()
        g_outputs = discriminator_vanilla(fake_images)
        g_loss = criterion(g_outputs, real_labels)
        g_loss.backward()
        g_optimizer_vanilla.step()

        epoch_d_loss += d_loss.item()
        epoch_g_loss += g_loss.item()

    avg_d_loss = epoch_d_loss / len(train_loader)
    avg_g_loss = epoch_g_loss / len(train_loader)
    d_losses_vanilla.append(avg_d_loss)
    g_losses_vanilla.append(avg_g_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], D Loss: {avg_d_loss:.4f}, G Loss: {avg_g_loss:.4f}')

print("Vanilla GAN Training finished.")

Starting Vanilla GAN Training...
Epoch [1/50], D Loss: 0.5724, G Loss: 3.5371
Epoch [2/50], D Loss: 0.5303, G Loss: 3.3323
Epoch [3/50], D Loss: 0.8558, G Loss: 2.3385
Epoch [4/50], D Loss: 0.8458, G Loss: 2.7774
Epoch [5/50], D Loss: 1.1185, G Loss: 2.0056
Epoch [6/50], D Loss: 1.1933, G Loss: 2.0702
Epoch [7/50], D Loss: 1.2112, G Loss: 2.1464
Epoch [8/50], D Loss: 1.0813, G Loss: 1.6654
Epoch [9/50], D Loss: 0.9060, G Loss: 1.7896
Epoch [10/50], D Loss: 0.6319, G Loss: 2.3653
Epoch [11/50], D Loss: 0.6195, G Loss: 2.3977
Epoch [12/50], D Loss: 0.6485, G Loss: 2.5455
Epoch [13/50], D Loss: 0.5769, G Loss: 3.0423
Epoch [14/50], D Loss: 0.5216, G Loss: 3.0470
Epoch [15/50], D Loss: 0.5751, G Loss: 2.8559
Epoch [16/50], D Loss: 0.5288, G Loss: 2.8231
Epoch [17/50], D Loss: 0.6179, G Loss: 2.5799
Epoch [18/50], D Loss: 0.6575, G Loss: 2.7438
Epoch [19/50], D Loss: 0.5196, G Loss: 3.0139
Epoch [20/50], D Loss: 0.5461, G Loss: 3.1352
Epoch [21/50], D Loss: 0.5466, G Loss: 3.0050
Epoch [22/

KeyboardInterrupt: 
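
In the generator step above, fake images are scored against `real_labels`. With `nn.BCELoss`, a target of 1 reduces the loss to `-log D(G(z))`, the non-saturating form of the generator objective. A short check of that identity on made-up discriminator outputs (the probabilities below are illustrative and not taken from the training run):

```python
import math

import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs for two fake images.
d_on_fake = torch.tensor([[0.1], [0.8]])
real_labels = torch.ones_like(d_on_fake)

g_loss = criterion(d_on_fake, real_labels)       # what the generator step computes
by_hand = -(math.log(0.1) + math.log(0.8)) / 2   # mean of -log D(G(z))
print(g_loss.item(), by_hand)
```

The image the discriminator rejects confidently (0.1) contributes far more to the loss than the one it mostly accepts (0.8), which is what gives the generator a useful gradient early in training.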

---

## **Part 2: Wasserstein GAN (WGAN) with Weight Clipping**

In [None]:
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1) # No Sigmoid for WGAN critic
        )

    def forward(self, x):
        return self.model(x)

In [None]:
learning_rate_wgan = 0.00005
critic_iterations = 5
clip_value = 0.01

critic_wgan = Critic().to(device)
generator_wgan = Generator().to(device)
d_optimizer_wgan = optim.RMSprop(critic_wgan.parameters(), lr=learning_rate_wgan)
g_optimizer_wgan = optim.RMSprop(generator_wgan.parameters(), lr=learning_rate_wgan)
d_losses_wgan, g_losses_wgan = [], []

os.makedirs('wgan_samples', exist_ok=True)

print("Starting WGAN Training...")
for epoch in range(num_epochs):
    epoch_d_loss, epoch_g_loss = 0.0, 0.0
    for i, (images, _) in enumerate(train_loader):
        images = images.reshape(-1, image_size).to(device)

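        # --- Critic steps: several critic updates per generator update ---
        # Note: every critic step reuses the same real batch and draws fresh noise.
        for _ in range(critic_iterations):
            d_optimizer_wgan.zero_grad()
            z = torch.randn(images.size(0), latent_dim).to(device)
            fake_images = generator_wgan(z).detach()
            real_output = critic_wgan(images)
            fake_output = critic_wgan(fake_images)
            # Maximize E[critic(real)] - E[critic(fake)] by minimizing its negative
            d_loss = -(torch.mean(real_output) - torch.mean(fake_output))
            d_loss.backward()
            d_optimizer_wgan.step()
            # Weight clipping: keep every critic parameter in [-clip_value, clip_value]
            with torch.no_grad():
                for p in critic_wgan.parameters():
                    p.clamp_(-clip_value, clip_value)

        # Only the loss from the last critic step is recorded
        epoch_d_loss += d_loss.item()

        # --- Generator step: raise the critic's score on generated images ---
        g_optimizer_wgan.zero_grad()
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images_for_g = generator_wgan(z)
        g_loss = -torch.mean(critic_wgan(fake_images_for_g))
        g_loss.backward()
        g_optimizer_wgan.step()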

        epoch_g_loss += g_loss.item()

    avg_d_loss = epoch_d_loss / len(train_loader)
    avg_g_loss = epoch_g_loss / len(train_loader)
    d_losses_wgan.append(avg_d_loss)
    g_losses_wgan.append(avg_g_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Critic Loss: {avg_d_loss:.4f}, Generator Loss: {avg_g_loss:.4f}')

print("WGAN Training finished.")
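
Weight clipping, as used in the critic loop above, is a crude way to keep the critic approximately Lipschitz: after every optimizer step, each parameter is clamped to `[-clip_value, clip_value]`. A minimal sketch on a stand-in layer (the layer size and the artificially large weights are assumptions chosen only to make the effect visible):

```python
import torch
import torch.nn as nn

clip_value = 0.01

# Stand-in for the critic; any nn.Module is clipped the same way.
layer = nn.Linear(4, 1)
with torch.no_grad():
    layer.weight.fill_(0.5)   # deliberately outside the clipping range
    layer.bias.fill_(-0.5)

# Clamp every parameter in place, without recording the operation in autograd.
with torch.no_grad():
    for p in layer.parameters():
        p.clamp_(-clip_value, clip_value)

print(layer.weight)  # every entry clipped to +clip_value
print(layer.bias)    # clipped to -clip_value
```

With a bound this tight, most weights tend to sit at one of the two limits, which is one reason clipping can limit what the critic is able to represent.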

---

## **Part 3: Spectral Normalization GAN (SNGAN) with Hinge Loss**

In [None]:
from torch.nn.utils import spectral_norm

class SNGAN_Discriminator(nn.Module):
    def __init__(self):
        super(SNGAN_Discriminator, self).__init__()
        self.model = nn.Sequential(
            spectral_norm(nn.Linear(image_size, 512)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Linear(512, 256)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Linear(256, 1))
        )

    def forward(self, x):
        return self.model(x)
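
`spectral_norm` rescales a layer's weight by a running estimate of its largest singular value, and that estimate is refined by one power-iteration step per forward pass in training mode. The sketch below, on a stand-in layer with made-up sizes, checks that the effective weight ends up with a spectral norm close to 1; the number of warm-up passes is an arbitrary choice:

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)

layer = spectral_norm(nn.Linear(32, 16))
layer.train()  # power iteration only runs in training mode

# Each forward pass refines the singular-value estimate by one power-iteration step.
x = torch.randn(8, 32)
for _ in range(200):
    layer(x)

# After a forward pass, `layer.weight` holds the normalized weight W / sigma.
with torch.no_grad():
    top_singular_value = torch.linalg.svdvals(layer.weight)[0].item()
print(f"largest singular value: {top_singular_value:.4f}")
```

Since each linear layer then has a spectral norm of roughly 1 and `LeakyReLU(0.2)` is 1-Lipschitz, the whole discriminator is approximately 1-Lipschitz without any clipping.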

In [None]:
learning_rate_sngan = 0.0002

discriminator_sngan = SNGAN_Discriminator().to(device)
generator_sngan = Generator().to(device)
d_optimizer_sngan = optim.Adam(discriminator_sngan.parameters(), lr=learning_rate_sngan, betas=(0.5, 0.999))
g_optimizer_sngan = optim.Adam(generator_sngan.parameters(), lr=learning_rate_sngan, betas=(0.5, 0.999))
d_losses_sngan, g_losses_sngan = [], []

os.makedirs('sngan_samples', exist_ok=True)

print("Starting SNGAN Training...")
for epoch in range(num_epochs):
    epoch_d_loss, epoch_g_loss = 0.0, 0.0
    for i, (images, _) in enumerate(train_loader):
        images = images.reshape(-1, image_size).to(device)

        # --- Discriminator step: hinge loss on raw (unbounded) scores ---
        d_optimizer_sngan.zero_grad()
        real_output = discriminator_sngan(images)
        # Penalize real scores that fall below +1
        d_loss_real = torch.mean(torch.relu(1.0 - real_output))
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images = generator_sngan(z).detach()
        fake_output = discriminator_sngan(fake_images)
        # Penalize fake scores that rise above -1
        d_loss_fake = torch.mean(torch.relu(1.0 + fake_output))
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer_sngan.step()
        epoch_d_loss += d_loss.item()

        # --- Generator step: raise the discriminator's score on fresh fakes ---
        g_optimizer_sngan.zero_grad()
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images_for_g = generator_sngan(z)
        g_output = discriminator_sngan(fake_images_for_g)
        g_loss = -torch.mean(g_output)
        g_loss.backward()
        g_optimizer_sngan.step()
        epoch_g_loss += g_loss.item()

    avg_d_loss = epoch_d_loss / len(train_loader)
    avg_g_loss = epoch_g_loss / len(train_loader)
    d_losses_sngan.append(avg_d_loss)
    g_losses_sngan.append(avg_g_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Discriminator Loss: {avg_d_loss:.4f}, Generator Loss: {avg_g_loss:.4f}')

print("SNGAN Training finished.")
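
The hinge loss used in the loop above only penalizes scores that are on the wrong side of a margin: real scores below +1 and fake scores above -1. A worked example on made-up scores (the numbers are illustrative, not outputs of the trained discriminator):

```python
import torch

# Made-up discriminator scores (raw, unbounded outputs; no sigmoid).
real_scores = torch.tensor([2.0, 0.5])
fake_scores = torch.tensor([-2.0, 0.0])

# Discriminator: penalize real scores below +1 and fake scores above -1.
d_loss_real = torch.relu(1.0 - real_scores).mean()   # (0 + 0.5) / 2 = 0.25
d_loss_fake = torch.relu(1.0 + fake_scores).mean()   # (0 + 1.0) / 2 = 0.50
d_loss = d_loss_real + d_loss_fake                   # 0.75

# Generator: raise the discriminator's score on fake samples.
g_loss = -fake_scores.mean()                         # -(-1.0) = 1.0

print(d_loss.item(), g_loss.item())
```

Samples already beyond the margin (the real score of 2.0 and the fake score of -2.0) contribute nothing to the discriminator loss, so its updates concentrate on the examples it still gets wrong or barely right.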

---

## **Part 4: Evaluation and Comparison**

#### **Loss Curve Plots**

In [None]:
plt.figure(figsize=(18, 5))

plt.subplot(1, 3, 1)
plt.title("Vanilla GAN Loss")
plt.plot(d_losses_vanilla, label="D Loss")
plt.plot(g_losses_vanilla, label="G Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 3, 2)
plt.title("WGAN Loss")
plt.plot(d_losses_wgan, label="Critic Loss")
plt.plot(g_losses_wgan, label="G Loss")
plt.xlabel("Epochs")
plt.legend()

plt.subplot(1, 3, 3)
plt.title("SNGAN Loss")
plt.plot(d_losses_sngan, label="D Loss")
plt.plot(g_losses_sngan, label="G Loss")
plt.xlabel("Epochs")
plt.legend()

plt.tight_layout()
plt.savefig('all_loss_curves.png')
plt.show()

#### **Final Generated Samples (Fair Comparison)**

In [None]:
# --- Create ONE fixed latent vector for the final comparison ---
# This ensures we are comparing what each generator produces from the exact same input noise.
fixed_z = torch.randn(64, latent_dim).to(device)

def show_final_samples(generator, title, filename, z_vector):
    """Generates and displays samples from a generator using a fixed noise vector."""
    generator.eval() # Set generator to evaluation mode
    with torch.no_grad():
        # Use the fixed z_vector passed as an argument
        final_images = generator(z_vector).view(-1, 1, 28, 28)
        save_image(final_images, filename, normalize=True)

        grid = torchvision.utils.make_grid(final_images, nrow=8, normalize=True)
        plt.figure(figsize=(8,8))
        plt.imshow(grid.permute(1, 2, 0).cpu())
        plt.title(title)
        plt.axis('off')
        plt.show()

# --- Generate and display samples using the SAME fixed_z for all models ---
print("Displaying final samples. Each model received the exact same input noise.")

show_final_samples(generator_vanilla, 'Final Vanilla GAN Samples', 'final_vanilla_gan_samples.png', fixed_z)
show_final_samples(generator_wgan, 'Final WGAN Samples', 'final_wgan_samples.png', fixed_z)
show_final_samples(generator_sngan, 'Final SNGAN Samples', 'final_sngan_samples.png', fixed_z)

#### **IS and FID Score Calculation Setup**

In [None]:
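def generate_for_evaluation(generator, eval_dir, num_images=10000):
    """Saves num_images generated samples as individual PNG files in eval_dir."""
    os.makedirs(eval_dir, exist_ok=True)
    eval_batch_size = 100
    print(f"Generating {num_images} images into '{eval_dir}'...")
    generator.eval()
    with torch.no_grad():
        for i in range(0, num_images, eval_batch_size):
            # The last batch is smaller when num_images is not a multiple of eval_batch_size
            current_batch = min(eval_batch_size, num_images - i)
            z = torch.randn(current_batch, latent_dim).to(device)
            generated_images = generator(z).view(-1, 1, 28, 28)
            for j in range(generated_images.size(0)):
                # normalize=True rescales each image to [0, 1] by its own min and max before saving
                save_image(generated_images[j, :, :, :], os.path.join(eval_dir, f'img_{i+j}.png'), normalize=True)
    print(f"Finished generating images for {eval_dir}.")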

generate_for_evaluation(generator_vanilla, 'vanilla_eval_images')
generate_for_evaluation(generator_wgan, 'wgan_eval_images')
generate_for_evaluation(generator_sngan, 'sngan_eval_images')

real_images_dir = 'real_mnist_images'
if not os.path.exists(real_images_dir):
    os.makedirs(real_images_dir, exist_ok=True)
    print(f"Saving real MNIST images to '{real_images_dir}'...")
    real_train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
    img_num = 0
    for i, (images, _) in enumerate(real_train_loader):
        for j in range(images.size(0)):
            save_image(images[j, :, :, :], os.path.join(real_images_dir, f'real_img_{img_num}.png'), normalize=True)
            img_num += 1
    print("Finished saving real images.")
else:
    print(f"Real images directory '{real_images_dir}' already exists.")

#### **Run Evaluation Commands**

In [None]:
!pip install torch-fidelity -q

# The image folders were created relative to the working directory (which is /content on Colab),
# so the commands reference them the same way.
print("--- Calculating metrics for Vanilla GAN ---")
!python -m torch_fidelity.fidelity --gpu 0 --fid --isc --input1 vanilla_eval_images --input2 real_mnist_images

print("\n--- Calculating metrics for WGAN ---")
!python -m torch_fidelity.fidelity --gpu 0 --fid --isc --input1 wgan_eval_images --input2 real_mnist_images

print("\n--- Calculating metrics for SNGAN ---")
!python -m torch_fidelity.fidelity --gpu 0 --fid --isc --input1 sngan_eval_images --input2 real_mnist_images
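
FID is the Fréchet distance between two Gaussians fitted to Inception features of the real and generated images: $\|\mu_1 - \mu_2\|^2 + \mathrm{Tr}\left(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\right)$. The sketch below implements only that final formula in NumPy on random stand-in features. It does not extract Inception features and is not a description of how `torch-fidelity` works internally; `frechet_distance` is an illustrative name.

```python
import numpy as np


def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (n_samples, dim) feature sets."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((cov_a cov_b)^(1/2)) equals the sum of square roots of the eigenvalues of
    # cov_a @ cov_b, which are real and non-negative for covariance matrices.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)


rng = np.random.default_rng(0)
features_a = rng.normal(size=(2000, 8))         # stand-in for "real" features
features_b = rng.normal(size=(2000, 8)) + 3.0   # same spread, mean shifted by 3 per dimension

fd_same = frechet_distance(features_a, features_a)
fd_shifted = frechet_distance(features_a, features_b)
print(fd_same)     # close to 0 for identical sets
print(fd_shifted)  # dominated by the mean shift, roughly 8 * 3**2 = 72
```

Lower is better: the distance grows as the generated features drift away from the real ones in either mean or covariance.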

## **Final Comparison**
After running the notebook and calculating the metrics, fill in the table below with the final scores to compare the performance of the three models.

| Metric | Vanilla GAN (BCE Loss) | WGAN (Wasserstein Loss) | SNGAN (Hinge Loss) | Analysis |
| :--- | :--- | :--- | :--- | :--- |
| **Stabilization Method** | None (Sigmoid output, no Lipschitz constraint) | Weight Clipping | Spectral Normalization | Spectral normalization enforces the Lipschitz constraint more directly than weight clipping, which can lead to exploding or vanishing gradients. The Vanilla GAN has no such constraint and is prone to mode collapse. |
| **Training Stability** | Loss is often volatile and not a good indicator of sample quality. | Loss is more meaningful (approximates Earth Mover's distance) and stable. | Stable training with meaningful loss curves. | SNGAN typically converges faster and more reliably than WGAN. |
| **Visual Quality** | Often suffers from noise and artifacts. | Generally cleaner images than Vanilla GAN, but quality can drop if clipping is not tuned well. | Typically produces the sharpest and cleanest images with the fewest artifacts among the three. | Expected ranking: SNGAN, then WGAN, then Vanilla GAN. |
| **Inception Score (IS)** | *Fill in from output* | *Fill in from output* | *Fill in from output* | Higher is better. SNGAN is expected to have the highest IS, followed by WGAN, then Vanilla GAN. |
| **Fréchet Inception Distance (FID)** | *Fill in from output* | *Fill in from output* | *Fill in from output* | Lower is better. SNGAN is expected to have the lowest FID, followed by WGAN, then Vanilla GAN. |