In [None]:
# conditional_gan_training.ipynb

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Device configuration: use CUDA when PyTorch reports it as available,
# otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyperparameters
# -- model / data dimensions --
z_dim = 100          # length of the noise vector fed to the generator
num_classes = 10     # number of label classes (MNIST digits 0-9)
image_dim = 28 * 28  # pixels in one flattened 28x28 grayscale image (784)
# -- optimisation --
batch_size = 64
num_epochs = 50      # adjust as needed; start with 50-100 and observe results
lr = 2e-4            # learning rate (same value as 0.0002)
b1 = 0.5             # Adam: decay rate of the first-moment (mean) estimate
b2 = 0.999           # Adam: decay rate of the second-moment estimate

# MNIST input pipeline: convert each image to a tensor, then normalise with
# mean 0.5 and std 0.5 so pixel values land in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Training split of MNIST, stored under ./data (downloaded when missing).
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of `batch_size` samples.
dataloader = DataLoader(dataset=mnist_dataset, batch_size=batch_size, shuffle=True)

# Conditional Generator Definition
class ConditionalGenerator(nn.Module):
    """Fully connected generator mapping (noise, class label) to an image.

    Each label is looked up in a learned embedding of width ``num_classes``,
    concatenated with the noise vector, and passed through a stack of linear
    layers ending in ``Tanh``, so every output value lies in [-1, 1].

    Args:
        z_dim: Length of each noise vector passed to ``forward``.
        num_classes: Number of distinct labels; also the embedding width.
        img_dim: Number of pixels in one flattened output image. It must be a
            perfect square because the output is reshaped to a square,
            single-channel image (784 -> 28x28).

    Raises:
        ValueError: If ``img_dim`` is not a positive perfect square.
    """

    def __init__(self, z_dim, num_classes, img_dim):
        super().__init__()
        # Derive the image side from img_dim instead of assuming 28x28, so the
        # reshape in forward() always agrees with the last linear layer.
        side = round(img_dim ** 0.5) if img_dim > 0 else 0
        if img_dim <= 0 or side * side != img_dim:
            raise ValueError(
                f"img_dim must be a positive perfect square, got {img_dim}"
            )
        self.img_side = side  # plain attribute; not part of the state_dict

        # Embedding layer to convert class labels into a dense vector
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.main = nn.Sequential(
            nn.Linear(z_dim + num_classes, 256),  # input: noise ++ label embedding
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()  # Output in range [-1, 1]
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: Tensor of shape (batch, z_dim).
            labels: Tensor of shape (batch,) holding class indices; it is cast
                to ``torch.long`` as required by ``nn.Embedding``.

        Returns:
            Tensor of shape (batch, 1, side, side) with values in [-1, 1],
            where ``side * side == img_dim``.
        """
        c = self.label_emb(labels.long())
        input_tensor = torch.cat([noise, c], 1)  # concatenate along the feature axis
        flat = self.main(input_tensor)
        # Reshape each flat row into a single-channel square image.
        return flat.view(flat.size(0), 1, self.img_side, self.img_side)

# Conditional Discriminator Definition
class ConditionalDiscriminator(nn.Module):
    """Fully connected discriminator scoring (image, class label) pairs.

    The label is embedded into a vector of width ``num_classes`` and appended
    to the flattened image. Three Linear -> LeakyReLU(0.2) -> Dropout(0.3)
    stages follow, then a final Linear + Sigmoid that yields one value in
    [0, 1] per sample.

    Args:
        num_classes: Number of distinct labels; also the embedding width.
        img_dim: Number of pixels in one flattened input image.
    """

    def __init__(self, num_classes, img_dim):
        super().__init__()
        # Embedding layer for class labels
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Feature widths of the hidden stack; the first entry is the size of
        # the concatenated (flattened image ++ label embedding) input.
        widths = (img_dim + num_classes, 1024, 512, 256)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2))
            stack.append(nn.Dropout(0.3))  # dropout for regularization
        stack.append(nn.Linear(widths[-1], 1))
        stack.append(nn.Sigmoid())  # squash the score to a probability in [0, 1]
        self.main = nn.Sequential(*stack)

    def forward(self, img, labels):
        """Score a batch of images against their labels.

        Args:
            img: Image tensor whose first dimension is the batch; the remaining
                dimensions are flattened.
            labels: Tensor of class indices, one per image; cast to
                ``torch.long`` as required by ``nn.Embedding``.

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        pixels = img.view(img.shape[0], -1)
        label_vec = self.label_emb(labels.long())
        joined = torch.cat((pixels, label_vec), dim=1)
        return self.main(joined)

# Build both networks and move them to the selected device.
generator = ConditionalGenerator(z_dim=z_dim, num_classes=num_classes, img_dim=image_dim)
generator = generator.to(device)
discriminator = ConditionalDiscriminator(num_classes=num_classes, img_dim=image_dim)
discriminator = discriminator.to(device)

# Binary cross-entropy is applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Both networks get their own Adam optimizer with identical settings.
adam_settings = {'lr': lr, 'betas': (b1, b2)}
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)
optimizer_G = optim.Adam(generator.parameters(), **adam_settings)

# Training Loop for CGAN
#
# Every mini-batch performs one discriminator update (real batch + freshly
# generated fake batch) followed by one generator update. Every 10th epoch,
# and on the final epoch, five samples of one digit are plotted.
print("Starting Conditional GAN Training Loop...")

# Preview inputs are sampled once, before training, so that the snapshots
# plotted at different epochs come from the same noise and are comparable.
test_digit = 7  # Example: generate images of digit 7
fixed_noise = torch.randn(5, z_dim, device=device)
fixed_labels = torch.full((5,), test_digit, dtype=torch.long, device=device)

for epoch in range(num_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        # Configure input
        real_imgs = imgs.to(device)
        real_labels_tensor = labels.to(device)  # Labels corresponding to real images
        # Size everything from the batch itself: the final batch of an epoch
        # can hold fewer than batch_size samples.
        cur_batch = real_imgs.size(0)

        # Adversarial ground truths (1 = real, 0 = fake), created on the device
        real_gan_labels = torch.ones(cur_batch, 1, device=device)
        fake_gan_labels = torch.zeros(cur_batch, 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real images
        output_real = discriminator(real_imgs, real_labels_tensor)
        d_loss_real = criterion(output_real, real_gan_labels)

        # Measure discriminator's ability to classify fake images.
        # Random labels are drawn so the fakes cover all digit classes.
        z = torch.randn(cur_batch, z_dim, device=device)  # Noise vector
        fake_labels_tensor_D = torch.randint(0, num_classes, (cur_batch,), device=device)
        # The generator is not updated in this step, so skip building its
        # autograd graph; the result carries no gradient history.
        with torch.no_grad():
            fake_imgs = generator(z, fake_labels_tensor_D)

        output_fake = discriminator(fake_imgs, fake_labels_tensor_D)
        d_loss_fake = criterion(output_fake, fake_gan_labels)

        # Total discriminator loss (average of real and fake loss)
        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Generate a fresh fake batch with new noise and new random target labels
        z = torch.randn(cur_batch, z_dim, device=device)
        gen_target_labels = torch.randint(0, num_classes, (cur_batch,), device=device)
        fake_imgs = generator(z, gen_target_labels)

        # The generator wants the discriminator to output '1' (real) for its
        # fake images when paired with the labels they were generated for.
        output = discriminator(fake_imgs, gen_target_labels)
        g_loss = criterion(output, real_gan_labels)

        g_loss.backward()
        optimizer_G.step()

        if (i + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(dataloader)}] "
                f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}"
            )

    # Plot generated images of the chosen digit for visual inspection:
    # every 10th epoch and also on the last epoch.
    if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
        with torch.no_grad():
            generated_images = generator(fixed_noise, fixed_labels).cpu()

            # Denormalize images from [-1, 1] to [0, 1] for plotting
            generated_images = (generated_images + 1) / 2

            fig, axes = plt.subplots(1, 5, figsize=(10, 2))
            for ax, img in zip(axes, generated_images):
                ax.imshow(img.squeeze(), cmap='gray')
                ax.axis('off')
            plt.suptitle(f'Generated Digit {test_digit} at Epoch {epoch+1}')
            plt.show()
            # Release the figure so previews do not accumulate in memory
            # when the backend does not close them on show().
            plt.close(fig)

# Save the trained generator model's state dictionary.
# This is the file you'll need for your web app.
# Tensors are moved to the CPU first: a checkpoint holding CUDA tensors cannot
# be loaded on a CPU-only host unless the loader passes map_location.
generator_state = generator.state_dict()
for param_name, tensor in list(generator_state.items()):
    # Replacing values in place keeps the original key order and metadata.
    generator_state[param_name] = tensor.cpu()
torch.save(generator_state, 'conditional_generator_mnist.pth')
print("Training complete. Conditional Generator model saved as conditional_generator_mnist.pth")