## Generative Adversarial Network (GAN)

Paper: [Generative Adversarial Nets](https://arxiv.org/pdf/1406.2661)

Helpful Resources:
- [Aladdin Persson's playlist on GANs](https://youtube.com/playlist?list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va&si=8ooImkbbXhCUC1xB)
- [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans)
- [Stanford's Deep Generative Models playlist](https://youtube.com/playlist?list=PLoROMvodv4rPOWA-omMM6STXaWW4FvJT8&si=N_TpTe1bPIhte-t8)
- [AssemblyAI's GAN tutorial](https://youtu.be/_pIMdDWK5sc?si=Mtx2oWh1ZO9tqWYg)
- [The Math Behind Generative Adversarial Networks Clearly Explained! - Normalized Nerd](https://youtu.be/Gib_kiXgnvA?si=wi7mSBZ7uUCsWBn6)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

torch.manual_seed(0)

print("Imports done!")

Imports done!


In [29]:
def plot_images(img_tensor, num_imgs=25, size=(1,28,28)):
    """
    Given a tensor of (possibly flattened) images, the number of images to show,
    and the size of each image, this function plots the first num_imgs images
    in a grid with 5 images per row.
    """
    img_unflat = img_tensor.detach().cpu().view(-1, *size)
    img_unflat = (img_unflat + 1) / 2   # undo Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 1] for imshow
    img_grid = make_grid(img_unflat[:num_imgs], nrow=5)
    plt.imshow(img_grid.permute(1,2,0).squeeze())
    plt.show()


In [32]:
def plot_results(results):
    """
    results is a dictionary of logged metrics. The planned keys are
        "gen_train_loss", "gen_test_loss", "disc_train_loss", "disc_test_loss",
        "gen_train_acc", "gen_test_acc", "disc_train_acc", "disc_test_acc".

    For now, the training loop logs (and this function plots) only the
    generator and discriminator train losses, one value every display_step batches.
    """
    plt.plot(results["gen_train_loss"], label="Generator train loss")
    plt.plot(results["disc_train_loss"], label="Discriminator train loss")
    plt.legend()
    plt.xlabel(f"Step (x{display_step})")   # one logged point per display_step batches
    plt.ylabel("Loss")
    plt.show()
    

In [12]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        """
        - param img_dim: dimension of image (eg: 28x28x1 = 784 for 
            grayscale MNIST images)
        """
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid()   # output between 0 and 1
        )
    
    def forward(self, x):
        return self.disc(x)
    

**Why use leaky ReLU instead of ReLU?**

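We use leaky ReLU to avoid the "dying ReLU" problem. A ReLU outputs zero, and has zero gradient, for every negative input. If a unit's pre-activation is negative for all inputs, no gradient reaches the weights feeding it, so they stop updating and the unit is effectively dead. Leaky ReLU keeps a small slope for negative inputs (0.01 here), so some gradient always flows. This matters in a GAN in particular, because the generator learns only from the gradients passed back through the discriminator.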
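A minimal check of this with autograd on a single negative input (the value -2.0 is arbitrary): the gradient through `nn.ReLU` is exactly zero, while `nn.LeakyReLU(0.01)`, the slope used in both networks here, still passes a gradient of 0.01.

```python
import torch
import torch.nn as nn

# A negative pre-activation: the case where ReLU "dies"
x = torch.tensor([-2.0], requires_grad=True)

nn.ReLU()(x).sum().backward()
relu_grad = x.grad.item()          # zero: no gradient flows back

x.grad = None                      # reset before the second backward pass
nn.LeakyReLU(0.01)(x).sum().backward()
leaky_grad = x.grad.item()         # the negative slope, 0.01

print(relu_grad, leaky_grad)
```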

In [13]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        """
        - param z_dim: dimension of latent noise vector
        - param img_dim: dimension of image (eg: 28x28x1 = 784 for 
            grayscale MNIST images)
        """
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh()   # output between -1 and 1
        )

    def forward(self, x):
        return self.gen(x)
    

In [27]:
# Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4           # Karpathy constant
z_dim = 64          # latent noise dimension
img_dim = 28*28*1   # 1 means grayscale image
batch_size = 32
num_epochs = 50
display_step = 500   # after how many steps to display loss
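As a quick sanity check of how these hyperparameters connect the two networks, the sketch below pushes one batch of noise through stand-ins built from the same layers as `Generator` and `Discriminator` (rebuilt with `nn.Sequential` so the snippet runs on its own; `check_shapes` is a hypothetical helper, not part of the notebook). It also shows that the discriminator's output has shape `(batch_size, 1)`, which matters for the BCE targets later.

```python
import torch
import torch.nn as nn

def check_shapes(z_dim=64, img_dim=28 * 28 * 1, batch_size=32):
    """Run one batch through stand-ins for the two networks and return the shapes."""
    # Same layer stacks as the Generator and Discriminator classes above
    toy_gen = nn.Sequential(nn.Linear(z_dim, 256), nn.LeakyReLU(0.01),
                            nn.Linear(256, img_dim), nn.Tanh())
    toy_disc = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.01),
                             nn.Linear(128, 1), nn.Sigmoid())
    with torch.no_grad():
        z = torch.randn(batch_size, z_dim)   # latent noise: (batch_size, z_dim)
        fake = toy_gen(z)                    # flattened images in [-1, 1]: (batch_size, img_dim)
        preds = toy_disc(fake)               # probabilities in (0, 1): (batch_size, 1)
    return tuple(fake.shape), tuple(preds.shape)

print(check_shapes())   # ((32, 784), (32, 1))
```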

In [None]:
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

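# Normalize((0.5,), (0.5,)) maps pixel values from [0, 1] to [-1, 1], which matches
# the range of the generator's Tanh output. The actual MNIST mean and std, i.e.,
# transforms.Normalize((0.1307,), (0.3081,)), would map pixels to roughly
# [-0.42, 2.82], so real and generated images would not share the same range.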

train_dataset = MNIST(root="dataset/", transform=transformations, download=True, train=True)
test_dataset = MNIST(root="dataset/", transform=transformations, download=True, train=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
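def fn():
    # Inspect one batch: each item is an [images, labels] pair, with images of
    # shape (batch_size, 1, 28, 28) and labels of shape (batch_size,)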
    for item in train_loader:
        print(len(item))
        print(item[0].shape, item[1].shape)
        break

fn()

In [15]:
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

# fixed_noise: a batch of latent noise vectors z with shape (batch_size, z_dim);
# torch.randn samples from the standard normal distribution N(0, 1)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# separate optimizers for generator and discriminator
optim_disc = optim.Adam(disc.parameters(), lr=lr)
optim_gen = optim.Adam(gen.parameters(), lr=lr)

criterion = nn.BCELoss()  # binary cross entropy loss

**Importance of the (latent) noise vector**

The noise vector is a key component of GANs: it is the random input from which the generator produces fake images. The generator is a deterministic network, so the same input always yields the same output; sampling a fresh noise vector is what lets it produce a wide variety of images. Without the noise vector (i.e., with a fixed input), the generator could only ever produce a single image.

To create the noise vector, we use `torch.randn` to sample from the standard normal distribution, i.e.,

`torch.randn((batch_size, z_dim)).to(device)`.
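A small illustration of both points, using a stand-in generator with the same layers as `Generator` above (`toy_gen` is an illustrative name, and `z_dim = 64` is assumed from the hyperparameters): the same noise vector always maps to the same image, and different noise vectors map to different images.

```python
import torch
import torch.nn as nn

# Stand-in generator with the same layers as the Generator class above (z_dim = 64)
toy_gen = nn.Sequential(nn.Linear(64, 256), nn.LeakyReLU(0.01),
                        nn.Linear(256, 28 * 28), nn.Tanh())

with torch.no_grad():
    z = torch.randn(2, 64)         # two different noise vectors
    imgs = toy_gen(z)              # one flattened image per noise vector
    imgs_again = toy_gen(z)        # the same noise vectors a second time

same_noise_same_image = torch.equal(imgs, imgs_again)               # deterministic network
different_noise_different_image = not torch.allclose(imgs[0], imgs[1])
print(same_noise_same_image, different_noise_different_image)
```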

Recall: 

The training of the discriminator was to ***maximize*** the following:

$$\text{log}(D(\text{real\_img})) \; + \; \text{log}(1 - D(G(z)))$$

where:

- $D$ is the discriminator
- $G$ is the generator
- $z$ is the latent noise vector
- $\text{real\_img}$ is the real image
- $G(z)$ is the image generated by the generator

**A note about using BCELoss** 

The formula for BCELoss in PyTorch is:

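$$\text{BCELoss} = -w_i [y_i \cdot \text{log}(x_i) + (1 - y_i) \cdot \text{log}(1 - x_i)]$$

for each sample $i$, where $x_i$ is the predicted probability and $y_i$ is the target. With the default `reduction="mean"`, these per-sample losses are averaged over the batch. Since we don't pass a `weight` argument, $w_i = 1$, and the formula becomes:

$$\text{BCELoss} = -[y_i \cdot \text{log}(x_i) + (1 - y_i) \cdot \text{log}(1 - x_i)]$$

Notice the negative sign at the beginning: minimizing this BCE loss is the same as maximizing the expression inside the brackets, which is exactly what the discriminator wants to maximize.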
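A quick numerical check of this formula against PyTorch, on made-up probabilities and targets: `reduction="none"` returns the per-sample losses, and the default `reduction="mean"` averages them.

```python
import torch
import torch.nn as nn

probs = torch.tensor([0.9, 0.2, 0.6])     # predicted probabilities x_i
targets = torch.tensor([1.0, 0.0, 1.0])   # targets y_i

# Per-sample formula with w_i = 1: -[y_i log(x_i) + (1 - y_i) log(1 - x_i)]
manual = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))

per_sample = nn.BCELoss(reduction="none")(probs, targets)   # one loss per sample
mean_loss = nn.BCELoss()(probs, targets)                    # default: batch mean

print(per_sample, mean_loss)
```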

The discriminator's objective (the quantity it maximizes) is:

$$\text{log}(D(\text{real\_img})) \; + \; \text{log}(1 - D(G(z)))$$

So, in the BCELoss formula, if we set $y_i = 1$ and $x_i = D(\text{real\_img})$, we get:

$$-[\text{log}(D(\text{real\_img}))]$$

We can get the above term (averaged over the batch) with `criterion(disc_real, torch.ones_like(disc_real))`, where:
- `disc_real` is $D(\text{real\_img})$
- `criterion` is the `BCELoss` function

This was the first term in the discriminator's objective. Before moving to the second term, note that we passed `torch.ones_like(disc_real)` rather than building the target with `torch.ones(...)`. `torch.ones` takes a shape, not a tensor, so we would have to spell out the shape and the device ourselves, e.g. `torch.ones(disc_real.size(0), device=device)`. That version is also easy to get wrong: it has shape `(batch_size,)`, while the discriminator's output here has shape `(batch_size, 1)`, and `BCELoss` expects the target to have the same shape as the input. `torch.ones_like(disc_real)` avoids both problems by creating a tensor of ones with the same shape, dtype, and device as `disc_real`.
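A small check of the shape issue, using a random stand-in for `disc_real` on the CPU (so no `device` argument is needed; `bce`, `target_like`, and `target_ones` are illustrative names). In current PyTorch versions, `nn.BCELoss` raises a `ValueError` when the target and input shapes differ; the snippet records whether that happens.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()
disc_real = torch.rand(32, 1)                   # stand-in for D(real_img): (batch_size, 1)

target_like = torch.ones_like(disc_real)        # (32, 1), same dtype and device
target_ones = torch.ones(disc_real.size(0))     # (32,), a different shape

try:
    bce(disc_real, target_ones)                 # input (32, 1) vs target (32,)
    shape_mismatch_rejected = False
except ValueError:
    shape_mismatch_rejected = True

print(target_like.shape, target_ones.shape, shape_mismatch_rejected)
```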


For the second term, if we set $y_i = 0$ and $x_i = D(G(z))$, we get:

$$-[\text{log}(1 - D(G(z)))]$$

Adding these two terms gives the discriminator's loss, the negative of its objective (`get_disc_loss` below also divides the sum by 2, which only rescales it), i.e.,

$$-[\text{log}(D(\text{real\_img})) \; + \; \text{log}(1 - D(G(z)))]$$
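Putting the two terms together on a small made-up batch (the values of `d_real` and `d_fake` are arbitrary stand-ins for discriminator outputs, shaped `(batch_size, 1)` like ours): the two `BCELoss` calls with all-ones and all-zeros targets add up to the formula above, averaged over the batch.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()

# Hypothetical discriminator outputs for 2 real and 2 fake images
d_real = torch.tensor([[0.8], [0.7]])   # D(real_img)
d_fake = torch.tensor([[0.3], [0.1]])   # D(G(z))

real_term = bce(d_real, torch.ones_like(d_real))    # batch mean of -log(D(real_img))
fake_term = bce(d_fake, torch.zeros_like(d_fake))   # batch mean of -log(1 - D(G(z)))

# The same quantity written directly from the formula
direct = -(torch.log(d_real) + torch.log(1 - d_fake)).mean()

print(real_term + fake_term, direct)
```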

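We may also want to reuse the fake images from the discriminator step, i.e., `fake_images = gen(noise_vector)` or mathematically $G(z)$, in the generator step instead of generating new ones. The catch is that `disc_loss.backward()` frees the intermediate values saved in the computational graph it traverses (the gradients stored in `.grad` are kept), so a second backward pass through the generator's part of that graph raises an error. There are 2 ways around this:
1. Detach the fake images when computing the discriminator loss: `disc_fake_preds = disc(fake_images.detach())`. The discriminator's backward pass then stops at the fake images and never touches the generator's graph, which stays intact for the generator step. It also skips computing generator gradients that the discriminator update does not use.
2. Keep the whole graph alive with `disc_loss.backward(retain_graph=True)`, at the cost of extra memory.

The two options are not equivalent: detaching removes the generator from the discriminator's backward pass, while `retain_graph=True` still backpropagates into the generator and only keeps the graph around afterwards. The code below uses both. Since `get_disc_loss` detaches the fake images and `get_gen_loss` generates a fresh batch anyway, `retain_graph=True` is harmless but not strictly needed here.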
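The difference between the two options shows up even on tiny stand-in models (the layer sizes and the helper name `reuse_fake_batch` are illustrative, and the losses are stand-ins; only the graph structure matters). Each call runs a discriminator step and then a generator step on the *same* fake batch, and reports whether the discriminator's backward pass reached the generator and whether the reuse succeeded.

```python
import torch
import torch.nn as nn

def reuse_fake_batch(detach, retain_graph):
    """Hypothetical helper: discriminator step, then a generator step on the same fakes.
    Returns (generator got gradients from the discriminator step, reuse succeeded)."""
    toy_gen = nn.Sequential(nn.Linear(4, 8), nn.Tanh())      # Tanh saves its output for backward
    toy_disc = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())
    fake = toy_gen(torch.randn(5, 4))                        # G(z)

    # Discriminator step
    disc_input = fake.detach() if detach else fake
    toy_disc(disc_input).mean().backward(retain_graph=retain_graph)
    gen_got_grads = any(p.grad is not None for p in toy_gen.parameters())

    # Generator step that reuses the same fake batch
    try:
        toy_disc(fake).mean().backward()
        reuse_ok = True
    except RuntimeError:   # "Trying to backward through the graph a second time ..."
        reuse_ok = False
    return gen_got_grads, reuse_ok

print(reuse_fake_batch(detach=False, retain_graph=False))  # (True, False)
print(reuse_fake_batch(detach=True, retain_graph=False))   # (False, True)
print(reuse_fake_batch(detach=False, retain_graph=True))   # (True, True)
```

Detaching is usually the cheaper choice: it keeps the generator's graph intact without holding on to the discriminator's buffers, and it avoids the wasted backward pass into the generator.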

In [24]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """
    Returns the loss of the discriminator.
    Parameters:
        - gen: the generator model, which returns an image given 
               z-dimensional noise
        - disc: the discriminator model, which returns a single-dimensional 
                prediction of real/fake
        - criterion: the loss function, which should be used to compare 
                     the discriminator's predictions to the ground truth 
                     reality of the images (e.g. fake = 0, real = 1)
        - real: a batch of real images
        - num_images: the number of images the generator should produce, 
                      which is also the length of the real images
        - z_dim: the dimension of the noise vector, a scalar
        - device: the device type (eg: cuda or cpu)
    Returns:
        disc_loss: a torch scalar loss value for the current batch

    The discriminator maximizes log(D(x)) + log(1 - D(G(z))). This function
    returns the equivalent BCE loss to minimize, averaged over the two terms:
        -(log(D(x)) + log(1 - D(G(z)))) / 2
    """
    
    # 1) Create a noise vector and generate a batch (ie, num_images) of fake images.
    noise_vector = torch.randn(num_images, z_dim).to(device)  # z
    fake_images = gen(noise_vector)                           # G(z)

    # 2) Get the discriminator's prediction of the fake image 
    #    and calculate the loss. Don't forget to detach the generator!
    #    (Remember the loss function you set earlier -- criterion. You need a 
    #    'ground truth' tensor in order to calculate the loss. 
    #    For example, a ground truth tensor for a fake image is all zeros.)
    disc_fake_preds = disc(fake_images.detach())                   # D(G(z))
    disc_fake_loss = criterion(disc_fake_preds, 
                               torch.zeros_like(disc_fake_preds))  # -log(1 - D(G(z))), batch mean
    
    # 3) Get the discriminator's prediction of the real image and calculate the loss.
    disc_real_preds = disc(real)                                   # D(x)
    disc_real_loss = criterion(disc_real_preds, 
                               torch.ones_like(disc_real_preds))   # -log(D(x)), batch mean

    # 4) Calculate the discriminator's loss by averaging the real and fake loss
    #    and set it to disc_loss.
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    
    return disc_loss


Recall: 

The training of the generator was to ***minimize*** the following:

$$\text{log}(1 - D(G(z)))$$

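However, early in training the discriminator easily rejects the fake images, so $D(G(z)) \approx 0$. In that regime $\text{log}(1 - D(G(z)))$ saturates: its gradient is tiny (a vanishing gradient), which leads to slow training, and sometimes no training at all. To avoid this, the original paper suggests a non-saturating alternative. It is not strictly equivalent, but it has the same goal (pushing $D(G(z))$ toward 1), and per the paper the same fixed point, while giving much stronger gradients early in training: the generator **maximizes** the following: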

$$\text{log}(D(G(z)))$$

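We will use BCELoss for the generator in a similar fashion as we did for the discriminator: with target $y_i = 1$ and $x_i = D(G(z))$, BCELoss gives $-\text{log}(D(G(z)))$, so minimizing it maximizes $\text{log}(D(G(z)))$.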
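The saturation can be made concrete by differentiating both objectives with respect to the discriminator's pre-sigmoid output $a$, where $D(G(z)) = \sigma(a)$ as in our `Discriminator`. Analytically, $\frac{d}{da}\text{log}(1 - \sigma(a)) = -\sigma(a)$ and $\frac{d}{da}\text{log}(\sigma(a)) = 1 - \sigma(a)$, so at $D(G(z)) = 0.01$ (an assumed value for a clearly rejected fake) the first is about $-0.01$ and the second about $0.99$. The sketch below checks this with autograd.

```python
import torch

# Early in training: the discriminator rejects a fake, D(G(z)) = sigmoid(a) = 0.01
a = torch.logit(torch.tensor(0.01)).requires_grad_()

# Saturating objective: the generator minimizes log(1 - D(G(z)))
torch.log(1 - torch.sigmoid(a)).backward()
grad_saturating = a.grad.item()        # about -0.01: almost no signal

a.grad = None
# Non-saturating objective: the generator maximizes log(D(G(z)))
torch.log(torch.sigmoid(a)).backward()
grad_non_saturating = a.grad.item()    # about 0.99: a strong signal

print(grad_saturating, grad_non_saturating)
```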

In [26]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """
    Returns the loss of the generator.
    Parameters:
        - gen: the generator model, which returns an image given 
               z-dimensional noise
        - disc: the discriminator model, which returns a single-dimensional 
                prediction of real/fake
        - criterion: the loss function, which should be used to compare 
                     the discriminator's predictions to the ground truth 
                     reality of the images (e.g. fake = 0, real = 1)
        - num_images: the number of images the generator should produce, 
                      which is also the length of the real images
        - z_dim: the dimension of the noise vector, a scalar
        - device: the device type (eg: cuda or cpu)
    Returns:
        gen_loss: a torch scalar loss value for the current batch

    The generator maximizes log(D(G(z))) (the non-saturating objective).
    This function returns the equivalent BCE loss to minimize: -log(D(G(z)))
    """

    # 1) Create noise vectors and generate a batch of fake images.
    noise_vector = torch.randn(num_images, z_dim).to(device)  # z
    fake_images = gen(noise_vector)                           # G(z)
    
    # 2) Get the discriminator's prediction of the fake image.
    disc_fake_preds = disc(fake_images)                       # D(G(z))
    
    # 3) Calculate the generator's loss. Remember the generator wants
    #    the discriminator to think that its fake images are real
    gen_loss = criterion(disc_fake_preds, 
                         torch.ones_like(disc_fake_preds))    # -log(D(G(z))), batch mean

    return gen_loss


**A note about GAN training**

Training GANs is a bit tricky. The discriminator and the generator are trained alternately: for every batch, we first take one optimization step on the discriminator and then one on the generator, and we repeat this for a fixed number of epochs. The discriminator is trained to catch the generator's fake images, and the generator is trained to fool the discriminator into classifying its fake images as real.

For each epoch, we process the entire dataset in batches, and for every batch we update the discriminator and the generator using their losses. A batch is a set of images that the models process together; the loss is averaged over the batch rather than computed after each image. Note that you may see losses greater than 1. This is fine: binary cross-entropy is unbounded above, so a sufficiently confident wrong guess can produce any positive loss.
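For example, if the discriminator assigns probability 0.01 to a real image (an arbitrary confident mistake), the loss is $-\text{log}(0.01) \approx 4.61$:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()
pred = torch.tensor([0.01])               # D(real_img): confidently, and wrongly, "fake"
loss = bce(pred, torch.ones_like(pred))   # -log(0.01), about 4.61
print(loss.item())
```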

It's also common for the discriminator to outperform the generator, especially at the start, because its job is easier. It's important that neither model gets too good (i.e., reaches near-perfect accuracy), because that can make the whole system stop learning; a near-perfect discriminator, for instance, leaves the generator with little useful gradient. Balancing the two models is remarkably hard to do in a standard GAN.
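One way to watch for this imbalance is to log the discriminator's accuracy next to the losses (the `plot_results` docstring already anticipates accuracy keys). A minimal sketch; the helper `disc_accuracy` and the 0.5 threshold are assumptions, not part of the original notebook. A value pinned near 1.0 suggests the discriminator is winning too easily.

```python
import torch

def disc_accuracy(disc_real_preds, disc_fake_preds, threshold=0.5):
    """Hypothetical helper: fraction of images the discriminator classifies correctly.
    A real image counts as correct when D(x) > threshold, a fake one when D(G(z)) <= threshold."""
    correct_real = (disc_real_preds > threshold).float().sum()
    correct_fake = (disc_fake_preds <= threshold).float().sum()
    total = disc_real_preds.numel() + disc_fake_preds.numel()
    return ((correct_real + correct_fake) / total).item()

# Made-up predictions: 3 of 4 real and 3 of 4 fake images are classified correctly
real_preds = torch.tensor([[0.9], [0.8], [0.4], [0.7]])
fake_preds = torch.tensor([[0.1], [0.6], [0.2], [0.3]])
print(disc_accuracy(real_preds, fake_preds))   # 0.75
```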

In [22]:
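def fn():
    # Compare two ways of building the all-zeros target. Since disc_fake is
    # flattened with .view(-1), both tensors have shape (batch_size,).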
    noise = torch.randn(batch_size, z_dim).to(device)  # z
    fake_img = gen(noise)
    disc_fake = disc(fake_img).view(-1)
    print(torch.zeros_like(disc_fake))
    print(torch.zeros(disc_fake.size(0), device=device))

fn()

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])


In [None]:
current_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
results = {
    "gen_train_loss": [],
    "disc_train_loss": [],
}

for epoch in tqdm(range(num_epochs)):
    
    # we iterate over the training dataloader
    # we only need the images, and not the labels
    for real_img, _ in train_loader:
        
        curr_batch_size = len(real_img)
        # Flatten the batch of real images
        real_img = real_img.view(curr_batch_size, -1).to(device)

        # Update discriminator (Notice that we first train the discriminator)
        # Zero out the gradients before backpropagation
        optim_disc.zero_grad()
        # Calculate the discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real_img, curr_batch_size, z_dim, device)
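        # Backpropagate to compute the discriminator's gradients. retain_graph=True is
        # not strictly needed: get_disc_loss detaches the fake images, and
        # get_gen_loss generates a fresh batch instead of reusing this graph
        disc_loss.backward(retain_graph=True)
        # Update the discriminator's parameters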
        optim_disc.step()

        # Update generator
        # Zero out the gradients before backpropagation
        optim_gen.zero_grad()
        # Calculate the generator loss
        gen_loss = get_gen_loss(gen, disc, criterion, curr_batch_size, z_dim, device)
        # Backpropagate to compute the generator's gradients (this graph is not reused)
        gen_loss.backward()
        # Update the generator's parameters
        optim_gen.step()

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item()
        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item()

        # Visualization code
        if current_step % display_step == 0 and current_step > 0:
            mean_discriminator_loss = mean_discriminator_loss / display_step
            mean_generator_loss = mean_generator_loss / display_step
            results["gen_train_loss"].append(mean_generator_loss)
            results["disc_train_loss"].append(mean_discriminator_loss)
            print(f"Step {current_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
            fake_noise = torch.randn(curr_batch_size, z_dim).to(device)
            fake_img = gen(fake_noise)
            plot_images(fake_img)
            plot_images(real_img)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        current_step += 1
        

In [None]:
plot_results(results)