## MNIST Handwritten Digit Generation Using Conditional Generative Adversarial Networks (Conditional GANs)

In [1]:
! pip install torch

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

# Hyperparameters
SEED = 42         # fixed seed so weight init, shuffling and noise sampling are repeatable
epochs = 50       # number of full passes over the training set
batch_size = 128  # samples per optimisation step
lr = 0.0002       # Adam learning rate, used for both generator and discriminator
z_dim = 100       # size of the latent noise vector fed to the generator
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's random number generators before any model or data loader is
# created. GPU kernels may still introduce small run-to-run differences.
torch.manual_seed(SEED)

# ==================== Generator ====================
class Generator(nn.Module):
    """Fully connected generator conditioned on a class label.

    A noise vector and a learned embedding of the requested label are
    concatenated and pushed through three Linear -> BatchNorm -> ReLU stages,
    followed by a Linear + Tanh head that emits 784 values per sample.

    Args:
        z_dim: length of the noise vector.
        label_dim: number of classes; also used as the embedding width.
    """

    def __init__(self, z_dim=100, label_dim=10):
        super().__init__()
        # One learnable vector of length `label_dim` per class index.
        self.label_emb = nn.Embedding(label_dim, label_dim)

        # Hidden stages are generated from consecutive pairs of widths.
        widths = (z_dim + label_dim, 256, 512, 1024)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.BatchNorm1d(fan_out))
            layers.append(nn.ReLU(inplace=True))

        # Output head: 784 = 28 * 28 pixels, squashed into [-1, 1] by tanh.
        layers.append(nn.Linear(widths[-1], 784))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, z, labels):
        """Generate images for the given noise and labels.

        Args:
            z: noise tensor of shape (batch, z_dim).
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
        """
        conditioned = torch.cat((z, self.label_emb(labels)), dim=1)
        return self.model(conditioned).view(-1, 1, 28, 28)

# ==================== Discriminator ====================
class Discriminator(nn.Module):
    """Fully connected discriminator conditioned on a class label.

    The flattened image is concatenated with a learned embedding of its label
    and scored by a small LeakyReLU MLP ending in a sigmoid.

    Args:
        label_dim: number of classes; also used as the embedding width.
    """

    def __init__(self, label_dim=10):
        super().__init__()
        # One learnable vector of length `label_dim` per class index.
        self.label_emb = nn.Embedding(label_dim, label_dim)

        layers = []
        in_features = 784 + label_dim  # 28 * 28 pixels plus the label embedding
        for out_features in (512, 256):
            layers += [
                nn.Linear(in_features, out_features),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            in_features = out_features
        # Single output unit mapped to (0, 1) by the sigmoid.
        layers += [nn.Linear(in_features, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Score image/label pairs.

        Args:
            img: image batch holding 784 values per sample
                (e.g. shape (batch, 1, 28, 28)).
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 1) with values between 0 and 1.
        """
        flat_pixels = img.view(img.shape[0], -1)
        joint = torch.cat((flat_pixels, self.label_emb(labels)), dim=1)
        return self.model(joint)

# ==================== Prepare Dataset ====================
# Convert each image to a tensor, then normalise with mean 0.5 / std 0.5 so
# pixel values share the [-1, 1] range of the generator's tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split, stored under ./data (downloaded if not already there).
dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of `batch_size` samples.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# ==================== Initialize Models ====================
# The generator is created before the discriminator; both are then moved to
# the selected device.
generator = Generator(z_dim=z_dim).to(device)
discriminator = Discriminator().to(device)

# Loss and Optimizer
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()
# Both networks use Adam with identical settings.
adam_settings = {"lr": lr, "betas": (0.5, 0.999)}
optimizer_G = optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)

# ==================== Training Loop ====================
# Every iteration performs one generator update followed by one discriminator
# update. The discriminator step reuses the images produced for the generator
# step (detached, so no gradient flows back into the generator).
for epoch in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        # Use the actual batch size: the last batch of an epoch may be smaller.
        batch_size_curr = imgs.size(0)
        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # Real and fake labels (BCE targets), allocated directly on the
        # training device instead of on the host followed by a copy.
        real = torch.ones(batch_size_curr, 1, device=device)
        fake = torch.zeros(batch_size_curr, 1, device=device)

        # === Train Generator ===
        # Sample noise and random digit classes (10 = number of MNIST classes)
        # directly on the device.
        z = torch.randn(batch_size_curr, z_dim, device=device)
        gen_labels = torch.randint(0, 10, (batch_size_curr,), device=device)
        gen_imgs = generator(z, gen_labels)
        validity = discriminator(gen_imgs, gen_labels)
        # The generator is rewarded when the discriminator scores its images as real.
        g_loss = criterion(validity, real)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # === Train Discriminator ===
        real_pred = discriminator(real_imgs, labels)
        d_real_loss = criterion(real_pred, real)

        # detach() keeps this loss from back-propagating into the generator.
        fake_pred = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = criterion(fake_pred, fake)

        d_loss = (d_real_loss + d_fake_loss) / 2

        # zero_grad() also discards the discriminator gradients accumulated
        # by g_loss.backward() above.
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Print log
        if i % 100 == 0:
            print(f"[Epoch {epoch+1}/{epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

# ==================== Save Trained Generator ====================
# Only the generator's weights (its state_dict) are persisted; the
# discriminator is not saved.
generator_path = "generator.pth"

# Create the checkpoint's parent directory only when the path has one, so the
# directory that gets created always matches where the file is written.
generator_dir = os.path.dirname(generator_path)
if generator_dir:
    os.makedirs(generator_dir, exist_ok=True)

torch.save(generator.state_dict(), generator_path)
print(f"✅ Generator saved to '{generator_path}'")

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.12MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 135kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.28MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.92MB/s]


[Epoch 1/50] [Batch 0/469] [D loss: 0.7027] [G loss: 0.7059]
[Epoch 1/50] [Batch 100/469] [D loss: 0.2328] [G loss: 2.3878]
[Epoch 1/50] [Batch 200/469] [D loss: 0.5293] [G loss: 0.5677]
[Epoch 1/50] [Batch 300/469] [D loss: 0.5890] [G loss: 0.7018]
[Epoch 1/50] [Batch 400/469] [D loss: 0.5416] [G loss: 1.2125]
[Epoch 2/50] [Batch 0/469] [D loss: 0.4786] [G loss: 1.0140]
[Epoch 2/50] [Batch 100/469] [D loss: 0.4451] [G loss: 1.3697]
[Epoch 2/50] [Batch 200/469] [D loss: 0.3464] [G loss: 1.1266]
[Epoch 2/50] [Batch 300/469] [D loss: 0.4900] [G loss: 1.0514]
[Epoch 2/50] [Batch 400/469] [D loss: 0.4474] [G loss: 0.8960]
[Epoch 3/50] [Batch 0/469] [D loss: 0.4363] [G loss: 0.9887]
[Epoch 3/50] [Batch 100/469] [D loss: 0.4613] [G loss: 0.9483]
[Epoch 3/50] [Batch 200/469] [D loss: 0.5714] [G loss: 0.5471]
[Epoch 3/50] [Batch 300/469] [D loss: 0.5680] [G loss: 0.4758]
[Epoch 3/50] [Batch 400/469] [D loss: 0.3731] [G loss: 1.1556]
[Epoch 4/50] [Batch 0/469] [D loss: 0.3831] [G loss: 1.5514]
