<a href="https://colab.research.google.com/github/Jawahars/ai-gan/blob/main/gan-1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **GAN.ipynb Documentation**

This notebook provides a complete implementation of a Generative Adversarial Network (GAN) trained on the MNIST dataset. It includes all necessary steps from data loading to model evaluation.

**Deliverables Covered:**
1.  **Code for vanilla GAN**: `Generator` and `Discriminator` classes are implemented.
2.  **Loss Logging**: Generator & discriminator losses are logged per epoch.
3.  **Sample Generation**: Generated samples are saved every 5 epochs and at the end of training.
4.  **Loss Plots**: Loss curves are plotted after training.
5.  **IS and FID Scores**: Code is included to generate images and run the evaluation command.

#### **Cell 1: Imports**
Import all the necessary libraries for the project.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
from torchvision.utils import save_image

#### **Cell 2: Device and Hyperparameters**
This cell sets up the device (GPU or CPU) and defines the key hyperparameters for the model and training process.

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
latent_dim = 100
image_size = 28 * 28
batch_size = 128
learning_rate = 0.0002
num_epochs = 50
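To make these settings concrete, the arithmetic below works out how many batches one epoch contains and how many optimizer steps the full run takes. It assumes the standard 60,000-image MNIST training split and `DataLoader`'s default `drop_last=False`, which keeps a smaller final batch. That final batch is why Cell 7 sizes its label tensors with `images.size(0)` rather than `batch_size`.

```python
import math

# Values from Cell 2; MNIST's training split has 60,000 images.
num_train_images = 60_000
batch_size = 128
num_epochs = 50

# With drop_last=False the last partial batch is kept, so round up.
batches_per_epoch = math.ceil(num_train_images / batch_size)
last_batch_size = num_train_images - (batches_per_epoch - 1) * batch_size
total_iterations = batches_per_epoch * num_epochs

print(batches_per_epoch, last_batch_size, total_iterations)  # 469 96 23450
```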

#### **Cell 3: Data Loading and Preprocessing**
Here, we load the MNIST dataset. Images are converted to tensors and normalized to the range [-1, 1] to match the generator's output activation (`tanh`).

In [3]:
# Image processing and dataset loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)



#### **Cell 4: Discriminator Network**
The Discriminator is a fully connected feedforward network that classifies images as real or fake. It takes a flattened 784-pixel image and, through a final `Sigmoid`, outputs a single probability that the input is real.

In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
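A quick way to confirm the Discriminator's input and output shapes is to push a random batch through it. The sketch below rebuilds the same layer stack as Cell 4 as a plain `nn.Sequential` so it runs on its own; the batch of 16 random vectors in `[-1, 1]` is an arbitrary stand-in for flattened images.

```python
import torch
import torch.nn as nn

image_size = 28 * 28  # as in Cell 2

# Same layers as the Discriminator in Cell 4, rebuilt so this check is self-contained.
discriminator = nn.Sequential(
    nn.Linear(image_size, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)

# 16 flattened "images" in the normalized range [-1, 1].
x = torch.rand(16, image_size) * 2 - 1
with torch.no_grad():
    probs = discriminator(x)

num_params = sum(p.numel() for p in discriminator.parameters())
print(probs.shape)  # torch.Size([16, 1]): one probability per image
print(num_params)   # 533505 = (784*512 + 512) + (512*256 + 256) + (256 + 1)
```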

#### **Cell 5: Generator Network**
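The Generator takes a 100-dimensional random noise vector from the latent space and transforms it into a flattened 784-pixel image. Its final `Tanh` keeps every output in `[-1, 1]`, the same range as the normalized training images.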

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, image_size),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)
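The check below samples a few latent vectors, runs them through a copy of the Cell 5 architecture (rebuilt as an `nn.Sequential` so the snippet is self-contained), and reshapes the flat output into `(N, 1, 28, 28)` image tensors, the same `view` used when saving samples in Cell 7.

```python
import torch
import torch.nn as nn

latent_dim = 100      # as in Cell 2
image_size = 28 * 28

# Same layers as the Generator in Cell 5.
generator = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, image_size), nn.Tanh(),
)

z = torch.randn(8, latent_dim)          # 8 noise vectors from N(0, I)
with torch.no_grad():
    flat = generator(z)                 # shape (8, 784), values in [-1, 1]
images = flat.view(-1, 1, 28, 28)       # batch, channel, height, width

num_params = sum(p.numel() for p in generator.parameters())
print(images.shape)  # torch.Size([8, 1, 28, 28])
print(num_params)    # 559632
```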

#### **Cell 6: Model Initialization, Loss, and Optimizers**
We initialize the models, move them to the configured device, and set up the Binary Cross-Entropy loss function and Adam optimizers.

In [6]:
# Initialize models
discriminator = Discriminator().to(device)
generator = Generator().to(device)

# Loss and optimizers
criterion = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
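`nn.BCELoss` averages `-[y * log(p) + (1 - y) * log(1 - p)]` over the batch, so the label decides which term is active: with label 1 it penalizes low `p`, with label 0 it penalizes high `p`. The training loop relies on this for the discriminator's real/fake terms and for the generator, which scores its fakes against real labels. The discriminator outputs below are made-up numbers for illustration.

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()  # default reduction='mean'

# Illustrative discriminator outputs for a batch of two images.
probs = torch.tensor([[0.9], [0.6]])
real_labels = torch.ones(2, 1)
fake_labels = torch.zeros(2, 1)

# With label 1 only the -log(p) term is active.
loss_as_real = criterion(probs, real_labels).item()
manual_as_real = -(math.log(0.9) + math.log(0.6)) / 2

# With label 0 only the -log(1 - p) term is active.
loss_as_fake = criterion(probs, fake_labels).item()
manual_as_fake = -(math.log(0.1) + math.log(0.4)) / 2

print(round(loss_as_real, 4), round(manual_as_real, 4))
print(round(loss_as_fake, 4), round(manual_as_fake, 4))
```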

#### **Cell 7: Training Loop with Logging**
This is the main training loop. It includes logic to:
1.  Calculate and store the average loss for both the generator and discriminator at the end of each epoch.
2.  Save a grid of generated image samples every 5 epochs to a directory named `gan_samples`.

In [None]:
# Lists to store loss history for plotting
d_losses = []
g_losses = []

# Create a directory to save samples
os.makedirs('gan_samples', exist_ok=True)

print("Starting Training...")
for epoch in range(num_epochs):
    epoch_d_loss = 0.0
    epoch_g_loss = 0.0
    for i, (images, _) in enumerate(train_loader):
        # Prepare real and fake labels
        real_labels = torch.ones(images.size(0), 1).to(device)
        fake_labels = torch.zeros(images.size(0), 1).to(device)

        # Reshape images for the linear layers
        images = images.reshape(-1, image_size).to(device)

        # --- Train Discriminator ---
        d_optimizer.zero_grad()

        # Loss on real images
        d_real_outputs = discriminator(images)
        d_loss_real = criterion(d_real_outputs, real_labels)

        # Loss on fake images
        z = torch.randn(images.size(0), latent_dim).to(device)
        fake_images = generator(z)
        d_fake_outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(d_fake_outputs, fake_labels)

        # Total discriminator loss and backpropagation
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # --- Train Generator ---
        g_optimizer.zero_grad()
        # Re-run D on the non-detached fakes: D has just been updated, and this time gradients must flow back into the generator
        g_outputs = discriminator(fake_images)
        g_loss = criterion(g_outputs, real_labels) # Generator wants discriminator to think these are real
        g_loss.backward()
        g_optimizer.step()

        # Accumulate losses for the epoch
        epoch_d_loss += d_loss.item()
        epoch_g_loss += g_loss.item()

    # --- End of Epoch Logging ---
    avg_d_loss = epoch_d_loss / len(train_loader)
    avg_g_loss = epoch_g_loss / len(train_loader)
    d_losses.append(avg_d_loss)
    g_losses.append(avg_g_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}], Discriminator Loss: {avg_d_loss:.4f}, Generator Loss: {avg_g_loss:.4f}')

    # Save generated samples every 5 epochs
    if (epoch + 1) % 5 == 0:
        with torch.no_grad():
            z_sample = torch.randn(64, latent_dim).to(device)
            sample_images = generator(z_sample).view(-1, 1, 28, 28)
            save_image(sample_images, f'gan_samples/samples_epoch_{epoch+1}.png', normalize=True)

print("Training finished.")

# Save the trained generator model
torch.save(generator.state_dict(), 'generator_final.pth')

Starting Training...
Epoch [1/50], Discriminator Loss: 0.4547, Generator Loss: 3.4025
Epoch [2/50], Discriminator Loss: 0.4281, Generator Loss: 3.6405
Epoch [3/50], Discriminator Loss: 1.0160, Generator Loss: 2.9232
Epoch [4/50], Discriminator Loss: 1.0953, Generator Loss: 2.5737
Epoch [5/50], Discriminator Loss: 1.3304, Generator Loss: 2.0038
Epoch [6/50], Discriminator Loss: 1.0974, Generator Loss: 1.8625
Epoch [7/50], Discriminator Loss: 0.8979, Generator Loss: 1.9266
Epoch [8/50], Discriminator Loss: 0.5899, Generator Loss: 2.3349
Epoch [9/50], Discriminator Loss: 0.3826, Generator Loss: 2.8498
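A useful reference point for reading these logs is the loss value when the discriminator is maximally uncertain, outputting 0.5 for every real and fake image. Each BCE term is then `-log(0.5) = log 2`, so Cell 7's summed discriminator loss is `2 log 2` and the generator loss is `log 2`. These are reference values, not targets; vanilla GAN losses usually oscillate. For example, the epoch-9 discriminator loss of 0.3826 is well below 1.3863, which suggests the discriminator was separating real from fake fairly easily at that point.

```python
import math

# If D outputs 0.5 for every real and fake image:
d_real_term = -math.log(0.5)        # BCE of D(x) = 0.5 against label 1
d_fake_term = -math.log(1 - 0.5)    # BCE of D(G(z)) = 0.5 against label 0
d_loss_equilibrium = d_real_term + d_fake_term  # Cell 7 sums the two terms
g_loss_equilibrium = -math.log(0.5)  # generator's BCE against real labels

print(f"D loss at D=0.5: {d_loss_equilibrium:.4f}")  # 1.3863
print(f"G loss at D=0.5: {g_loss_equilibrium:.4f}")  # 0.6931
```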


#### **Deliverable 1: Final Trained Generator Samples**
This cell generates and saves the final batch of images from the fully trained generator.

In [None]:
# Generate and save final samples
with torch.no_grad():
    z_final = torch.randn(64, latent_dim).to(device)
    final_images = generator(z_final).view(-1, 1, 28, 28)
    save_image(final_images, 'final_gan_samples.png', normalize=True)

    # Display the final grid
    print("Displaying final generated samples...")
    grid = torchvision.utils.make_grid(final_images, nrow=8, normalize=True)
    plt.figure(figsize=(8,8))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.title('Final Generated Samples')
    plt.axis('off')
    plt.show()

#### **Deliverable 2: Loss Plots**
This cell plots the saved generator and discriminator losses to visualize their behavior over the training epochs.

In [None]:
# Plot the loss curves
print("Plotting loss curves...")
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
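epochs = range(1, len(g_losses) + 1)  # match the 1-based epoch numbers in the training log
plt.plot(epochs, g_losses, label="Generator Loss (G)")
plt.plot(epochs, d_losses, label="Discriminator Loss (D)")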
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.savefig('loss_curves.png')
plt.show()
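If the kernel restarts after training, the `d_losses` and `g_losses` lists are lost, but the printed log from Cell 7 usually survives. The helper below is a hypothetical utility, not part of the original notebook, that recovers the per-epoch values from that log format so they can be passed to the plotting code. The example input is the first two lines of the log shown above.

```python
import re

# Matches the f-string printed at the end of each epoch in Cell 7.
LOG_PATTERN = re.compile(
    r"Epoch \[(\d+)/\d+\], Discriminator Loss: ([\d.]+), Generator Loss: ([\d.]+)"
)

def parse_loss_log(text):
    """Hypothetical helper: return (epochs, d_losses, g_losses) parsed from Cell 7's output."""
    epochs, d_losses, g_losses = [], [], []
    for match in LOG_PATTERN.finditer(text):
        epochs.append(int(match.group(1)))
        d_losses.append(float(match.group(2)))
        g_losses.append(float(match.group(3)))
    return epochs, d_losses, g_losses

log_text = """Epoch [1/50], Discriminator Loss: 0.4547, Generator Loss: 3.4025
Epoch [2/50], Discriminator Loss: 0.4281, Generator Loss: 3.6405"""
epochs, d_hist, g_hist = parse_loss_log(log_text)
print(epochs, d_hist, g_hist)
```

The recovered lists can then be plotted with `plt.plot(epochs, g_hist)` and `plt.plot(epochs, d_hist)`.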

#### **Deliverable 3: IS and FID Scores**
To calculate the Inception Score (IS) and Fréchet Inception Distance (FID), we first need to generate a large number of samples and save them as individual image files. A library like `torch-fidelity` can then compute IS from the generated images alone and FID by comparing their statistics against real MNIST images.

In [None]:
# First, ensure the evaluation library is installed
!pip install torch-fidelity

# --- Generate a large number of images for evaluation ---
eval_dir = 'eval_images'
os.makedirs(eval_dir, exist_ok=True)

num_eval_images = 10000 # A common sample count for FID/IS; 50,000 is also widely used
eval_batch_size = 100

print(f"Generating {num_eval_images} images for evaluation...")
generator.eval() # Set generator to evaluation mode
with torch.no_grad():
    for i in range(0, num_eval_images, eval_batch_size):
        z = torch.randn(eval_batch_size, latent_dim).to(device)
        generated_images = generator(z).view(-1, 1, 28, 28)

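        for j in range(generated_images.size(0)):
            # Map the fixed tanh range [-1, 1] to [0, 1]; normalize=True alone would rescale each image by its own min/max
            save_image(generated_images[j], os.path.join(eval_dir, f'img_{i+j}.png'), normalize=True, value_range=(-1, 1))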

print("Images generated successfully.")
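For intuition about what the two scores measure, here is a NumPy sketch of their textbook formulas: `FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2))` over Gaussian fits to feature vectors, and `IS = exp(E_x[KL(p(y|x) || p(y))])` over class probabilities. This is not torch-fidelity's implementation. The real metrics use InceptionV3 features and class probabilities, which this sketch assumes are already computed. The trace of the matrix square root is taken as the sum of square roots of the eigenvalues of `S_r @ S_g`, which are real and non-negative for covariance matrices up to round-off.

```python
import numpy as np

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """FID between two Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr((sigma1 @ sigma2)^(1/2)) = sum of sqrt of eigenvalues of sigma1 @ sigma2.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_sqrt)

def inception_score(probs, eps=1e-12):
    """IS from an (N, num_classes) array whose rows are p(y|x)."""
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution p(y)
    kl = (probs * (np.log(probs + eps) - np.log(p_y + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))
```

Two sanity checks follow from the formulas. Identical Gaussians give an FID of 0. Confident, evenly spread predictions over K classes (for example `np.eye(10)`) give an IS of about K, while identical predictions for every sample give 1.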

#### **Final Step: Command for Metric Calculation**

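**Important:** The following command should be run in your **terminal/shell** from the same directory as your notebook. It calls `fidelity`, the command-line tool installed by the `torch-fidelity` package: `--input1` is the folder of generated images, `--input2` is the real data to compare against (IS uses only `--input1`), and `--isc`/`--fid` select the metrics. The `--input2 mnist-train` argument is intended to use the MNIST training set as the real data. Check that `mnist-train` is a registered input name in your installed `torch-fidelity` version; if it is not, export real MNIST images to a folder and pass that folder path as `--input2` instead.

You can also run it from a notebook cell by prefixing it with `!`, as shown below.

In [None]:
!fidelity --gpu 0 --fid --isc --input1 eval_images --input2 mnist-train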