<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/_Implementation_of_the_Conditional_GAN_(CGAN)_for_generating_MNIST_images.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# ---- Model / data configuration (MNIST: 28x28 grayscale digits) ----
latent_dim = 100  # length of the random noise vector z given to the generator
num_classes = 10  # number of conditioning classes: the digits 0 through 9
img_size = 28     # images are img_size x img_size pixels
channels = 1      # a single (grayscale) channel

# Prefer a CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- Generator ----
class Generator(nn.Module):
    """Conditional generator: maps (noise, class label) to an image.

    The class label is looked up in a learned ``num_classes``-dim embedding,
    concatenated with the noise vector along the last dimension, and passed
    through an MLP whose final ``Tanh`` keeps pixel values in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # One learned num_classes-dim vector per class index.
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        # Module order (and therefore Sequential indices) is:
        # Linear, LeakyReLU, then [Linear, BatchNorm1d, LeakyReLU] x 3,
        # then the output Linear and Tanh.
        self.model = nn.Sequential(
            *self._dense(latent_dim + num_classes, 128, normalize=False),
            *self._dense(128, 256),
            *self._dense(256, 512),
            *self._dense(512, 1024),
            nn.Linear(1024, img_size * img_size * channels),
            nn.Tanh(),
        )

    @staticmethod
    def _dense(in_features, out_features, normalize=True):
        """Return the layers of one hidden stage: Linear [-> BatchNorm1d] -> LeakyReLU(0.2)."""
        stage = [nn.Linear(in_features, out_features)]
        if normalize:
            stage.append(nn.BatchNorm1d(out_features))
        stage.append(nn.LeakyReLU(0.2))
        return stage

    def forward(self, noise, labels):
        """Generate a batch of images.

        Args:
            noise: Tensor whose last dimension is ``latent_dim`` (one row per sample).
            labels: Integer class indices in ``[0, num_classes)``, one per sample.

        Returns:
            Tensor of shape ``(N, channels, img_size, img_size)`` with values in [-1, 1].
        """
        label_vec = self.label_embedding(labels)
        flat = self.model(torch.cat([noise, label_vec], dim=-1))
        return flat.view(flat.shape[0], channels, img_size, img_size)

# ---- Discriminator ----
class Discriminator(nn.Module):
    """Conditional discriminator: scores an (image, class label) pair.

    The image is flattened, concatenated with a learned label embedding, and
    passed through an MLP ending in ``Sigmoid``. The output is therefore a
    value in (0, 1) that training pushes toward 1 for real pairs and 0 for
    generated ones.
    """

    def __init__(self):
        super().__init__()
        # One learned num_classes-dim vector per class index.
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        in_features = img_size * img_size * channels + num_classes
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Return an ``(N, 1)`` tensor of real/fake scores in (0, 1).

        Args:
            img: Batch of images; each sample is flattened to
                ``img_size * img_size * channels`` features.
            labels: Integer class indices in ``[0, num_classes)``, one per sample.
        """
        flat = img.view(img.shape[0], -1)
        cond = self.label_embedding(labels)
        return self.model(torch.cat([flat, cond], dim=-1))

# ---- Networks, optimizers, loss, and data ----
# Build both networks and move their parameters to the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Adam for both players: lr = 2e-4, and beta1 lowered from PyTorch's default 0.9 to 0.5.
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)

# Binary cross-entropy on the discriminator's sigmoid outputs.
adversarial_loss = nn.BCELoss()

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST training split (downloaded into ./data if absent), shuffled, 64 images per batch.
dataloader = DataLoader(
    dataset=datasets.MNIST(root='./data', train=True, download=True, transform=transform),
    batch_size=64,
    shuffle=True,
)

# Training loop: one discriminator update followed by one generator update per batch.
n_epochs = 200
generator.train()
discriminator.train()
for epoch in range(n_epochs):
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)
        batch_size = imgs.size(0)

        # BCE targets: 1 = "real", 0 = "fake". Allocated directly on the device
        # instead of building on the CPU and copying every iteration.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # ---- Train Discriminator ----
        optimizer_D.zero_grad()

        # Real (image, label) pairs should be scored as real.
        real_loss = adversarial_loss(discriminator(imgs, labels), valid)

        # Generated pairs with random labels should be scored as fake.
        z = torch.randn(batch_size, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)
        fake_imgs = generator(z, gen_labels)
        # detach(): the discriminator step must not backprop into the generator.
        # Without it, d_loss.backward() computes generator gradients that
        # optimizer_G.zero_grad() below would simply discard (wasted work).
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach(), gen_labels), fake)

        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        optimizer_G.zero_grad()

        # Fresh samples; the generator is rewarded when the discriminator calls them real.
        z = torch.randn(batch_size, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)
        gen_imgs = generator(z, gen_labels)
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)

        g_loss.backward()
        optimizer_G.step()

    # Progress uses the losses of the last batch of the epoch; epochs are shown 1-based.
    print(f"Epoch [{epoch + 1}/{n_epochs}]  Loss D: {d_loss.item():.4f}, loss G: {g_loss.item():.4f}")

# Function to sample and display images
def sample_images(n_row=5, labels=None):
    """Generate an ``n_row`` x ``n_row`` grid of digits with the generator and show it.

    The generator is temporarily switched to eval mode, so BatchNorm uses its
    running statistics and is not updated. Sampling runs under
    ``torch.no_grad()``. The generator's previous train/eval mode is restored
    afterwards.

    Args:
        n_row: Number of rows and columns in the grid (``n_row ** 2`` images).
            Must be >= 1.
        labels: Optional sequence of ``n_row ** 2`` class indices in
            ``[0, num_classes)``, filled row by row. If None (the default),
            labels are drawn uniformly at random.

    Raises:
        ValueError: If ``n_row < 1`` or ``labels`` has the wrong length.
    """
    if n_row < 1:
        raise ValueError(f"n_row must be >= 1, got {n_row}")
    n_images = n_row ** 2

    z = torch.randn(n_images, latent_dim, device=device)
    if labels is None:
        labels = torch.randint(0, num_classes, (n_images,), device=device)
    else:
        labels = torch.as_tensor(labels, dtype=torch.long, device=device).flatten()
        if labels.numel() != n_images:
            raise ValueError(f"expected {n_images} labels, got {labels.numel()}")

    # Eval mode: train-mode BatchNorm would normalize with this batch's
    # statistics and overwrite the running stats just by displaying samples.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            gen_imgs = generator(z, labels).cpu()
    finally:
        generator.train(was_training)
    gen_imgs = gen_imgs.view(-1, channels, img_size, img_size)

    # squeeze=False keeps `axes` a 2-D array even when n_row == 1.
    fig, axes = plt.subplots(n_row, n_row, figsize=(10, 10), squeeze=False)
    for ax, img, label in zip(axes.flat, gen_imgs, labels.tolist()):
        # Fixed [-1, 1] range (the generator's Tanh range) instead of per-image autoscaling.
        ax.imshow(img.squeeze(), cmap='gray', vmin=-1, vmax=1)
        ax.set_title(str(label), fontsize=8)  # the class the sample was conditioned on
        ax.axis('off')
    fig.tight_layout()
    plt.show()


# Display generated samples
sample_images()