In [2]:
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision
import torchvision.utils as vutils

In [None]:
class ConditionalGenerator(nn.Module):
    """
    Generator half of the Conditional GAN.

    A class label is looked up in an embedding table, concatenated with the noise
    vector, projected to a 1024x4x4 feature map and then upsampled by five
    stride-2 transposed convolutions (each doubles height and width), giving a
    128x128 image squashed into [-1, 1] by Tanh.

    Args:
        noise_dim: length of the latent noise vector.
        num_classes: number of rows in the label embedding table.
        img_channels: channel count of the generated image.
        embedding_dim: length of each label embedding.
    """
    def __init__(self, noise_dim, num_classes, img_channels, embedding_dim):
        super().__init__()

        # Label lookup table: (batch,) integer labels -> (batch, embedding_dim)
        self.label_embed = nn.Embedding(num_classes, embedding_dim)

        seed_channels, seed_size = 1024, 4
        seed_features = seed_channels * seed_size * seed_size

        # Projection stem: (batch, noise_dim + embedding_dim) -> (batch, 1024, 4, 4)
        stack = [
            nn.Linear(noise_dim + embedding_dim, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.ReLU(True),
            nn.Unflatten(1, (seed_channels, seed_size, seed_size)),
        ]

        # Upsampling trunk: 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
        widths = (seed_channels, 512, 256, 128, 64)
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.ReLU(True))

        # Output head: (batch, 64, 64, 64) -> (batch, img_channels, 128, 128)
        stack.append(nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1))
        stack.append(nn.Tanh())

        self.model = nn.Sequential(*stack)

    def forward(self, class_labels, noise):
        """Map (class_labels, noise) to a batch of images; note labels come first."""
        conditioned = torch.cat((noise, self.label_embed(class_labels)), dim=1)
        return self.model(conditioned)

In [None]:
class ConditionalDiscriminator(nn.Module):
    """
    Discriminator half of the Conditional GAN.

    The label embedding is broadcast over the image's spatial grid and stacked
    onto the image as extra channels. Five stride-2 convolutions then halve the
    resolution each time, and a final 4x4 convolution plus Sigmoid produces the
    realness probability.

    For a 128x128 input the feature maps are 64 -> 32 -> 16 -> 8 -> 4 pixels wide,
    so the output has shape (batch, 1, 1, 1).

    Args:
        num_classes: number of rows in the label embedding table.
        img_channels: channel count of the input image.
        embedding_dim: length of each label embedding (added as input channels).
    """
    def __init__(self, num_classes, img_channels, embedding_dim):
        super().__init__()

        self.label_embed = nn.Embedding(num_classes, embedding_dim)

        # Channel schedule for the downsampling trunk; each step halves H and W.
        channels = (embedding_dim + img_channels, 64, 128, 256, 512, 1024)
        stack = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stack.extend((
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ))

        # Head: an unpadded 4x4 convolution collapses a 4x4 map to one score.
        stack.extend((
            nn.Conv2d(channels[-1], 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ))

        self.model = nn.Sequential(*stack)

    def forward(self, class_labels, image):
        """Score `image` conditioned on `class_labels`; note labels come first."""
        height, width = image.size(2), image.size(3)
        # (batch, embedding_dim) -> (batch, embedding_dim, 1, 1) -> (batch, embedding_dim, H, W)
        label_planes = self.label_embed(class_labels)[:, :, None, None]
        label_planes = label_planes.expand(-1, -1, height, width)
        return self.model(torch.cat((image, label_planes), dim=1))

In [None]:
def train_model(Generator, Discriminator, G_optimizer, D_optimizer, criterion, device, train_dataloader, num_classes, noise_dim, num_epoch):
    """
    Trains the Conditional GAN model.

    For every batch, one discriminator update (real images with their labels vs.
    generated images with random labels) is followed by one generator update on
    a freshly sampled fake batch.

    Args:
        Generator: module called as ``Generator(labels, noise)``.
        Discriminator: module called as ``Discriminator(labels, images)``.
        G_optimizer: optimizer stepping the generator's parameters.
        D_optimizer: optimizer stepping the discriminator's parameters.
        criterion: loss called as ``criterion(prediction, target)``.
        device: device on which batches, noise and random labels are placed.
        train_dataloader: sized iterable yielding ``(real_images, labels)`` batches.
        num_classes: exclusive upper bound for the randomly drawn fake labels.
        noise_dim: length of each noise vector fed to the generator.
        num_epoch: number of passes over ``train_dataloader``.

    Returns:
        ``(Generator, D_losses, G_losses)`` where the two lists hold the mean
        per-batch loss of each epoch.

    Raises:
        ValueError: if ``train_dataloader`` contains no batches.
    """
    num_batches = len(train_dataloader)
    if num_batches == 0:
        # Fail early with a clear message instead of a ZeroDivisionError when averaging.
        raise ValueError("train_dataloader yields no batches; nothing to train on.")

    # A previous .eval() call (evaluate_model makes one) would otherwise leave
    # BatchNorm frozen on its running statistics for the whole training run.
    Generator.train()
    Discriminator.train()

    D_losses, G_losses = [], []

    print("Training started...\n")
    for epoch in range(num_epoch):
        total_D_loss, total_G_loss = 0, 0

        for i, (real_images, labels) in enumerate(train_dataloader):
            real_images = real_images.to(device)
            labels = labels.to(device)
            batch_size = real_images.size(0)

            # Train Discriminator
            # The fake batch is only an input here, so skip building the
            # generator's autograd graph altogether.
            with torch.no_grad():
                noise = torch.randn(batch_size, noise_dim, device=device)
                fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
                fake_images = Generator(fake_labels, noise)

            D_real = Discriminator(labels, real_images)
            D_fake = Discriminator(fake_labels, fake_images)

            real_loss = criterion(D_real, torch.ones_like(D_real))
            fake_loss = criterion(D_fake, torch.zeros_like(D_fake))
            D_loss = real_loss + fake_loss

            D_optimizer.zero_grad()
            D_loss.backward()
            D_optimizer.step()
            total_D_loss += D_loss.item()

            # Train Generator
            noise = torch.randn(batch_size, noise_dim, device=device)
            fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
            fake_images = Generator(fake_labels, noise)

            # Target is "real": the generator is rewarded for fooling the discriminator.
            D_fake = Discriminator(fake_labels, fake_images)
            G_loss = criterion(D_fake, torch.ones_like(D_fake))

            # This backward pass also deposits gradients in the discriminator;
            # they are cleared by D_optimizer.zero_grad() before its next step.
            G_optimizer.zero_grad()
            G_loss.backward()
            G_optimizer.step()
            total_G_loss += G_loss.item()

            if i % 100 == 0:
                print(f"Batch {i}/{num_batches}: D_loss = {D_loss.item():.4f}, G_loss = {G_loss.item():.4f}")

        avg_D_loss = total_D_loss / num_batches
        avg_G_loss = total_G_loss / num_batches
        D_losses.append(avg_D_loss)
        G_losses.append(avg_G_loss)

        print(f"[Epoch {epoch+1}/{num_epoch}] ➤ D_loss: {avg_D_loss:.4f}, G_loss: {avg_G_loss:.4f}")

    return Generator, D_losses, G_losses

In [None]:
def evaluate_model(Generator, noise_dim, device, num_classes=101):
    """
    Generates and visualizes samples using the trained generator.

    One image is produced for every class label in ``range(num_classes)``. The
    batch is written to ``cgan_samples.png`` and also shown as a matplotlib grid
    with five images per row.

    Args:
        Generator: module called as ``Generator(labels, noise)``; it is switched
            to eval mode and left that way.
        noise_dim: length of each noise vector fed to the generator.
        device: device on which the noise and labels are created.
        num_classes: number of class labels to sample, one image each. Defaults
            to 101, the value that was previously hard-coded.
    """
    print("Generating images with trained Generator...\n")
    Generator.eval()

    noise = torch.randn(num_classes, noise_dim, device=device)
    labels = torch.arange(num_classes, device=device)

    # Pure inference: no gradients are needed, so do not retain an autograd graph.
    with torch.no_grad():
        generated_images = Generator(labels, noise)

    # normalize=True rescales the saved image using the tensor's own min and max.
    vutils.save_image(generated_images, "cgan_samples.png", nrow=5, normalize=True)

    # Undo the [-1, 1] range of the generator's Tanh output for display in [0, 1].
    generated_images = generated_images * 0.5 + 0.5

    grid = vutils.make_grid(generated_images.cpu(), nrow=5, padding=2, normalize=False)
    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.title("Generated Images")
    plt.imshow(grid.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
    plt.show()

In [None]:
if __name__ == "__main__":
    # ---- Hyperparameters -------------------------------------------------
    num_classes = 101
    noise_dim = 256
    embedding_dim = 100
    batch_size = 64
    num_epoch = 100
    G_lr = 2e-4
    D_lr = 1e-4

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # ---- Data ------------------------------------------------------------
    # Resize to the 128x128 resolution the networks are built for and scale
    # pixels to [-1, 1], matching the generator's Tanh output range.
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    print("Loading dataset...\n")
    train_data = torchvision.datasets.Food101(
        root="./data",
        split="train",
        download=True,
        transform=transform,
    )
    train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

    # ---- Models, optimizers, loss ----------------------------------------
    Generator = ConditionalGenerator(
        noise_dim, num_classes, img_channels=3, embedding_dim=embedding_dim
    ).to(device)
    Discriminator = ConditionalDiscriminator(
        num_classes, img_channels=3, embedding_dim=embedding_dim
    ).to(device)

    G_optimizer = optim.Adam(Generator.parameters(), lr=G_lr, betas=(0.5, 0.999))
    D_optimizer = optim.Adam(Discriminator.parameters(), lr=D_lr, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    # ---- Training --------------------------------------------------------
    trained_generator, D_losses, G_losses = train_model(
        Generator=Generator,
        Discriminator=Discriminator,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        criterion=criterion,
        device=device,
        train_dataloader=train_dataloader,
        num_classes=num_classes,
        noise_dim=noise_dim,
        num_epoch=num_epoch,
    )

    # ---- Loss curves: one panel per network ------------------------------
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 8))

    epochs = range(1, num_epoch + 1)
    panels = (
        (axes[0], D_losses, "Discriminator_losses", "blue", "Discriminator Loss vs Epoch"),
        (axes[1], G_losses, "Generator_losses", "red", "Generator Loss vs Epoch"),
    )
    for ax, losses, label, color, title in panels:
        ax.plot(epochs, losses, label=label, color=color)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title(title)
        ax.legend()

    plt.tight_layout()
    plt.show()

    # ---- Sample images from the trained generator -------------------------
    evaluate_model(trained_generator, noise_dim, device)


