## Week 4 : Generative Adversarial Networks
```
- Generative Artificial Intelligence (Fall semester 2023)
- Professor: Muhammad Fahim
- Teaching Assistant: Gcinizwe Dlamini
```
<hr>


```
Lab Plan
    1. Vanilla GAN architecture
    2. GAN training procedure
    3. Conditional GAN
```

<hr>

## 1. Vanilla Generative Adversarial Network (GAN)

![caption](https://www.researchgate.net/profile/Zhaoqing-Pan/publication/331756737/figure/fig1/AS:736526694621184@1552613056409/The-architecture-of-generative-adversarial-networks.png)

### 1.1 Dataset

For this lesson we will use the SVHN dataset, which is readily available in `torchvision`, and we will apply only minimal transformations.

Install `torchvision` : `pip install torchvision`

In [None]:
# import libraries
import matplotlib.pyplot as plt
import numpy as np

import torch
from torchvision import datasets
from torchvision import transforms


# Mini-batch settings for the training loader.
batch_size = 256
num_workers = 0

# Preprocessing applied to every image: resize to 32x32, convert to a tensor,
# then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=[32, 32]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# SVHN training split; downloaded into data/ when requested via download=True.
svhn_train = datasets.SVHN(
    root='data/',
    split='train',
    transform=transform,
    download=True,
)

# Shuffled mini-batch iterator over the training split.
train_loader = torch.utils.data.DataLoader(
    svhn_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to data/train_32x32.mat


100%|██████████| 182040794/182040794 [00:02<00:00, 90779375.10it/s] 


## 1.2 Generator & Discriminator Definition

In [None]:
import torch.nn as nn
import torch.nn.functional as F
#ngf : Number of generator filters
#ndf : Number of discriminator filters
nz = 32
class Discriminator(nn.Module):
    """Convolutional discriminator that scores 3-channel 32x32 images.

    Two stride-2 convolutions reduce the input from 32x32 to 8x8, a 4x4
    stride-1 convolution leaves a single 5x5 map, and a linear layer followed
    by a sigmoid turns that map into one probability per image.

    Args:
        ndf: Number of discriminator filters in the first convolution; the
            second convolution uses ``ndf * 2``.
        conv_dim: Not used by the network; accepted so that existing callers
            passing it keep working.
    """

    def __init__(self, ndf=3, conv_dim=32):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            # LeakyReLU on every hidden layer: a plain ReLU would zero the
            # gradient for negative activations, starving the generator of
            # signal, and would be inconsistent with the layer below.
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, 1, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1),
            nn.Flatten(),
            # 5 * 5 is the spatial size the convolutions leave for a 32x32
            # input; other input sizes need a different value here.
            nn.Linear(5 * 5, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the probability that each image in ``x`` is real.

        Args:
            x: Image batch of shape ``(batch, 3, 32, 32)`` (real or generated).

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        return self.model(x)

class Generator(nn.Module):
    """Transposed-convolution generator mapping noise vectors to 3x32x32 images.

    A linear layer projects the noise to a ``(conv_dim * 4, 4, 4)`` feature
    map; three stride-2 transposed convolutions then upsample it
    4 -> 8 -> 16 -> 32. The final ``Tanh`` keeps every output value in [-1, 1].

    Args:
        z_size: Length of the input noise vector.
        ngf: Number of generator filters in the last hidden layer (the layer
            before it uses ``ngf * 2``).
        conv_dim: Base width; the projected feature map has ``conv_dim * 4``
            channels.
    """

    def __init__(self, z_size, ngf, conv_dim=32):
        super().__init__()
        self.conv_dim = conv_dim
        # Derive both sizes from conv_dim so the projection, the reshape in
        # forward() and the first transposed convolution always agree.
        # conv_dim=32 yields 128 channels and 2048 projected features.
        depth = conv_dim * 4
        self.input_layer = nn.Linear(in_features=z_size, out_features=depth * 4 * 4, bias=True)
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels=depth, out_channels=ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=ngf * 2),
            nn.Tanh(),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_features=ngf),
            nn.Tanh(),
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate a batch of images from random noise.

        Args:
            x: Noise tensor of shape ``(batch, z_size)``.

        Returns:
            Tensor of shape ``(batch, 3, 32, 32)`` with values in [-1, 1].
        """
        x = self.input_layer(x)
        x = x.view(-1, self.conv_dim * 4, 4, 4)  # (batch, depth, 4, 4)
        return self.model(x)

![](https://www.researchgate.net/publication/336144594/figure/fig2/AS:808881324322820@1569863744938/An-example-of-the-deconvolution-process-using-transpose-convolution-In-the-figure.png)

## 1.3 Set hyperparams and training parameters

In [None]:
# Hyper-parameters: base filter width, noise-vector length, training epochs.
conv_dim = 32
z_size = 100
num_epochs = 10

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build both networks, then move them to the selected device.
# conv_dim is handed to the discriminator as its filter count (ndf).
D = Discriminator(ndf=conv_dim)
D = D.to(device)
G = Generator(z_size=z_size, ngf=3, conv_dim=conv_dim)
G = G.to(device)

# Print both model summaries, separated by a blank line.
print(D, G, sep="\n\n")

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ReLU(inplace=True)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(64, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (6): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Flatten(start_dim=1, end_dim=-1)
    (8): Linear(in_features=25, out_features=1, bias=True)
  )
)

Generator(
  (input_layer): Linear(in_features=100, out_features=2048, bias=True)
  (model): Sequential(
    (0): ConvTranspose2d(128, 6, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Tanh()
    (3): ConvTranspose2d(6, 3, kernel_size=(4, 4

## 1.4 Define the loss function for D(x) and G(x)

In [None]:
import torch.optim as optim

def real_loss(D_out, smooth=False):
    """Binary cross-entropy of discriminator outputs against "real" targets.

    Args:
        D_out: Discriminator probabilities (already passed through a sigmoid),
            one per sample, e.g. of shape ``(batch, 1)`` or ``(batch,)``.
        smooth: If True, draw each target uniformly from [0.9, 1) instead of
            using exactly 1 (label smoothing).

    Returns:
        Scalar tensor holding the mean binary cross-entropy.
    """
    batch_size = D_out.size(0)
    # Create the targets directly on the device/dtype of the predictions so
    # the function does not rely on a module-level device variable.
    if smooth:
        labels = torch.empty(batch_size, dtype=D_out.dtype, device=D_out.device).uniform_(0.9, 1)
    else:
        labels = torch.ones(batch_size, dtype=D_out.dtype, device=D_out.device)  # real labels = 1

    # BCELoss (not the "with logits" variant) because D_out is a probability.
    criterion = nn.BCELoss()
    # reshape(-1) instead of squeeze(): squeeze() would collapse a batch of one
    # to a 0-d tensor whose shape no longer matches the labels.
    return criterion(D_out.reshape(-1), labels)

def fake_loss(D_out, smooth=True):
    """Binary cross-entropy of discriminator outputs against "fake" targets.

    Args:
        D_out: Discriminator probabilities (already passed through a sigmoid),
            one per sample, e.g. of shape ``(batch, 1)`` or ``(batch,)``.
        smooth: If True (the default), draw each target uniformly from
            [0, 0.1) instead of using exactly 0.

    Returns:
        Scalar tensor holding the mean binary cross-entropy.
    """
    batch_size = D_out.size(0)
    # Create the targets directly on the device/dtype of the predictions so
    # the function does not rely on a module-level device variable.
    if smooth:
        labels = torch.empty(batch_size, dtype=D_out.dtype, device=D_out.device).uniform_(0, 0.1)
    else:
        labels = torch.zeros(batch_size, dtype=D_out.dtype, device=D_out.device)  # fake labels = 0

    # The discriminator already ends in a sigmoid, so use BCELoss; the
    # "with logits" variant would apply a second sigmoid to the probabilities.
    criterion = nn.BCELoss()
    # reshape(-1) instead of squeeze(): squeeze() would collapse a batch of one
    # to a 0-d tensor whose shape no longer matches the labels.
    return criterion(D_out.reshape(-1), labels)

# Adam hyper-parameters shared by both networks.
learning_rate = 0.0003
beta1 = 0.5
beta2 = 0.999  # Adam's default value

# Create optimizers for the discriminator and generator.
# Both use Adam with the betas defined above, so beta1/beta2 actually take
# effect and the two networks are updated by the same kind of optimizer.
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=(beta1, beta2))
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=(beta1, beta2))

## 2. GAN training Loop

In [None]:
# Logging: report the running losses every `print_every` epochs.
print_every = 2

# Fixed noise, held constant throughout training so that generated samples can
# be compared across epochs. It is prepared here but not consumed by this loop.
sample_size = 16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float()

num_batches = len(train_loader)

# train the network
for epoch in range(num_epochs):
    g_l = 0  # generator loss summed over the epoch
    d_l = 0  # discriminator loss summed over the epoch
    for real_images, _ in train_loader:
        # Size of *this* batch (the last one may be smaller). Kept under its
        # own name so the loader's global `batch_size` setting is not clobbered.
        current_batch_size = real_images.size(0)
        real_images = real_images.to(device)

        # ---- TRAIN THE DISCRIMINATOR ----
        # Step 1: zero gradients
        d_optimizer.zero_grad()

        # Steps 2-3: score real images and compute the loss on them
        D_real = D(real_images)
        d_real_loss = real_loss(D_real)

        # Steps 4-6: generate fake images and compute the loss on them
        z = torch.FloatTensor(current_batch_size, z_size).uniform_(-1, 1).to(device)
        fake_images = G(z)
        # detach(): this step only updates D, so there is no need to
        # backpropagate through the generator here.
        D_fake = D(fake_images.detach())
        d_fake_loss = fake_loss(D_fake)

        # Step 7: add up the losses, backpropagate and update D
        d_loss = d_real_loss + d_fake_loss
        d_l += d_loss.item()
        d_loss.backward()
        d_optimizer.step()

        # ---- TRAIN THE GENERATOR (fake images, flipped labels) ----
        # Step 1: zero gradients
        g_optimizer.zero_grad()

        # Step 2: generate fake images from fresh random noise (z)
        z = torch.FloatTensor(current_batch_size, z_size).uniform_(-1, 1).to(device)
        fake_images = G(z)

        # Step 3: discriminator loss on the fakes with flipped ("real") labels
        D_fake = D(fake_images)
        g_loss = real_loss(D_fake)
        g_l += g_loss.item()

        # Step 4: backpropagate and update G
        g_loss.backward()
        g_optimizer.step()

    # Print the mean per-batch losses for this epoch
    if epoch % print_every == 0:
        print(
            f"Epoch: {epoch + 1}/{num_epochs}"
            f"\td_loss:{round(d_l / num_batches, 4)}"
            f"\tg_loss:{round(g_l / num_batches, 4)}"
        )

Keep in mind:

1. Always use a higher learning rate for the discriminator than for the generator.
2. Keep training even if you see that the losses are going up.
3. There are many variations with different loss functions which are worth exploring.
4. If you get mode collapse, lower the learning rates.
5. Adding noise to the training data helps make the model more stable.
6. Label smoothing: instead of setting the real labels to 1, set them to 0.9.


## 3. Conditional GAN

![](https://www.researchgate.net/profile/Gerasimos-Spanakis/publication/330474693/figure/fig1/AS:956606955139072@1605084279074/GAN-conditional-GAN-CGAN-and-auxiliary-classifier-GAN-ACGAN-architectures-where-x_Q320.jpg)

### 3.1 Read Data

In [None]:
import torch
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

# Mini-batch settings for the training loader.
batch_size = 256
num_workers = 0

# Per-image preprocessing: resize to 32x32, convert to a tensor, then
# normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=[32, 32]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# SVHN training split; downloaded into data/ when requested via download=True.
svhn_train = datasets.SVHN(
    root='data/',
    split='train',
    transform=transform,
    download=True,
)

# Shuffled mini-batch iterator yielding (images, labels) pairs.
train_loader = torch.utils.data.DataLoader(
    svhn_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

### 3.2 Define helper functions

In [None]:
def conv_block(c_in, c_out, k_size=4, stride=2, pad=1, use_bn=True, transpose=False):
    """Build a (transposed) convolution, optionally followed by batch norm.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        k_size: Convolution kernel size.
        stride: Convolution stride.
        pad: Convolution padding.
        use_bn: If True, append ``nn.BatchNorm2d(c_out)`` and create the
            convolution without a bias term.
        transpose: If True, use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.

    Returns:
        An ``nn.Sequential`` holding the convolution and, when requested, the
        batch-norm layer.
    """
    conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
    # The convolution carries a bias only when no batch norm follows it.
    layers = [conv_cls(c_in, c_out, k_size, stride, pad, bias=not use_bn)]
    if use_bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

### 3.3 Define Generator

<font color='red'>**TODO:** Define Generator using architecture of your choice</font>

In [None]:
class Generator(nn.Module):
    """Conditional generator skeleton (exercise template).

    The class label is embedded, concatenated with the noise vector along the
    channel axis and fed to a stack of transposed-convolution blocks. Only the
    first block (``l1``) is defined; ``l2``-``l4`` and the body of ``forward``
    are placeholders left to be implemented.

    Args:
        z_dim: Length of the noise vector.
        num_classes: Number of distinct condition labels.
        label_embed_size: Size of the learned label embedding.
        channels: Number of image channels (not used by the layers defined so far).
        conv_dim: Base number of filters; ``l1`` outputs ``conv_dim * 4`` channels.
    """

    def __init__(self, z_dim=10, num_classes=10, label_embed_size=5, channels=3, conv_dim=64):
        super().__init__()
        self.image_size = 32
        self.label_embedding = nn.Embedding(num_classes, label_embed_size)
        # First block consumes the concatenated noise + label embedding.
        self.l1 = conv_block(z_dim + label_embed_size, conv_dim * 4, pad=0, transpose=True)
        # TODO: define the remaining blocks.
        self.l2 = self.l3 = self.l4 = None

        # Weight init: N(0, 0.02) for (transposed) convolutions; batch norm
        # starts with weight 1 and bias 0.
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(layer.weight, 0.0, 0.02)
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x, condition):
        """Combine noise and label embedding; the network pass is still a TODO.

        Args:
            x: Noise batch; reshaped to ``(batch, -1, 1, 1)``.
            condition: Label indices looked up in ``label_embedding``.

        Returns:
            Currently ``None`` - the placeholder must be replaced by the pass
            through ``l1``..``l4``.
        """
        noise = x.reshape(x.shape[0], -1, 1, 1)
        label_vec = self.label_embedding(condition)
        label_vec = label_vec.reshape(label_vec.shape[0], -1, 1, 1)
        x = torch.cat([noise, label_vec], dim=1)
        x = None  # TODO
        return x

### 3.4 Define Discriminator

<font color='red'>**TODO:** Define Discriminator using architecture of your choice</font>

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator skeleton (exercise template).

    The class label is embedded into an ``image_size x image_size`` map and
    stacked onto the image as one extra channel. Only the first convolution
    block (``conv1``) is defined; ``conv2``-``conv4`` and the body of
    ``forward`` are placeholders left to be implemented.

    Args:
        num_classes: Number of distinct condition labels.
        channels: Number of image channels.
        conv_dim: Number of filters produced by ``conv1``.
    """

    def __init__(self, num_classes=10, channels=3, conv_dim=64):
        super().__init__()
        self.image_size = 32
        side = self.image_size
        # One embedding row per class, later reshaped into a single-channel map.
        self.condition_embedding = nn.Embedding(num_classes, side * side)
        # "+ 1" accounts for the label map concatenated to the image channels.
        self.conv1 = conv_block(channels + 1, conv_dim, use_bn=False)
        # TODO: define the remaining blocks.
        self.conv2 = self.conv3 = self.conv4 = None

        # Weight init: N(0, 0.02) for convolutions; batch norm starts with
        # weight 1 and bias 0.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.normal_(layer.weight, 0.0, 0.02)

            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x, condition):
        """Stack the label map onto the image; the network pass is still a TODO.

        Args:
            x: Image batch; concatenated with the label map along dim 1.
            condition: Label indices looked up in ``condition_embedding``.

        Returns:
            Nothing yet: while the placeholder ``x = None`` is in place the
            final ``squeeze()`` raises ``AttributeError``.
        """
        alpha = 0.2  # not used yet; available to the layers to be added
        label_map = self.condition_embedding(condition)
        side = self.image_size
        label_map = label_map.reshape(label_map.shape[0], 1, side, side)
        x = torch.cat([x, label_map], dim=1)
        x = None  # TODO
        return x.squeeze()

### 3.5 Assemble a cGAN

In [None]:
# Noise-vector length for the conditional generator.
z_dim = 10

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Instantiate the generator first, then the discriminator, and move each to
# the selected device.
G = Generator(z_dim=z_dim, num_classes=10, label_embed_size=5, channels=3)
G = G.to(device)
D = Discriminator(num_classes=10, channels=3)
D = D.to(device)

# Print both model summaries, separated by a blank line.
print(D, G, sep="\n\n")

### 3.6 Define optimizer and criterion

In [None]:
import torch.optim as optim

# Both networks use Adam with identical settings.
g_optimizer = optim.Adam(
    G.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
    weight_decay=2e-5,
)
d_optimizer = optim.Adam(
    D.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
    weight_decay=2e-5,
)

# Binary cross-entropy criterion for the training loop.
criterion = nn.BCELoss()

## 3.7 Training conditional GAN (training loop)


<font color='red'>**TODO:** Train conditional GAN</font>


In [None]:
# Training
num_epochs = 5

for epoch in range(num_epochs):
    G.train()
    D.train()
    for batch_i, (x_real, y_real) in enumerate(train_loader):
        batch_size = x_real.size(0)
        x_real = x_real.to(device)
        y_real = y_real.to(device)

        # Smoothed targets, built per batch so their length always matches the
        # current batch (the last batch of an epoch can be smaller).
        real_label = torch.FloatTensor(batch_size).uniform_(0.9, 1).to(device)  # or torch.ones(batch_size)
        fake_label = torch.FloatTensor(batch_size).uniform_(0, 0.1).to(device)

        # TODO
        # TRAIN THE DISCRIMINATOR
        # Step 1: Zero gradients (zero_grad)
        # Step 2: Train with real images
        # Step 3: Compute the discriminator losses on real images

        # Step 4: Train with fake images
        # Step 5: Generate fake images and move x to GPU, if available
        # Step 6: Compute the discriminator losses on fake images
        # Step 7: add up loss and perform backprop


        # TRAIN THE GENERATOR
        # Step 1: Zero gradients
        # Step 2: Generate fake images from random noise (z) and condition (y)
        # Step 3: Compute the discriminator losses on fake images using flipped labels (labels -- true/fake)
        # Step 4: Perform backprop and take optimizer step

    # Print the loss for each epoch

## Resources

* [Deconvolutional Networks](https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf)
* [PyTorch `ConvTranspose2d`](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html)
* [Computational Imaging and Display](https://stanford.edu/class/ee367/reading/lecture6_notes.pdf)