## **Project: A Trilogy of GANs on MNIST**

This notebook implements and compares three foundational types of Generative Adversarial Networks.

*   **Part 1: Vanilla GAN**: The original GAN architecture using Binary Cross-Entropy loss.
*   **Part 2: Wasserstein GAN (WGAN)**: An improved architecture using Wasserstein loss and weight clipping to enhance training stability.
*   **Part 3: Spectral Normalization GAN (SNGAN)**: A more recent approach that combines Spectral Normalization with Hinge Loss, aiming for more stable training and higher-quality samples.
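
As a rough side-by-side view of what distinguishes the three parts, the sketch below writes each discriminator (or critic) objective as a small function on raw, pre-sigmoid scores. The function names and toy scores are illustrative assumptions, and the BCE version uses `binary_cross_entropy_with_logits` in place of the notebook's `Sigmoid` + `BCELoss` pairing, which computes the same quantity.

```python
import torch
import torch.nn.functional as F

def bce_d_loss(real_logits, fake_logits):
    """Vanilla GAN discriminator loss on raw (pre-sigmoid) scores."""
    real = F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
    fake = F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
    return real + fake

def wasserstein_critic_loss(real_scores, fake_scores):
    """WGAN critic loss: negative of the estimated Wasserstein distance."""
    return fake_scores.mean() - real_scores.mean()

def hinge_d_loss(real_scores, fake_scores):
    """Hinge loss: penalize real scores below +1 and fake scores above -1."""
    return F.relu(1.0 - real_scores).mean() + F.relu(1.0 + fake_scores).mean()

# Toy scores (assumed values, for illustration only)
real_scores = torch.tensor([[2.0], [1.5], [0.5]])
fake_scores = torch.tensor([[-2.0], [-0.5], [0.5]])

losses = {
    "bce": bce_d_loss(real_scores, fake_scores).item(),
    "wasserstein": wasserstein_critic_loss(real_scores, fake_scores).item(),
    "hinge": hinge_d_loss(real_scores, fake_scores).item(),
}
for name, value in losses.items():
    print(f"{name}: {value:.4f}")
```

Only the BCE loss is bounded below by zero; the Wasserstein critic loss is unbounded, which is why the WGAN critic needs a constraint such as weight clipping.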

### **Common Setup: Imports and Reproducibility**
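The seed in this section is set once at the top of the notebook, so the three models, which are created at different points in the run, do not start from identical random states. One possible way to tighten the comparison, sketched here with a hypothetical `set_seed` helper that is not part of the original notebook, is to reset every RNG immediately before each model is built.

```python
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    """Reset the Python, NumPy, and PyTorch RNGs (hypothetical helper)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Re-seeding before each draw reproduces the same random numbers.
set_seed(42)
first_draw = torch.randn(4)
set_seed(42)
second_draw = torch.randn(4)
print(torch.equal(first_draw, second_draw))
```

Calling such a helper right before each `Generator()` is constructed would give all three generators the same initial weights, at the cost of a few extra lines per part.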

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
from torchvision.utils import save_image
import numpy as np
import random

In [None]:
# --- Set a Global Random Seed for Reproducibility ---
# This ensures that weight initializations, data shuffling, and noise vectors are the same for each run,
# allowing for a fair comparison between the models.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # For multi-GPU setups
    # The following two lines make cuDNN deterministic; some other CUDA ops may still be nondeterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
latent_dim = 100
image_size = 28 * 28
batch_size = 128
num_epochs = 50

# Image processing and dataset loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

---

## **Part 1: Vanilla GAN with BCE Loss**

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid() # Outputs a probability
        )

    def forward(self, x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, image_size),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

In [None]:
learning_rate_vanilla = 0.0002

discriminator_vanilla = Discriminator().to(device)
generator_vanilla = Generator().to(device)
criterion = nn.BCELoss()
d_optimizer_vanilla = optim.Adam(discriminator_vanilla.parameters(), lr=learning_rate_vanilla)
g_optimizer_vanilla = optim.Adam(generator_vanilla.parameters(), lr=learning_rate_vanilla)
d_losses_vanilla, g_losses_vanilla = [], []

os.makedirs('vanilla_gan_samples', exist_ok=True)

print("Starting Vanilla GAN Training...")
for epoch in range(num_epochs):
    epoch_d_loss, epoch_g_loss = 0.0, 0.0
    for i, (images, _) in enumerate(train_loader):
        real_labels = torch.ones(images.size(0), 1).to(device)
        fake_labels = torch.zeros(images.size(0), 1).to(device)
        images = images.reshape(-1, image_size).to(device)

        d_optimizer_vanilla.zero_grad()
        d_real_outputs = discriminator_vanilla(images)
        d_loss_real = criterion(d_real_outputs, real_labels)
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images = generator_vanilla(z)
        d_fake_outputs = discriminator_vanilla(fake_images.detach())
        d_loss_fake = criterion(d_fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer_vanilla.step()

        g_optimizer_vanilla.zero_grad()
        g_outputs = discriminator_vanilla(fake_images)
        g_loss = criterion(g_outputs, real_labels)
        g_loss.backward()
        g_optimizer_vanilla.step()
        
        epoch_d_loss += d_loss.item()
        epoch_g_loss += g_loss.item()

    avg_d_loss = epoch_d_loss / len(train_loader)
    avg_g_loss = epoch_g_loss / len(train_loader)
    d_losses_vanilla.append(avg_d_loss)
    g_losses_vanilla.append(avg_g_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], D Loss: {avg_d_loss:.4f}, G Loss: {avg_g_loss:.4f}')

print("Vanilla GAN Training finished.")

---

## **Part 2: Wasserstein GAN (WGAN) with Weight Clipping**
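Weight clipping is the mechanism this part uses to keep the critic approximately Lipschitz: after every critic update, each parameter is clamped to a small interval. A minimal sketch on a toy critic follows; `clip_weights` and `toy_critic` are illustrative names, not part of the notebook.

```python
import torch
import torch.nn as nn

def clip_weights(module: nn.Module, clip_value: float) -> None:
    """Clamp every parameter of `module` to [-clip_value, clip_value] in place."""
    with torch.no_grad():
        for p in module.parameters():
            p.clamp_(-clip_value, clip_value)

# Toy critic with assumed layer sizes, for illustration only
toy_critic = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(0.2), nn.Linear(8, 1))
clip_weights(toy_critic, clip_value=0.01)

max_abs = max(p.abs().max().item() for p in toy_critic.parameters())
print(f"largest |parameter| after clipping: {max_abs:.4f}")
```

Wrapping the clamp in `torch.no_grad()` keeps the operation out of the autograd graph and avoids touching `.data` directly.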

In [None]:
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1) # No Sigmoid for WGAN critic
        )

    def forward(self, x):
        return self.model(x)

In [None]:
learning_rate_wgan = 0.00005
critic_iterations = 5
clip_value = 0.01

critic_wgan = Critic().to(device)
generator_wgan = Generator().to(device)
d_optimizer_wgan = optim.RMSprop(critic_wgan.parameters(), lr=learning_rate_wgan)
g_optimizer_wgan = optim.RMSprop(generator_wgan.parameters(), lr=learning_rate_wgan)
d_losses_wgan, g_losses_wgan = [], []

os.makedirs('wgan_samples', exist_ok=True)

print("Starting WGAN Training...")
for epoch in range(num_epochs):
    epoch_d_loss, epoch_g_loss = 0.0, 0.0
    for i, (images, _) in enumerate(train_loader):
        images = images.reshape(-1, image_size).to(device)

        for _ in range(critic_iterations):
            d_optimizer_wgan.zero_grad()
            z = torch.randn(images.size(0), latent_dim).to(device)
            fake_images = generator_wgan(z).detach()
            real_output = critic_wgan(images)
            fake_output = critic_wgan(fake_images)
            d_loss = -(torch.mean(real_output) - torch.mean(fake_output))
            d_loss.backward()
            d_optimizer_wgan.step()
            # Enforce the weight-clipping constraint on the critic
            with torch.no_grad():
                for p in critic_wgan.parameters():
                    p.clamp_(-clip_value, clip_value)
        
        # Only the loss from the last critic update of this batch is logged
        epoch_d_loss += d_loss.item()

        g_optimizer_wgan.zero_grad()
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images_for_g = generator_wgan(z)
        g_loss = -torch.mean(critic_wgan(fake_images_for_g))
        g_loss.backward()
        g_optimizer_wgan.step()

        epoch_g_loss += g_loss.item()

    avg_d_loss = epoch_d_loss / len(train_loader)
    avg_g_loss = epoch_g_loss / len(train_loader)
    d_losses_wgan.append(avg_d_loss)
    g_losses_wgan.append(avg_g_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Critic Loss: {avg_d_loss:.4f}, Generator Loss: {avg_g_loss:.4f}')

print("WGAN Training finished.")

---

## **Part 3: Spectral Normalization GAN (SNGAN) with Hinge Loss**
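Spectral normalization rescales each weight matrix by an estimate of its largest singular value, so the normalized layer has a spectral norm close to 1. The sketch below checks this on a single toy layer; the layer sizes and the number of warm-up passes are arbitrary assumptions.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)

# Toy layer with assumed sizes, for illustration only
layer = spectral_norm(nn.Linear(32, 16))
layer.train()
x = torch.randn(8, 32)

with torch.no_grad():
    # In training mode, each forward pass refines the power-iteration estimate
    for _ in range(100):
        layer(x)
    raw_top = torch.linalg.svdvals(layer.weight_orig)[0].item()
    top_singular_value = torch.linalg.svdvals(layer.weight)[0].item()

print(f"largest singular value before normalization: {raw_top:.4f}")
print(f"largest singular value after normalization:  {top_singular_value:.4f}")
```

Unlike weight clipping, this constrains the layer's overall gain without forcing individual weights into a narrow range.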

In [None]:
from torch.nn.utils import spectral_norm

class SNGAN_Discriminator(nn.Module):
    def __init__(self):
        super(SNGAN_Discriminator, self).__init__()
        self.model = nn.Sequential(
            spectral_norm(nn.Linear(image_size, 512)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Linear(512, 256)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Linear(256, 1))
        )

    def forward(self, x):
        return self.model(x)

In [None]:
learning_rate_sngan = 0.0002

discriminator_sngan = SNGAN_Discriminator().to(device)
generator_sngan = Generator().to(device)
d_optimizer_sngan = optim.Adam(discriminator_sngan.parameters(), lr=learning_rate_sngan, betas=(0.5, 0.999))
g_optimizer_sngan = optim.Adam(generator_sngan.parameters(), lr=learning_rate_sngan, betas=(0.5, 0.999))
d_losses_sngan, g_losses_sngan = [], []

os.makedirs('sngan_samples', exist_ok=True)

print("Starting SNGAN Training...")
for epoch in range(num_epochs):
    epoch_d_loss, epoch_g_loss = 0.0, 0.0
    for i, (images, _) in enumerate(train_loader):
        images = images.reshape(-1, image_size).to(device)

        d_optimizer_sngan.zero_grad()
        real_output = discriminator_sngan(images)
        # Hinge loss: penalize real scores below +1 and fake scores above -1
        d_loss_real = torch.mean(torch.relu(1.0 - real_output))
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images = generator_sngan(z).detach()
        fake_output = discriminator_sngan(fake_images)
        d_loss_fake = torch.mean(torch.relu(1.0 + fake_output))
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer_sngan.step()
        epoch_d_loss += d_loss.item()

        g_optimizer_sngan.zero_grad()
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images_for_g = generator_sngan(z)
        g_output = discriminator_sngan(fake_images_for_g)
        g_loss = -torch.mean(g_output)
        g_loss.backward()
        g_optimizer_sngan.step()
        epoch_g_loss += g_loss.item()

    avg_d_loss = epoch_d_loss / len(train_loader)
    avg_g_loss = epoch_g_loss / len(train_loader)
    d_losses_sngan.append(avg_d_loss)
    g_losses_sngan.append(avg_g_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Discriminator Loss: {avg_d_loss:.4f}, Generator Loss: {avg_g_loss:.4f}')

print("SNGAN Training finished.")

---

## **Part 4: Evaluation and Comparison**

#### **Loss Curve Plots**

In [None]:
plt.figure(figsize=(18, 5))

plt.subplot(1, 3, 1)
plt.title("Vanilla GAN Loss")
plt.plot(d_losses_vanilla, label="D Loss")
plt.plot(g_losses_vanilla, label="G Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 3, 2)
plt.title("WGAN Loss")
plt.plot(d_losses_wgan, label="Critic Loss")
plt.plot(g_losses_wgan, label="G Loss")
plt.xlabel("Epochs")
plt.legend()

plt.subplot(1, 3, 3)
plt.title("SNGAN Loss")
plt.plot(d_losses_sngan, label="D Loss")
plt.plot(g_losses_sngan, label="G Loss")
plt.xlabel("Epochs")
plt.legend()

plt.tight_layout()
plt.savefig('all_loss_curves.png')
plt.show()

#### **Final Generated Samples (Fair Comparison)**

In [None]:
# --- Create ONE fixed latent vector for the final comparison ---
# This ensures we are comparing what each generator produces from the exact same input noise.
fixed_z = torch.randn(64, latent_dim).to(device)

def show_final_samples(generator, title, filename, z_vector):
    """Generates and displays samples from a generator using a fixed noise vector."""
    generator.eval() # Set generator to evaluation mode
    with torch.no_grad():
        # Use the fixed z_vector passed as an argument
        final_images = generator(z_vector).view(-1, 1, 28, 28)
        save_image(final_images, filename, normalize=True)
        
        grid = torchvision.utils.make_grid(final_images, nrow=8, normalize=True)
        plt.figure(figsize=(8,8))
        plt.imshow(grid.permute(1, 2, 0).cpu())
        plt.title(title)
        plt.axis('off')
        plt.show()

# --- Generate and display samples using the SAME fixed_z for all models ---
print("Displaying final samples. Each model received the exact same input noise.")

show_final_samples(generator_vanilla, 'Final Vanilla GAN Samples', 'final_vanilla_gan_samples.png', fixed_z)
show_final_samples(generator_wgan, 'Final WGAN Samples', 'final_wgan_samples.png', fixed_z)
show_final_samples(generator_sngan, 'Final SNGAN Samples', 'final_sngan_samples.png', fixed_z)

#### **IS and FID Score Calculation**

In [None]:
!pip install torch-fidelity

def generate_for_evaluation(generator, eval_dir, num_images=10000):
    os.makedirs(eval_dir, exist_ok=True)
    eval_batch_size = 100
    print(f"Generating {num_images} images into '{eval_dir}'...")
    generator.eval()
    with torch.no_grad():
        for i in range(0, num_images, eval_batch_size):
            z = torch.randn(eval_batch_size, latent_dim).to(device)
            generated_images = generator(z).view(-1, 1, 28, 28)
            for j in range(generated_images.size(0)):
                save_image(generated_images[j, :, :, :], os.path.join(eval_dir, f'img_{i+j}.png'), normalize=True)
    print(f"Finished generating images for {eval_dir}.")

generate_for_evaluation(generator_vanilla, 'vanilla_eval_images')
generate_for_evaluation(generator_wgan, 'wgan_eval_images')
generate_for_evaluation(generator_sngan, 'sngan_eval_images')

real_images_dir = 'real_mnist_images'
if not os.path.exists(real_images_dir):
    os.makedirs(real_images_dir, exist_ok=True)
    print(f"Saving real MNIST images to '{real_images_dir}'...")
    real_train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
    img_num = 0
    for i, (images, _) in enumerate(real_train_loader):
        for j in range(images.size(0)):
            save_image(images[j, :, :, :], os.path.join(real_images_dir, f'real_img_{img_num}.png'), normalize=True)
            img_num += 1
    print("Finished saving real images.")
else:
    print(f"Real images directory '{real_images_dir}' already exists.")

#### **Run Evaluation Commands**

In [None]:
print("--- Calculating metrics for Vanilla GAN ---")
!python -m torch_fidelity.fidelity --gpu 0 --fid --isc --input1 /content/vanilla_eval_images --input2 /content/real_mnist_images

print("\n--- Calculating metrics for WGAN ---")
!python -m torch_fidelity.fidelity --gpu 0 --fid --isc --input1 /content/wgan_eval_images --input2 /content/real_mnist_images

print("\n--- Calculating metrics for SNGAN ---")
!python -m torch_fidelity.fidelity --gpu 0 --fid --isc --input1 /content/sngan_eval_images --input2 /content/real_mnist_images

## **Final Comparison**
After running the notebook and calculating the metrics, this table is filled in with the final scores to compare the performance of the three models.

| Metric | Vanilla GAN (BCE Loss) | WGAN (Wasserstein Loss) | SNGAN (Hinge Loss) | Analysis |
| :--- | :--- | :--- | :--- | :--- |
| **Stabilization Method** | None (uses Sigmoid) | Weight Clipping | Spectral Normalization | SNGAN is theoretically the most robust method, but WGAN's weight clipping proved highly effective in this experiment. |
| **Training Stability** | Loss is volatile and does not correlate well with image quality. | Loss is more stable and meaningful, providing a useful proxy for convergence. | Also provides stable and meaningful loss curves. | Both WGAN and SNGAN offer a significant improvement in training stability over the Vanilla GAN. |
| **Visual Quality** | Samples are noticeably noisy, and mode collapse is a risk. | Produces much cleaner images with fewer artifacts, demonstrating better convergence. | Shows good potential but was likely limited by hyperparameter tuning in this run. | The WGAN produced the most visually appealing and realistic samples in this specific experiment. |
| **Inception Score (IS)** | **2.10** | **2.09** | **1.86** | The Vanilla GAN and WGAN produced similarly diverse and recognizable images. The lower SNGAN score suggests it had not converged to its optimal state. |
| **Fréchet Inception Distance (FID)** | **143.59** | **91.09** | **108.99** | **WGAN has the lowest (best) FID.** Its score is about 37% lower than the Vanilla GAN's, a large improvement in realism. The SNGAN also improved on the Vanilla GAN (about 24% lower) but was outperformed by the WGAN in this run. |

## Overall Insights and Observations

This series of experiments traces a progression in GAN training methods and highlights key trade-offs in model design and training. The generator architecture is held constant while the discriminator and loss function change, so the three methods can be compared side by side. One caveat is that each part also uses its own optimizer and learning rate, so the differences cannot be attributed to the loss function alone.

### Training Stability and Convergence Speed

*   **Vanilla GAN**: The loss curves confirm the instability of the original GAN formulation. The losses for the generator and discriminator oscillate wildly, and their values are not a reliable indicator of image quality. A lower loss does not necessarily mean better images.
*   **WGAN**: The training process is visibly more stable. The critic's loss provides a much more meaningful metric that correlates with image quality: it is the negative of the critic's estimate of the Earth Mover's distance, so as the generator improves, its magnitude tends to shrink toward zero. While training is slower *per epoch* due to the multiple critic updates, the *convergence to high-quality samples* is significantly faster and more reliable than with the Vanilla GAN.
*   **SNGAN**: This model also demonstrated stable training, similar to WGAN. Its loss curves were smooth and did not diverge. However, its final image quality (as measured by FID) was not as good as WGAN's in this specific experiment, suggesting that while the training was stable, it did not converge to an optimal solution within 50 epochs with the given hyperparameters.

### Mode Collapse

Mode collapse is a classic failure mode where the generator produces only a limited variety of samples. 
*   The **Vanilla GAN** is highly susceptible to this, although our run (with an IS of 2.10) managed to avoid a catastrophic collapse and produced a diverse set of digits.
*   Both **WGAN** and **SNGAN** are specifically designed to mitigate mode collapse by providing more stable gradients that encourage the generator to explore the entire data distribution. The diversity seen in their output grids supports their effectiveness in this regard.
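
The link between the Inception Score and mode collapse can be made concrete with a small NumPy sketch. It implements the standard definition, IS = exp(E_x[KL(p(y|x) || p(y))]), on hand-made class probabilities rather than real Inception outputs. `inception_score`, `diverse`, and `collapsed` are illustrative names, and the scores reported in this notebook come from `torch-fidelity`, not from this code.

```python
import numpy as np

def inception_score(probs: np.ndarray, eps: float = 1e-12) -> float:
    """Inception Score from an (N, C) array of class probabilities p(y|x)."""
    marginal = probs.mean(axis=0, keepdims=True)  # p(y), averaged over samples
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))

num_classes = 10

# Diverse generator: one confident sample per class
diverse = np.eye(num_classes)

# Collapsed generator: every sample is confidently the same class
collapsed = np.tile(np.eye(num_classes)[0], (num_classes, 1))

is_diverse = inception_score(diverse)
is_collapsed = inception_score(collapsed)
print(f"diverse: {is_diverse:.2f}, collapsed: {is_collapsed:.2f}")
```

With ten perfectly confident, evenly spread classes the score reaches its maximum of 10; when every sample lands in one class it falls to its minimum of 1.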

### Analysis of Evaluation Metrics (IS and FID)

The quantitative scores tell a compelling story:

*   **FID as the Key Metric**: The Fréchet Inception Distance proved to be the most decisive metric for judging realism. The **WGAN's FID of 91.09** represents a massive leap in quality over the **Vanilla GAN's 143.59**. This numerically validates that WGAN's generated distribution is much closer to the real one. The **SNGAN's FID of 108.99**, while a major improvement over the Vanilla GAN, did not reach the level of the WGAN.

*   **The SNGAN Result - A Lesson in Hyperparameters**: The fact that SNGAN did not outperform WGAN is an important lesson. SNGAN is often considered state-of-the-art for GAN stabilization, but its performance is not automatic. This result suggests that the Adam settings chosen for the SNGAN (learning rate 0.0002, betas of (0.5, 0.999)) suited its training dynamics less well than the RMSprop settings (learning rate 0.00005, five critic updates per generator step) suited the WGAN's. The SNGAN might surpass the WGAN with further tuning (e.g., adjusting the learning rate or training for more epochs), but this run does not test that.
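
For reference, the Fréchet distance behind FID compares two Gaussians fitted to feature vectors: ||μ1 − μ2||² + Tr(Σ1 + Σ2 − 2(Σ1Σ2)^{1/2}). The NumPy sketch below applies that formula to synthetic 8-dimensional features instead of Inception activations, so it illustrates the computation only. `frechet_distance`, `feature_stats`, and the toy data are assumptions, and the scores in this notebook come from `torch-fidelity`.

```python
import numpy as np

def feature_stats(features: np.ndarray):
    """Mean vector and covariance matrix of an (N, D) feature array."""
    return features.mean(axis=0), np.cov(features, rowvar=False)

def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr((S1 S2)^(1/2)) equals the sum of the square roots of the eigenvalues
    # of S1 S2, which are real and non-negative for PSD covariances.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * trace_sqrt)

# Synthetic stand-ins for feature vectors (assumed data, for illustration only)
rng = np.random.default_rng(0)
real_feats = rng.normal(0.0, 1.0, size=(2000, 8))
same_feats = real_feats.copy()
shifted_feats = real_feats + 0.5  # same covariance, mean shifted by 0.5 per dimension

fid_same = frechet_distance(*feature_stats(real_feats), *feature_stats(same_feats))
fid_shifted = frechet_distance(*feature_stats(real_feats), *feature_stats(shifted_feats))
print(f"identical features: {fid_same:.4f}, shifted features: {fid_shifted:.4f}")
```

Identical feature sets give a distance of essentially zero, while shifting the mean by 0.5 in each of 8 dimensions gives 8 × 0.25 = 2, since the covariance term cancels.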

### Final Conclusion

In this controlled experiment, the **WGAN with weight clipping delivered the best overall performance**, achieving the lowest FID score and demonstrating a great balance of training stability and high-quality results with minimal tuning. 

The experiment successfully demonstrates that moving beyond the original BCE loss to more advanced frameworks like Wasserstein distance or Spectral Normalization provides a clear, measurable, and significant improvement to the stability and performance of Generative Adversarial Networks.