## Wasserstein Generative Adversarial Network (WGAN)

Paper: [Wasserstein GAN](https://arxiv.org/pdf/1701.07875)

Helpful Resources:
- [Aladdin Persson's playlist on GANs](https://youtube.com/playlist?list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va&si=8ooImkbbXhCUC1xB)
- [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans)
- [Stanford's Deep Generative Models playlist](https://youtube.com/playlist?list=PLoROMvodv4rPOWA-omMM6STXaWW4FvJT8&si=N_TpTe1bPIhte-t8)

This notebook just includes the implementation of the WGAN model and its training loop. The results are not shown here.

Feel free to check the results on my Kaggle notebook: https://www.kaggle.com/code/aryamanbansal/wgan

#### Imports and some helpful utility functions

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torchinfo import summary

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import glob
from datetime import datetime
import pytz
import time
import os
from tqdm import tqdm
import random
import gc

%matplotlib inline

print("Imports done!")

Imports done!


In [9]:
def get_torch_version():
    """
    Returns the installed PyTorch version as a "major.minor" float.

    Any local build suffix after "+" (e.g. "+cu118") is dropped first, then
    only the first two dot-separated components are kept.

    Note: because the result is a float, a two-digit minor version loses
    information (e.g. "1.10.2" becomes 1.1), so compare the result with care.
    """
    release = torch.__version__.split("+", 1)[0]
    parts = release.split(".")
    return float(parts[0] + "." + parts[1])


def set_seed(seed: int = 42):
    """
    Seeds the Python, NumPy and PyTorch random number generators for
    reproducibility of results.

    When CUDA is available, the CUDA generator is seeded as well, cuDNN
    autotuning is disabled and deterministic algorithms are requested.

    Parameters:
    - seed: int, the seed value shared by all generators.
    """
    random.seed(seed)
    # NOTE: hash randomization is fixed at interpreter start-up, so this only
    # influences processes spawned after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # Feature-detect the API instead of comparing a float version number:
        # as a float, "1.10" becomes 1.1 and would wrongly look older than 1.7,
        # selecting the legacy call on releases that no longer support it.
        if hasattr(torch, "use_deterministic_algorithms"):
            torch.use_deterministic_algorithms(True)
        else:
            torch.set_deterministic(True)


def set_scheduler(scheduler, results, scheduler_on):
    """
    Makes the necessary updates to the scheduler.

    Steps the scheduler once with the most recent value of the chosen metric.

    Parameters:
    - scheduler: torch.optim.lr_scheduler, the scheduler to update. Its `step`
        method is called with a single metric value.
    - results: dict, the results dictionary containing the training and test metric values.
        Keys in the results dictionary are: "gen_train_loss", "gen_val_loss", 
                                            "disc_train_loss", "disc_val_loss".
    - scheduler_on: str, the metric to use for the scheduler update. Must be
        one of the keys listed above.

    Returns:
    - the same scheduler object, after the step.

    Raises:
    - ValueError: if `scheduler_on` is not one of the supported metric names.
    """
    valid_metrics = (
        "gen_train_loss",
        "gen_val_loss",
        "disc_train_loss",
        "disc_val_loss",
    )
    # A single membership check replaces the old if / if-elif-else ladder, in
    # which "gen_val_loss" stepped the scheduler and then still fell through
    # to the final `else` and raised.
    if scheduler_on not in valid_metrics:
        raise ValueError("Invalid `scheduler_on` choice.")
    scheduler.step(results[scheduler_on][-1])
    return scheduler


def save_model_info(path: str, model, model_name, optimizer, optimizer_name, 
                    scheduler=None, scheduler_name=""):
    """
    Writes the state dicts of the model, the optimizer and (optionally) the
    scheduler into the given directory.

    Parameters:
    - path: str, the directory in which the files are written.
    - model: nn.Module, the model whose weights are saved.
    - model_name: str, the file name used for the model weights.
    - optimizer: torch.optim, the optimizer whose state is saved.
    - optimizer_name: str, the file name used for the optimizer state.
    - scheduler: torch.optim.lr_scheduler, optional scheduler whose state is saved.
    - scheduler_name: str, the file name used for the scheduler state.

    Note:
    The model is moved to the CPU before saving so that its weights are stored
    as CPU tensors, which can be loaded on machines with or without CUDA. The
    model is left on the CPU afterwards.
    """
    model.to("cpu")
    to_save = [(model, model_name), (optimizer, optimizer_name)]
    if scheduler is not None:
        to_save.append((scheduler, scheduler_name))
    for stateful, file_name in to_save:
        torch.save(stateful.state_dict(), os.path.join(path, file_name))
    print("Model info saved!")
    
    
def load_model_info(path, device, model, model_name, optimizer, optimizer_name, 
                    scheduler=None, scheduler_name=""):
    """
    Loads the model and optimizer weights from the specified path.

    Parameters:
    - path: str, the path to the directory containing the model and optimizer weights.
    - device: str, the device to load the model and optimizer weights onto.
    - model: nn.Module, the model to load the weights into.
    - model_name: str, the name of the model weights file.
    - optimizer: torch.optim, the optimizer to load the weights into.
    - optimizer_name: str, the name of the optimizer weights file.
    - scheduler: torch.optim.lr_scheduler, the scheduler to load the weights into.
    - scheduler_name: str, the name of the scheduler weights file.
    """
    # `map_location` remaps every stored tensor onto `device` while reading,
    # so a checkpoint that still contains CUDA tensors (e.g. optimizer state
    # saved during GPU training) can be loaded on a CPU-only machine.
    model.load_state_dict(
        torch.load(os.path.join(path, model_name), map_location=device)
    )
    model.to(device)
    optimizer.load_state_dict(
        torch.load(os.path.join(path, optimizer_name), map_location=device)
    )
    if scheduler is not None:
        scheduler.load_state_dict(
            torch.load(os.path.join(path, scheduler_name), map_location=device)
        )
    print("Model info loaded!")
    
    
def get_current_time():
    """
    Returns the current time in the 'Canada/Eastern' time zone (Toronto),
    formatted as "DD_MM_YYYY__HH_MM_SS".
    """
    eastern = pytz.timezone('Canada/Eastern')
    return datetime.now(eastern).strftime("%d_%m_%Y__%H_%M_%S")


def show_tensor_images(image_tensor, num_images=25, size=(1,28,28)):
    """
    Displays the first `num_images` images of a batch as a grid with
    5 images per row.

    The tensor is shifted and scaled with (x + 1) / 2 before plotting
    (presumably to map values from [-1, 1] to [0, 1] -- confirm against the
    caller), detached from the graph and moved to the CPU.

    Note: `size` is accepted but not used by this function.
    """
    rescaled = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(rescaled[:num_images], nrow=5)
    # make_grid returns channels-first; matplotlib expects channels-last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()


def visualize_results(results):
    """
    Draws two side-by-side line plots of the generator and discriminator
    losses: training losses on the left, test losses on the right.

    results is dictionary with keys: 
        "gen_train_loss", "gen_val_loss", "disc_train_loss", "disc_val_loss".
    """
    keys = ("gen_train_loss", "gen_val_loss", "disc_train_loss", "disc_val_loss")
    gen_train, gen_val, disc_train, disc_val = [results[key] for key in keys]
    _, axes = plt.subplots(nrows=1, ncols=2, figsize=(12,4))
    panels = (
        ("Training Loss", gen_train, disc_train),
        ("Test Loss", gen_val, disc_val),
    )
    for axis, (y_label, gen_curve, disc_curve) in zip(axes, panels):
        axis.set_xlabel("Epochs")
        axis.set_ylabel(y_label)
        axis.plot(gen_curve, label="generator", color="orange")
        axis.plot(disc_curve, label="discriminator", color="blue")
        axis.legend()
    plt.show()


print("Utility functions created!")

Utility functions created!


### Some theory about WGANs

Before we jump in to the code, let's understand the theory behind WGANs, and why we need them. 

Note: the following notes are heavily inspired by the [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans).


#### **Motivation**

In simple GANs using binary cross entropy (BCE) loss, the generator tries to make the discriminator think its fake images are real. The issue arises because the generator's learning depends entirely on how the discriminator evaluates its images.

When the discriminator strongly classifies an image as real (assigns higher probability to that image as real), the generator focuses heavily on producing that image repeatedly, leading to a lack of variety. Conversely, if the discriminator strongly classifies the image as fake, the generator receives almost no useful feedback, as BCE loss gives a vanishing gradient when predictions are very confident. This stalling of the generator’s learning leads to the bad scenario where the generator is stuck with no clear way to improve.

This phenomenon occurs because BCE loss measures how well the discriminator classifies an image but doesn't give the generator useful signals when its outputs are far from realistic.

#### **Mode Collapse**

- **What is it**: Mode collapse happens when the generator produces limited types of images (e.g., just one kind of dog) instead of a diverse set of realistic outputs.

- **When does it happen**: Mode collapse occurs when the generator gets stuck generating one mode. The discriminator will eventually learn to differentiate the generator's fakes when this happens and outskill it, ending the model's learning. In other words, mode collapse often occurs when the generator "overfits" to fool the discriminator by creating only a few images that it knows will work well, ignoring the rest of the data distribution. It generally occurs when we use the BCE loss function to train GANs.

- **Why does it happen**: This happens because the discriminator doesn't penalize the generator for repeatedly producing the same outputs, as long as they seem realistic. The generator's objective doesn’t explicitly encourage diversity.

#### **Problem with BCE Loss function**

**TL;DR**: The BCE loss function is not a good choice for GANs because it can lead to vanishing gradient problem and mode collapse. In other words, the discriminator gets too good at distinguishing between real and fake values, and as a result, the generator doesn't get useful feedback to improve.

> Binary Cross Entropy (BCE) loss function review. 

The BCE loss function is defined as:

$$\text{BCELoss} = -\frac{1}{m} \sum_{i=1}^m \left[ y_i \cdot \text{log}(x_i) + (1 - y_i) \cdot \text{log}(1 - x_i) \right]$$

A simplified form of the above loss function is:

$$\text{BCELoss} \quad = \quad \min\limits_{d} \max\limits_{g} \quad - \left[ \mathbb{E}(\text{log}(d(x))) \; + \; \mathbb{E}(\text{log}(1 - d(g(z)))) \right]$$

, where:

- $d$ is the discriminator
- $g$ is the generator
- $x$ is the real values
- $z$ is the noise vector
- $d(x)$ is the discriminator's prediction for real values
- $g(z)$ is the fake values generated by the generator
- $d(g(z))$ is the discriminator's prediction for fake values


The first term ($y_i \cdot \text{log}(x_i)$) is for real values, and the second term ($(1 - y_i) \cdot \text{log}(1 - x_i)$) is for fake values.

The discriminator's goal is to minimize the BCE loss function, because that means it's classifying things correctly. So its loss function is:

$$\text{min} \; - \left[ \; y_i \cdot \text{log}(x_i) \; + \; (1 - y_i) \cdot \text{log}(1 - x_i) \; \right]$$

OR

$$\min\limits_{d} \quad - \left[ \mathbb{E}(\text{log}(d(x))) \; + \; \mathbb{E}(\text{log}(1 - d(g(z)))) \right]$$


Meanwhile, the generator's goal is to maximize the BCE loss function, because that means the discriminator is doing poorly and is classifying the generator's fake values as real. The generator only has control over the fake values, so its objective is:

$$\text{max} \; - \left[ \; (1 - y_i) \cdot \text{log}(1 - x_i) \; \right]$$

OR

$$\max\limits_{g} \quad - \left[ \mathbb{E}(\text{log}(1 - d(g(z)))) \right]$$

> The big picture of this minimax game.

This maximization and minimization is often called a minimax game. The goal of the entire GAN architecture is to bring the generated data distribution and the real data distribution as close together as possible. In other words, during the entire training process, the discriminator is trying to precisely classify which values are fake and which are real, while the generator is trying to generate fake values that are as close to real as possible.

> Roles of the generator and discriminator.

However, let's take a step back again to the roles of the generator and discriminator. The discriminator needs to output just a single value prediction within zero and one, ie, whether the value it sees is fake or real. Whereas the generator actually needs to produce the realistic-looking fake value, which is a pretty complex output composed of multiple features to try and fool the discriminator, for example, an image. As a result, the discriminator's job tends to be a little bit easier. In other words, it's much easier to look at images in a museum than to actually paint those masterpieces. So, during training, it's quite common for the discriminator to outperform the generator, ie, the discriminator quickly learns to differentiate between real and fake values, and the generator can't keep up. This is where the problem arises.

> Rise of the (vanishing gradient) problem with the BCE loss function.

The fact that the discriminator quickly learns to outperform the generator isn't such a big problem at the beginning of training, because the discriminator isn't that good yet. So, initially, the discriminator is able to give useful feedback to the generator in the form of a non-zero gradient. However, as training progresses, its ability to distinguish between real and fake values becomes much more precise. As a result, as the discriminator keeps getting better, it starts giving less informative feedback to the generator. In fact, the discriminator might give gradients close to zero, and that becomes unhelpful for the generator because then the generator doesn't know how to improve. This is how the vanishing gradient problem arises.

#### **Earth Mover's Distance (EMD)**

<!-- - EMD is one of the loss functions for GANs. -->
- EMD is a concept and not a loss function.
- It measures how different the generated data distribution and the real data distribution are. 
- It does so by estimating the amount of effort it takes to make the generated distribution equal to the real.
- Intuitively, if the generated distribution was a pile of dirt, EMD measures how difficult it would be to move that pile of dirt and mold it into the shape and location of the real distribution.
- EMD depends on both the distance and the amount that the generated distribution needs to be moved.
- Unlike with the BCE loss function, the gradient values here are not confined between 0 and 1, ie, in EMD, gradients can be any value (from $-\infty$ to $\infty$), which mitigates the vanishing gradient problem.  
- GANs trained using EMD generally outperform the ones trained using the BCE loss function.


#### **Wasserstein Loss**

> Introduction to the Wasserstein loss.

The Wasserstein loss (aka W-loss) is an alternative loss function choice that approximates the Earth Mover's distance as follows:

$$\text{W-Loss} \;\;  = \;\; \min\limits_{g} \max\limits_{c} \; \left[ \mathbb{E}(c(x)) \; - \; \mathbb{E}(c(g(z))) \right]$$

, where:

- $g$ is the generator
- $c$ is the critic (aka the discriminator, but we won't call it that because it's not a binary classifier, ie, it no longer classifies values as real or fake, but rather assigns a score to them)
- $x$ is the real values
- $z$ is the noise vector

So, the critic (earlier called discriminator) is trying to maximize the distance between what it thinks are real images and what it thinks are fake images. In other words, it is trying to push away these two distributions to be as far apart as possible. Meanwhile, the generator is trying to minimize this distance, because it wants the critic to think that its fake images are as close as possible to the real images.

> The last layer of the critic's neural network.

When we use the BCE Loss, the output of the discriminator needs to be a prediction between 0 and 1, ie, whether the value is real or fake. And so the discriminator's neural network for GANs, trained with BCE Loss, needs to have a sigmoid activation function in the output layer to then squash the values between 0 and 1. On the contrary, W-Loss doesn't have that requirement at all, so you can actually have a linear layer at the end of the discriminator's neural network (now called critic), and that could produce any real value output. And you can interpret that output as how real an image is considered by the critic.

> Advantages of critic over discriminator.

What's common between the discriminator and critic is that they both want to maximize the difference between the expected values of the predictions for real and fake.

Both the critic (W-Loss) and the discriminator (BCE loss) measure the distance between the fake and the real distributions. However, the discriminator is bounded between 0 and 1 (ie, it outputs that the image it is fed is either fake or real, and not how much fake or how much real), whereas the critic is no longer bounded, and is just trying to separate the two distributions as much as possible. As a result, because the critic is not bounded, it is allowed to improve without degrading its feedback to the generator. This is how the critic helps to mitigate the vanishing gradient problem.


#### **Condition on Wasserstein Critic**

> The 1-Lipschitz Continuous condition.

Up until now, we have seen that WGAN can solve problems like mode collapse and vanishing gradient problem faced by the simple GAN and the DCGAN models. However, in order for it to work well, there is one special condition that needs to be met by the Critic model. This condition is called 1-Lipschitz Continuous (or 1-L Continuous, for short). For a function like the Critic model to be 1-Lipschitz Continuous, the norm of its gradient needs to be at most one for every point. This means that the slope can't be greater than one at any point, ie, its gradient can't be greater than one. The big picture is that the gradients of the Critic model should not change very rapidly. 

> How to mathematically check if a function is 1-Lipschitz Continuous.

It's a bit difficult to describe in words (tbh, I'm just too lazy to plot the graphs and describe the whole thing to you ;)) how to mathematically check if a function is 1-Lipschitz Continuous, as it involves graphing the function in consideration, then drawing 2 lines through each point of the graph, one with slope 1 and the other with slope -1, and checking if the graph of the function lies within those 2 lines at every point. If it does, then the function is 1-Lipschitz Continuous, otherwise it's not. So, for this, I highly recommend watching the video titled "Condition on Wasserstein Critic" in Course 1 Week 3 of the [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans).

> Benefits of the 1-Lipschitz Continuous condition.

This condition on the Critic model is important for W-Loss because it assures that the W-Loss function is not only continuous and differentiable, but also that it doesn't grow too much and maintains some stability during training. This is required for training both the Critic and Generator models, and it also increases stability because the variation as the GAN learns will be bounded. 

#### **1-Lipschitz Continuity Enforcement**

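There are two common ways to enforce the 1-Lipschitz condition on the critic:

- **Weight clipping** (used in the original WGAN paper): after every critic update, each weight of the critic is clamped to a small fixed range such as $[-0.01, 0.01]$. It is simple, but it limits the critic's capacity and is sensitive to the chosen clipping range.
- **Gradient penalty** (WGAN-GP): a regularization term is added to the critic's loss that penalizes the critic whenever the norm of its gradient, taken with respect to an interpolation $\hat{x} = \epsilon x + (1 - \epsilon) g(z)$ of real and fake images, deviates from 1:

$$\lambda \; \mathbb{E} \left[ \left( \lVert \nabla_{\hat{x}} c(\hat{x}) \rVert_2 - 1 \right)^2 \right]$$

The hyperparameters defined later in this notebook include a gradient-penalty weight (`c_lambda`), ie, this notebook is set up for the gradient penalty approach.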

### Alright enough theory, let's write some code now!

#### Loading the data

In [12]:
class SomeDataClass:
    """
    Placeholder for the data-loading class. Nothing is implemented yet:
    the constructor stores no state and both dataloader methods return None.
    """

    def __init__(self):
        """Does nothing (placeholder)."""

    def train_dataloader(self):
        """Placeholder for the training dataloader; currently returns None."""

    def val_dataloader(self):
        """Placeholder for the validation dataloader; currently returns None."""


#### Model Architectures and Hyperparameters

In [13]:
class Generator(nn.Module):
    """
    Convolutional generator: maps a noise vector to an image through a stack
    of four transposed-convolution blocks.
    """

    def __init__(self, z_dim=10, img_channel=1, hidden_dim=64):
        """
        Builds the generator network.

        Parameters:
            - z_dim: the dimension of the noise vector, a scalar
            - img_channel: the number of channels of the output image, a scalar
                (MNIST is grayscale, so default value is img_channel=1)
            - hidden_dim: the inner dimension, a scalar
        """
        super().__init__()
        self.z_dim = z_dim
        # Channel widths shrink from hidden_dim*4 down to img_channel.
        blocks = [
            self.gen_block(z_dim, hidden_dim * 4),
            self.gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.gen_block(hidden_dim * 2, hidden_dim),
            self.gen_block(hidden_dim, img_channel, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*blocks)

    def gen_block(self, in_channel, out_channel, kernel_size=3, stride=2, 
                  final_layer=False):
        """
        Builds one generator block.

        A regular block is transposed convolution -> batch norm -> ReLU;
        the final block is transposed convolution -> Tanh.

        Parameters:
        - in_channel: the number of channels in the input, a scalar
        - out_channel: the number of channels in the output, a scalar
        - kernel_size: the size of the kernel, a scalar
        - stride: the stride of the kernel, a scalar
        - final_layer: a boolean, True if this is the final layer and False otherwise
        """
        upsample = nn.ConvTranspose2d(
            in_channel, out_channel, kernel_size=kernel_size, stride=stride
        )
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """
        Reshapes the noise to (batch, z_dim, 1, 1) and runs it through the
        network, returning the generated images.
        """
        batch = len(noise)
        return self.gen(noise.view(batch, self.z_dim, 1, 1))


Notice that there is no change in the architecture of the Generator of WGAN model as compared to that of the DCGAN model.

In [14]:
class Critic(nn.Module):
    """
    Convolutional critic: maps an image to an unbounded score through a stack
    of three strided-convolution blocks (no activation on the output).
    """

    def __init__(self, img_channel=1, hidden_dim=64):
        """
        Builds the critic network.

        Parameters:
        - img_channel: the number of channels of the input image, a scalar
            (MNIST is grayscale, so default value is img_channel=1)
        - hidden_dim: the inner dimension, a scalar
        """
        super().__init__()
        blocks = [
            self.critic_block(img_channel, hidden_dim),
            self.critic_block(hidden_dim, hidden_dim * 2),
            self.critic_block(hidden_dim * 2, 1, final_layer=True),
        ]
        self.critic = nn.Sequential(*blocks)

    def critic_block(self, in_channel, out_channel, kernel_size=4, stride=2,
                   final_layer=False):
        """
        Builds one critic block.

        A regular block is convolution -> batch norm -> LeakyReLU(0.2);
        the final block is a bare convolution.

        Parameters:
        - in_channel: the number of channels in the input, a scalar
        - out_channel: the number of channels in the output, a scalar
        - kernel_size: the size of the kernel, a scalar
        - stride: the stride of the kernel, a scalar
        - final_layer: a boolean, True if this is the final layer and False otherwise
        """
        downsample = nn.Conv2d(
            in_channel, out_channel, kernel_size=kernel_size, stride=stride
        )
        if final_layer:
            return nn.Sequential(downsample)
        return nn.Sequential(
            downsample,
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        """
        Scores a batch of images and returns the scores flattened to
        one row per image.

        Parameters:
            image: an image tensor fed to the convolutional stack
                (presumably (batch, img_channel, height, width) -- confirm
                against the caller)
        """
        scores = self.critic(image)
        return scores.view(len(scores), -1)


In [15]:
# Hyperparameters
# Module-level configuration values. Nothing visible in this notebook consumes
# them yet (the training functions are still stubs), so the notes marked
# "presumably" are intentions to confirm once the training loop is written.
device = "cuda" if torch.cuda.is_available() else "cpu"  # "cuda" when a CUDA device is available, else "cpu"
lr = 2e-4           # presumably the optimizer learning rate -- TODO confirm
z_dim = 64          # latent noise dimension
img_dim = 1         # 1 means grayscale image
batch_size = 128    # presumably the number of images per batch -- TODO confirm
num_epochs = 50     # presumably the number of training epochs -- TODO confirm
display_step = 500   # after how many steps to display loss

# These parameters control the Adam optimizer's momentum:
# https://distill.pub/2017/momentum/
beta_1 = 0.5 
beta_2 = 0.999

c_lambda = 10         # weight of the gradient penalty
crit_repeats = 5      # number of times to update the critic per generator update

Notice that the architectures of the Critic (of WGAN model) and the Discriminator (of the DCGAN model) are almost the same. The only difference is that the Critic has a linear output layer, whereas the Discriminator has a sigmoid output layer.

#### Training loop and loss functions

In [None]:
def get_gen_loss():
    """
    Placeholder (presumably for the generator loss): not implemented yet.
    Takes no arguments and returns None.
    """

In [4]:
def get_crit_loss():
    """
    Placeholder (presumably for the critic loss): not implemented yet.
    Takes no arguments and returns None.
    """

In [5]:
def train_step():
    """
    Placeholder (presumably for a single training step): not implemented yet.
    Takes no arguments and returns None.
    """

In [7]:
def val_step():
    """
    Placeholder (presumably for a single validation step): not implemented yet.
    Takes no arguments and returns None.
    """

In [8]:
def training_fn():
    """
    Placeholder (presumably for the full training loop): not implemented yet.
    Takes no arguments and returns None.
    """

#### Saving the models' weights and Inferring on the test set