<a href="https://colab.research.google.com/github/AbhirKarande/ImgGenEngine/blob/main/CGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

In [None]:
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    The integer label is embedded into a ``num_classes``-dimensional vector and
    concatenated in front of the noise vector. The result goes through four
    Linear -> BatchNorm1d -> ReLU stages (128, 256, 512, 1024 units) and a
    final Linear + Tanh, so every output value lies in ``[-1, 1]``.

    Args:
        latent_dim: size of the noise vector passed to ``forward``.
        num_classes: number of distinct labels (embedding table size).
        img_shape: per-sample output shape, e.g. ``(channels, height, width)``.
            Defaults to ``(1, 28, 28)``, i.e. the 784-value output this class
            always produced before the parameter existed.
    """

    def __init__(self, latent_dim, num_classes, img_shape=(1, 28, 28)):
        super(Generator, self).__init__()
        self.img_shape = tuple(img_shape)
        # Number of values the final Linear layer must emit for one image.
        num_pixels = 1
        for dim in self.img_shape:
            num_pixels *= dim

        # Registered in the same order as before (label_emb, then fc) so the
        # state_dict keys stay compatible with existing checkpoints.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_pixels),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: tensor of shape ``(batch, latent_dim)``.
            labels: integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, *img_shape)`` with values in ``[-1, 1]``.
        """
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.fc(gen_input)
        return img.view(img.size(0), *self.img_shape)

In [None]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator scoring (image, class label) pairs.

    Each image is flattened and concatenated with a ``num_classes``-dimensional
    label embedding. The result goes through two Linear -> BatchNorm1d ->
    LeakyReLU(0.2) stages (512, 256 units) and a final Linear + Sigmoid, giving
    one probability per sample.

    Args:
        num_classes: number of distinct labels (embedding table size).
        img_shape: per-sample input shape, e.g. ``(channels, height, width)``.
            Defaults to ``(1, 28, 28)``, i.e. the 784-value input this class
            always expected before the parameter existed.
    """

    def __init__(self, num_classes, img_shape=(1, 28, 28)):
        super(Discriminator, self).__init__()
        # Number of values in one flattened input image.
        num_pixels = 1
        for dim in img_shape:
            num_pixels *= dim

        # Registered in the same order as before (label_emb, then fc) so the
        # state_dict keys stay compatible with existing checkpoints.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.fc = nn.Sequential(
            nn.Linear(num_pixels + num_classes, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Score each (image, label) pair.

        Args:
            img: batch-first tensor whose non-batch dimensions flatten to
                ``prod(img_shape)`` values.
            labels: integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        # flatten(1) also works on non-contiguous inputs, where .view would fail.
        flat = img.flatten(1)
        d_in = torch.cat((flat, self.label_emb(labels)), -1)
        return self.fc(d_in)

In [None]:
def train(generator, discriminator, optimizer_G, optimizer_D, dataloader, device, num_epochs, latent_dim, num_classes, sample_dir="images"):
    """Train a conditional GAN with alternating generator/discriminator steps.

    For every batch:
      1. Generator step: sample noise and uniformly random labels, generate
         images, and minimise BCE between the discriminator's score and 1.
      2. Discriminator step: minimise BCE(real batch, 1) + BCE(detached fake
         batch from step 1, 0).
    Losses are printed every 100 batches. After each epoch, up to 25 generated
    images from the last batch are saved to ``<sample_dir>/<epoch>.png`` as a
    grid with 5 images per row.

    Args:
        generator: module called as ``generator(z, labels)``.
        discriminator: module called as ``discriminator(imgs, labels)``. Its
            output is fed to ``nn.BCELoss``, so it must be a probability of
            shape ``(batch, 1)``.
        optimizer_G: optimizer over the generator's parameters.
        optimizer_D: optimizer over the discriminator's parameters.
        dataloader: iterable of ``(imgs, labels)`` batches that supports ``len``.
        device: device that noise, labels and images are placed on.
        num_epochs: number of passes over ``dataloader``.
        latent_dim: size of the noise vector sampled per example.
        num_classes: fake labels are drawn uniformly from ``[0, num_classes)``.
        sample_dir: directory for the per-epoch sample grids; created if missing.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch before any
            images have been generated (there would be nothing to save).
    """
    import os

    criterion = nn.BCELoss()
    # save_image does not create parent directories, so a missing sample_dir
    # would otherwise crash at the end of the first epoch.
    os.makedirs(sample_dir, exist_ok=True)

    # BatchNorm layers behave differently in eval mode; make sure both models
    # are in training mode before updating them.
    generator.train()
    discriminator.train()

    gen_imgs = None
    for epoch in range(num_epochs):
        for i, (imgs, labels) in enumerate(dataloader):
            batch_size = imgs.size(0)
            real_imgs = imgs.to(device)
            labels = labels.to(device)
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            # Train Generator
            optimizer_G.zero_grad()

            z = torch.randn(batch_size, latent_dim, device=device)
            gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)
            gen_imgs = generator(z, gen_labels)

            g_loss = criterion(discriminator(gen_imgs, gen_labels), valid)
            g_loss.backward()
            optimizer_G.step()

            # Train Discriminator. zero_grad also discards the gradients that
            # g_loss.backward() accumulated on the discriminator's parameters.
            optimizer_D.zero_grad()

            d_real_loss = criterion(discriminator(real_imgs, labels), valid)
            # detach(): the discriminator loss must not update the generator.
            d_fake_loss = criterion(discriminator(gen_imgs.detach(), gen_labels), fake)
            d_loss = d_real_loss + d_fake_loss

            d_loss.backward()
            optimizer_D.step()

            # Print statistics
            if i % 100 == 0:
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]"
                      % (epoch, num_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

        if gen_imgs is None:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        # Save generated images
        save_image(gen_imgs.detach()[:25], os.path.join(sample_dir, "%d.png" % epoch),
                   nrow=5, normalize=True)
