<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Conditional_GAN_(cGAN)_for_Image_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, label) pairs to flattened images.

    The noise vector and the label encoding are concatenated along the feature
    dimension and passed through a one-hidden-layer MLP. The final Tanh keeps
    every output value in [-1, 1].

    Args:
        z_dim: Size of the noise vector ``z``.
        label_dim: Size of the label encoding (the training loop below feeds
            one-hot vectors, so this is the number of classes there).
        img_dim: Size of the flattened output image.
        hidden_dim: Width of the hidden layer. Defaults to 128, which matches
            the original fixed architecture.
    """

    def __init__(self, z_dim, label_dim, img_dim, hidden_dim=128):
        super().__init__()
        # Keep the attribute name and layer order unchanged so existing
        # state_dict keys (model.0.*, model.2.*) remain loadable.
        self.model = nn.Sequential(
            nn.Linear(z_dim + label_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, img_dim),
            nn.Tanh(),
        )

    def forward(self, z, labels):
        """Generate images conditioned on ``labels``.

        Args:
            z: Noise tensor, concatenated with ``labels`` along dim 1
                (``(batch, z_dim)`` in the training loop).
            labels: Label encoding with ``label_dim`` features per row
                (one-hot floats in the training loop).

        Returns:
            Tensor with ``img_dim`` features per row, values in [-1, 1].
        """
        # Named `x` rather than `input` to avoid shadowing the builtin.
        x = torch.cat((z, labels), dim=1)
        return self.model(x)

class Discriminator(nn.Module):
    """Conditional MLP discriminator scoring (image, label) pairs as real or fake.

    The flattened image and the label encoding are concatenated along the
    feature dimension and passed through a one-hidden-layer MLP ending in a
    Sigmoid. The output is therefore a probability in [0, 1], suitable for
    ``nn.BCELoss``.

    NOTE(review): returning raw logits and using ``nn.BCEWithLogitsLoss`` would
    be more numerically stable. That would change this module's output
    contract, so the Sigmoid is kept here.

    Args:
        img_dim: Size of the flattened input image.
        label_dim: Size of the label encoding (one-hot class vectors in the
            training loop below).
        hidden_dim: Width of the hidden layer. Defaults to 128, which matches
            the original fixed architecture.
    """

    def __init__(self, img_dim, label_dim, hidden_dim=128):
        super().__init__()
        # Attribute name and layer order are unchanged so state_dict keys stay compatible.
        self.model = nn.Sequential(
            nn.Linear(img_dim + label_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Return the probability that each (img, label) pair is real.

        Args:
            img: Flattened images, concatenated with ``labels`` along dim 1
                (``(batch, img_dim)`` in the training loop).
            labels: Label encoding with ``label_dim`` features per row.

        Returns:
            Tensor with one value in [0, 1] per row.
        """
        # Named `x` rather than `input` to avoid shadowing the builtin.
        x = torch.cat((img, labels), dim=1)
        return self.model(x)

# Hyperparameters
z_dim = 100          # size of the noise vector fed to the generator
img_dim = 28 * 28    # images are flattened to this many features before the MLPs
label_dim = 10       # number of classes; labels are one-hot encoded to this width
batch_size = 64
lr = 0.0002
num_epochs = 50      # single source of truth for the loop bound and the log line

# Train on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data loading and transformation.
# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1],
# matching the range of the generator's Tanh output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.MNIST('.', transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model initialization
gen = Generator(z_dim, label_dim, img_dim).to(device)
disc = Discriminator(img_dim, label_dim).to(device)
criterion = nn.BCELoss()
opt_gen = optim.Adam(gen.parameters(), lr=lr)
opt_disc = optim.Adam(disc.parameters(), lr=lr)

# Training loop
for epoch in range(num_epochs):
    epoch_loss_disc = 0.0
    epoch_loss_gen = 0.0
    num_batches = 0

    for real, labels in dataloader:
        # Use a separate name so the `batch_size` hyperparameter is not overwritten.
        # The last batch of an epoch can be smaller than `batch_size`.
        cur_batch = real.size(0)
        real = real.view(cur_batch, -1).to(device)
        labels_onehot = torch.nn.functional.one_hot(labels, label_dim).float().to(device)

        # Train Discriminator: push D(real) -> 1 and D(fake) -> 0.
        noise = torch.randn(cur_batch, z_dim, device=device)
        fake = gen(noise, labels_onehot)
        disc_real = disc(real, labels_onehot)
        # detach() keeps the discriminator loss from backpropagating into the generator.
        disc_fake = disc(fake.detach(), labels_onehot)
        loss_disc = (criterion(disc_real, torch.ones_like(disc_real)) +
                     criterion(disc_fake, torch.zeros_like(disc_fake))) / 2
        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Train Generator: push the (just-updated) D(fake) -> 1.
        # This backward pass also writes grads into disc's parameters; they are
        # cleared by opt_disc.zero_grad() before the next discriminator step.
        output = disc(fake, labels_onehot)
        loss_gen = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        epoch_loss_disc += loss_disc.item()
        epoch_loss_gen += loss_gen.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than the noisy last-batch value.
    # max(..., 1) avoids a division by zero (and the old NameError) on an empty loader.
    num_batches = max(num_batches, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}] Loss D: {epoch_loss_disc / num_batches:.4f}, '
          f'Loss G: {epoch_loss_gen / num_batches:.4f}')