## Generative Adversarial Network
In this notebook, I will be building a generative adversarial network (GAN). A network like this has two main components:
1. Generator: a neural network that takes random noise as input and produces synthetic (fake) data that imitates the training set.
2. Discriminator: a second network that tries to tell real data (from the training set) apart from fake data (from the generator).

The two networks learn through a back-and-forth game: the generator keeps trying to fool the discriminator, while the discriminator keeps trying to catch the fakes.
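Formally, this game is usually written as a minimax objective (Goodfellow et al., 2014). Here $D(x)$ is the discriminator's estimated probability that $x$ is real, $G(z)$ is the generator's output for a noise vector $z$, $p_{\text{data}}$ is the distribution of the training data, and $p_z$ is the noise distribution. These symbols are introduced only for illustration and don't correspond to names in the notebook code.

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

The discriminator tries to maximize $V$ by scoring real images close to 1 and fakes close to 0, while the generator tries to minimize it by making $D(G(z))$ close to 1.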

Here are some of the resources used for this study:
1. https://www.youtube.com/watch?v=_pIMdDWK5sc
2. https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html
3. https://developers.google.com/machine-learning/gan/

```python
# Let's start by importing all the necessary modules
import os
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as func
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
```

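```python
# Setting constants
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")  # where MNIST is downloaded and stored
BATCH_SIZE = 256 if torch.cuda.is_available() else 64  # larger batches when a GPU is available
# Use half of the CPU cores for data loading; os.cpu_count() can return None, so fall back to 2
NUM_WORKERS = (os.cpu_count() or 2) // 2
```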

The MNIST data module below is taken from the Lightning basic GAN tutorial: https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html

(Purely because I didn't want to spend too much time writing this, since it's readily available lol)

```python
class MNISTDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str = PATH_DATASETS,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
```
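`ToTensor()` scales pixel values to [0, 1], and `Normalize((0.1307,), (0.3081,))` then computes `(x - mean) / std` with MNIST's mean and standard deviation. The small pure-Python sketch below shows the resulting value range and how to invert the transform, for example to map generated images back to [0, 1] before plotting. The helper names `normalize` and `denormalize` and the constants are hypothetical, not part of the Lightning module.

```python
# Hypothetical helpers reproducing the arithmetic of transforms.Normalize for one channel
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(pixel: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Map a pixel in [0, 1] to the normalized scale: (x - mean) / std."""
    return (pixel - mean) / std


def denormalize(value: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Inverse of normalize: map a normalized value back to the [0, 1] pixel scale."""
    return value * std + mean


# Black (0.0) and white (1.0) pixels give the bounds of the normalized range
low, high = normalize(0.0), normalize(1.0)
print(f"normalized range: [{low:.4f}, {high:.4f}]")
```

This prints a range of roughly [-0.42, 2.82]. The Generator defined below ends in a plain `Conv2d` with no output activation, so its outputs are unbounded and can cover this range. A common alternative (not used in this notebook) is to normalize to [-1, 1] with `Normalize((0.5,), (0.5,))` and end the generator with `tanh`.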

## Generator
Now, let's build the Generator. This is a convolutional neural network (CNN) that takes a random latent vector and generates fake images that resemble the MNIST digits.

The general structure of this network is:

Input -> Linear layer -> Upsample layer 1 -> Upsample layer 2 -> Convolution to 28 × 28

*A note on the dimensions of the convolution layers:*

At the beginning of the generation process, the feature maps need a high number of channels (64 here) to capture diverse patterns and encode the complexity of the image. At this stage the spatial dimensions are small (7×7), so there is room for more depth in the feature representation.

As the spatial resolution increases (7×7 → 16×16 → 34×34, before the final convolution brings it to 28×28), the feature maps focus more on refining the image's structure and texture than on adding new abstract details. Thus, fewer channels are needed (64 → 32 → 16 → 1).

The output size of the transposed convolution (upsampling) layers can be calculated with the following formula, assuming PyTorch's defaults of `dilation=1` and `output_padding=0`:

$$
H_{\text{out}} = (H_{\text{in}} - 1) \cdot \text{stride} - 2 \cdot \text{padding} + \text{kernel size}
$$

*Here H is the height, which equals the width in our case because every feature map is square.*

For a regular convolution layer (again with `dilation=1`), the output size is:

$$
H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2 \cdot \text{padding} - \text{kernel size}}{\text{stride}} \right\rfloor + 1
$$
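To double-check the shapes noted in the Generator's comments, here is a small pure-Python sketch that applies the two formulas above layer by layer. The helper names `conv_transpose2d_out` and `conv2d_out` are hypothetical (they are not PyTorch functions), and the sketch assumes stride 1 and no padding for the final `Conv2d`, as in the Generator code.

```python
def conv_transpose2d_out(h_in: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output size of a transposed convolution (dilation=1, output_padding=0)."""
    return (h_in - 1) * stride - 2 * padding + kernel_size


def conv2d_out(h_in: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output size of a regular convolution (dilation=1); // performs the floor."""
    return (h_in + 2 * padding - kernel_size) // stride + 1


# Trace the Generator: Linear -> (64, 7, 7)
# -> ConvTranspose2d(64, 32, 4, stride=2) -> ConvTranspose2d(32, 16, 4, stride=2)
# -> Conv2d(16, 1, kernel_size=7)
h = 7
shapes = [(64, h, h)]
h = conv_transpose2d_out(h, kernel_size=4, stride=2)  # (7 - 1) * 2 + 4 = 16
shapes.append((32, h, h))
h = conv_transpose2d_out(h, kernel_size=4, stride=2)  # (16 - 1) * 2 + 4 = 34
shapes.append((16, h, h))
h = conv2d_out(h, kernel_size=7)                       # (34 - 7) / 1 + 1 = 28
shapes.append((1, h, h))

for channels, height, width in shapes:
    print(f"{channels} x {height} x {width}")
```

The trace gives 7 → 16 → 34 → 28. With `padding=1` on both transposed convolutions, the same formula would give an exact doubling (7 → 14 → 28), and the final convolution would then need `padding=3` to stay at 28×28. The layers here instead overshoot to 34×34 and let the 7×7 kernel crop back down.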


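```python
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        # Project the latent vector to 64 feature maps of size 7x7
        self.linear = nn.Linear(latent_dim, 7 * 7 * 64)
        # Each transposed convolution roughly doubles the spatial size
        self.convTrans1 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2)
        self.convTrans2 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2)
        # Final convolution maps 16 channels to 1 and shrinks 34x34 to 28x28
        self.conv = nn.Conv2d(16, 1, kernel_size=7)

    def forward(self, x):
        # Project the latent vector and reshape it to (batch, 64, 7, 7)
        x = self.linear(x)
        x = func.relu(x)
        x = x.view(-1, 64, 7, 7)

        # Upsampling to 16x16: (7 - 1) * 2 + 4 = 16
        x = self.convTrans1(x)
        x = func.relu(x)

        # Upsampling to 34x34: (16 - 1) * 2 + 4 = 34
        x = self.convTrans2(x)
        x = func.relu(x)

        # Convolving back down to 28x28: 34 - 7 + 1 = 28
        x = self.conv(x)

        return x
```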

## Discriminator
Now that we have our generator, it's time to build our Discriminator. If the generator is the criminal, the discriminator is our detective, trying to tell the real images apart from the fake ones.

When we get to writing the loss functions, the discriminator plays a very important role in backpropagation for both networks: its judgments are what push the generator to make better fake images, and they are also what the discriminator itself is trained to improve. But we will get to that at the training stage.

This is also a convolutional neural network, a more traditional one with two convolutional layers followed by two fully connected layers. It classifies whether an image is real or fake by outputting a value between 0 and 1 (via a sigmoid), which, with the usual convention of labelling real images as 1, can be read as the probability that the image is real.
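As a preview of the training stage, here is a minimal pure-Python sketch of how the discriminator's output enters both losses. It assumes binary cross-entropy (the notebook has not picked a loss yet; in PyTorch this would typically be `func.binary_cross_entropy` applied to the sigmoid outputs), and the function names and probabilities are made up for illustration. The generator loss uses the common "non-saturating" variant, which labels fakes as real, instead of minimizing log(1 - D(G(z))) directly.

```python
import math


def bce(prob: float, target: float, eps: float = 1e-12) -> float:
    """Binary cross-entropy for a single prediction; eps guards against log(0)."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))


def discriminator_loss(d_real: float, d_fake: float) -> float:
    """The discriminator should output 1 for real images and 0 for generated ones."""
    return bce(d_real, 1.0) + bce(d_fake, 0.0)


def generator_loss(d_fake: float) -> float:
    """Non-saturating generator loss: low when the discriminator calls a fake real."""
    return bce(d_fake, 1.0)


# Hypothetical discriminator outputs for one real and one generated image
d_real, d_fake = 0.9, 0.2
print(f"D loss: {discriminator_loss(d_real, d_fake):.4f}")
print(f"G loss: {generator_loss(d_fake):.4f}")
```

Both losses depend on `d_fake`, so the discriminator's judgment is what drives the generator's gradients. During the generator update, only the generator's optimizer is stepped; the discriminator's weights are left unchanged.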

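```python
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # (1, 28, 28) -> (10, 24, 24)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # (10, 12, 12) -> (20, 8, 8)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)                  # 320 = 20 channels * 4 * 4
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        # Conv -> 2x2 max pool -> ReLU: 28x28 -> 24x24 -> 12x12
        x = func.relu(func.max_pool2d(self.conv1(x), 2))
        # Conv -> channel dropout -> 2x2 max pool -> ReLU: 12x12 -> 8x8 -> 4x4
        x = func.relu(func.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten to (batch, 320)
        x = x.view(-1, 320)
        x = func.relu(self.fc1(x))
        x = func.dropout(x, training=self.training)
        x = self.fc2(x)
        # Squash the single logit to a probability in (0, 1)
        x = torch.sigmoid(x)

        return x
```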
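The hard-coded `320` in `fc1` (and in `x.view(-1, 320)`) comes from tracing a 28×28 MNIST image through the two conv/pool stages. Below is a small pure-Python check, assuming the PyTorch defaults used above: stride 1 and no padding for the convolutions, and `func.max_pool2d`'s stride defaulting to its kernel size. The helper names are hypothetical.

```python
def conv2d_out(h_in: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output size of a convolution (dilation=1)."""
    return (h_in + 2 * padding - kernel_size) // stride + 1


def max_pool2d_out(h_in: int, kernel_size: int) -> int:
    """Output size of max pooling with no padding and stride equal to kernel_size."""
    return (h_in - kernel_size) // kernel_size + 1


h = 28
h = conv2d_out(h, kernel_size=5)      # conv1: 28 -> 24
h = max_pool2d_out(h, kernel_size=2)  # pool:  24 -> 12
h = conv2d_out(h, kernel_size=5)      # conv2: 12 -> 8
h = max_pool2d_out(h, kernel_size=2)  # pool:   8 -> 4
flat_features = 20 * h * h            # conv2 produces 20 channels
print(h, flat_features)
```

If the input size or the convolution settings change, this number has to be recomputed. `nn.LazyLinear`, which infers its input size on the first forward pass, is one way to avoid hard-coding it.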