# Handwritten Digit Generator Web App in Google Colab

**This notebook will:**
1. Install the necessary libraries (PyTorch, Streamlit, pyngrok).
2. Define and train a Conditional DCGAN on the MNIST dataset.
3. Save the trained Generator model.
4. Create a Streamlit web application script.

In [11]:
# --- 1. Install necessary libraries ---
# This part needs to be run only once.
!pip install torch torchvision streamlit pillow --quiet

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image

In [18]:
# --- 2. Device configuration and Hyperparameters ---
# Use a CUDA-enabled GPU (like a T4) when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed PyTorch's random number generators so that weight initialisation, data
# shuffling and noise/label sampling start from the same state on every run.
seed = 42
torch.manual_seed(seed)

# Hyperparameters for GAN training
latent_dim = 100   # number of noise values fed to the generator per sample
num_classes = 10   # digit labels 0-9; also used as the label-embedding width
image_size = 28    # height and width of the images, in pixels
num_epochs = 200   # full passes over the training set
batch_size = 256   # samples per DataLoader batch
lr = 0.0001        # Adam learning rate, shared by both networks
beta1 = 0.5        # first Adam beta coefficient, shared by both networks

Using device: cuda


In [13]:
# --- 3. Data Transformation and Loading ---
# ToTensor produces values in [0, 1]; subtracting 0.5 and dividing by 0.5 then
# rescales them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, downloaded into ./data when it is not already there.
print("Downloading MNIST dataset...")
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
# Batches are reshuffled every epoch.
dataloader = torch.utils.data.DataLoader(
    dataset=mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
)
print("MNIST dataset loaded.")

Downloading MNIST dataset...
MNIST dataset loaded.


In [14]:
#  --- 4. Generator and Discriminator Architectures ---
# Generator model for generating images from latent vectors and digit labels
class Generator(nn.Module):
    """Conditional DCGAN generator.

    Turns a noise vector plus a digit label into a 1 x 28 x 28 image whose
    pixel values lie in [-1, 1]. The module-level ``latent_dim`` and
    ``num_classes`` are read when an instance is constructed.
    """

    def __init__(self):
        super().__init__()
        # One learned vector of length num_classes per digit label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # The layers are listed in order and unpacked into nn.Sequential, so
        # the sub-module indices (main.0, main.1, ...) follow the list order.
        in_channels = latent_dim + num_classes
        layers = [
            # (latent_dim + num_classes) x 1 x 1  ->  256 x 7 x 7
            nn.ConvTranspose2d(in_channels, 256, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # 256 x 7 x 7  ->  128 x 14 x 14
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 128 x 14 x 14  ->  1 x 28 x 28 (single grayscale channel)
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),  # squashes pixel values into [-1, 1]
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: tensor with batch as its first dimension; everything after
                the batch dimension is flattened into a single vector.
            labels: integer class indices, one per sample.

        Returns:
            The output of ``self.main`` for the combined input.
        """
        batch = noise.size(0)
        label_vec = self.label_emb(labels)
        noise_vec = noise.view(batch, -1)
        # Label embedding first, noise second, joined along the last axis.
        joined = torch.cat((label_vec, noise_vec), dim=-1)
        # ConvTranspose2d expects (batch, channels, 1, 1).
        as_feature_map = joined.view(joined.size(0), joined.size(1), 1, 1)
        return self.main(as_feature_map)

# Discriminator model for distinguishing real from fake images
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    Takes an image together with its digit label and returns one value in
    [0, 1] per sample. The module-level ``num_classes`` and ``image_size``
    are read at construction and forward time respectively.
    """

    def __init__(self):
        super().__init__()
        # One learned vector of length num_classes per digit label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # The layers are listed in order and unpacked into nn.Sequential, so
        # the sub-module indices (main.0, main.1, ...) follow the list order.
        stages = [
            # (1 + num_classes) x 28 x 28  ->  64 x 14 x 14
            nn.Conv2d(1 + num_classes, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 14 x 14  ->  128 x 7 x 7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 7 x 7  ->  256 x 4 x 4
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 4 x 4  ->  1 x 1 x 1 (one score per sample)
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),  # maps the score into [0, 1]
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, img, labels):
        """Score each (image, label) pair.

        Args:
            img: image batch; concatenated with the label planes along dim 1.
            labels: integer class indices, one per image.

        Returns:
            Tensor of shape (N, 1) holding the sigmoid outputs.
        """
        n_samples = labels.size(0)
        # Turn each label embedding into num_classes constant planes of size
        # image_size x image_size so it can be stacked onto the image channels.
        label_planes = self.label_emb(labels).view(n_samples, num_classes, 1, 1)
        label_planes = label_planes.repeat(1, 1, image_size, image_size)
        stacked = torch.cat((img, label_planes), dim=1)
        scores = self.main(stacked)
        # Flatten to one column for the BCE loss.
        return scores.view(-1, 1)

In [15]:
# --- 5. Initialize models and weights ---
# Build the generator first, then the discriminator; each is moved onto
# `device` right after construction (Module.to relocates the module in place).
netG = Generator()
netG.to(device)
netD = Discriminator()
netD.to(device)

# Custom weights initialization for DCGAN
def weights_init(m):
    """Re-initialise a single module in place, chosen by its class name.

    * Class name contains "Conv": weights drawn from a normal distribution
      with mean 0.0 and std 0.02.
    * Otherwise, class name contains "BatchNorm": weights drawn from a normal
      distribution with mean 1.0 and std 0.02, bias set to 0.
    * Anything else is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.constant_(m.bias, 0)

# Module.apply calls weights_init on every sub-module and on the module itself.
netG.apply(weights_init)
# Kept as the cell's final bare expression: the notebook displays the returned
# module, which is why the Discriminator layout is printed below this cell.
netD.apply(weights_init)

Discriminator(
  (label_emb): Embedding(10, 10)
  (main): Sequential(
    (0): Conv2d(11, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (9): Sigmoid()
  )
)

In [16]:
# --- 6. Loss function and Optimizers ---
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both built from the same settings.
adam_kwargs = {"lr": lr, "betas": (beta1, 0.999)}
optimizerD = optim.Adam(netD.parameters(), **adam_kwargs)
optimizerG = optim.Adam(netG.parameters(), **adam_kwargs)

In [17]:
# --- 7. Training Loop ---
# Every iteration performs one discriminator update (a real batch plus a
# generated batch) followed by one generator update, both using `criterion`.
print("\nStarting Training Loop...")
for epoch in range(num_epochs):
    for i, (real_images, labels) in enumerate(dataloader):
        # Move data to the selected device (GPU if available)
        real_images = real_images.to(device)
        labels = labels.to(device)
        # Size of *this* batch, which may be smaller than the configured one
        # for the final batch of an epoch. It is stored under its own name so
        # the global `batch_size` hyperparameter is not overwritten; otherwise
        # re-running the DataLoader cell after training would silently pick up
        # the size of the last batch seen here.
        cur_batch_size = real_images.size(0)

        # Train Discriminator
        netD.zero_grad()
        # Train with real images: target is 1 for every sample
        output_real = netD(real_images, labels).view(-1)  # Flatten output for criterion
        errD_real = criterion(output_real, torch.ones_like(output_real))
        errD_real.backward()

        # Generate fake images for randomly drawn digit labels
        noise = torch.randn(cur_batch_size, latent_dim, 1, 1, device=device)
        fake_labels = torch.randint(0, num_classes, (cur_batch_size,), device=device)
        fake_images = netG(noise, fake_labels)
        # Train with fake images: target is 0 for every sample.
        # detach() keeps this backward pass from reaching the generator.
        output_fake = netD(fake_images.detach(), fake_labels).view(-1)
        errD_fake = criterion(output_fake, torch.zeros_like(output_fake))
        errD_fake.backward()
        # Gradients of both halves have accumulated in netD; the sum is only used for logging.
        errD = errD_real + errD_fake
        optimizerD.step()

        # Train Generator
        netG.zero_grad()
        # Re-score the same fakes with the just-updated discriminator, this
        # time without detach() so gradients flow back into netG.
        output_gen = netD(fake_images, fake_labels).view(-1)
        errG = criterion(output_gen, torch.ones_like(output_gen))  # Generator wants D to classify fakes as real
        errG.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{len(dataloader)}] "
                  f"Loss D: {errD.item():.4f} Loss G: {errG.item():.4f}")

    # Optionally save generator model at certain epochs for checkpoints
    # if (epoch + 1) % 10 == 0:
    #     torch.save(netG.state_dict(), f"generator_epoch_{epoch+1}.pth")
    #     print(f"Generator model saved at epoch {epoch+1}")


Starting Training Loop...
Epoch [1/200] Batch [0/469] Loss D: 1.5969 Loss G: 0.7121
Epoch [1/200] Batch [100/469] Loss D: 0.7971 Loss G: 1.4931
Epoch [1/200] Batch [200/469] Loss D: 0.8702 Loss G: 1.4949
Epoch [1/200] Batch [300/469] Loss D: 1.3534 Loss G: 0.6256
Epoch [1/200] Batch [400/469] Loss D: 0.9888 Loss G: 0.9355
Epoch [2/200] Batch [0/469] Loss D: 0.9804 Loss G: 1.2319
Epoch [2/200] Batch [100/469] Loss D: 0.8839 Loss G: 1.0952
Epoch [2/200] Batch [200/469] Loss D: 0.7474 Loss G: 1.3542
Epoch [2/200] Batch [300/469] Loss D: 1.0748 Loss G: 1.0636
Epoch [2/200] Batch [400/469] Loss D: 0.9844 Loss G: 1.1816
Epoch [3/200] Batch [0/469] Loss D: 2.5184 Loss G: 0.8988
Epoch [3/200] Batch [100/469] Loss D: 1.0732 Loss G: 1.0243
Epoch [3/200] Batch [200/469] Loss D: 0.9942 Loss G: 1.2419
Epoch [3/200] Batch [300/469] Loss D: 0.8452 Loss G: 1.4341
Epoch [3/200] Batch [400/469] Loss D: 1.6712 Loss G: 0.8547
Epoch [4/200] Batch [0/469] Loss D: 1.2470 Loss G: 1.4087
Epoch [4/200] Batch [

In [20]:
# --- 8. Save the final trained Generator model ---
# Only the generator's state_dict is written to disk, not the module object.
model_save_path = "generator_final.pth"
generator_state = netG.state_dict()
torch.save(generator_state, model_save_path)
print(f"Training complete. Final generator model saved to {model_save_path}")

Training complete. Final generator model saved to generator_final.pth
