<a href="https://colab.research.google.com/github/Vishal-113/Transformers/blob/main/Digit_Class_Controlled_Image_Generation_with_Conditional_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Step 1: Load MNIST Dataset
# ToTensor() scales pixel values to [0, 1]; Normalize with mean=0.5, std=0.5
# then maps them to [-1, 1], the same range as the generator's final Tanh.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
mnist_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 64 (image, label) pairs from the training split.
dataloader = DataLoader(dataset=mnist_data, batch_size=64, shuffle=True)

# Step 2: Define the Generator
class Generator(nn.Module):
    """Conditional MLP generator: (noise, class label) -> flattened image.

    The label is turned into a learned ``num_classes``-wide embedding and
    concatenated to the noise vector before passing through four Linear
    layers (ReLU between them, Tanh at the end, so outputs lie in [-1, 1]).

    Args:
        noise_dim: Width of each noise vector.
        num_classes: Number of distinct integer labels (also the embedding width).
        img_dim: Width of each generated (flattened) image.
    """

    def __init__(self, noise_dim, num_classes, img_dim):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        # Hidden stack: (noise+label) -> 128 -> 256 -> 512, each followed by ReLU.
        widths = [noise_dim + num_classes, 128, 256, 512]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output projection squashed into [-1, 1] by Tanh.
        layers += [nn.Linear(widths[-1], img_dim), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate images for ``noise`` (batch, noise_dim) and 1-D integer ``labels``.

        Returns a (batch, img_dim) tensor.
        """
        conditioned = torch.cat((noise, self.label_embedding(labels)), dim=1)
        return self.model(conditioned)

# Step 3: Define the Discriminator
class Discriminator(nn.Module):
    """Conditional MLP discriminator: (flattened image, class label) -> probability.

    A learned ``num_classes``-wide label embedding is appended to the image
    vector; three Linear layers (LeakyReLU(0.2) between them) reduce it to a
    single unit passed through Sigmoid, giving a value in (0, 1).

    Args:
        num_classes: Number of distinct integer labels (also the embedding width).
        img_dim: Width of each input (flattened) image.
    """

    def __init__(self, num_classes, img_dim):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        layers = []
        in_width = img_dim + num_classes
        for out_width in (512, 256):
            layers.extend([nn.Linear(in_width, out_width), nn.LeakyReLU(0.2)])
            in_width = out_width
        # Single score, squashed to a probability for use with BCELoss.
        layers.extend([nn.Linear(in_width, 1), nn.Sigmoid()])
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Score ``img`` (batch, img_dim) against 1-D integer ``labels``.

        Returns a (batch, 1) tensor of probabilities.
        """
        joint = torch.cat((img, self.label_embedding(labels)), dim=1)
        return self.model(joint)

# Step 4: Initialize Models, Loss, and Optimizers
noise_dim = 100      # width of the random noise vector fed to the generator
img_dim = 28 * 28    # flattened MNIST image width
num_classes = 10     # number of label values used by both embeddings

generator = Generator(noise_dim, num_classes, img_dim)
discriminator = Discriminator(num_classes, img_dim)

# BCELoss consumes probabilities, which the discriminator's final Sigmoid provides.
criterion = nn.BCELoss()

# One Adam optimizer per network, same learning rate (2e-4 == 0.0002).
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=2e-4) for net in (generator, discriminator)
)

# Step 5: Training Loop
# Each mini-batch does one discriminator update followed by one generator update.
# The per-epoch log reports the MEAN loss over all batches of the epoch (the
# last batch alone is too noisy to judge progress by).
epochs = 50
# Make sure both nets are in training mode even if this cell is re-run after
# the visualization cell, which switches the generator to eval mode.
generator.train()
discriminator.train()
for epoch in range(epochs):
    d_loss_sum, g_loss_sum, num_batches = 0.0, 0.0, 0
    for real_imgs, labels in dataloader:
        batch_size = real_imgs.size(0)  # last batch can be smaller than 64
        real_imgs = real_imgs.view(batch_size, -1)  # flatten to (batch, 28*28)
        labels = labels.to(torch.int64)  # nn.Embedding needs integer indices
        real_targets = torch.ones(batch_size, 1)
        fake_targets = torch.zeros(batch_size, 1)

        # Train Discriminator: real pairs -> 1, generated pairs -> 0.
        noise = torch.randn(batch_size, noise_dim)
        fake_labels = torch.randint(0, num_classes, (batch_size,))
        fake_imgs = generator(noise, fake_labels)

        real_validity = discriminator(real_imgs, labels)
        # detach(): the discriminator step must not push gradients into G.
        fake_validity = discriminator(fake_imgs.detach(), fake_labels)

        real_loss = criterion(real_validity, real_targets)
        fake_loss = criterion(fake_validity, fake_targets)
        d_loss = real_loss + fake_loss

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: make the (just updated) discriminator output 1 for
        # generated pairs. This backward also leaves gradients on D's
        # parameters; they are harmless because optimizer_D.zero_grad() runs
        # before the next discriminator backward.
        gen_validity = discriminator(fake_imgs, fake_labels)
        g_loss = criterion(gen_validity, real_targets)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        d_loss_sum += d_loss.item()
        g_loss_sum += g_loss.item()
        num_batches += 1

    # max(..., 1) guards against an empty dataloader.
    mean_d = d_loss_sum / max(num_batches, 1)
    mean_g = g_loss_sum / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{epochs} | D Loss: {mean_d:.4f} | G Loss: {mean_g:.4f}")

# Step 6: Visualization of Generated Digits
# Grid of samples: row i holds `cols` images generated for class label i.
generator.eval()
rows, cols = num_classes, 10  # one row per class; never index past the embedding
fig, axs = plt.subplots(rows, cols, figsize=(10, 10))
# Inference only: no_grad avoids building an autograd graph for the samples.
with torch.no_grad():
    for i in range(rows):
        label = torch.full((cols,), i, dtype=torch.int64)
        noise = torch.randn(cols, noise_dim)
        generated_imgs = generator(noise, label).view(-1, 1, 28, 28)
        for j in range(cols):
            # The generator ends in Tanh, so pixels lie in [-1, 1]. Pin the
            # colour range; otherwise imshow rescales every tile to its own
            # min/max and hides differences in intensity between samples.
            axs[i, j].imshow(generated_imgs[j][0], cmap='gray', vmin=-1, vmax=1)
            axs[i, j].axis('off')
plt.tight_layout()
plt.show()
