# cGAN (conditional GAN) implementation Using Pytorch

## Implementation Starts here

In [43]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils

from tqdm.auto import tqdm

### Non-Convolutional GAN

In [79]:
# Generator

class Generator(nn.Module):
    def __init__(self, input_size_z = 100, input_size_condition= 100, hidden_size = 256, output_size = 784, layers = 1,leaky = 0.2, device = 'cuda'):
        super().__init__()

        self.device = device

        self.init_layer = nn.Sequential(
            nn.Linear(input_size_z +input_size_condition, 800, device = self.device),
            nn.LeakyReLU(leaky),
        )

        self.combine = nn.Sequential(
            nn.Linear(800, hidden_size, device = self.device),
            nn.LeakyReLU(leaky),
        )

        self.layer = nn.ModuleList([nn.Sequential(
            nn.Linear(hidden_size, hidden_size, device = self.device),
            nn.LeakyReLU(leaky),
        ) for _ in range(layers-1)])

        self.final = nn.Sequential(
            nn.Linear(hidden_size, output_size, device = self.device),
            nn.Tanh()
        )

    def forward(self, z, y):
        """
        z: the vector 
        y: the label for the vector
        """
        z = z.to(self.device)
        y = y.to(self.device)

        combined = self.init_layer(torch.cat((z, y), dim=1))
        combined = self.combine(combined)

        for layer in self.layer:
            combined = layer(combined)
            
        return self.final(combined) # logits 

In [4]:
def test_generator():
    gen = Generator(input_size_z=100, input_size_condition=10, output_size=784, layers=3, device='cuda')
    noise = torch.randn(1, 100)
    label = torch.randn(1, 10)

    output = gen(noise, label)
    assert output.shape == (1, 784)
    print("Generator test passed")
test_generator()

Generator test passed


In [37]:
# Discriminator - not using maxout because of instability issues

class Discriminator(nn.Module):
    def __init__(self, d_in =784, d_label = 10, hidden_size=256, d_out =1, leaky = 0.2, dropout = 0.5, device='cuda'):
        super().__init__()

        self.device = device

        self.dropout = nn.Dropout(dropout)

        self.label_embed = nn.Linear(d_label, hidden_size, device = self.device)

        self.model = nn.Sequential(
            nn.Linear(d_in + hidden_size, hidden_size, device = self.device),
            nn.LeakyReLU(leaky),
            nn.Linear(hidden_size, hidden_size, device = self.device),
            nn.LeakyReLU(leaky),
            nn.Linear(hidden_size, d_out, device = self.device)
        )

        
    def forward(self, x, y):
        x = x.to(self.device)

        y = y.to(self.device, dtype=torch.float32)
        
        y = self.label_embed(y)

        combined = torch.cat((x,y), dim =1)

        logits = self.dropout(self.model(combined))

        return logits # torch.sigmoid(combined) or torch.tanh(combined) - used with loss

In [33]:
# Unit testing for Discriminator

def test_discriminator():
   disc = Discriminator(d_in = 784, d_label = 10, hidden_size=256, d_out =1, leaky = 0.2, dropout = 0.5, device='cuda')
   x = torch.randn(1, 784)
   y = torch.randn(1, 10)

   output = disc(x, y)
   assert output.shape == (1, 1)
   print("Discriminator test passed")
test_discriminator()

torch.Size([1, 784]) torch.Size([1, 10]) torch.float32 torch.float32
Discriminator test passed


#### Training and Testing Loops

In [81]:
#Hyperparameters
DISC_LR= 0.001
GEN_LR= 0.1
MIN_LR = 0.000001
DECAY_FACTOR = 1.00004
DROPOUT = 0.5
INIT_MOMENTUM = 0.5
MAX_MOMENTUM = 0.7
BATCH_SIZE = 100
EPOCHS = 100

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [74]:
def train(gen, disc, optim_gen, optim_disc, scheduler_gen, scheduler_disc, criterion_gen, criterion_disc, train_loader, device):
    gen.train()
    disc.train()

    for x, y in train_loader:
        # x is the real data
        # y is the labels

        x = x.view(x.shape[0], -1).to(device)
        y = y.unsqueeze(1).to(device)


        optim_gen.zero_grad()
        optim_disc.zero_grad()
        
        z = torch.randn(x.shape[0], 784).to(device)
        gen_out = gen(z, y)

        # Discriminator predictions
        disc_out_real = disc(x, y)  # D(real)
        disc_out_fake = disc(gen_out.detach(), y)  # D(fake), detach G to avoid gradient flow to Generator

        # Create real and fake labels
        real_labels = torch.ones_like(disc_out_real)
        fake_labels = torch.zeros_like(disc_out_fake)

        # Compute Discriminator loss
        loss_disc_real = criterion_disc(disc_out_real, real_labels)  # D(x) should be 1
        loss_disc_fake = criterion_disc(disc_out_fake, fake_labels)  # D(G(z)) should be 0
        loss_disc = (loss_disc_real + loss_disc_fake) / 2

        loss_disc.backward()
        optim_disc.step()
 
        # Train Generator (G)
        # Recalculate fake images (since .detach() was used before)
        gen_out = gen(z, y)
        disc_out_fake = disc(gen_out, y)  # D(G(z)), should be classified as real

        # Compute Generator loss
        real_labels = torch.ones_like(disc_out_fake)  # Generator wants D to classify as real
        loss_gen = criterion_gen(disc_out_fake, real_labels)

        loss_gen.backward()
        optim_gen.step()
     
        scheduler_gen.step()
        scheduler_disc.step()

In [77]:
def test(gen, disc, criterion_gen, criterion_disc, test_loader, device):
    gen.eval()
    disc.eval()
    with torch.inference_mode():
        sample_gen = None
        disc_loss = 0 
        gen_loss = 0
        for x, y in test_loader:
            x = x.view(x.shape[0], -1).to(device)
            y = y.unsqueeze(1).to(device)

            z = torch.randn(x.shape[0], 784).to(device)
            gen_out = gen(z, y)

            disc_out_real = disc(x, y)
            disc_out_fake = disc(gen_out, y)

            real_labels = torch.ones_like(disc_out_real)
            fake_labels = torch.zeros_like(disc_out_fake)

            loss_disc_real = criterion_disc(disc_out_real, real_labels)
            loss_disc_fake = criterion_disc(disc_out_fake, fake_labels)
            loss_disc = (loss_disc_real + loss_disc_fake) / 2

            loss_gen = criterion_gen(disc_out_fake, real_labels)

            disc_loss += loss_disc.item()
            gen_loss += loss_gen.item()

            
            sample_gen  = torch.stack([x, gen_out], dim=1)
        
        return sample_gen, disc_loss/len(test_loader), gen_loss/len(test_loader)

In [28]:
def visualize(sample_gen):
    sample_gen = sample_gen.reshape(-1, 1, 28, 28)
    plt.figure(figsize=(15,15))
    plt.axis("off")
    plt.title("Generated Images")
    plt.imshow(np.transpose(vutils.make_grid(sample_gen, padding=2, normalize=True).cpu(), (1,2,0)))
    plt.show()

In [64]:
def train_and_test(gen, disc, optim_gen, optim_disc, scheduler_gen, scheduler_disc, criterion_gen, criterion_disc, train_loader, test_loader, epochs, device):
    for i in tqdm(range(epochs)):
        train(gen, disc, optim_gen, optim_disc, scheduler_gen, scheduler_disc, criterion_gen, criterion_disc, train_loader, device)
        sample_gen, disc_loss, gen_loss = test(gen, disc, criterion_gen, criterion_disc, test_loader, device)
        if (i+1) % 10 == 0:
            print(f"Epoch: {i}, Discriminator loss: {disc_loss}, Generator loss: {gen_loss}") 

    visualize(sample_gen)

#### Data (MNIST)

In [30]:
import torchvision.datasets as datasets

train_dataset = datasets.MNIST(
    root='./train_data', train=True, download=True, transform=transforms.ToTensor()
)
test_dataset = datasets.MNIST(
    root='./test_data', train=False, download=True, transform=transforms.ToTensor()
)


In [13]:
# Data loaders 
train_loader, test_loader = DataLoader(train_dataset,batch_size = BATCH_SIZE, shuffle= True), DataLoader(test_dataset,batch_size = BATCH_SIZE, shuffle= True)

#### Actual Training

In [82]:
gen = Generator(input_size_z=784, input_size_condition=1, output_size=784, hidden_size=256, layers=3, device=device)
disc = Discriminator(d_in = 784, d_label = 1, hidden_size=256, d_out =1, leaky = 0.2, dropout = DROPOUT,device=device)

optim_gen = optim.SGD(gen.parameters(), lr=GEN_LR, momentum=INIT_MOMENTUM, weight_decay=0.0001)
optim_disc = optim.SGD(disc.parameters(), lr=DISC_LR, momentum=INIT_MOMENTUM,weight_decay=0.01)

scheduler_gen = optim.lr_scheduler.ExponentialLR(optim_gen, DECAY_FACTOR) 
scheduler_disc = optim.lr_scheduler.ExponentialLR(optim_disc, DECAY_FACTOR)

criterion_gen = nn.BCEWithLogitsLoss() # we use the raw logits
criterion_disc = nn.BCEWithLogitsLoss() # we use the raw logits


gen = gen.to(device)
disc = disc.to(device)

train_and_test(gen, disc, optim_gen, optim_disc, scheduler_gen, scheduler_disc, criterion_gen, criterion_disc, train_loader, test_loader, EPOCHS, device)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 9, Discriminator loss: 0.6941260331869126, Generator loss: 0.6868000757694245


KeyboardInterrupt: 