## Wasserstein Generative Adversarial Network (WGAN)

Papers: 
- [Wasserstein GAN](https://arxiv.org/pdf/1701.07875)
- [Improved Training of Wasserstein GANs](https://arxiv.org/pdf/1704.00028)

Helpful Resources:
- [Aladdin Persson's playlist on GANs](https://youtube.com/playlist?list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va&si=8ooImkbbXhCUC1xB)
- [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans)  $\; \rightarrow \;$ most of the content of this notebook has been borrowed from this course
- [Stanford's Deep Generative Models playlist](https://youtube.com/playlist?list=PLoROMvodv4rPOWA-omMM6STXaWW4FvJT8&si=N_TpTe1bPIhte-t8)
- [From GAN to WGAN](https://lilianweng.github.io/posts/2017-08-20-gan/)
- [Read-through: Wasserstein GAN](https://www.alexirpan.com/2017/02/22/wasserstein-gan.html)

This notebook just includes the implementation of the WGAN model and its training loop. The results are not shown here.

Feel free to check the results on my Kaggle notebook: https://www.kaggle.com/code/aryamanbansal/wgan-gp

#### Imports and some helpful utility functions

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torchinfo import summary

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import glob
from datetime import datetime
import pytz
import time
import os
from tqdm import tqdm
import random
import gc

%matplotlib inline

print("Imports done!")

Imports done!


In [27]:
def get_torch_version():
    """
    Return the installed PyTorch ``major.minor`` version encoded as a float.

    Any local build suffix after ``+`` (e.g. ``+cu118``) is discarded first,
    so ``"2.1.0+cu118"`` yields ``2.1``.

    Caveat: this encoding is lossy for two-digit minor versions (``"1.10"``
    becomes ``1.1``), so numeric comparisons such as ``<= 1.7`` are not
    reliable across every release.
    """
    public_version = torch.__version__.partition("+")[0]
    components = public_version.split(".")
    return float(f"{components[0]}.{components[1]}")


def set_seed(seed: int = 42):
    """
    Seed the random number generators used by this notebook for reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch, and records
    ``PYTHONHASHSEED``. When CUDA is available it additionally seeds all CUDA
    devices and switches cuDNN/PyTorch into deterministic mode.

    Parameters:
    - seed: int, the seed value to use (default 42).
    """
    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only
    # affects Python processes launched after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # Deterministic cuBLAS (CUDA >= 10.2) needs a fixed workspace size;
        # without it, use_deterministic_algorithms(True) makes cuBLAS-backed
        # ops raise a RuntimeError. setdefault keeps any user-provided value.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        # Feature-detect rather than compare get_torch_version() floats:
        # "1.10" parses as 1.1 (<= 1.7), which wrongly routed newer releases
        # to the deprecated torch.set_deterministic API.
        if hasattr(torch, "use_deterministic_algorithms"):
            torch.use_deterministic_algorithms(True)
        else:
            torch.set_deterministic(True)


def set_scheduler(scheduler, results, scheduler_on):
    """
    Step ``scheduler`` with the most recent value of the metric ``scheduler_on``.

    Parameters:
    - scheduler: torch.optim.lr_scheduler, the scheduler to update. Its
        ``step`` is called with a single metric value (as ReduceLROnPlateau expects).
    - results: dict mapping metric names to lists of per-epoch values.
    - scheduler_on: str, the metric to step on. One of "gen_val_loss",
        "gen_train_loss", "disc_val_loss", "disc_train_loss", or the aliases
        "crit_val_loss" / "crit_train_loss".

    Returns:
    - the same scheduler object.

    Raises:
    - ValueError: if ``scheduler_on`` is not one of the supported names.
    """
    valid_metrics = (
        "gen_val_loss",
        "gen_train_loss",
        "disc_val_loss",
        "disc_train_loss",
        # "crit_*" names match the keys read by visualize_results().
        "crit_val_loss",
        "crit_train_loss",
    )
    # A single membership check replaces the former if/if-elif chain, where
    # "gen_val_loss" stepped the scheduler and then fell through to the raise.
    if scheduler_on not in valid_metrics:
        raise ValueError("Invalid `scheduler_on` choice.")
    scheduler.step(results[scheduler_on][-1])
    return scheduler


def _tensors_to_cpu(obj):
    """Recursively copy every tensor inside nested dicts/lists/tuples to the CPU."""
    if torch.is_tensor(obj):
        return obj.detach().cpu()
    if isinstance(obj, dict):
        return {key: _tensors_to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_tensors_to_cpu(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(_tensors_to_cpu(value) for value in obj)
    return obj


def save_model_info(path: str, model, model_name, optimizer, optimizer_name, 
                    scheduler=None, scheduler_name=""):
    """
    Saves the model and optimizer weights to the specified path.

    Side effect: ``model`` is moved to the CPU and left there.

    Parameters:
    - path: str, the path to the directory to save the model and optimizer weights.
    - model: nn.Module, the model to save.
    - model_name: str, the name of the model weights file.
    - optimizer: torch.optim, the optimizer to save.
    - optimizer_name: str, the name of the optimizer weights file.
    - scheduler: torch.optim.lr_scheduler, the scheduler to save.
    - scheduler_name: str, the name of the scheduler weights file.

    Free advice:
    A good practice is to transfer the model to the CPU before calling torch.save as this 
    will save tensors as CPU tensors and not as CUDA tensors. This will help in loading
    the model onto any machine, whether it contains CUDA capabilities or not.
    """
    model.to("cpu")
    torch.save(model.state_dict(), os.path.join(path, model_name))
    # model.to() does not move optimizer state (e.g. Adam's running moments),
    # so copy it to the CPU explicitly to keep the checkpoint device-agnostic.
    torch.save(_tensors_to_cpu(optimizer.state_dict()),
               os.path.join(path, optimizer_name))
    if scheduler is not None:
        torch.save(scheduler.state_dict(), os.path.join(path, scheduler_name))
    print("Model info saved!")
    
    
def load_model_info(path, device, model, model_name, optimizer, optimizer_name, 
                    scheduler=None, scheduler_name=""):
    """
    Loads the model and optimizer weights from the specified path.

    Checkpoints are loaded with an explicit ``map_location``, so files that
    contain CUDA tensors can still be restored on a CPU-only machine.

    Parameters:
    - path: str, the path to the directory containing the model and optimizer weights.
    - device: str, the device to load the model and optimizer weights onto.
    - model: nn.Module, the model to load the weights into.
    - model_name: str, the name of the model weights file.
    - optimizer: torch.optim, the optimizer to load the weights into.
    - optimizer_name: str, the name of the optimizer weights file.
    - scheduler: torch.optim.lr_scheduler, the scheduler to load the weights into.
    - scheduler_name: str, the name of the scheduler weights file.
    """
    model_state = torch.load(os.path.join(path, model_name), map_location="cpu")
    model.load_state_dict(model_state)
    # Move the model before loading the optimizer: Optimizer.load_state_dict
    # casts its state tensors to the device of the matching parameters.
    model.to(device)
    optimizer_state = torch.load(os.path.join(path, optimizer_name),
                                 map_location=device)
    optimizer.load_state_dict(optimizer_state)
    if scheduler is not None:
        scheduler_state = torch.load(os.path.join(path, scheduler_name),
                                     map_location=device)
        scheduler.load_state_dict(scheduler_state)
    print("Model info loaded!")
    
    
def get_current_time():
    """
    Build a filesystem-friendly timestamp for "now" in Toronto.

    The clock is read in the 'Canada/Eastern' time zone and rendered as
    ``DD_MM_YYYY__HH_MM_SS``.
    """
    toronto_zone = pytz.timezone('Canada/Eastern')
    return datetime.now(toronto_zone).strftime("%d_%m_%Y__%H_%M_%S")


def show_tensor_images(image_tensor, num_images=25, size=(1,28,28), nrow=5):
    """
    Function for visualizing images: given a tensor of images, the number of
    images, and the size per image, plots the images in a uniform grid.

    Values are rescaled with (x + 1) / 2, i.e. the input is assumed to lie in
    [-1, 1] (the range of the Generator's final Tanh).

    Parameters:
    - image_tensor: a batch of images, either shaped (batch, channels,
        height, width) or flattened so that each image has prod(size) values.
    - num_images: int, how many images from the start of the batch to show.
    - size: tuple, (channels, height, width) of a single image; used to
        unflatten inputs that are not already 4-D.
    - nrow: int, number of images per row of the grid.
    """
    images = (image_tensor.detach().cpu() + 1) / 2
    if images.dim() != 4:
        # `size` was previously ignored, so flattened batches crashed in make_grid.
        images = images.reshape(-1, *size)
    image_grid = make_grid(images[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()


def visualize_results(results):
    """
    Draw the generator and critic loss curves in two side-by-side panels.

    The left panel holds the training losses and the right panel the
    validation losses. In each panel the generator is drawn in orange and
    the critic in blue.

    Parameters:
    - results: dict holding per-epoch loss lists under the keys
        "gen_train_loss", "gen_val_loss", "crit_train_loss", "crit_val_loss".
    """
    gen_train, gen_val = results["gen_train_loss"], results["gen_val_loss"]
    crit_train, crit_val = results["crit_train_loss"], results["crit_val_loss"]

    _, axes = plt.subplots(nrows=1, ncols=2, figsize=(12,4))
    panels = (
        (axes[0], "Training Loss", gen_train, crit_train),
        (axes[1], "Validation Loss", gen_val, crit_val),
    )
    for panel, y_title, gen_curve, crit_curve in panels:
        panel.set_xlabel("Epochs")
        panel.set_ylabel(y_title)
        panel.plot(gen_curve, label="generator", color="orange")
        panel.plot(crit_curve, label="critic", color="blue")
        panel.legend()
    plt.show()


print("Utility functions created!")

Utility functions created!


## Some theory about WGANs

Before we jump into the code, let's understand the theory behind WGANs, and why we need them. 

Note: the following notes are heavily inspired by the [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans).


#### **Motivation**

In simple GANs using binary cross entropy (BCE) loss, the generator tries to make the discriminator think its fake images are real. The issue arises because the generator's learning depends entirely on how the discriminator evaluates its images.

When the discriminator strongly classifies an image as real (assigns higher probability to that image as real), the generator focuses heavily on producing that image repeatedly, leading to a lack of variety. Conversely, if the discriminator strongly classifies the image as fake, the generator receives almost no useful feedback, as BCE loss gives a vanishing gradient when predictions are very confident. This stalling of the generator’s learning leads to the bad scenario where the generator is stuck with no clear way to improve.

This phenomenon occurs because BCE loss measures how well the discriminator classifies an image but doesn't give the generator useful signals when its outputs are far from realistic.

#### **Mode Collapse**

- **What is it**: Mode collapse happens when the generator produces limited types of images (e.g., just one kind of dog) instead of a diverse set of realistic outputs.

- **When does it happen**: Mode collapse occurs when the generator gets stuck generating one mode. The discriminator will eventually learn to differentiate the generator's fakes when this happens and outskill it, ending the model's learning. In other words, mode collapse often occurs when the generator "overfits" to fool the discriminator by creating only a few images that it knows will work well, ignoring the rest of the data distribution. It generally occurs when we use the BCE loss function to train GANs.

- **Why does it happen**: This happens because the discriminator doesn't penalize the generator for repeatedly producing the same outputs, as long as they seem realistic. The generator's objective doesn’t explicitly encourage diversity.

#### **Problem with BCE Loss function**

**TL;DR**: The BCE loss function is not a good choice for GANs because it can lead to vanishing gradient problem and mode collapse. In other words, the discriminator gets too good at distinguishing between real and fake values, and as a result, the generator doesn't get useful feedback to improve.

> Binary Cross Entropy (BCE) loss function review. 

The BCE loss function is defined as:

$$\text{BCELoss} = -\frac{1}{m} \sum_{i=1}^m \left[ y_i \cdot \text{log}(x_i) + (1 - y_i) \cdot \text{log}(1 - x_i) \right]$$

A simplified form of the above loss function is:

$$\text{BCELoss} \quad = \quad \min\limits_{d} \max\limits_{g} \quad - \left[ \mathbb{E}(\text{log}(d(x))) \; + \; \mathbb{E}(\text{log}(1 - d(g(z)))) \right]$$

, where:

- $d$ is the discriminator
- $g$ is the generator
- $x$ is the real values
- $z$ is the noise vector
- $d(x)$ is the discriminator's prediction for real values
- $g(z)$ is the fake values generated by the generator
- $d(g(z))$ is the discriminator's prediction for fake values


The first term ($y_i \cdot \text{log}(x_i)$) is for real values, and the second term ($(1 - y_i) \cdot \text{log}(1 - x_i)$) is for fake values.

The discriminator's goal is to minimize the BCE loss function, because that means it's classifying things correctly. So its loss function is:

$$\text{min} \; - \left[ \; y_i \cdot \text{log}(x_i) \; + \; (1 - y_i) \cdot \text{log}(1 - x_i) \; \right]$$

OR

$$\min\limits_{d} \quad - \left[ \mathbb{E}(\text{log}(d(x))) \; + \; \mathbb{E}(\text{log}(1 - d(g(z)))) \right]$$


Meanwhile, the generator's goal is to maximize the BCE loss function, because that means the discriminator is doing poorly and is classifying its fake values as real. The generator only has control over the fake values, so its objective is:

$$\text{max} \; - \left[ \; (1 - y_i) \cdot \text{log}(1 - x_i) \; \right]$$

OR

$$\max\limits_{g} \quad - \left[ \mathbb{E}(\text{log}(1 - d(g(z)))) \right]$$

> The big picture of this minimax game.

This maximization and minimization is often called a minimax game, at the end of which the goal of the entire GAN architecture is to bring the generated data distribution and the real data distribution as close as possible. In other words, during the entire training process, the discriminator is trying to precisely classify which values are fake, and which are real, while the generator is trying to generate fake values that are as close to real as possible.

> Roles of the generator and discriminator.

However, let's take a step back again to the roles of the generator and discriminator. The discriminator needs to output just a single value prediction within zero and one, ie, whether the value it sees is fake or real. Whereas the generator actually needs to produce the realistic-looking fake value, which is a pretty complex output composed of multiple features to try and fool the discriminator, for example, an image. As a result, the discriminator's job tends to be a little bit easier. In other words, it's much easier to look at images in a museum than to actually paint those masterpieces. So, during training, it's quite common for the discriminator to outperform the generator, ie, the discriminator quickly learns to differentiate between real and fake values, and the generator can't keep up. This is where the problem arises.

> Rise of the (vanishing gradient) problem with the BCE loss function.

The fact that the discriminator quickly learns to outperform the generator isn't such a big problem at the beginning of training, because the discriminator isn't that good. So, initially, the discriminator is able to give useful feedback to the generator in the form of a non-zero gradient. However, as it gets better during training, its ability to distinguish between real and fake values becomes much more precise. As a result, as the discriminator keeps getting better, it starts giving less informative feedback to the generator. In fact, the discriminator might give gradients closer to zero, and that becomes unhelpful for the generator because then the generator doesn't know how to improve. This is how the vanishing gradient problem will arise.

#### **Earth Mover's Distance (EMD)**

<!-- - EMD is one of the loss functions for GANs. -->
- EMD is a concept and not a loss function.
- It measures how different the generated data distribution and the real data distribution are. 
- It does so by estimating the amount of effort it takes to make the generated distribution equal to the real.
- Intuitively, if the generated distribution was a pile of dirt, EMD measures how difficult it would be to move that pile of dirt and mold it into the shape and location of the real distribution.
- EMD depends on both the distance and the amount that the generated distribution needs to be moved.
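- Unlike the BCE setup, where the discriminator's outputs are squashed between 0 and 1 by a sigmoid (so its gradients vanish once it becomes confident), EMD is not confined to a fixed range: it keeps growing as the two distributions move further apart, so it keeps supplying useful gradients. This mitigates the vanishing gradient problem.  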
- GANs trained using EMD generally outperform the ones trained using the BCE loss function.


#### **Wasserstein Loss**

> Introduction to the Wasserstein loss.

The Wasserstein loss (aka W-loss) is an alternative loss function choice that approximates the Earth Mover's distance as follows:

$$\text{W-Loss} \;\;  = \;\; \min\limits_{g} \max\limits_{c} \; \left[ \mathbb{E}(c(x)) \; - \; \mathbb{E}(c(g(z))) \right]$$

, where:

- $g$ is the generator
- $c$ is the critic (aka the discriminator, but we won't call it that because it's not a binary classifier, ie, it no longer classifies values as real or fake, but rather assigns a score to them)
- $x$ is the real values
- $z$ is the noise vector

So, the critic (earlier called discriminator) is trying to maximize the distance between what it thinks are real images and what it thinks are fake images. In other words, it is trying to push away these two distributions to be as far apart as possible. Meanwhile, the generator is trying to minimize this distance, because it wants the critic to think that its fake images are as close as possible to the real images.

> The last layer of the critic's neural network.

When we use the BCE Loss, the output of the discriminator needs to be a prediction between 0 and 1, ie, whether the value is real or fake. And so the discriminator's neural network for GANs, trained with BCE Loss, needs to have a sigmoid activation function in the output layer to then squash the values between 0 and 1. On the contrary, W-Loss doesn't have that requirement at all, so you can actually have a linear layer at the end of the discriminator's neural network (now called critic), and that could produce any real value output. And you can interpret that output as how real an image is considered by the critic.

> Advantages of critic over discriminator.

What's common between the discriminator and critic is that they both want to maximize the difference between the expected values of the predictions for real and fake.

Both critic (W-Loss) and discriminator (BCE loss) measure the distance between the fake and the real distributions. However, the discriminator is bounded between 0 and 1 (ie, it outputs that the image it is fed is either fake or real, and not how much fake or how much real), whereas the critic is no longer bounded, and just tries to separate the two distributions as much as possible. As a result, because the critic is not bounded, it is allowed to improve without degrading its feedback to the generator. This is how the critic helps to mitigate the vanishing gradient problem.


#### **Condition on Wasserstein Critic**

> The 1-Lipschitz Continuous condition.

Up until now, we have seen that WGAN can mitigate problems like mode collapse and the vanishing gradient problem faced by the simple GAN and the DCGAN models. However, in order for it to work well, there is one special condition that needs to be met by the Critic model. This condition is called 1-Lipschitz Continuity (or 1-L Continuity, for short). For a function like the Critic model to be 1-Lipschitz Continuous, the norm of its gradient needs to be at most one at every single point of this function. Mathematically, we can write this as:

$$\left\| \nabla c(x) \right\|_2 \; \leq \; 1, \quad \forall x \in D$$

, where:
- $c$ is the function (ie, the Critic model)
- $x$ is the input to the function (ie, the image)
- $D$ is the domain of the function
- $\left\| \cdot \right\|_2$ is the L2 norm, which represents the Euclidean distance

This means that the slope can't be greater than one in magnitude at any point, ie, the norm of its gradient can't be greater than one. The big picture is that the output of the Critic model should not change too rapidly as its input changes. 

> How to mathematically check if a function is 1-Lipschitz Continuous.

It's a bit difficult to describe in words (tbh, I'm just too lazy to plot the graphs and describe the whole thing to you ;)) how to mathematically check if a function is 1-Lipschitz Continuous as it involves graphing the function in consideration, then drawing 2 lines, one with slope 1 and the other with slope -1, and checking if the graph of the function lies within those 2 lines at every point. If it does, then the function is 1-Lipschitz Continuous, otherwise it's not. So, for this, I highly recommend watching the video titled "Condition on Wasserstein Critic" in Course 1 Week 3 of the [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans).

> Benefits of the 1-Lipschitz Continuous condition.

This condition on the Critic model is important for W-Loss because it assures that the W-Loss function is not only continuous and differentiable, but also that it doesn't grow too much and maintain some stability during training. This is required for training both the Critic and Generator models and it also increases stability because the variation as the GAN learns will be bounded. 

#### **1-Lipschitz Continuity Enforcement**

> Big picture.

Just to recap, 1-L continuity in the Critic model ensures that W-loss is valid. There are 2 methods to enforce this condition:
1. Weight clipping
2. Gradient Penalty

However, gradient penalty is a better method than weight clipping for enforcing 1-L continuity.

> Weight clipping.

- Weight clipping forces the weights of the Critic model to be within a fixed interval, say, $[-0.01, \; 0.01]$. After we update the weights using gradient descent, we will clip any weights outside of the desired interval $[-0.01, \; 0.01]$. This means that weights outside of our chosen interval will be set to the maximum or minimum amount allowed. In our case, weights greater than $0.01$ will be set to $0.01$, and weights less than $-0.01$ will be set to $-0.01$.
- This method has some downsides:
    - Forcing the weights of the critic to a limited range of values could limit the critic's ability to learn, and ultimately the GAN's ability to perform well: if the critic's weights can't take on many different values, then it might not be able to improve easily or find a good global optimum (Note: Global optima refers to the set of best possible solutions in an optimization problem, while global maxima specifically refers to the highest value point in the entire domain of possible solutions. Global optima includes global maxima.)
    - On the other hand, weight clipping might not limit the critic enough, ie, it might not be able to enforce the 1-Lipschitz continuity condition well enough.
    - Choosing the interval for weight clipping is a hyperparameter that needs to be tuned, and it's not always easy to find the right interval that works well for the critic.

> Gradient Penalty.

With the gradient penalty, all we need to do is add a regularization term ($\text{reg}$) multiplied by a hyperparameter value ($\lambda$) to our loss function. Mathematically, we can write this as:

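$$\text{W-Loss} \quad  = \quad \min\limits_{g} \max\limits_{c} \; \left[ \mathbb{E}(c(x)) \; - \; \mathbb{E}(c(g(z))) \; - \; \lambda \cdot \text{reg} \right]$$

The penalty is subtracted inside the $\max$ because the critic maximizes this objective. Equivalently, the critic minimizes $\mathbb{E}(c(g(z))) - \mathbb{E}(c(x)) + \lambda \cdot \text{reg}$, which is where the regularization term gets added to its loss.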

, where:
- $\text{reg}$ is the regularization term
- $\lambda$ is the hyperparameter value that tells how much to weigh this regularization term $\text{reg}$ against the main loss function.

What this regularization term ($\text{reg}$) does to our W-loss function is penalize the critic when its gradient norm is higher than one. Now, it's virtually impossible (or at least not practical) to check the gradient of the critic at every possible point in the feature space. So, we sample some points (or images in our case) by interpolating between the real and fake images, and check the gradient at those points. It's a bit difficult to describe in words (again, classic me being lazy) how to get these interpolated images, but the idea is to first choose some value $\varepsilon$ (say, 0.3) to do interpolation between the real and fake images. This $\varepsilon$ acts as a weight to the real image, and $1 - \varepsilon$ (which is 0.7) acts as a weight to the fake image. In practice, a fresh $\varepsilon$ is drawn uniformly at random from $[0, 1]$ for each image in the batch. So, the interpolated image $\hat{x}$ is an image between the real and the fake images. Mathematically, we can write this as:

$$\hat{x} \;\; = \;\; \varepsilon \cdot x \; + \; (1 - \varepsilon) \cdot g(z)$$

, where:
- $\hat{x}$ is the interpolated image
- $x$ is the real image
- $g$ is the generator
- $z$ is the noise vector
- $g(z)$ is the fake image
- $\varepsilon$ is the weight for the real image

For a little visual flavor of the above, I highly recommend watching the video titled "1-Lipschitz Continuity Enforcement" in Course 1 Week 3 of the [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans).

Okay, so it's on $\hat{x}$ that we want to get the critic's gradients to be less than or equal to one, ie, 

$$\left\| \nabla c(\hat{x}) \right\|_2 \; \leq \; 1$$


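Now, in practice it's simpler (and works well) to push the norm of the gradient towards exactly one, as opposed to only penalizing norms greater than one. The WGAN-GP paper motivates this two-sided penalty by showing that the optimal critic has gradient norm 1 almost everywhere along the lines between real and generated samples. So, instead of the inequality above, our regularization term $\text{reg}$ measures how far the norm is from one, and we'll write it as: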

$$\text{reg} \;\; = \;\; \left\| \nabla c(\hat{x}) \right\|_2 \; - \; 1 $$

We're not done yet. There is one more change to be made to the above equation. We'll square the above equation to get the final regularization term. We could have taken the absolute value instead of taking the square, but squaring the equation makes it differentiable at the point where the gradient norm is exactly one. Also, the squaring enables us to penalize the critic more when its gradient norm is further away from one. So, we'll incorporate this change, and write the new equation as:

$$\text{reg} \;\; = \;\; \left( \left\| \nabla c(\hat{x}) \right\|_2 \; - \; 1 \right)^2 $$

The complete W-loss function with the gradient penalty is:

$$\text{W-Loss} \quad  = \quad \min\limits_{g} \max\limits_{c} \; \left[ \mathbb{E}(c(x)) \; - \; \mathbb{E}(c(g(z))) \; - \; \lambda \cdot \mathbb{E} \left( \left( \left\| \nabla c(\hat{x}) \right\|_2 \; - \; 1 \right)^2 \right) \right]$$

So, the first term in the above equation $\displaystyle \min\limits_{g} \max\limits_{c} \; \left[ \mathbb{E}(c(x)) - \mathbb{E}(c(g(z))) \right]$ approximates Earth Mover's distance, which makes the GAN less prone to mode collapse and vanishing gradients, while the second term $\displaystyle \lambda \cdot \mathbb{E} \left( \left( \left\| \nabla c(\hat{x}) \right\|_2 \; - \; 1 \right)^2 \right)$ meets the condition for what the critic desires in order to make the first term valid. 


We must note one subtle point about the gradient penalty method. With this method, we're not strictly enforcing 1-L continuity, but just encouraging it. This has proven to work well and much better than weight clipping.


## Alright enough theory, let's write some code now!

#### Loading the data

In [17]:
class SomeDataClass:
    """
    Holds the MNIST train/test splits and hands out DataLoaders for them.

    Images go through ToTensor followed by Normalize(mean=0.5, std=0.5),
    which rescales pixels into [-1, 1]. The data is stored under "dataset/"
    and downloaded if it is missing.
    """

    def __init__(self, batch_size=128):
        """
        Parameters:
        - batch_size: int, number of samples per batch for both loaders.
        """
        self.batch_size = batch_size
        self.transformations = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
        self.train_dataset = self._load_split(is_train=True)
        self.valid_dataset = self._load_split(is_train=False)

    def _load_split(self, is_train):
        """Return the MNIST training split when ``is_train`` is true, else the test split."""
        return MNIST(root="dataset/", transform=self.transformations,
                     download=True, train=is_train)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this object's batch size."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def get_trainloader(self):
        """Returns the training dataloader (shuffled)."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def get_validloader(self):
        """Returns the validation dataloader (not shuffled)."""
        return self._make_loader(self.valid_dataset, shuffle=False)
    

In [None]:
def fn():
    """Sanity check: print the structure and tensor shapes of the first training batch."""
    loader = SomeDataClass().get_trainloader()
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        print(len(first_batch))
        images, labels = first_batch[0], first_batch[1]
        print(images.shape, labels.shape)

fn()

2
torch.Size([128, 1, 28, 28]) torch.Size([128])


#### Model Architectures and Hyperparameters

In [30]:
class Generator(nn.Module):
    """
    DCGAN-style generator that decodes a noise vector into an image.

    With the default kernel sizes and strides, the spatial size grows
    1 -> 3 -> 6 -> 13 -> 28, so the output is a 28x28 image whose values
    lie in [-1, 1] (final Tanh).
    """

    def __init__(self, z_dim=10, img_channel=1, hidden_dim=64):
        """
        Parameters:
            - z_dim: length of the input noise vector
            - img_channel: number of channels in the generated image
                (MNIST is grayscale, hence the default of 1)
            - hidden_dim: base width; the hidden blocks use 4x, 2x and 1x of it
        """
        super().__init__()
        self.z_dim = z_dim
        wide, mid, narrow = 4 * hidden_dim, 2 * hidden_dim, hidden_dim
        self.gen = nn.Sequential(
            self.gen_block(z_dim, wide),
            self.gen_block(wide, mid, kernel_size=4, stride=1),
            self.gen_block(mid, narrow),
            self.gen_block(narrow, img_channel, kernel_size=4, final_layer=True),
        )

    def gen_block(self, in_channel, out_channel, kernel_size=3, stride=2, 
                  final_layer=False):
        """
        Build one upsampling block.

        Hidden blocks are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        block is ConvTranspose2d -> Tanh.

        Parameters:
        - in_channel: number of input channels
        - out_channel: number of output channels
        - kernel_size: size of the transposed-convolution kernel
        - stride: stride of the transposed convolution
        - final_layer: True for the output block, False otherwise
        """
        layers = [
            nn.ConvTranspose2d(in_channel, out_channel,
                               kernel_size=kernel_size, stride=stride)
        ]
        if final_layer:
            layers.append(nn.Tanh())
        else:
            layers += [nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True)]
        return nn.Sequential(*layers)

    def forward(self, noise):
        """
        Reshape a batch of noise vectors (z_dim values per sample) to
        (batch, z_dim, 1, 1) and decode them into images.
        """
        noise_map = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(noise_map)


Notice that there is no change in the architecture of the Generator of WGAN model as compared to that of the DCGAN model.

In [31]:
class Critic(nn.Module):
    """
    DCGAN-style critic that assigns an unbounded "realness" score to each image.

    There is no sigmoid at the end: the raw output of the final convolution
    is the score used by the Wasserstein loss. With the default kernel size
    and stride, a 28x28 input shrinks 28 -> 13 -> 5 -> 1, so each image
    receives exactly one score.
    """

    def __init__(self, img_channel=1, hidden_dim=64, norm="batch"):
        """
        The Critic class.

        Parameters:
        - img_channel: the number of channels of the input image, a scalar
            (MNIST is grayscale, so default value is img_channel=1)
        - hidden_dim: the inner dimension, a scalar
        - norm: normalization used in the hidden blocks, "batch" (default,
            BatchNorm2d) or "instance" (InstanceNorm2d with affine=True).
            The WGAN-GP paper recommends against batch norm in the critic:
            it makes each image's score depend on the rest of the batch,
            while the gradient penalty assumes per-sample scores.

        Raises:
        - ValueError: if ``norm`` is not "batch" or "instance".
        """
        super().__init__()
        if norm not in ("batch", "instance"):
            raise ValueError(
                f"Unsupported norm {norm!r}; expected 'batch' or 'instance'."
            )
        self.norm_type = norm
        self.critic = nn.Sequential(
            self.critic_block(img_channel, hidden_dim),
            self.critic_block(hidden_dim, hidden_dim*2),
            self.critic_block(hidden_dim*2, 1, final_layer=True),
        )

    def critic_block(self, in_channel, out_channel, kernel_size=4, stride=2,
                     final_layer=False):
        """
        Returns the layers of a critic block.

        Hidden blocks are Conv2d -> normalization -> LeakyReLU(0.2); the final
        block is a bare Conv2d.

        Parameters:
        - in_channel: the number of channels in the input, a scalar
        - out_channel: the number of channels in the output, a scalar
        - kernel_size: the size of the kernel, a scalar
        - stride: the stride of the kernel, a scalar
        - final_layer: a boolean, True if this is the final layer and False otherwise
        """
        conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size,
                         stride=stride)
        if final_layer:
            # No activation: the W-loss needs an unbounded score.
            return nn.Sequential(conv)
        if self.norm_type == "instance":
            normalization = nn.InstanceNorm2d(out_channel, affine=True)
        else:
            normalization = nn.BatchNorm2d(out_channel)
        return nn.Sequential(conv, normalization, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, image):
        """
        Score a batch of images.

        Parameters:
            image: image tensor of shape (batch, img_channel, height, width).
                The layers are Conv2d, so the input must NOT be flattened.
        Returns:
            tensor of shape (batch, k) with the critic's raw scores; k == 1
            for 28x28 inputs with the default architecture.
        """
        critic_pred = self.critic(image)
        return critic_pred.view(len(critic_pred), -1)


Notice that the architectures of the Critic (of WGAN model) and the Discriminator (of the DCGAN model) are almost the same. The only difference is that the Critic has a linear output layer, whereas the Discriminator has a sigmoid output layer.

#### Gradient Penalty

Calculating the gradient penalty can be broken into two functions: 
1. compute the gradient wrt the images $\;\; \rightarrow \;\;$ `get_gradient()`
2. compute the gradient penalty given the gradient $\;\; \rightarrow \;\;$ `gradient_penalty()`

Let's see how to code the `get_gradient()` function.

First, we create a mixed image by weighting the real and fake images with epsilon and then adding them together. The mathematical representation of this is:

$$\hat{x} \;\; = \;\; \varepsilon \cdot x \; + \; (1 - \varepsilon) \cdot g(z)$$

Once we have the intermediate image, we can get the critic's output on the image, ie,

$$c(\hat{x})$$

Finally, we can compute the gradient of the critic's score on the mixed images (output) with respect to the pixels of the mixed images (input), ie,

$$\nabla c(\hat{x})$$

In [1]:
def get_gradient(critic, real, fake, epsilon):
    """
    Compute the gradient of the critic's scores with respect to random
    interpolations between real and fake images.
    Parameters:
        critic: the critic model
        real: a batch of real images
        fake: a batch of fake images (generated by the generator)
        epsilon: a tensor of uniformly random proportions of real/fake
                 per mixed image, broadcastable against ``real``
                 (e.g. shape (N, 1, 1, 1)); it does NOT need to
                 require grad
    Returns:
        gradient: the gradient of the critic's scores with respect
                  to the mixes of real and fake images (same shape as
                  the mixed images).
    """
    # Mix the images together: x_hat = eps * x + (1 - eps) * g(z)
    mixed_images = real * epsilon + fake * (1 - epsilon)

    # autograd.grad can only differentiate w.r.t. a tensor that is tracked
    # by autograd. If neither the images nor epsilon require grad (e.g. a
    # detached `fake` and a plain torch.rand epsilon), the mixed images are
    # a constant leaf tensor, so explicitly mark them as requiring grad.
    if not mixed_images.requires_grad:
        mixed_images.requires_grad_(True)

    # Calculate the critic's scores on the mixed images
    mixed_scores = critic(mixed_images)

    # Take the gradient of the scores w.r.t. the mixed images (outputs
    # w.r.t. inputs). create_graph=True records this gradient computation
    # so that a penalty built from it can itself be back-propagated into
    # the critic's parameters (a gradient of a gradient).
    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    return gradient


Now, let's see how to calculate the `gradient_penalty()` function.

Our goal is to compute the gradient penalty given the gradient. First, we calculate the magnitude of each image's gradient. The magnitude of a gradient is also called the norm. The mathematical representation of this is:

$$\left\| \nabla c(\hat{x}) \right\|_2$$

Then, for each image, we calculate the penalty by squaring the distance between its gradient magnitude and the ideal norm of 1, ie, 

$$\text{reg} \;\; = \;\; \left( \left\| \nabla c(\hat{x}) \right\|_2 \; - \; 1 \right)^2$$

Finally, we take the mean of these per-image penalties over the batch to get the gradient penalty, ie,

$$\text{gp} \;\; = \;\; \mathbb{E}_{\hat{x}} \left[ \left( \left\| \nabla c(\hat{x}) \right\|_2 \; - \; 1 \right)^2 \right]$$

In [2]:
def gradient_penalty(gradient):
    """
    Given a batch of image gradients, compute the L2 norm of each image's
    gradient and return the mean squared distance of those norms from 1.
    Parameters:
        gradient: the gradient of the critic's scores with respect to
                  the mixed images; the first dimension indexes images
    Returns:
        penalty: the gradient penalty, a scalar tensor
    """
    # Flatten the gradients so that each row captures one image.
    # reshape (rather than view) also works when the gradient tensor is
    # not contiguous, e.g. when autograd returns it in channels_last layout.
    gradient = gradient.reshape(len(gradient), -1)

    # Calculate the magnitude (L2 norm) of every row, ie, every image
    gradient_norm = torch.linalg.vector_norm(gradient, ord=2, dim=1)

    # Penalize the mean squared distance of the gradient norms from 1
    penalty = torch.mean((gradient_norm - 1) ** 2)

    return penalty


#### Loss functions and training loop

For the generator's loss function, the loss is calculated by maximizing the critic's prediction (or scores) on the generator's fake images. The argument `crit_fake_pred` has the scores for all fake images in the batch, but we will use the mean of them. The mathematical representation of what `get_gen_loss` function returns is:

$$- \mathbb{E}(c(g(z)))$$

In [7]:
def get_gen_loss(crit_fake_pred):
    """
    Compute the generator's loss from the critic's scores on fake images.

    The generator wants the critic to score its images as highly as
    possible, so the loss is the negated mean score: minimizing it
    maximizes E[c(g(z))].
    Parameters:
        crit_fake_pred: the critic's scores of the fake images
    Returns:
        gen_loss: a scalar loss value for the current batch of the generator
    """
    return -crit_fake_pred.mean()


For the critic's loss function, the loss is calculated by maximizing the distance between the critic's predictions on the real images and its predictions on the fake images, while also adding a gradient penalty. The gradient penalty is weighted by lambda. The arguments are the scores for all the images in the batch, and we will use their mean. The mathematical representation of what the `get_crit_loss` function returns is:

$$\mathbb{E}(c(g(z))) \; - \; \mathbb{E}(c(x)) \;\; + \;\; \lambda \cdot \mathbb{E}(\text{reg})$$

In [8]:
def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    """
    Compute the critic's WGAN-GP loss.

    The critic tries to widen the gap between its mean score on real
    images and its mean score on fake images. The gradient penalty,
    scaled by ``c_lambda``, is added on top.
    Parameters:
        crit_fake_pred: the critic's scores of the fake images
        crit_real_pred: the critic's scores of the real images
        gp: the unweighted gradient penalty
        c_lambda: the current weight of the gradient penalty
    Returns:
        crit_loss: a scalar for the critic's loss, accounting for the
                   relevant factors
    """
    fake_mean = crit_fake_pred.mean()
    real_mean = crit_real_pred.mean()
    weighted_penalty = c_lambda * gp
    return fake_mean - real_mean + weighted_penalty


Before we start working on the training loop, we need to consider a few things:
- Even on GPU, the **training will run more slowly** than that of the DCGAN implementation because the gradient penalty requires you to compute the gradient of a gradient -- this means potentially a few minutes per epoch! For best results, run this for as long as you can while on GPU.
- One important difference compared to simple GAN and DCGAN is that we will **update the critic multiple times** every time we update the generator. This helps prevent the generator from overpowering the critic. Sometimes, we might see the reverse, with the generator updated more times than the critic. This depends on architectural (e.g. the depth and width of the network) and algorithmic choices (e.g. which loss we're using).
- WGAN-GP isn't necessarily meant to improve the overall performance of a GAN, but rather to **increase stability** and help avoid mode collapse. In general, a WGAN will be able to train in a much more stable way than the vanilla DCGAN, though it will generally run a bit slower. We should also be able to train our WGAN-GP model for more epochs without it collapsing.

In [15]:
def update_critic(real, device, critic, generator, crit_optim, 
                  z_dim, n_samples, c_lambda, test=False):
    """
    Run one critic step: score real and fake images and compute the
    WGAN-GP critic loss. Unless ``test`` is True, also back-propagate
    and update the critic's weights.
    Parameters:
        - real: the real images
        - device: the device to train the model on
        - critic: the critic model
        - generator: the generator model
        - crit_optim: the optimizer for the critic (unused, and may be
          None, when test is True)
        - z_dim: the dimension of the noise vector
        - n_samples: the number of samples to generate; should equal
          len(real) so that real and fake images can be mixed
        - c_lambda: the weight of the gradient penalty
        - test: a boolean, True if the function is called during
          testing/validation (no weight update is performed)
    Returns: the critic loss for this step, as a Python float
    Raises: ValueError if test is False and crit_optim is None
    """
    if not test:
        if crit_optim is None:
            raise ValueError("crit_optim is required when test=False")
        crit_optim.zero_grad()

    # The fake images are only ever used detached from the generator, so
    # there is no need to record the generator's autograd graph.
    with torch.no_grad():
        fake_noise = torch.randn(n_samples, z_dim, device=device)
        fake = generator(fake_noise)

    # Critic scores for fake and real images
    crit_fake_pred = critic(fake)
    crit_real_pred = critic(real)

    # One random mixing proportion per image. requires_grad=True makes the
    # mixed images built from it tracked by autograd, so their gradient
    # can be taken.
    epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
    gradient = get_gradient(critic, real, fake, epsilon)
    gp = gradient_penalty(gradient)

    crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

    if not test:
        # The graph is used by a single backward pass, so it need not be
        # retained afterwards.
        crit_loss.backward()
        crit_optim.step()

    return crit_loss.item()


In [16]:
def update_generator(device, critic, generator, gen_optim, 
                     z_dim, n_samples, test=False):
    """
    Run one generator step: generate fake images, score them with the
    critic and compute the generator loss. Unless ``test`` is True, also
    back-propagate and update the generator's weights.
    Parameters:
        - device: the device to train the model on
        - critic: the critic model
        - generator: the generator model
        - gen_optim: the optimizer for the generator (unused, and may be
          None, when test is True)
        - z_dim: the dimension of the noise vector
        - n_samples: the number of samples to generate
        - test: a boolean, True if the function is called during
          testing/validation (no weight update is performed)
    Returns: the generator loss for this step, as a Python float
    Raises: ValueError if test is False and gen_optim is None
    """
    if test:
        # Evaluation only: no gradients are needed, so skip building the
        # autograd graph through the generator and the critic.
        with torch.no_grad():
            fake_noise = torch.randn(n_samples, z_dim, device=device)
            fake = generator(fake_noise)
            crit_fake_pred = critic(fake)
            gen_loss = get_gen_loss(crit_fake_pred)
        return gen_loss.item()

    if gen_optim is None:
        raise ValueError("gen_optim is required when test=False")
    gen_optim.zero_grad()
    fake_noise = torch.randn(n_samples, z_dim, device=device)
    fake = generator(fake_noise)
    crit_fake_pred = critic(fake)
    gen_loss = get_gen_loss(crit_fake_pred)
    # Note: this backward also accumulates gradients into the critic's
    # parameters. Only gen_optim steps here, and the critic's gradients
    # are cleared by crit_optim.zero_grad() before its next update.
    gen_loss.backward()
    gen_optim.step()
    return gen_loss.item()


In [24]:
def train_step(train_dataloader, critic_repeats, device, critic, generator,
               crit_optim, gen_optim, z_dim, c_lambda):
    """
    Train the critic and the generator for one epoch.

    For every batch the critic is updated ``critic_repeats`` times (its
    losses are averaged), then the generator is updated once.
    Parameters:
        - train_dataloader: the training dataloader, yielding (images, labels)
        - critic_repeats: the number of times to update the critic per generator update
        - device: the device to train the model on
        - critic: the critic model
        - generator: the generator model
        - crit_optim: the optimizer for the critic
        - gen_optim: the optimizer for the generator
        - z_dim: the dimension of the noise vector
        - c_lambda: the weight of the gradient penalty
    Returns:
        (mean generator loss, mean critic loss) over all batches
    """
    gen_batch_losses = []
    crit_batch_losses = []
    generator.train()
    critic.train()

    # Labels are ignored: only the real images are needed
    for real_images, _ in train_dataloader:
        n_real = len(real_images)
        real_images = real_images.to(device)

        crit_step_losses = [
            update_critic(real_images, device, critic, generator,
                          crit_optim, z_dim, n_real, c_lambda)
            for _ in range(critic_repeats)
        ]
        crit_batch_losses.append(sum(crit_step_losses) / critic_repeats)

        gen_batch_losses.append(
            update_generator(device, critic, generator, gen_optim,
                             z_dim, n_real)
        )

    return np.mean(gen_batch_losses), np.mean(crit_batch_losses)
    

In [25]:
def val_step(valid_dataloader, critic_repeats, device, critic, generator,
             z_dim, c_lambda):
    """
    Evaluate the critic and generator losses on the validation data
    without updating any weights.
    Parameters:
        - valid_dataloader: the validation dataloader, yielding (images, labels)
        - critic_repeats: the number of critic loss evaluations averaged per batch
        - device: the device to run the models on
        - critic: the critic model
        - generator: the generator model
        - z_dim: the dimension of the noise vector
        - c_lambda: the weight of the gradient penalty
    Returns:
        (mean generator loss, mean critic loss) over all validation batches
    """
    generator_losses_per_batch = []
    critic_losses_per_batch = []
    generator.eval()
    critic.eval()
    # We only need the real images, not their labels
    for real, _ in valid_dataloader:
        curr_batch_size = len(real)
        real = real.to(device)
        avg_iteration_critic_loss = 0

        # Evaluate the critic loss critic_repeats times (fresh noise and
        # epsilon each time) and average. No optimizer step happens in test
        # mode, but crit_optim is a required parameter of update_critic, so
        # pass None explicitly (omitting it raises a TypeError).
        for _ in range(critic_repeats):
            avg_iteration_critic_loss += update_critic(real=real, device=device, critic=critic, 
                                                       generator=generator, crit_optim=None,
                                                       z_dim=z_dim, n_samples=curr_batch_size, 
                                                       c_lambda=c_lambda, test=True)
        avg_iteration_critic_loss = avg_iteration_critic_loss / critic_repeats
        critic_losses_per_batch.append([avg_iteration_critic_loss])

        # Evaluate the generator loss (gen_optim is required, so pass None)
        gen_loss = update_generator(device=device, critic=critic, generator=generator,
                                    gen_optim=None, z_dim=z_dim, 
                                    n_samples=curr_batch_size, test=True)
        generator_losses_per_batch.append([gen_loss])
    return np.mean(generator_losses_per_batch), np.mean(critic_losses_per_batch)


In [28]:
def training_fn(n_epochs, train_dataloader, valid_dataloader, device, 
                critic_repeats, critic, generator, z_dim, crit_optim, 
                gen_optim, c_lambda):
    """
    Run the full WGAN-GP training loop, validating after every epoch.
    Parameters:
        - n_epochs: the number of epochs to train the model
        - train_dataloader: the DataLoader for the training data
        - valid_dataloader: the DataLoader for the validation data
        - device: the device to train the model on
        - critic_repeats: the number of times to update the critic every time 
            the generator is updated
        - critic: the critic model
        - generator: the generator model
        - z_dim: the dimension of the noise vector
        - crit_optim: the optimizer for the critic
        - gen_optim: the optimizer for the generator
        - c_lambda: the weight of the gradient penalty
    Returns:
        a dict mapping "gen_train_loss", "gen_val_loss", "crit_train_loss"
        and "crit_val_loss" to lists holding one mean loss per epoch
    """
    loss_names = ("gen_train_loss", "gen_val_loss",
                  "crit_train_loss", "crit_val_loss")
    history = {name: [] for name in loss_names}

    for _ in tqdm(range(n_epochs)):
        # One pass over the training data (weights are updated)
        gen_loss, crit_loss = train_step(
            train_dataloader, critic_repeats, device, critic, generator,
            crit_optim, gen_optim, z_dim, c_lambda)
        history["gen_train_loss"].append(gen_loss)
        history["crit_train_loss"].append(crit_loss)

        # One pass over the validation data (losses only)
        gen_loss, crit_loss = val_step(
            valid_dataloader, critic_repeats, device, critic, generator,
            z_dim, c_lambda)
        history["gen_val_loss"].append(gen_loss)
        history["crit_val_loss"].append(crit_loss)

    return history


In [35]:
# Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"

# Data / model sizes
img_dim = 1           # number of image channels (1 = grayscale)
z_dim = 64            # dimension of the latent noise vector
batch_size = 128      # images per forward/backward pass

# Optimization
num_epochs = 50
lr = 2e-4
# Adam's beta coefficients (its momentum terms), see
# https://distill.pub/2017/momentum/
beta_1 = 0.5
beta_2 = 0.999

# WGAN-GP specific settings
c_lambda = 10         # weight of the gradient penalty
crit_repeats = 5      # critic updates per generator update

In [36]:
# Instantiate the critic and the generator and move them to the selected device
critic = Critic(img_dim).to(device)
generator = Generator(z_dim, img_dim).to(device)

# fixed_noise is a fixed batch of latent noise vectors (batch_size x z_dim)
# drawn from a standard normal distribution via torch.randn.
# NOTE(review): presumably meant for visualizing generator samples across
# training; it is not used in the visible cells.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# separate Adam optimizers for the critic and the generator
crit_optim = optim.Adam(critic.parameters(), lr=lr, betas=(beta_1, beta_2))
gen_optim = optim.Adam(generator.parameters(), lr=lr, betas=(beta_1, beta_2))

In [37]:
# Weight initialization applied to every submodule via nn.Module.apply
def weights_init(m):
    """
    Initialize a module's parameters in place (use with ``nn.Module.apply``).

    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02); bias left untouched.
    - BatchNorm2d: weight ~ N(0, 0.02), bias set to 0.
    - Any other module is left unchanged.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_types):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        torch.nn.init.zeros_(m.bias)

# Initialize the generator first, then the critic. nn.Module.apply returns
# the module itself, so gen / crit are just aliases of generator / critic.
gen, crit = generator.apply(weights_init), critic.apply(weights_init)

In [None]:
# NOTE(review): `SomeDataClass` is a placeholder -- it is not defined in the
# visible notebook. Replace it with a class whose get_trainloader() /
# get_validloader() return DataLoaders yielding (images, labels) batches,
# which is what train_step / val_step iterate over.
data = SomeDataClass()
train_loader = data.get_trainloader()
valid_loader = data.get_validloader()

In [10]:
# Train for num_epochs epochs. The returned dict holds per-epoch mean
# train/validation losses for the generator and the critic.
results = training_fn(
    n_epochs=num_epochs,
    train_dataloader=train_loader,
    valid_dataloader=valid_loader,
    device=device,
    critic_repeats=crit_repeats,
    critic=critic,
    generator=generator,
    z_dim=z_dim,
    crit_optim=crit_optim,
    gen_optim=gen_optim,
    c_lambda=c_lambda,
)

In [None]:
# Plot the per-epoch losses collected in `results`
# (visualize_results is presumably a helper defined earlier in the notebook)
visualize_results(results)

#### Saving the models' weights and Inferring on the test set

In [None]:
# Timestamp string used below as the checkpoint folder name
# (get_current_time is presumably a helper defined earlier in the notebook)
curr_time = get_current_time()
curr_time

In [None]:
# Create a folder named after the timestamp inside the current working
# directory (exist_ok=True: no error if it already exists)
path = os.path.join(os.getcwd(), curr_time)
os.makedirs(path, exist_ok=True)

print(path)

In [39]:
# Save the critic and its optimizer under `path` as crit1.pt / crit_optim1.pt
# (save_model_info is presumably a helper defined earlier in the notebook)
save_model_info(path, model=critic, model_name="crit1.pt", optimizer=crit_optim, 
                optimizer_name="crit_optim1.pt")

In [None]:
# Save the generator and its optimizer under `path` as gen1.pt / gen_optim1.pt
# (save_model_info is presumably a helper defined earlier in the notebook)
save_model_info(path, model=generator, model_name="gen1.pt", optimizer=gen_optim, 
                optimizer_name="gen_optim1.pt")

In [41]:
# NOTE(review): "something" is a placeholder -- set `path` to the folder the
# checkpoint files were saved in before running this cell.
path = "something"
# Rebuild a fresh generator and Adam optimizer with the same hyperparameters,
# then restore them from gen1.pt / gen_optim1.pt (presumably what
# load_model_info does -- it is not defined in the visible code).
loaded_gen1 = Generator(z_dim, img_dim).to(device)
loaded_gen_optim1 = optim.Adam(loaded_gen1.parameters(), lr=lr, betas=(beta_1, beta_2))
load_model_info(path, device, model=loaded_gen1, model_name="gen1.pt", 
                optimizer=loaded_gen_optim1, optimizer_name="gen_optim1.pt")

In [None]:
# NOTE(review): "something" is a placeholder -- set `path` to the folder the
# checkpoint files were saved in before running this cell.
path = "something"
# Rebuild a fresh critic and Adam optimizer with the same hyperparameters,
# then restore them from crit1.pt / crit_optim1.pt (presumably what
# load_model_info does -- it is not defined in the visible code).
loaded_crit1 = Critic(img_dim).to(device)
loaded_crit_optim1 = optim.Adam(loaded_crit1.parameters(), lr=lr, betas=(beta_1, beta_2))
load_model_info(path, device, model=loaded_crit1, model_name="crit1.pt", 
                optimizer=loaded_crit_optim1, optimizer_name="crit_optim1.pt")