## **Project: Vanilla GAN and WGAN on MNIST**

This notebook implements and compares two types of Generative Adversarial Networks.

*   **Part 1: Vanilla GAN**: The foundational GAN architecture using Binary Cross-Entropy loss.
*   **Part 2: Wasserstein GAN (WGAN)**: A variant that trains a critic with the Wasserstein loss and weight clipping, aiming for more stable training and better image quality.

### **Common Setup: Imports and Data Loading**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
from torchvision.utils import save_image

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
latent_dim = 100
image_size = 28 * 28
batch_size = 128
num_epochs = 50

# Image processing and dataset loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
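
`transforms.Normalize((0.5,), (0.5,))` applies `(x - mean) / std` per channel, so the `[0, 1]` pixel range produced by `ToTensor` becomes `[-1, 1]`, which matches the range of the generator's `Tanh` output. A minimal sketch of that arithmetic in plain Python follows; the helper names `normalize` and `denormalize` are illustrative and not part of torchvision.

```python
# Illustrative helpers (not torchvision API): the per-pixel arithmetic
# behind Normalize(mean=0.5, std=0.5) and its inverse.
MEAN, STD = 0.5, 0.5


def normalize(pixel: float) -> float:
    """Map a ToTensor-style pixel in [0, 1] to [-1, 1]."""
    return (pixel - MEAN) / STD


def denormalize(value: float) -> float:
    """Map a value in [-1, 1] back to [0, 1], e.g. for display."""
    return value * STD + MEAN


for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```

The inverse mapping is what `save_image(..., normalize=True)` approximates later in the notebook by rescaling each saved tensor to its own min/max range.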

---

## **Part 1: Vanilla GAN with BCE Loss**

#### **Vanilla GAN Models**

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid() # Outputs a probability
        )

    def forward(self, x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, image_size),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)
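
A quick shape check can catch wiring mistakes before any training starts. The sketch below rebuilds the two networks with the same layer sizes as the classes above, written as plain `nn.Sequential` stacks only so that the example is self-contained, and passes a small batch of noise through both.

```python
import torch
import torch.nn as nn

latent_dim, image_size = 100, 28 * 28

# Same layer sizes as the Generator / Discriminator classes above,
# written as nn.Sequential only to keep this sketch self-contained.
generator = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, image_size), nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Linear(image_size, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)

with torch.no_grad():
    z = torch.randn(4, latent_dim)   # batch of 4 latent vectors
    fake = generator(z)              # shape (4, 784), values in [-1, 1]
    prob = discriminator(fake)       # shape (4, 1), values in [0, 1]

print(fake.shape, prob.shape)
```

The generator's output width equals the discriminator's input width (`image_size`), which is why the training loop can feed one directly into the other.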

#### **Vanilla GAN Training**

In [None]:
learning_rate_vanilla = 0.0002

# Initialize models
discriminator_vanilla = Discriminator().to(device)
generator_vanilla = Generator().to(device)

# Loss and optimizers
criterion = nn.BCELoss()
d_optimizer_vanilla = optim.Adam(discriminator_vanilla.parameters(), lr=learning_rate_vanilla)
g_optimizer_vanilla = optim.Adam(generator_vanilla.parameters(), lr=learning_rate_vanilla)

# Lists to store loss history for plotting
d_losses_vanilla = []
g_losses_vanilla = []

os.makedirs('vanilla_gan_samples', exist_ok=True)

print("Starting Vanilla GAN Training...")
for epoch in range(num_epochs):
    epoch_d_loss = 0.0
    epoch_g_loss = 0.0
    for i, (images, _) in enumerate(train_loader):
        real_labels = torch.ones(images.size(0), 1).to(device)
        fake_labels = torch.zeros(images.size(0), 1).to(device)
        images = images.reshape(-1, image_size).to(device)

        # Train Discriminator
        d_optimizer_vanilla.zero_grad()
        d_real_outputs = discriminator_vanilla(images)
        d_loss_real = criterion(d_real_outputs, real_labels)
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images = generator_vanilla(z)
        d_fake_outputs = discriminator_vanilla(fake_images.detach())
        d_loss_fake = criterion(d_fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer_vanilla.step()

        # Train Generator
        g_optimizer_vanilla.zero_grad()
        g_outputs = discriminator_vanilla(fake_images)
        g_loss = criterion(g_outputs, real_labels)
        g_loss.backward()
        g_optimizer_vanilla.step()
        
        epoch_d_loss += d_loss.item()
        epoch_g_loss += g_loss.item()

    # End of Epoch Logging
    avg_d_loss = epoch_d_loss / len(train_loader)
    avg_g_loss = epoch_g_loss / len(train_loader)
    d_losses_vanilla.append(avg_d_loss)
    g_losses_vanilla.append(avg_g_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], D Loss: {avg_d_loss:.4f}, G Loss: {avg_g_loss:.4f}')

    if (epoch + 1) % 10 == 0:
        save_image(fake_images.view(-1, 1, 28, 28), f'vanilla_gan_samples/samples_epoch_{epoch+1}.png', normalize=True)

print("Vanilla GAN Training finished.")
torch.save(generator_vanilla.state_dict(), 'generator_vanilla_final.pth')
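
The generator step above scores fake images against `real_labels`, so it minimizes `-log D(G(z))` rather than `log(1 - D(G(z)))`. This is commonly called the non-saturating generator loss. The sketch below, in plain Python with illustrative helper names, compares the gradient of each form with respect to the discriminator's pre-sigmoid output when the discriminator confidently rejects a fake.

```python
import math


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def grad_non_saturating(logit: float) -> float:
    """d/d(logit) of -log(sigmoid(logit)): the form used in the generator step."""
    return -(1.0 - sigmoid(logit))


def grad_saturating(logit: float) -> float:
    """d/d(logit) of log(1 - sigmoid(logit)): the original minimax form."""
    return -sigmoid(logit)


# A strongly negative logit means D(G(z)) is close to 0 (fake easily rejected).
logit = -5.0
print("D(G(z))            :", sigmoid(logit))
print("non-saturating grad:", grad_non_saturating(logit))
print("saturating grad    :", grad_saturating(logit))
```

When the discriminator is winning, the non-saturating form still yields a gradient of magnitude close to 1, while the minimax form yields one close to 0. That is the usual motivation for training the generator with real labels.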

---

## **Part 2: Wasserstein GAN (WGAN) with Weight Clipping**

This section implements the WGAN. Key changes include:
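1.  **Critic Model**: The Discriminator (now called a Critic) has no final Sigmoid layer, so it outputs an unbounded score rather than a probability.
2.  **Wasserstein Loss**: The loss approximates the Earth Mover's (Wasserstein-1) distance between the real and generated distributions.
3.  **Weight Clipping**: The critic's weights are clamped to a small range after each update to enforce the Lipschitz constraint.
4.  **RMSprop Optimizer**: Used in place of Adam, as suggested in the original WGAN paper.
5.  **Multiple Critic Updates**: The critic is updated several times (`critic_iterations`) for each generator update.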
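
The loss and clipping steps from the list above can be sketched on a toy problem. Everything here is a stand-in chosen for illustration: the critic is a single linear layer, and the "real" and "fake" batches are random 8-dimensional tensors rather than MNIST images or generator output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (hypothetical sizes): a linear critic on 8-dimensional inputs.
critic = nn.Linear(8, 1)
optimizer = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
clip_value = 0.01

real = torch.randn(16, 8) + 1.0   # stand-in for a batch of real samples
fake = torch.randn(16, 8) - 1.0   # stand-in for detached generator output

optimizer.zero_grad()
# The critic maximizes E[f(real)] - E[f(fake)], so we minimize the negative.
critic_loss = -(critic(real).mean() - critic(fake).mean())
critic_loss.backward()
optimizer.step()

# Clamp every parameter into [-clip_value, clip_value] after the update.
with torch.no_grad():
    for p in critic.parameters():
        p.clamp_(-clip_value, clip_value)

# The generator's objective against this critic: raise the score on fakes.
generator_loss = -critic(fake).mean()

max_abs_weight = max(p.abs().max().item() for p in critic.parameters())
print("largest |weight| after clipping:", max_abs_weight)
```

Because the default `nn.Linear` initialization is much wider than `clip_value`, the clamp is active from the first step. This is a reminder that clipping constrains the critic's capacity as well as its Lipschitz constant.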

#### **WGAN Models**

In [None]:
# The Critic model for WGAN is a discriminator without the final sigmoid layer
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1) # No Sigmoid
        )

    def forward(self, x):
        return self.model(x)

# The Generator is the same as in the Vanilla GAN
# We will create new instances for the WGAN training.

#### **WGAN Training**

In [None]:
# WGAN specific hyperparameters
learning_rate_wgan = 0.00005
critic_iterations = 5 # Number of critic updates per generator update
clip_value = 0.01 # Weight clipping value

# Initialize models
critic_wgan = Critic().to(device)
generator_wgan = Generator().to(device)

# Optimizers - RMSprop is recommended for WGAN
d_optimizer_wgan = optim.RMSprop(critic_wgan.parameters(), lr=learning_rate_wgan)
g_optimizer_wgan = optim.RMSprop(generator_wgan.parameters(), lr=learning_rate_wgan)

# Lists to store loss history
d_losses_wgan = []
g_losses_wgan = []

os.makedirs('wgan_samples', exist_ok=True)

print("Starting WGAN Training...")
for epoch in range(num_epochs):
    epoch_d_loss = 0.0
    epoch_g_loss = 0.0
    for i, (images, _) in enumerate(train_loader):
        images = images.reshape(-1, image_size).to(device)

        # --- Train Critic ---
        # Note: the same real batch is reused for every critic iteration below.
        for _ in range(critic_iterations):
            d_optimizer_wgan.zero_grad()
            
            # Sample noise and generate fake images
            z = torch.randn(images.size(0), latent_dim).to(device)
            fake_images = generator_wgan(z).detach() # Detach to avoid backprop through generator
            
            # Calculate critic scores
            real_output = critic_wgan(images)
            fake_output = critic_wgan(fake_images)
            
            # Wasserstein loss
            d_loss = -(torch.mean(real_output) - torch.mean(fake_output))
            d_loss.backward()
            d_optimizer_wgan.step()

            # Clip weights of critic to enforce the Lipschitz constraint
            with torch.no_grad():
                for p in critic_wgan.parameters():
                    p.clamp_(-clip_value, clip_value)
        
        epoch_d_loss += d_loss.item()

        # --- Train Generator ---
        g_optimizer_wgan.zero_grad()
        
        # Generate new fake images
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images_for_g = generator_wgan(z)
        
        # Calculate loss for generator
        g_loss = -torch.mean(critic_wgan(fake_images_for_g))
        g_loss.backward()
        g_optimizer_wgan.step()

        epoch_g_loss += g_loss.item()

    # End of Epoch Logging
    # Note: epoch_d_loss accumulates only the last critic iteration's loss for each batch,
    # so dividing by len(train_loader) gives a per-batch average for both losses.
    avg_d_loss = epoch_d_loss / len(train_loader)
    avg_g_loss = epoch_g_loss / len(train_loader)
    d_losses_wgan.append(avg_d_loss)
    g_losses_wgan.append(avg_g_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Critic Loss: {avg_d_loss:.4f}, Generator Loss: {avg_g_loss:.4f}')

    if (epoch + 1) % 10 == 0:
        save_image(fake_images_for_g.view(-1, 1, 28, 28), f'wgan_samples/samples_epoch_{epoch+1}.png', normalize=True)

print("WGAN Training finished.")
torch.save(generator_wgan.state_dict(), 'generator_wgan_final.pth')

---

## **Part 3: Evaluation and Comparison**

#### **Loss Curve Plots**

In [None]:
# Plot Vanilla GAN Losses
plt.figure(figsize=(10, 5))
plt.title("Vanilla GAN Loss During Training")
plt.plot(d_losses_vanilla, label="Discriminator Loss")
plt.plot(g_losses_vanilla, label="Generator Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.savefig('vanilla_gan_loss_curves.png')
plt.show()

# Plot WGAN Losses
plt.figure(figsize=(10, 5))
plt.title("WGAN Loss During Training")
plt.plot(d_losses_wgan, label="Critic Loss")
plt.plot(g_losses_wgan, label="Generator Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.savefig('wgan_loss_curves.png')
plt.show()

#### **Final Generated Samples**

In [None]:
# Generate final samples from Vanilla GAN
with torch.no_grad():
    z_final = torch.randn(64, latent_dim).to(device)
    final_images = generator_vanilla(z_final).view(-1, 1, 28, 28)
    save_image(final_images, 'final_vanilla_gan_samples.png', normalize=True)
    grid = torchvision.utils.make_grid(final_images, nrow=8, normalize=True)
    plt.figure(figsize=(8,8))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.title('Final Vanilla GAN Samples')
    plt.axis('off')
    plt.show()

# Generate final samples from WGAN
with torch.no_grad():
    z_final = torch.randn(64, latent_dim).to(device)
    final_images = generator_wgan(z_final).view(-1, 1, 28, 28)
    save_image(final_images, 'final_wgan_samples.png', normalize=True)
    grid = torchvision.utils.make_grid(final_images, nrow=8, normalize=True)
    plt.figure(figsize=(8,8))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.title('Final WGAN Samples')
    plt.axis('off')
    plt.show()

#### **IS and FID Score Calculation**
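To calculate the Inception Score (IS) and Fréchet Inception Distance (FID), we first generate a large number of samples from each model and save them to a directory. The real MNIST training images are saved to a separate directory to serve as the reference set for FID.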

In [None]:
# Install evaluation library
!pip install torch-fidelity

# --- Function to generate and save images for evaluation ---
def generate_for_evaluation(generator, eval_dir, num_images=10000):
    os.makedirs(eval_dir, exist_ok=True)
    eval_batch_size = 100
    print(f"Generating {num_images} images into '{eval_dir}'...")
    generator.eval()
    with torch.no_grad():
        for i in range(0, num_images, eval_batch_size):
            z = torch.randn(eval_batch_size, latent_dim).to(device)
            generated_images = generator(z).view(-1, 1, 28, 28)
            for j in range(generated_images.size(0)):
                save_image(generated_images[j, :, :, :], os.path.join(eval_dir, f'img_{i+j}.png'), normalize=True)
    print("Finished generating images.")

# Generate images for both models
generate_for_evaluation(generator_vanilla, 'vanilla_eval_images')
generate_for_evaluation(generator_wgan, 'wgan_eval_images')

# --- Create a directory of Real MNIST Images for Comparison ---
real_images_dir = 'real_mnist_images'
if not os.path.exists(real_images_dir):
    os.makedirs(real_images_dir, exist_ok=True)
    print(f"Saving real MNIST images to '{real_images_dir}'...")
    real_train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
    img_num = 0
    for i, (images, _) in enumerate(real_train_loader):
        for j in range(images.size(0)):
            save_image(images[j, :, :, :], os.path.join(real_images_dir, f'real_img_{img_num}.png'), normalize=True)
            img_num += 1
    print("Finished saving real images.")
else:
    print(f"Real images directory '{real_images_dir}' already exists.")

#### **Run Evaluation Commands**
Run the following commands to get the IS and FID scores for each model. The results can then be compiled into the final summary table.

In [None]:
# Relative paths match the directories created above (in Colab they resolve under /content).
print("--- Calculating metrics for Vanilla GAN ---")
!python -m torch_fidelity.fidelity --gpu 0 --fid --isc --input1 vanilla_eval_images --input2 real_mnist_images

print("\n--- Calculating metrics for WGAN ---")
!python -m torch_fidelity.fidelity --gpu 0 --fid --isc --input1 wgan_eval_images --input2 real_mnist_images
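
To make the two reported numbers easier to interpret, here is a small sketch of the formulas behind them. It is not how `torch-fidelity` computes them: the real metrics are derived from Inception-v3 predictions and features, whereas this example uses hand-written two-class probabilities and 1-D Gaussians. The function names are illustrative.

```python
import math


def inception_score(class_probs):
    """IS = exp(mean over images of KL(p(y|x) || p(y)))."""
    n = len(class_probs)
    num_classes = len(class_probs[0])
    # Marginal class distribution p(y), averaged over all images.
    marginal = [sum(p[k] for p in class_probs) / n for k in range(num_classes)]
    total_kl = 0.0
    for p in class_probs:
        # Terms with p[k] == 0 contribute 0 (the usual 0 * log 0 = 0 convention).
        total_kl += sum(
            p[k] * math.log(p[k] / marginal[k])
            for k in range(num_classes)
            if p[k] > 0
        )
    return math.exp(total_kl / n)


def frechet_distance_1d(mean_a, std_a, mean_b, std_b):
    """The FID formula specialised to two 1-D Gaussians."""
    return (mean_a - mean_b) ** 2 + (std_a - std_b) ** 2


confident_and_diverse = [[1.0, 0.0], [0.0, 1.0]]   # sharp and varied predictions
uninformative = [[0.5, 0.5], [0.5, 0.5]]           # every image looks the same

print("IS (confident, diverse):", inception_score(confident_and_diverse))
print("IS (uninformative)     :", inception_score(uninformative))
print("Frechet distance (1-D) :", frechet_distance_1d(0.0, 1.0, 3.0, 2.0))
```

A higher IS requires predictions that are both confident per image and spread across classes, and its maximum equals the number of classes. The Fréchet distance is zero only when the two Gaussians share the same mean and spread, which is why a lower FID indicates generated samples that are statistically closer to the real data.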

## **Final Comparison**
After running the notebook and calculating the metrics, fill in the table below with the final scores to compare the performance of the two models.

| Metric | Vanilla GAN (BCE Loss) | WGAN (Wasserstein Loss) | Analysis |
| :--- | :--- | :--- | :--- |
| **Training Stability** | Loss curves are often volatile and only weakly indicative of image quality. Prone to mode collapse. | Loss curves tend to be more stable and meaningful: the critic's loss correlates better with sample quality. More resistant to mode collapse. | The WGAN is expected to train more stably; check this against the loss curves plotted above. |
| **Visual Quality** | Generated images often suffer from noise and minor artifacts. | Generated images are typically cleaner, with fewer artifacts and sharper details, which is attributed to more stable training gradients. | The WGAN is expected to produce cleaner samples; check this against the final sample grids above. |
| **Inception Score (IS)** | *~2.21* | *Fill in from output* | The WGAN is expected to have a higher IS, indicating better quality and diversity. |
| **Fréchet Inception Distance (FID)** | *~136.97* | *Fill in from output* | The WGAN is expected to have a significantly lower FID, showing its output distribution is closer to the real data. |