## Week 13 : Generative Adversarial Networks
```
- Machine Learning, Innopolis University 
- Professor: Adil Khan 
- Teaching Assistant: Gcinizwe Dlamini
```
<hr>


```
Lab Plan
    1. Vanilla GAN architecture
    2. GAN training procedure
```

<hr>

## 1. Vanilla Generative Adversarial Network (GAN)

![The architecture of generative adversarial networks](https://www.researchgate.net/profile/Zhaoqing-Pan/publication/331756737/figure/fig1/AS:736526694621184@1552613056409/The-architecture-of-generative-adversarial-networks.png)

### 1.1 Dataset 

For this lesson we will use the SVHN dataset, which is readily available in `torchvision`, and apply only minimal transformations.

### Task : Normalize the data

```python
# import libraries
import matplotlib.pyplot as plt
import numpy as np

import torch
from torchvision import datasets
from torchvision import transforms


def normalize(data_tensor):
    '''Re-scale image values from [0, 1] (the output of ToTensor) to [-1, 1].'''
    # TODO: implement the re-scaling
    pass

transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: normalize(x))])

# SVHN training dataset
svhn_train = datasets.SVHN(root='data/', split='train', download=True, transform=transform)

batch_size = 128
num_workers = 0

# build a DataLoader for the SVHN dataset
train_loader = torch.utils.data.DataLoader(dataset=svhn_train,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=num_workers)
```
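
One possible sketch of the normalization task, assuming the usual linear map: `ToTensor()` produces values in [0, 1], and the generator ends in a Tanh with range [-1, 1], so real images are shifted to the same range with `2x - 1`. The `denormalize` helper is a hypothetical addition for plotting generated samples. An equivalent built-in option is `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, which computes `(x - 0.5) / 0.5`.

```python
import torch

def normalize(data_tensor):
    '''Re-scale image values from [0, 1] (the output of ToTensor) to [-1, 1].'''
    return data_tensor * 2.0 - 1.0

def denormalize(data_tensor):
    '''Hypothetical inverse helper: map values from [-1, 1] back to [0, 1] for plotting.'''
    return (data_tensor + 1.0) / 2.0

# A fake 3x32x32 "image" with values in [0, 1), like ToTensor() produces
x = torch.rand(3, 32, 32)
y = normalize(x)
print(y.min().item() >= -1.0, y.max().item() <= 1.0)  # True True
```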

## 1.2 Generator & Discriminator Definition

There are a couple of ways to upsample the generator input (*z*) to the desired output size:
1. Fully connected layers: increase the number of neurons (e.g. a linear layer whose output is reshaped into a small feature map)
2. Transposed convolutions `torch.nn.ConvTranspose2d` [More info](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html)
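
The two options are usually combined. The sketch below applies the output-size formula from the `ConvTranspose2d` documentation to the generator's three transposed convolutions (kernel 4, stride 2, padding 1). Each such layer exactly doubles the spatial size, so three of them multiply it by 8: reaching 32x32 requires a 4x4 starting map, whereas a noise vector treated as a 1x1 map only reaches 8x8. This is why, under this assumed layout, a fully connected layer is typically used first to project *z* into a 4x4 feature map. The helper names are hypothetical.

```python
def conv_transpose2d_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    '''Output height/width of ConvTranspose2d (formula from the PyTorch docs).'''
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

def trace(size, n_layers=3):
    '''Spatial sizes after each transposed conv with kernel 4, stride 2, padding 1.'''
    sizes = [size]
    for _ in range(n_layers):
        size = conv_transpose2d_out(size, kernel=4, stride=2, padding=1)
        sizes.append(size)
    return sizes

print(trace(4))  # [4, 8, 16, 32]  -> 32x32 output if z is first projected to 4x4
print(trace(1))  # [1, 2, 4, 8]    -> only 8x8 if z is used as a 1x1 map
```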

### TASK : Define the generator and discriminator networks using the architectures specified below

#### Discriminator : <br>
```
1. conv layer 1 -> output channels 32, kernel size 4x4, stride 2x2
2. conv layer 2 -> output channels 64, kernel size 4x4, stride 2x2
3. Add batch normalization & Leaky ReLU activation 
4. conv layer 3 -> output channels 1, kernel size 4x4, stride 1x1
5. Add batch normalization
6. Flatten layer
7. Output layer (one logit per image)
```
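
The specification does not give padding values, so the input size of the output layer depends on that choice. A quick shape trace, assuming padding 1 for the two stride-2 convolutions and padding 0 for the last one, uses the `Conv2d` output-size formula from the PyTorch docs. Under these assumptions a 32x32 SVHN image shrinks to a 1x5x5 map, so the flattened vector has 25 features. The helper names are hypothetical.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    '''Output height/width of Conv2d (floor formula from the PyTorch docs).'''
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def discriminator_sizes(image_size=32):
    '''Spatial size after each conv layer; padding values are assumptions.'''
    sizes = [image_size]
    for kernel, stride, padding in [(4, 2, 1), (4, 2, 1), (4, 1, 0)]:
        sizes.append(conv2d_out(sizes[-1], kernel, stride, padding))
    return sizes

sizes = discriminator_sizes(32)
print(sizes)             # [32, 16, 8, 5]
print(1 * sizes[-1]**2)  # 25 features after flattening (1 channel x 5 x 5)
```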

#### Generator : <br>
```
1. Transpose 2d layer -> output channels 6, kernel size 4x4, stride 2x2, padding 1
2. Add batch normalization & Leaky ReLU activation
3. Transpose 2d layer -> output channels 3, kernel size 4x4, stride 2x2, padding 1
4. Batch normalization & Leaky ReLU activation
5. Transpose 2d layer -> output channels 3, kernel size 4x4, stride 2x2, padding 1 
6. Tanh activation 
```

```python
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):

    def __init__(self, conv_dim=32):
        super(Discriminator, self).__init__()
        # TODO: add the discriminator layers listed above
        self.model = nn.Sequential()

    def forward(self, x):
        # Step 1: pass the input (real or fake samples) through all layers;
        # the output is one logit per image, shape (batch_size, 1)
        return self.model(x)

class Generator(nn.Module):

    def __init__(self, z_size, conv_dim=32):
        super(Generator, self).__init__()
        # Step 1: Define the generator network architecture
        # NOTE: the input is a noise vector of length z_size and the output
        # is an image of shape (3, 32, 32) with values in [-1, 1]
        self.model = nn.Sequential()  # TODO: add the generator layers listed above

    def forward(self, x):
        # Step 1: pass the random noise through the network to generate fake samples
        return self.model(x)
```
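
A minimal sketch of one way to fill in both networks from the specification, under stated assumptions: the discriminator padding values (1, 1, 0) follow the shape trace for a 32x32 input, the Leaky ReLU negative slope of 0.2 is a common but arbitrary choice, and the generator first projects *z* with a fully connected layer into a `(conv_dim * 4) x 4 x 4` feature map (the size of that map is an assumption) so that three stride-2 transposed convolutions reach 32x32. This is not the only architecture that satisfies the task.

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, conv_dim=32):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1),             # 32x32 -> 16x16
            nn.Conv2d(conv_dim, conv_dim * 2, kernel_size=4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.BatchNorm2d(conv_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(conv_dim * 2, 1, kernel_size=4, stride=1, padding=0),          # 8x8 -> 5x5
            nn.BatchNorm2d(1),
            nn.Flatten(),                                                            # (N, 25)
            nn.Linear(5 * 5, 1),                                                     # one logit per image
        )

    def forward(self, x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self, z_size, conv_dim=32):
        super().__init__()
        self.model = nn.Sequential(
            # Assumption: project z to a (conv_dim * 4) x 4 x 4 feature map first
            nn.Linear(z_size, conv_dim * 4 * 4 * 4),
            nn.Unflatten(1, (conv_dim * 4, 4, 4)),
            nn.ConvTranspose2d(conv_dim * 4, 6, kernel_size=4, stride=2, padding=1),  # 4 -> 8
            nn.BatchNorm2d(6),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(6, 3, kernel_size=4, stride=2, padding=1),             # 8 -> 16
            nn.BatchNorm2d(3),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(3, 3, kernel_size=4, stride=2, padding=1),             # 16 -> 32
            nn.Tanh(),                                                                # values in [-1, 1]
        )

    def forward(self, z):
        return self.model(z)

G = Generator(z_size=100)
D = Discriminator()
fake = G(torch.randn(4, 100))
print(fake.shape)     # torch.Size([4, 3, 32, 32])
print(D(fake).shape)  # torch.Size([4, 1])
```

The discriminator returns raw logits rather than probabilities because the loss functions in the next section use `nn.BCEWithLogitsLoss`, which applies the sigmoid internally.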

## 1.3 Set hyperparams and training parameters

### Task : create the discriminator and generator networks

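```python
# define hyperparams
conv_dim = 32
z_size = 100
num_epochs = 50

# use the GPU if one is available (the loss functions below also rely on `device`)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: define the discriminator and generator and send them to `device`
D = None
G = None

# print the model summaries
print(D)
print()
print(G)
```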

## 1.4 Define the loss functions for D(x) and G(z)
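
Both helpers defined in this section are wrappers around `nn.BCEWithLogitsLoss`, which takes the discriminator's raw logit $s$ and a target $y$ and applies the sigmoid $\sigma$ internally. Writing $D(\cdot)$ for the logit output, the per-sample loss and the two objectives (averaged over a batch of $N$ samples, as the default `reduction='mean'` does) are:

$$
\ell(s, y) = -\big[\, y \log \sigma(s) + (1 - y) \log\big(1 - \sigma(s)\big) \big], \qquad \sigma(s) = \frac{1}{1 + e^{-s}}
$$

$$
\mathcal{L}_D = \frac{1}{N}\sum_{i=1}^{N} \Big[ \ell\big(D(x_i),\, y_i^{\text{real}}\big) + \ell\big(D(G(z_i)),\, y_i^{\text{fake}}\big) \Big]
$$

$$
\mathcal{L}_G = \frac{1}{N}\sum_{i=1}^{N} \ell\big(D(G(z_i)),\, 1\big) = -\frac{1}{N}\sum_{i=1}^{N} \log \sigma\big(D(G(z_i))\big)
$$

Here $y^{\text{real}} = 1$ (or drawn uniformly from $[0.9, 1]$ when `smooth=True`) and $y^{\text{fake}}$ is drawn uniformly from $[0, 0.1]$. The generator loss is the "flipped labels" trick: $G$ is scored as if its samples were real, which is the non-saturating generator loss suggested in the original GAN paper.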

```python
import torch.optim as optim

def real_loss(D_out, smooth=False):
    batch_size = D_out.size(0)
    if smooth:
        # label smoothing: real labels drawn uniformly from [0.9, 1]
        labels = torch.empty(batch_size, device=device).uniform_(0.9, 1.0)
    else:
        labels = torch.ones(batch_size, device=device)  # real labels = 1

    # binary cross-entropy on raw logits (the sigmoid is applied inside the loss)
    criterion = nn.BCEWithLogitsLoss()
    # view(-1) turns (batch_size, 1) into (batch_size,); unlike squeeze(), it also works for a batch of 1
    loss = criterion(D_out.view(-1), labels)
    return loss

def fake_loss(D_out):
    batch_size = D_out.size(0)
    # noisy fake labels drawn uniformly from [0, 0.1], i.e. approximately 0
    labels = torch.empty(batch_size, device=device).uniform_(0.0, 0.1)
    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(D_out.view(-1), labels)
    return loss

# params
learning_rate = 0.0003
beta1 = 0.5
beta2 = 0.999  # default value

# TODO: create optimizers for the discriminator and generator
d_optimizer = None
g_optimizer = None
```
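
A minimal sketch of the optimizer TODO, assuming Adam (the `beta1`/`beta2` names above match Adam's `betas` argument). Each network gets its own optimizer, so a discriminator step only updates D's weights and a generator step only updates G's. `beta1 = 0.5` is lower than Adam's default of 0.9, a common choice when training GANs. The two tiny `nn.Linear` modules are hypothetical stand-ins so the snippet runs on its own; in the notebook, pass `D.parameters()` and `G.parameters()` of the real networks.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-ins; replace with the Discriminator and Generator defined above
D = nn.Linear(8, 1)
G = nn.Linear(4, 8)

learning_rate = 0.0003
beta1, beta2 = 0.5, 0.999

# One Adam optimizer per network
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=(beta1, beta2))
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=(beta1, beta2))
```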

## 1.5 GAN training Loop

### Task 

1. Implement GAN training procedure
1. (optional) : Add TensorBoard to monitor the generator and discriminator losses

```python
# keep track of the losses and the generated, "fake" samples
losses = []

print_every = 2

# Get some fixed noise for sampling. These vectors are held constant
# throughout training, which lets us inspect how the generator improves
sample_size = 16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float().to(device)

# train the network
for epoch in range(num_epochs):

    for batch_i, (real_images, _) in enumerate(train_loader):

        batch_size = real_images.size(0)
        real_images = real_images.to(device)

        # TODO: TRAIN THE DISCRIMINATOR
        # Step 1: Zero the discriminator gradients
        d_optimizer.zero_grad()
        # Step 2: Pass the real images through D
        # Step 3: Compute the discriminator loss on real images
        D_real = None
        d_real_loss = real_loss(D_real)

        # Step 4: Sample noise z (uniform in [-1, 1], like fixed_z) on `device`
        # Step 5: Generate fake images with G; detach them so this step does not update G
        # Step 6: Compute the discriminator loss on fake images
        fake_images = None
        d_fake_loss = None

        # Step 7: Add up the losses and perform backprop
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # TODO: TRAIN THE GENERATOR (train with fake images and flipped labels)
        # Step 1: Zero the generator gradients
        g_optimizer.zero_grad()
        # Step 2: Generate fake images from random noise (z)
        # Step 3: Compute the discriminator loss on the fake images using flipped
        #         labels, i.e. real_loss(D(fake_images)), so G is rewarded for fooling D
        # Step 4: Perform backprop and take an optimizer step

    # TODO: Append (d_loss, g_loss) to `losses` and print some loss stats
    if epoch % print_every == 0:
        pass
```
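
One possible way to fill in a single training step, shown as a self-contained sketch. To run standalone it uses hypothetical tiny MLP stand-ins for G and D on flattened 3x8x8 "images", random data instead of `train_loader`, and `BCEWithLogitsLoss` inlined with the same label choices as `real_loss(..., smooth=True)` and `fake_loss`. In the notebook you would use the real `D`, `G`, `real_loss`, `fake_loss` and optimizers instead. The key points are: zero each network's gradients before its own backward pass, detach the fake images during the discriminator step, and score the generator's fakes against "real" targets (flipped labels).

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical tiny stand-ins for G and D (flattened 3x8x8 "images")
z_size, img_dim = 16, 3 * 8 * 8
G = nn.Sequential(nn.Linear(z_size, img_dim), nn.Tanh()).to(device)
D = nn.Sequential(nn.Linear(img_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1)).to(device)
d_optimizer = optim.Adam(D.parameters(), lr=3e-4, betas=(0.5, 0.999))
g_optimizer = optim.Adam(G.parameters(), lr=3e-4, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()

def sample_z(n):
    # Same distribution as fixed_z in the notebook: uniform in [-1, 1]
    return torch.rand(n, z_size, device=device) * 2 - 1

def train_step(real_images):
    n = real_images.size(0)
    real_images = real_images.to(device)

    # Discriminator: real -> smoothed labels in [0.9, 1], fake -> labels in [0, 0.1]
    d_optimizer.zero_grad()
    d_real = D(real_images).view(-1)
    d_real_loss = criterion(d_real, torch.empty(n, device=device).uniform_(0.9, 1.0))
    fake_images = G(sample_z(n)).detach()  # detach: this step must not update G
    d_fake = D(fake_images).view(-1)
    d_fake_loss = criterion(d_fake, torch.empty(n, device=device).uniform_(0.0, 0.1))
    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optimizer.step()

    # Generator: flipped labels, i.e. push D to output "real" (1) for fakes
    g_optimizer.zero_grad()
    d_fake = D(G(sample_z(n))).view(-1)
    g_loss = criterion(d_fake, torch.ones(n, device=device))
    g_loss.backward()   # also fills D's grads; they are cleared at the next D step
    g_optimizer.step()  # only G's parameters are updated here

    return d_loss.item(), g_loss.item()

# Stand-in "real" batch with values in [-1, 1]
real_batch = torch.rand(32, img_dim) * 2 - 1
losses = [train_step(real_batch) for _ in range(5)]
print(losses[-1])
```

Returning plain floats via `.item()` keeps the `losses` list free of autograd graphs, so it can be plotted after training without holding GPU memory.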

Keep in mind:

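1. A common heuristic is to use a higher learning rate for the discriminator than for the generator.

2. Keep training even if the losses go up: GAN losses oscillate as the two networks compete, so a rising loss alone does not mean training has failed.

3. There are many GAN variants with different loss functions (e.g. Wasserstein or least-squares losses) that are worth exploring.

4. If you get mode collapse (the generator produces only a few kinds of samples), try lowering the learning rates.

5. Adding small amounts of noise to the training data (the images fed to the discriminator) can make training more stable.

6. Label smoothing: instead of using 1 as the target for real images, use 0.9 (or values drawn from [0.9, 1], as `real_loss(..., smooth=True)` does).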
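
A minimal sketch of the noise tip, often called instance noise: add zero-mean Gaussian noise to every batch the discriminator sees (both real and fake images) and shrink it as training progresses. The helper names, the starting standard deviation of 0.1, and the linear decay schedule are all hypothetical choices for illustration, not settings prescribed by this lab.

```python
import torch

def add_instance_noise(images, std):
    '''Add zero-mean Gaussian noise with standard deviation `std` (hypothetical helper).'''
    return images + std * torch.randn_like(images)

def noise_std(epoch, num_epochs, start_std=0.1):
    '''Hypothetical linear schedule: decay the noise level to 0 by the last epoch.'''
    return start_std * max(0.0, 1.0 - epoch / max(1, num_epochs - 1))

batch = torch.zeros(4, 3, 32, 32)
noisy = add_instance_noise(batch, noise_std(epoch=0, num_epochs=50))
print(noise_std(0, 50), noise_std(49, 50))  # 0.1 0.0
```

In the training loop this would be applied as `D(add_instance_noise(real_images, std))` and `D(add_instance_noise(fake_images, std))`, so the discriminator cannot separate the two distributions on tiny pixel-level details early in training.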


## References

1. [Deep Convolutional Generative Adversarial Network](https://www.tensorflow.org/tutorials/generative/dcgan)

1. [Generative adversarial networks: What GANs are and how they’ve evolved](https://venturebeat.com/2019/12/26/gan-generative-adversarial-network-explainer-ai-machine-learning/)

1. [Generative Adversarial Nets](https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf)

1. [GANs by Google](https://developers.google.com/machine-learning/gan)

1. [A Gentle Introduction to Generative Adversarial Networks (GANs)](https://machinelearningmastery.com/what-are-generative-adversarial-networks-gans/)

1. [A Beginner's Guide to Generative Adversarial Networks (GANs)](https://pathmind.com/wiki/generative-adversarial-network-gan)

1. [Understanding Generative Adversarial Networks (GANs)](https://towardsdatascience.com/understanding-generative-adversarial-networks-gans-cd6e4651a29)

1. [Deep Learning (PyTorch)](https://github.com/udacity/deep-learning-v2-pytorch)

1. [10 Lessons I Learned Training GANs for one Year](https://towardsdatascience.com/10-lessons-i-learned-training-generative-adversarial-networks-gans-for-a-year-c9071159628)