# Handwritten Digit Generator Web App in Google Colab

**This notebook will:**
1. Install the necessary libraries (PyTorch, torchvision, Streamlit, Pillow, pyngrok).
2. Define and train a Conditional DCGAN on the MNIST dataset.
3. Save the trained Generator model.
4. Create a Streamlit web application script.

In [1]:
# --- 1. Install necessary libraries ---
# This part needs to be run only once.
!pip install torch torchvision streamlit pillow --quiet

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m53.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m36.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m49.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# --- 2. Device configuration and Hyperparameters ---
# Prefer a CUDA GPU (e.g. Colab's T4) when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Model dimensions
latent_dim = 100   # length of the generator's noise vector
num_classes = 10   # number of digit classes (0-9)
image_size = 28    # side length of the square MNIST images

# Optimisation settings
num_epochs = 50    # kept small for a Colab demo; 100-200 epochs give better quality
batch_size = 128   # mini-batch size used by the DataLoader
lr = 2e-4          # Adam learning rate for both networks
beta1 = 0.5        # Adam beta1 for both networks

Using device: cuda


In [3]:
# --- 3. Data Transformation and Loading ---
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them
# to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, downloaded into ./data on first use, served in shuffled batches.
print("Downloading MNIST dataset...")
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset=mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
)
print("MNIST dataset loaded.")

Downloading MNIST dataset...


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 482kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.48MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.46MB/s]

MNIST dataset loaded.





In [4]:
# --- 4. Generator and Discriminator Architectures ---
# Conditional generator: turns a noise vector plus a digit label into an image.
class Generator(nn.Module):
    """Conditional DCGAN generator.

    The digit label is embedded into a ``num_classes``-dimensional vector,
    placed in front of the flattened noise vector, and the combined vector is
    upsampled by transposed convolutions: 1x1 -> 7x7 -> 14x14 -> 28x28.
    The final Tanh keeps pixel values in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Learned label embedding (created before ``main``, as in the original
        # layout, so parameter creation order is unchanged).
        self.label_emb = nn.Embedding(num_classes, num_classes)

        in_channels = latent_dim + num_classes
        layers = [
            # (latent_dim + num_classes) x 1 x 1 -> 256 x 7 x 7
            nn.ConvTranspose2d(in_channels, 256, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # 256 x 7 x 7 -> 128 x 14 x 14
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 128 x 14 x 14 -> 1 x 28 x 28 (single grayscale channel)
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        # Layer positions inside ``main`` determine the state_dict keys
        # (main.0.weight, main.1.running_mean, ...) that the saved model uses.
        self.main = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: tensor whose first dimension is the batch; it is flattened
                to (batch, -1), e.g. shape (batch, latent_dim, 1, 1).
            labels: integer digit labels, one per sample.

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
        """
        batch = noise.size(0)
        label_vec = self.label_emb(labels)
        noise_vec = noise.view(batch, -1)
        # Label embedding first, then noise: the order the weights were trained with.
        combined = torch.cat((label_vec, noise_vec), dim=-1)
        # Present the combined vector as a C x 1 x 1 feature map.
        return self.main(combined[:, :, None, None])

# Conditional discriminator: scores how likely an image is a real sample of the given digit.
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    The label embedding is tiled over the full image plane and stacked onto the
    image as ``num_classes`` extra channels. Strided convolutions then reduce
    28x28 -> 14x14 -> 7x7 -> 4x4 -> 1x1, and a Sigmoid produces a probability.
    """

    def __init__(self):
        super().__init__()
        # Learned label embedding, created before ``main`` as in the original layout.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.main = nn.Sequential(
            # (1 + num_classes) x 28 x 28 -> 64 x 14 x 14
            nn.Conv2d(1 + num_classes, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 64 x 14 x 14 -> 128 x 7 x 7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 128 x 7 x 7 -> 256 x 4 x 4
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 256 x 4 x 4 -> 1 x 1 x 1 real/fake score
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Return a (batch, 1) tensor of real-vs-fake probabilities.

        Args:
            img: image batch; assumed (batch, 1, image_size, image_size).
            labels: integer digit labels, one per image.
        """
        batch = labels.size(0)
        # Tile each label embedding across every pixel so it can be stacked as channels.
        label_planes = (
            self.label_emb(labels)
            .view(batch, num_classes, 1, 1)
            .repeat(1, 1, image_size, image_size)
        )
        stacked = torch.cat([img, label_planes], dim=1)
        scores = self.main(stacked)
        # Flatten to (batch, 1) for the BCE loss.
        return scores.view(-1, 1)

In [5]:
# --- 5. Initialize models and weights ---
# Build the generator first, then the discriminator, and move both to the selected device.
netG, netD = Generator().to(device), Discriminator().to(device)

# Custom weights initialization for DCGAN
def weights_init(m):
    """Re-initialize one module in place; meant to be passed to ``nn.Module.apply``.

    - Class name contains "Conv": weight drawn from N(0, 0.02); bias untouched.
    - Class name contains "BatchNorm": weight drawn from N(1, 0.02), bias set to 0.
    - Any other module (Embedding, activations, containers) is left as is.
    """
    kind = type(m).__name__
    if "Conv" in kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if "BatchNorm" in kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

# Recursively apply the DCGAN initialization to every submodule of both networks.
netG.apply(fn=weights_init)
netD.apply(fn=weights_init)

Discriminator(
  (label_emb): Embedding(10, 10)
  (main): Sequential(
    (0): Conv2d(11, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (9): Sigmoid()
  )
)

In [6]:
# --- 6. Loss function and Optimizers ---
# Binary cross-entropy on the discriminator's Sigmoid probabilities (real = 1, fake = 0).
criterion = nn.BCELoss()
# One Adam optimizer per network; both share the same learning rate and betas.
optimizerD = optim.Adam(params=netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(params=netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [7]:
# --- 7. Training Loop ---
# Save a generator checkpoint every `checkpoint_interval` epochs; 0 disables checkpointing.
checkpoint_interval = 0

print("\nStarting Training Loop...")
# Ensure BatchNorm layers use batch statistics while training, even if a
# previous cell switched either model to eval mode.
netG.train()
netD.train()
for epoch in range(num_epochs):
    for i, (real_images, labels) in enumerate(dataloader):
        # Move data to the selected device (GPU if available)
        real_images = real_images.to(device)
        labels = labels.to(device)
        # Size of *this* batch (the last batch of an epoch can be smaller).
        # Deliberately not named `batch_size`: reassigning that global here
        # would silently change the hyperparameter the DataLoader is built from.
        cur_batch_size = real_images.size(0)

        # ---- Discriminator step: real images -> 1, generated images -> 0 ----
        netD.zero_grad()
        output_real = netD(real_images, labels).view(-1)  # Flatten output for criterion
        errD_real = criterion(output_real, torch.ones_like(output_real))
        errD_real.backward()

        # Generate fake images for randomly chosen target digits
        noise = torch.randn(cur_batch_size, latent_dim, 1, 1, device=device)
        fake_labels = torch.randint(0, num_classes, (cur_batch_size,), device=device)
        fake_images = netG(noise, fake_labels)
        # detach() keeps this discriminator update from backpropagating into G
        output_fake = netD(fake_images.detach(), fake_labels).view(-1)
        errD_fake = criterion(output_fake, torch.zeros_like(output_fake))
        errD_fake.backward()
        errD = errD_real + errD_fake  # for logging; gradients were accumulated above
        optimizerD.step()

        # ---- Generator step: G wants D to classify its fakes as real ----
        netG.zero_grad()
        output_gen = netD(fake_images, fake_labels).view(-1)
        errG = criterion(output_gen, torch.ones_like(output_gen))
        errG.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{len(dataloader)}] "
                  f"Loss D: {errD.item():.4f} Loss G: {errG.item():.4f}")

    # Optional periodic generator checkpoint
    if checkpoint_interval and (epoch + 1) % checkpoint_interval == 0:
        torch.save(netG.state_dict(), f"generator_epoch_{epoch+1}.pth")
        print(f"Generator model saved at epoch {epoch+1}")


Starting Training Loop...
Epoch [1/50] Batch [0/469] Loss D: 1.5644 Loss G: 1.3757
Epoch [1/50] Batch [100/469] Loss D: 0.9501 Loss G: 1.1153
Epoch [1/50] Batch [200/469] Loss D: 1.1492 Loss G: 1.1713
Epoch [1/50] Batch [300/469] Loss D: 0.8305 Loss G: 1.6863
Epoch [1/50] Batch [400/469] Loss D: 1.2643 Loss G: 1.3840
Epoch [2/50] Batch [0/469] Loss D: 0.7887 Loss G: 1.2657
Epoch [2/50] Batch [100/469] Loss D: 0.9160 Loss G: 1.5377
Epoch [2/50] Batch [200/469] Loss D: 0.7424 Loss G: 1.2741
Epoch [2/50] Batch [300/469] Loss D: 1.1483 Loss G: 1.8136
Epoch [2/50] Batch [400/469] Loss D: 0.8067 Loss G: 1.0559
Epoch [3/50] Batch [0/469] Loss D: 1.1811 Loss G: 0.9084
Epoch [3/50] Batch [100/469] Loss D: 0.9174 Loss G: 1.0582
Epoch [3/50] Batch [200/469] Loss D: 1.4335 Loss G: 0.7474
Epoch [3/50] Batch [300/469] Loss D: 0.7952 Loss G: 1.3036
Epoch [3/50] Batch [400/469] Loss D: 1.3171 Loss G: 1.1933
Epoch [4/50] Batch [0/469] Loss D: 1.3376 Loss G: 1.3377
Epoch [4/50] Batch [100/469] Loss D: 

In [8]:
# --- 8. Save the final trained Generator model ---
# Only the generator's state_dict (weights and BatchNorm buffers) is written;
# the Streamlit app rebuilds the Generator class and loads these tensors into it.
model_save_path = "generator_final.pth"
generator_state = netG.state_dict()
torch.save(generator_state, model_save_path)
print(f"Training complete. Final generator model saved to {model_save_path}")

Training complete. Final generator model saved to generator_final.pth


In [10]:
# --- 9. Create Streamlit Application Script ---
# The app source is written to a file so it can be launched with `streamlit run`.
# Hyperparameters and the model path are interpolated from this notebook, so the
# app rebuilds the Generator with exactly the shapes it was trained with.
# This is an f-string: braces meant to appear literally in the app are doubled.
app_script_path = "streamlit_app.py"
streamlit_app_code = f"""
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
import numpy as np

# st.set_page_config must be the first Streamlit command in the script; the
# model loader below emits st.success / st.error, so configure the page first.
st.set_page_config(layout="centered", page_title="Handwritten Digit Generator")

# Any device works here: torch.load(map_location=device) remaps the saved weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model architecture (must be identical to the trained model's Generator)
latent_dim = {latent_dim}
num_classes = {num_classes}
image_size = {image_size}

# Generated digits are tiny; they are upscaled to this many pixels per side for display.
DISPLAY_SIZE = 4 * image_size

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.main = nn.Sequential(
            nn.ConvTranspose2d(latent_dim + num_classes, 256, 7, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        gen_input = torch.cat((self.label_emb(labels), noise.view(noise.size(0), -1)), -1)
        gen_input = gen_input.view(gen_input.size(0), gen_input.size(1), 1, 1)
        return self.main(gen_input)

# Load the pre-trained generator model
@st.cache_resource  # Cache the model loading for efficiency
def load_generator_model(model_path={model_save_path!r}):
    generator = Generator().to(device)
    try:
        # The file holds a plain state_dict, so weights_only=True is sufficient
        # and refuses to unpickle arbitrary objects.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
    except FileNotFoundError:
        st.error(f"Error: Model file not found at {{model_path}}. Please ensure the model is trained and saved.")
        st.stop()  # Stop execution if model is not found
    generator.load_state_dict(state_dict)
    generator.eval()  # Set to evaluation mode for inference
    st.success(f"Generator model loaded successfully from {{model_path}}")
    return generator

generator_model = load_generator_model()

# Function to generate images for a given digit
def generate_digit_images(digit, num_images=5):
    with torch.no_grad():  # Disable gradient calculations for inference
        # Fresh random noise on every call gives diverse samples of the same digit.
        noise = torch.randn(num_images, latent_dim, 1, 1, device=device)
        labels = torch.full((num_images,), digit, dtype=torch.long, device=device)
        # Denormalize the Tanh output from [-1, 1] to [0, 1] for display.
        generated_images = (generator_model(noise, labels).cpu() + 1) / 2

    pil_images = []
    for img_tensor in generated_images:
        # Drop the channel dimension and scale to 0-255; a 2-D uint8 array
        # becomes a grayscale ('L') PIL image.
        img_np = (img_tensor.squeeze(0).numpy() * 255).astype(np.uint8)
        img_pil = Image.fromarray(img_np)
        # Nearest-neighbour upscaling keeps the pixelated digit sharp.
        pil_images.append(img_pil.resize((DISPLAY_SIZE, DISPLAY_SIZE), Image.NEAREST))
    return pil_images

# --- Streamlit UI ---
st.title("Handwritten Digit Generator")
st.write("Generate 5 diverse images of a handwritten digit (0-9) similar to the MNIST dataset.")

# User input for digit selection
selected_digit = st.selectbox("Select a digit to generate:", options=list(range(10)))

# Button to trigger image generation
if st.button("Generate Images"):
    st.write(f"Generating 5 images for digit: **{{selected_digit}}**...")
    with st.spinner('Generating images...'):  # Show a spinner while images are being generated
        images = generate_digit_images(selected_digit)

    # One column per image. Images are shown at their (already upscaled) size,
    # so no width-stretching flag is needed (use_column_width is deprecated).
    cols = st.columns(len(images))
    for i, (col, img) in enumerate(zip(cols, images)):
        with col:
            st.image(img, caption=f"Digit {{selected_digit}} - Image {{i+1}}")
    st.success("Images generated!")

"""

# Write the Streamlit app code to a file; launch it with `streamlit run <app_script_path>`.
with open(app_script_path, "w", encoding="utf-8") as f:
    f.write(streamlit_app_code)
print(f"Streamlit app code written to {app_script_path}")

Streamlit app code written to app.py
