# CSET419 – Introduction to Generative AI
## Lab 2 – Experiment 1
### Training a Basic GAN for MNIST Image Generation

**Objective:**  
To design and train a Generative Adversarial Network (GAN) to generate synthetic handwritten digit images using the MNIST dataset and evaluate their quality.

**Platform:** Google Colab  
**Framework:** PyTorch


### Step 1: Import Required Libraries


In [21]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader


In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


### Step 2: Set Device and Input Parameters


In [23]:
# ===== User Inputs =====
dataset_choice = 'mnist'   # only mnist for now
epochs = 50
batch_size = 64
noise_dim = 100
learning_rate = 0.0002
save_interval = 5


### Step 3: Load and Preprocess MNIST Dataset


In [24]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True
)


### Step 4: Define the Discriminator Network


In [25]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


In [26]:
discriminator = Discriminator().to(device)


### Step 5: Define the Generator Network


In [27]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(),

            nn.Linear(256, 512),
            nn.ReLU(),

            nn.Linear(512, 1024),
            nn.ReLU(),

            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )

    def forward(self, z):
        img_flat = self.model(z)
        img = img_flat.view(z.size(0), 1, 28, 28)
        return img


In [28]:
generator = Generator().to(device)


### Step 6: Define Loss Function and Optimizers


In [29]:
adversarial_loss = nn.BCELoss()


In [30]:
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)


In [32]:
real_label = 1.0
fake_label = 0.0


### Step 7: Train the GAN Model


In [33]:
os.makedirs("generated_samples", exist_ok=True)

for epoch in range(1, epochs + 1):

    for i, (real_images, _) in enumerate(dataloader):

        batch_size_curr = real_images.size(0)

        # ======================
        #  Train Discriminator
        # ======================
        discriminator.zero_grad()

        # Real images
        real_images = real_images.to(device)
        real_targets = torch.full((batch_size_curr, 1), real_label, device=device)

        real_output = discriminator(real_images)
        d_loss_real = adversarial_loss(real_output, real_targets)

        # Fake images
        noise = torch.randn(batch_size_curr, noise_dim, device=device)
        fake_images = generator(noise)

        fake_targets = torch.full((batch_size_curr, 1), fake_label, device=device)
        fake_output = discriminator(fake_images.detach())
        d_loss_fake = adversarial_loss(fake_output, fake_targets)

        # Total Discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # ======================
        #  Train Generator
        # ======================
        generator.zero_grad()

        fake_targets = torch.full((batch_size_curr, 1), real_label, device=device)
        output = discriminator(fake_images)
        g_loss = adversarial_loss(output, fake_targets)

        g_loss.backward()
        optimizer_G.step()

    # ======================
    # Logging
    # ======================
    print(f"Epoch {epoch}/{epochs} | D_loss: {d_loss.item():.4f} | G_loss: {g_loss.item():.4f}")

    # ======================
    # Save Images
    # ======================
    if epoch % save_interval == 0:
        save_image(fake_images[:25], f"generated_samples/epoch_{epoch:02d}.png",
                   nrow=5, normalize=True)


Epoch 1/50 | D_loss: 0.0195 | G_loss: 5.7767
Epoch 2/50 | D_loss: 0.7622 | G_loss: 3.1844
Epoch 3/50 | D_loss: 1.4343 | G_loss: 2.0612
Epoch 4/50 | D_loss: 0.9354 | G_loss: 1.3968
Epoch 5/50 | D_loss: 0.7687 | G_loss: 2.2506
Epoch 6/50 | D_loss: 0.3121 | G_loss: 4.4063
Epoch 7/50 | D_loss: 0.8148 | G_loss: 3.3540
Epoch 8/50 | D_loss: 0.3487 | G_loss: 4.3838
Epoch 9/50 | D_loss: 0.3353 | G_loss: 4.1856
Epoch 10/50 | D_loss: 0.1572 | G_loss: 5.1089
Epoch 11/50 | D_loss: 0.1725 | G_loss: 3.5568
Epoch 12/50 | D_loss: 0.2362 | G_loss: 4.1887
Epoch 13/50 | D_loss: 0.1683 | G_loss: 5.4852
Epoch 14/50 | D_loss: 0.1700 | G_loss: 8.8803
Epoch 15/50 | D_loss: 0.4032 | G_loss: 3.8065
Epoch 16/50 | D_loss: 0.4234 | G_loss: 6.3318
Epoch 17/50 | D_loss: 0.3881 | G_loss: 4.5608
Epoch 18/50 | D_loss: 0.1293 | G_loss: 5.1988
Epoch 19/50 | D_loss: 0.4525 | G_loss: 2.7478
Epoch 20/50 | D_loss: 0.7445 | G_loss: 2.4639
Epoch 21/50 | D_loss: 1.0109 | G_loss: 3.1016
Epoch 22/50 | D_loss: 0.6048 | G_loss: 3.01

### Step 8: Generate and Save Final Images


In [34]:
os.makedirs("final_generated_images", exist_ok=True)

generator.eval()  # evaluation mode

with torch.no_grad():
    noise = torch.randn(100, noise_dim, device=device)
    fake_images = generator(noise)
    save_image(fake_images, "final_generated_images/final_100_images.png",
               nrow=10, normalize=True)


In [35]:
class MNIST_Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.model(x)


### Step 9: Predict Labels of Generated Images


In [36]:
classifier = MNIST_Classifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_C = optim.Adam(classifier.parameters(), lr=0.001)

classifier.train()
for epoch in range(3):  # small training
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer_C.zero_grad()
        outputs = classifier(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_C.step()


In [37]:
classifier.eval()

with torch.no_grad():
    outputs = classifier(fake_images)
    predicted_labels = torch.argmax(outputs, dim=1)

# Label distribution
labels, counts = torch.unique(predicted_labels, return_counts=True)

print("Label Distribution of Generated Images:")
for l, c in zip(labels.cpu().numpy(), counts.cpu().numpy()):
    print(f"Digit {l}: {c} images")


Label Distribution of Generated Images:
Digit 0: 3 images
Digit 1: 25 images
Digit 2: 2 images
Digit 3: 10 images
Digit 4: 13 images
Digit 5: 6 images
Digit 6: 10 images
Digit 7: 13 images
Digit 8: 10 images
Digit 9: 8 images


## Results and Observations

- The Generator gradually learned to produce realistic handwritten digit images.
- Image quality improved over training epochs.
- Generated images were saved periodically during training.
- A trained MNIST classifier was able to predict meaningful digit labels from the generated images.
