## Generative Adversarial Network (GAN)

Paper: [Generative Adversarial Nets](https://arxiv.org/pdf/1406.2661)

Helpful Resources:
- [Aladdin Persson's playlist on GANs](https://youtube.com/playlist?list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va&si=8ooImkbbXhCUC1xB)
- [GANs specialization on coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans)
- [Stanford's Deep Generative Models playlist](https://youtube.com/playlist?list=PLoROMvodv4rPOWA-omMM6STXaWW4FvJT8&si=N_TpTe1bPIhte-t8)
- [AssemblyAI's GAN tutorial](https://youtu.be/_pIMdDWK5sc?si=Mtx2oWh1ZO9tqWYg)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision  # needed later for torchvision.utils.make_grid
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

from torch.utils.tensorboard import SummaryWriter

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

print("Imports done!")

Imports done!


In [4]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        """
        - param img_dim: dimension of image (eg: 28x28x1 = 784 for 
            grayscale MNIST images)
        """
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid()   # output between 0 and 1
        )
    
    def forward(self, x):
        return self.disc(x)
    

In [5]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        """
        - param z_dim: dimension of latent noise vector
        - param img_dim: dimension of image (eg: 28x28x1 = 784 for 
            grayscale MNIST images)
        """
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh()   # output between -1 and 1
        )

    def forward(self, x):
        return self.gen(x)
    

In [None]:
# Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4   # Karpathy constant
z_dim = 64
img_dim = 28*28*1   # 1 means grayscale image
batch_size = 32
num_epochs = 50

In [None]:
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

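# alternatively, we could normalize with the actual mean and std of MNIST, i.e.,
# transforms.Normalize((0.1307,), (0.3081,))
# but then real images would no longer lie in [-1, 1], the range of the generator's Tanh output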

train_dataset = MNIST(root="dataset/", transform=transformations, download=True, train=True)
test_dataset = MNIST(root="dataset/", transform=transformations, download=True, train=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
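def inspect_first_batch():
    # sanity check: print the structure and shapes of one training batch
    for item in train_loader:
        print(len(item))  # 2: (images, labels)
        print(item[0].shape, item[1].shape)  # images: (32, 1, 28, 28), labels: (32,)
        break

inspect_first_batch()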

In [None]:
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

# fixed_noise is a fixed batch of latent vectors, reused at every logging step
# so that generated samples can be compared across epochs
# torch.randn draws samples from a standard normal distribution
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# separate optimizers for generator and discriminator
optim_disc = optim.Adam(disc.parameters(), lr=lr)
optim_gen = optim.Adam(gen.parameters(), lr=lr)

criterion = nn.BCELoss()  # binary cross entropy loss

# for tensorboard
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0

Recall that the discriminator is trained to ***maximize*** the following:

$$\text{log}(D(\text{real\_img})) \; + \; \text{log}(1 - D(G(z)))$$

where:

- $D$ is the discriminator
- $G$ is the generator
- $z$ is the latent noise vector
- $\text{real\_img}$ is the real image
- $G(z)$ is the image generated by the generator

#### A note about using BCELoss

For a single prediction $x_i$ with target $y_i$, PyTorch's `BCELoss` computes the following (and, by default, averages the result over the batch):

$$\text{BCELoss} = -w_i [y_i \cdot \text{log}(x_i) + (1 - y_i) \cdot \text{log}(1 - x_i)]$$

Here $w_i$ is an optional rescaling weight. We leave it at its default of $w_i = 1$, so the formula becomes:

$$\text{BCELoss} = -[y_i \cdot \text{log}(x_i) + (1 - y_i) \cdot \text{log}(1 - x_i)]$$

Notice the negative sign at the beginning. Minimizing this BCE loss is therefore the same as maximizing the bracketed term, which is how we turn the discriminator's maximization problem into a loss to minimize.

The discriminator's objective is:

$$\text{log}(D(\text{real\_img})) \; + \; \text{log}(1 - D(G(z)))$$

So, in the BCELoss formula, if we set $y_i = 1$ and $x_i = D(\text{real\_img})$, we get:

$$-[\text{log}(D(\text{real\_img}))]$$

This is the negative of the first term in the discriminator's objective. For the second term, if we set $y_i = 0$ and $x_i = D(G(z))$, we get:

$$-[\text{log}(1 - D(G(z)))]$$

Adding these two terms gives the negative of the discriminator's objective, which is the loss we minimize:

$$-[\text{log}(D(\text{real\_img})) \; + \; \text{log}(1 - D(G(z)))]$$
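
The correspondence above can be checked numerically. The following is a minimal sketch, assuming some made-up discriminator outputs (the values in `d_real` and `d_fake` are hypothetical, not results from training); it compares `nn.BCELoss` against the hand-written log terms.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()  # default reduction: mean over the batch

# Hypothetical discriminator outputs, i.e. probabilities after the sigmoid
d_real = torch.tensor([0.9, 0.8, 0.7])  # D(real_img)
d_fake = torch.tensor([0.2, 0.1, 0.3])  # D(G(z))

# y = 1 keeps only the -log(x) term
loss_real = criterion(d_real, torch.ones_like(d_real))
manual_real = -torch.log(d_real).mean()

# y = 0 keeps only the -log(1 - x) term
loss_fake = criterion(d_fake, torch.zeros_like(d_fake))
manual_fake = -torch.log(1 - d_fake).mean()

print(loss_real.item(), manual_real.item())
print(loss_fake.item(), manual_fake.item())
```

Each pair of printed values should agree up to floating-point error, which is why labels of ones and zeros are enough to express the discriminator's loss with a single criterion.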

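We want to reuse the fake images produced by the generator, i.e., `fake_img = gen(noise)` or, mathematically, $G(z)$, for both the discriminator update and the generator update. However, if the discriminator's loss is backpropagated through `fake_img`, `lossD.backward()` frees the intermediate buffers of the computational graph (including the generator's part) to save memory, so a second backward pass through the same graph for the generator would fail. Without a workaround we would have to regenerate the fake images, which wastes computation. There are two ways around this:
1. Detach the fake images from the computational graph when training the discriminator: `disc_fake = disc(fake_img.detach()).view(-1)`. The discriminator's backward pass then stops at `fake_img` and leaves the generator's graph intact.
2. Call `lossD.backward(retain_graph=True)` so that the graph is kept for the generator's backward pass.

Either option works here. Option 1 also avoids computing generator gradients during the discriminator update, and it is the one used in the training loop below.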
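
To illustrate option 1, here is a small sketch with tiny stand-in networks (the layer sizes and the names `gen` and `disc` as single linear layers are assumptions for illustration, not the models defined earlier). It shows that the discriminator's backward pass on detached fakes leaves the generator without gradients, and that the same `fake_img` can still be used for the generator's backward pass afterwards.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the generator and discriminator (hypothetical sizes)
gen = nn.Linear(4, 8)
disc = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())
criterion = nn.BCELoss()

noise = torch.randn(5, 4)
fake_img = gen(noise)  # G(z), still attached to the generator's graph

# Discriminator step on detached fakes: backward stops at fake_img
disc_fake = disc(fake_img.detach()).view(-1)
lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
lossD_fake.backward()
gen_grad_after_disc = gen.weight.grad  # the generator received no gradient

# Generator step reuses the same fake_img; its graph has not been freed
disc_fake2 = disc(fake_img).view(-1)
lossG = criterion(disc_fake2, torch.ones_like(disc_fake2))
lossG.backward()
gen_grad_after_gen = gen.weight.grad  # now populated

print(gen_grad_after_disc is None, gen_grad_after_gen is not None)
```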

Recall that the generator is trained to ***minimize*** the following:

$$\text{log}(1 - D(G(z)))$$

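However, early in training the discriminator easily rejects the generator's samples, so $D(G(z))$ is close to 0 and $\text{log}(1 - D(G(z)))$ saturates: its gradient with respect to the generator is tiny, which slows training or stalls it altogether. The original paper therefore suggests training the generator to **maximize** the following instead. This objective is not identical to the one above, but it has the same fixed point and provides much stronger gradients early in learning: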

$$\text{log}(D(G(z)))$$

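We use `BCELoss` for the generator in the same way as for the discriminator: setting $y_i = 1$ and $x_i = D(G(z))$ gives $-\text{log}(D(G(z)))$, and minimizing it maximizes $\text{log}(D(G(z)))$.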
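
The difference in gradient strength between the two generator objectives can be seen on a single number. The sketch below assumes a hypothetical discriminator logit of `-6.0` for a fake sample (so $D(G(z))$ is close to 0, as early in training) and compares the gradient magnitude of each loss with respect to that logit.

```python
import torch

# Hypothetical discriminator logit for a fake sample that is confidently rejected
logit_sat = torch.tensor(-6.0, requires_grad=True)
logit_ns = torch.tensor(-6.0, requires_grad=True)

# Saturating generator loss: minimize log(1 - D(G(z)))
loss_sat = torch.log(1 - torch.sigmoid(logit_sat))
loss_sat.backward()

# Non-saturating generator loss: minimize -log(D(G(z)))
loss_ns = -torch.log(torch.sigmoid(logit_ns))
loss_ns.backward()

grad_sat = logit_sat.grad.abs().item()
grad_ns = logit_ns.grad.abs().item()
print(f"saturating: {grad_sat:.4f}, non-saturating: {grad_ns:.4f}")
```

Analytically, the first gradient has magnitude $D(G(z))$ and the second has magnitude $1 - D(G(z))$, so when the discriminator is confident that a sample is fake, only the second gives the generator a useful learning signal.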

In [None]:
for epoch in tqdm(range(num_epochs)):
    
    # we iterate over the training dataloader
    # we only need the images, and not the labels
    for batch_idx, (real_img, _) in enumerate(train_loader):
        
        # flatten each image to a 1D tensor of length img_dim, keeping the batch dimension
        real_img = real_img.view(-1, img_dim).to(device)
        # the first dimension is the size of this batch (in general, the last batch may be smaller)
        cur_batch_size = real_img.shape[0]

        # Discriminator training: maximize log(D(x)) + log(1 - D(G(z))),
        # i.e. minimize the corresponding BCE terms
        noise = torch.randn(cur_batch_size, z_dim, device=device)  # z
        fake_img = gen(noise)  # G(z)
        disc_real = disc(real_img).view(-1)   # D(x); .view(-1) flattens (N, 1) to (N,)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))  # -log(D(x))
        # detach so this backward pass does not run through the generator
        disc_fake = disc(fake_img.detach()).view(-1)  # D(G(z))
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))  # -log(1 - D(G(z)))
        lossD = (lossD_real + lossD_fake) / 2  # average of the two BCE terms
        
        disc.zero_grad()
        lossD.backward()
        optim_disc.step()

        # Generator training: maximize log(D(G(z))), i.e. minimize -log(D(G(z)))
        disc_fake2 = disc(fake_img).view(-1)  # D(G(z)), using the updated discriminator
        lossG = criterion(disc_fake2, torch.ones_like(disc_fake2))  # -log(D(G(z)))
        
        gen.zero_grad()
        lossG.backward()
        optim_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(train_loader)} "
                f"Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real_img.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1
