<a href="https://colab.research.google.com/github/TheDarkNight21/routellm-assignment/blob/main/cgan_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
!pip install tensorboardX



In [13]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder, CIFAR10
from torchvision import transforms
from torch import autograd
from torch.autograd import Variable
from torchvision.utils import make_grid
from tensorboardX import SummaryWriter
import torch.nn.utils.spectral_norm as spectral_norm

In [14]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Record the environment in the cell output so saved results show what they ran on.
print('torch version:', torch.__version__)
print('device:', device)

torch version: 2.8.0+cu126
device: cpu


Preprocessing pipeline applied to each image: convert it to a tensor, then normalize every RGB channel to the range $[-1, 1]$.

In [15]:
# Preprocessing applied to every CIFAR-10 image:
#   1. ToTensor: PIL image (H, W, C) with values in [0, 255] -> float tensor (C, H, W) in [0, 1].
#   2. Normalize: (x - mean) / std per channel. mean = std = 0.5 for each of the 3 RGB
#      channels maps [0, 1] -> [-1, 1], matching a tanh generator output.
# Normalize's signature is (mean, std, inplace=False): the per-channel values belong in the
# first two lists. A third positional list would be silently read as a truthy `inplace` flag.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

Load the CIFAR-10 training set and set the data parameters.

In [16]:
img_size = 32    # height/width of CIFAR-10 images, in pixels
batch_size = 64  # number of images per mini-batch

# CIFAR-10 training split (downloaded into `data/` if not already present), preprocessed
# with `transform` and served in shuffled mini-batches of `batch_size` images.
data_loader = torch.utils.data.DataLoader(
    CIFAR10(root='data', train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

Projection discriminator (Miyato & Koyama, "cGANs with Projection Discriminator", ICLR 2018)

The discriminator takes the following form:

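$D(x, y) = h(f(x)) + \langle f(x), e(y) \rangle$

where $f(x)$ is the feature vector extracted from image $x$, $h(f(x))$ scores how real the image looks, and the inner product $\langle f(x), e(y) \rangle$ between those features and the learned label embedding $e(y)$ scores how well the image matches its assigned label $y$.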

In [17]:
# nn.Module is PyTorch's base class for networks: it registers the parameters of every
# sub-module assigned as an attribute and requires a forward() method for the forward pass.
class conditionalDiscriminator(nn.Module):
  """Spectrally-normalized projection discriminator conditioned on a class label.

  Computes D(x, y) = h(f(x)) + <f(x), e(y)>, where
    f: three stride-2 convs + global average pooling -> 256-dim feature vector,
    h: linear head giving an unconditional real/fake logit,
    e: learnable 256-dim embedding per class label.

  Args:
    num_classes: number of distinct class labels (rows of the embedding table).
      Defaults to 10, matching CIFAR-10.
  """

  def __init__(self, num_classes=10):
    super().__init__()

    # Feature extractor f: each stride-2 conv halves the spatial size.
    # For a 32x32 input: [N, 3, 32, 32] -> [N, 64, 16, 16] -> [N, 128, 8, 8] -> [N, 256, 4, 4].
    self.model = nn.Sequential(
        spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1)),
        nn.LeakyReLU(0.2, inplace=True),
        spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)),
        nn.LeakyReLU(0.2, inplace=True),
        spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)),
        nn.LeakyReLU(0.2, inplace=True),
    )

    # Global average pooling collapses each of the 256 feature maps to one value.
    self.pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    # h(.): one real/fake logit per image.
    self.linear = spectral_norm(nn.Linear(in_features=256, out_features=1))

    # e(.): one learnable 256-dim vector per class label.
    self.embedding = spectral_norm(nn.Embedding(num_embeddings=num_classes, embedding_dim=256))

  def forward(self, x, labels):
    """Score a batch of images against their class labels.

    Args:
      x: image batch of shape [N, 3, H, W] (the shape comments in __init__ assume 32x32).
      labels: integer class indices of shape [N], each in [0, num_classes).

    Returns:
      Tensor of shape [N, 1]: one unbounded logit per image, the sum of the
      unconditional score h(f(x)) and the projection term <f(x), e(y)>.
    """
    features = self.pooling(self.model(x))           # [N, 256, 1, 1]
    features = torch.flatten(features, start_dim=1)  # [N, 256]  -- f(x)

    hfx = self.linear(features)                      # [N, 1]    -- h(f(x))
    ey = self.embedding(labels)                      # [N, 256]  -- e(y)

    # Row-wise inner product <f(x), e(y)>: rewards features that agree with the label.
    projection = torch.sum(features * ey, dim=1, keepdim=True)  # [N, 1]

    return hfx + projection