## Conditional Generative Adversarial Network (CGAN)

Paper:
- [Conditional Generative Adversarial Nets](https://arxiv.org/pdf/1411.1784)

Helpful Resources:
- [Aladdin Persson's playlist on GANs](https://youtube.com/playlist?list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va&si=8ooImkbbXhCUC1xB)
- [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans)  $\; \rightarrow \;$ most of the content of this notebook has been borrowed from this course
- [Stanford's Deep Generative Models playlist](https://youtube.com/playlist?list=PLoROMvodv4rPOWA-omMM6STXaWW4FvJT8&si=N_TpTe1bPIhte-t8)

This notebook just includes the implementation of the CGAN model and its training loop. The results are not shown here.

Feel free to check the results on my Kaggle notebook: https://www.kaggle.com/code/aryamanbansal/conditional_gan

#### Imports and some helpful utility functions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torchinfo import summary

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import glob
from datetime import datetime
import pytz
import time
import os
from tqdm import tqdm
import random
import gc

%matplotlib inline

print("Imports done!")

Imports done!


In [None]:
def get_torch_version():
    """
    Returns the installed PyTorch version as a (major, minor) tuple of ints.

    A tuple is used instead of a float because a float cannot tell
    versions such as 1.1 and 1.10 apart.
    """
    torch_version = torch.__version__.split("+")[0]
    major, minor = torch_version.split(".")[:2]
    return int(major), int(minor)


def set_seed(seed: int = 42):
    """
    Seeds basic parameters for reproducibility of results
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # Tuple comparison keeps versions like 1.10 from being read as 1.1
        if get_torch_version() <= (1, 7):
            torch.set_deterministic(True)
        else:
            torch.use_deterministic_algorithms(True)


def set_scheduler(scheduler, results, scheduler_on):
    """
    Makes the necessary updates to the scheduler.
    
    Parameters:
    - scheduler: torch.optim.lr_scheduler, the scheduler to update.
    - results: dict, the results dictionary containing the training and test metric values.
        Keys in the results dictionary are: "gen_train_loss", "gen_val_loss", 
                                            "disc_train_loss", "disc_val_loss".
    - scheduler_on: str, the metric to use for the scheduler update.
    """
    if scheduler_on == "gen_val_loss":
        scheduler.step(results["gen_val_loss"][-1])
    elif scheduler_on == "disc_val_loss":
        scheduler.step(results["disc_val_loss"][-1])
    elif scheduler_on == "gen_train_loss":
        scheduler.step(results["gen_train_loss"][-1])
    elif scheduler_on == "disc_train_loss":
        scheduler.step(results["disc_train_loss"][-1])
    else:
        raise ValueError("Invalid `scheduler_on` choice.")
    return scheduler


def save_model_info(path: str, model, model_name, optimizer, optimizer_name, 
                    scheduler=None, scheduler_name=""):
    """
    Saves the model and optimizer weights to the specified path.

    Parameters:
    - path: str, the path to the directory to save the model and optimizer weights.
    - model: nn.Module, the model to save.
    - model_name: str, the name of the model weights file.
    - optimizer: torch.optim, the optimizer to save.
    - optimizer_name: str, the name of the optimizer weights file.
    - scheduler: torch.optim.lr_scheduler, the scheduler to save.
    - scheduler_name: str, the name of the scheduler weights file.

    Free advice:
    A good practice is to transfer the model to the CPU before calling torch.save as this 
    will save tensors as CPU tensors and not as CUDA tensors. This will help in loading
    the model onto any machine, whether it contains CUDA capabilities or not.
    """
    model.to("cpu")
    torch.save(model.state_dict(), os.path.join(path,model_name))
    torch.save(optimizer.state_dict(), os.path.join(path,optimizer_name))
    if scheduler is not None:
        torch.save(scheduler.state_dict(), os.path.join(path,scheduler_name))    
    print("Model info saved!")
    
    
def load_model_info(path, device, model, model_name, optimizer, optimizer_name, 
                    scheduler=None, scheduler_name=""):
    """
    Loads the model and optimizer weights from the specified path.

    Parameters:
    - path: str, the path to the directory containing the model and optimizer weights.
    - device: str, the device to load the model and optimizer weights onto.
    - model: nn.Module, the model to load the weights into.
    - model_name: str, the name of the model weights file.
    - optimizer: torch.optim, the optimizer to load the weights into.
    - optimizer_name: str, the name of the optimizer weights file.
    - scheduler: torch.optim.lr_scheduler, the scheduler to load the weights into.
    - scheduler_name: str, the name of the scheduler weights file.
    """
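    # map_location lets checkpoints saved on a GPU machine load on any device
    model.load_state_dict(torch.load(os.path.join(path,model_name), map_location=device))
    model.to(device)
    optimizer.load_state_dict(torch.load(os.path.join(path,optimizer_name), map_location=device))
    if scheduler is not None:
        scheduler.load_state_dict(torch.load(os.path.join(path,scheduler_name), map_location=device))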
    print("Model info loaded!")
    
    
def get_current_time():
    """Returns the current time in Toronto."""
    now = datetime.now(pytz.timezone('Canada/Eastern'))
    current_time = now.strftime("%d_%m_%Y__%H_%M_%S")
    return current_time


def show_tensor_images(image_tensor, num_images=25, size=(1,28,28)):
    """
    Function for visualizing images: Given a tensor of images 
    (with values in [-1, 1]), number of images, and size per image, 
    plots the images in a uniform grid.
    """
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()


def visualize_results(results):
    """
    Plots the train and test losses of the generator and the critic.

    `results` is a dictionary with keys: 
        "gen_train_loss", "gen_val_loss", "crit_train_loss", "crit_val_loss".
    """
    gen_train_loss = results["gen_train_loss"]
    gen_val_loss = results["gen_val_loss"]
    crit_train_loss = results["crit_train_loss"]
    crit_val_loss = results["crit_val_loss"]
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12,4))
    ax[0].set_xlabel("Epochs")
    ax[0].set_ylabel("Training Loss")
    ax[0].plot(gen_train_loss, label="generator", color="orange")
    ax[0].plot(crit_train_loss, label="critic", color="blue")
    ax[0].legend()
    ax[1].set_xlabel("Epochs")
    ax[1].set_ylabel("Validation Loss")
    ax[1].plot(gen_val_loss, label="generator", color="orange")
    ax[1].plot(crit_val_loss, label="critic", color="blue")
    ax[1].legend()
    plt.show()


print("Utility functions created!")

Utility functions created!


## Some theory about CGANs

Before we jump into the code, let's understand the theory behind CGANs, and why we need them. 

Note: the following notes are heavily inspired by the [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans).


#### **Motivation**

Let's first understand the difference between unconditional and conditional generation.

Up to now, the models we built (simple GAN, DCGAN, and WGAN) were all unconditional GANs. This means that we input a random noise vector into the generator, and it generates a random image, so we have no control over what our GAN generates. On the other hand, we don't need the dataset to be labelled, as an unconditional GAN needs only the data points and not their corresponding labels.

In conditional generation, by contrast, we can control what our GAN generates. We input some information to the generator, and it generates an image based on that information. For example, we can input a label to the generator, and it will generate an image of that class. However, the generator can generate any variation of that class, and not the exact image we have in mind. Consider this analogy: we go to a vending machine and insert some money along with the item code of the item we want. The vending machine gives us the item corresponding to that item code, but we can't choose every feature of that item (say, its expiration date or manufacturing date). Likewise, when using a conditional GAN, we input the label of the class we want, and the generator generates some example of that class. The cost is that the dataset needs to be labelled, because the labels are used as inputs during training.

#### **How to make a GAN conditional?**

> The input to the generator of a CGAN.

We've already seen with unconditional generation that the generator needs a noise vector to produce random examples. For conditional generation, we also need a vector that tells the generator which class the generated examples should come from. This new vector is called a class vector. It is usually a one-hot vector, meaning it contains zeros in every position except the one corresponding to the class we want. The input to the generator in a conditional GAN is the concatenation of the noise vector and the class vector. Here, the class vector tells the generator which class to generate, and the noise vector tells it which variation of that class to generate. Of course, since the noise vector is random, we cannot control the exact variation that the generator produces. If we use a different noise vector every time but the same class vector, the generator will generate different variations of the same class.

!["cgan_gen_input"](./imgs/cgan_gen_input.png "cgan_gen_input")

Source: [Course 1 Week 4 of GANs specialization on Coursera](https://www.coursera.org/learn/build-basic-generative-adversarial-networks-gans?specialization=generative-adversarial-networks-gans)
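
The concatenation described above can be sketched in a few lines of PyTorch. This is only a minimal illustration: the helper name `combine_noise_and_labels` and the sizes (`z_dim = 64`, 10 classes) are assumptions made for the example, not values taken from the notebook.

```python
import torch
import torch.nn.functional as F


def combine_noise_and_labels(noise, labels, n_classes):
    """Concatenates noise vectors with one-hot class vectors (hypothetical helper)."""
    # (N,) integer labels -> (N, n_classes) one-hot float vectors
    one_hot = F.one_hot(labels, num_classes=n_classes).float()
    # (N, z_dim) and (N, n_classes) -> (N, z_dim + n_classes)
    return torch.cat([noise, one_hot], dim=1)


z_dim, n_classes = 64, 10            # assumed sizes for this sketch
noise = torch.randn(4, z_dim)        # a different noise vector per sample
labels = torch.tensor([3, 3, 3, 3])  # the same class for every sample
gen_input = combine_noise_and_labels(noise, labels, n_classes)
print(gen_input.shape)  # torch.Size([4, 74])
```

Because the four rows share a class vector but differ in their noise, a trained generator would be expected to produce four different variations of class 3.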

<br>

> The input to the discriminator of a CGAN.

In an unconditional GAN, the input to the discriminator is just the image. To make the GAN conditional, we need some way of conveying the class information to the discriminator as well. We saw above that the generator receives the noise vector concatenated with a one-hot class vector. The discriminator, however, receives an image (RGB or grayscale) rather than a flat vector, so the class information has to be reshaped to match the image's height and width. We can do this by appending one extra channel per class: every added channel is a matrix filled with zeros, except for the channel corresponding to the class we want, which is filled with ones. The diagram below shows the idea.

!["cgan_disc_input"](./imgs/cgan_disc_input.png "cgan_disc_input")

Source: [Course 1 Week 4 of GANs specialization on Coursera](https://www.coursera.org/learn/build-basic-generative-adversarial-networks-gans?specialization=generative-adversarial-networks-gans)
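
As a rough sketch of the channel-based approach, the one-hot vector can be broadcast across the spatial dimensions and stacked onto the image. The helper name `combine_image_and_labels` and the MNIST-like shapes are assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def combine_image_and_labels(images, labels, n_classes):
    """Appends one constant channel per class to a batch of images (hypothetical helper)."""
    # (N,) integer labels -> (N, n_classes) one-hot float vectors
    one_hot = F.one_hot(labels, num_classes=n_classes).float()
    height, width = images.shape[2:]
    # (N, n_classes) -> (N, n_classes, 1, 1) -> (N, n_classes, H, W)
    label_maps = one_hot[:, :, None, None].expand(-1, -1, height, width)
    # Stack along the channel dimension: (N, C + n_classes, H, W)
    return torch.cat([images, label_maps], dim=1)


images = torch.randn(4, 1, 28, 28)   # assumed grayscale 28x28 batch
labels = torch.tensor([0, 3, 7, 9])
disc_input = combine_image_and_labels(images, labels, n_classes=10)
print(disc_input.shape)  # torch.Size([4, 11, 28, 28])
```

The discriminator's first layer then has to accept `C + n_classes` input channels (11 here) instead of `C`.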

There are many other, less space-consuming ways of conveying class information to the discriminator, such as compressing the class information into a more compact format. We can even add a separate neural network head to do that for us, which would be wise if we had many different classes. However, when the number of classes is around 10 or so, the method described above works fine.
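
One possible compact alternative, shown here purely as an illustration and not as something the course or this notebook prescribes, is to learn an embedding that maps each label to a single extra channel. The module name `LabelToChannel` is hypothetical.

```python
import torch
import torch.nn as nn


class LabelToChannel(nn.Module):
    """Maps each class label to ONE learned image-sized channel (hypothetical module)."""

    def __init__(self, n_classes, img_size):
        super().__init__()
        self.img_size = img_size
        # One learnable vector of length H * W per class
        self.embed = nn.Embedding(n_classes, img_size * img_size)

    def forward(self, images, labels):
        n = images.shape[0]
        # (N,) -> (N, H * W) -> (N, 1, H, W)
        label_map = self.embed(labels).view(n, 1, self.img_size, self.img_size)
        # The channel count grows by 1, regardless of the number of classes
        return torch.cat([images, label_map], dim=1)


conditioner = LabelToChannel(n_classes=1000, img_size=28)  # assumed sizes
images = torch.randn(4, 1, 28, 28)
labels = torch.tensor([0, 10, 500, 999])
out = conditioner(images, labels)
print(out.shape)  # torch.Size([4, 2, 28, 28])
```

With 1000 classes, the one-hot channel method would add 1000 channels per image, whereas this sketch adds one, at the cost of extra learnable parameters.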

<br>

> How does the generator learn to generate the correct class?

The discriminator receives both the image and its class label, and it judges whether the image looks like a real image of that specific class. For example, suppose we're building a CGAN to generate different breeds of dogs. If the generator is asked for a golden retriever but produces a poodle, then no matter how realistic that image looks, the discriminator will assign it a very low probability of being real, because it isn't a golden retriever. To fool the discriminator, the generator therefore has to learn to generate images that belong to whatever class is input to it, and not just any realistic image.
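
To make this concrete, here is a minimal sketch of how the losses for one conditional training step could be computed. The tiny networks, the sizes, the random stand-in data, and the use of `nn.BCEWithLogitsLoss` are all assumptions made for illustration; they are not the architecture or loss used later in this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, z_dim, n_classes = 8, 16, 10  # assumed sizes for this sketch

# Toy stand-in networks (hypothetical, far smaller than a real CGAN)
gen = nn.Sequential(
    nn.Linear(z_dim + n_classes, 128), nn.ReLU(),
    nn.Linear(128, 28 * 28), nn.Tanh(),
    nn.Unflatten(1, (1, 28, 28)),
)
disc = nn.Sequential(
    nn.Conv2d(1 + n_classes, 8, kernel_size=4, stride=2), nn.LeakyReLU(0.2),
    nn.Flatten(),
    nn.Linear(8 * 13 * 13, 1),  # a 28x28 input gives a 13x13 feature map
)
criterion = nn.BCEWithLogitsLoss()

# Random stand-ins for a batch of real images in [-1, 1] and their labels
real = torch.rand(batch_size, 1, 28, 28) * 2 - 1
labels = torch.randint(0, n_classes, (batch_size,))
one_hot = F.one_hot(labels, num_classes=n_classes).float()
label_maps = one_hot[:, :, None, None].expand(-1, -1, 28, 28)

# The generator is told which class to produce
noise = torch.randn(batch_size, z_dim)
fake = gen(torch.cat([noise, one_hot], dim=1))

# The discriminator always sees an (image, label) pair, so an image of the
# wrong class can be rejected even when it looks realistic
disc_real_pred = disc(torch.cat([real, label_maps], dim=1))
disc_fake_pred = disc(torch.cat([fake.detach(), label_maps], dim=1))
disc_loss = 0.5 * (
    criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    + criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
)

# The generator is rewarded only if its image passes as real FOR THAT LABEL
gen_pred = disc(torch.cat([fake, label_maps], dim=1))
gen_loss = criterion(gen_pred, torch.ones_like(gen_pred))
```

The key point is that the same `labels` condition both networks: the generator's loss depends on the discriminator's verdict for the requested class, which is what pushes it toward class-consistent outputs.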

## Alright, enough theory, let's write some code now!

#### Loading the data