# Generative Adversarial Networks (GANs)

As the name suggests, these are generative models (they can produce novel samples) that model the data distribution implicitly, i.e. there is no loss function expressed in terms of an explicit probability density.

# Basic Idea

Train a generator to produce synthetic images and a discriminator to tell them apart from real ones, until the synthetic images look real. The generator is pitted against the discriminator, hence the word *adversarial*.

<p align="center">
<img src="./doc_imgs/GAN_architecture.png" width="" height="300" />
</p>

* The generator tries to generate novel (real looking) samples from noise
* The discriminator distinguishes between real and fake images and forces the generator to produce successively more realistic samples

# Loss Functions

**Discriminator Loss:**

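$$
J^{(D)} = -\frac{1}{2}\mathbb{E}_{x\sim p_{data}}\left[\log D(x)\right] - \frac{1}{2}\mathbb{E}_{z \sim p_{z}(z)}\left[\log\left(1-D(G(z))\right)\right]
$$

where $J^{(D)}$ is simply the binary cross-entropy loss (real samples labelled 1, generated samples labelled 0).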

**Generator Loss:**

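$$
J^{(G)} = -J^{(D)}
$$

$D$ provides the supervision for the gradient of $G$; we can also say that $D$ is a learnable loss function.

In practice, the following (non-saturating) loss is used for the generator instead. The reason is that $\log(1-D(G(z)))$ has a vanishing gradient early in training, when $D$ confidently rejects the generated samples:

$$
J^{(G)} = -\frac{1}{2}\mathbb{E}_{z}\left[\log D(G(z))\right]
$$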

**Minimax Game:**

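D tries to maximize the probability that it correctly classifies real and fake samples, while G tries to minimize the probability that D predicts its outputs are fake (notice the $1-D(\cdot)$ and $D(\cdot)$ terms in the loss). With $J^{(G)} = -J^{(D)}$ this is the zero-sum game $\min_G \max_D \; \mathbb{E}_{x\sim p_{data}}\left[\log D(x)\right] + \mathbb{E}_{z}\left[\log\left(1-D(G(z))\right)\right]$.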


# Implementation

In [15]:
# Essential Imports
import torchmetrics
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch import optim

import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import TQDMProgressBar


import einops
import matplotlib.pyplot as plt


# Seed Python, NumPy and torch RNGs so runs are repeatable.
pl.seed_everything(42)

# Ensure that cuDNN operations are deterministic on GPU (if used) for reproducibility.
# NOTE: the attribute name must be spelled exactly `deterministic`; a misspelled name
# just creates an unused attribute on the module and silently has no effect.
torch.backends.cudnn.deterministic = True
# Disable cuDNN's algorithm benchmarking, which can select different kernels across runs.
torch.backends.cudnn.benchmark = False


# The Generator
class G(nn.Module):
    """Fully-connected generator: latent vector -> single-channel 28x28 image.

    Hidden widths are 128 -> 256 -> 512 -> 1024, each followed by LeakyReLU(0.2);
    every hidden layer except the first also has BatchNorm1d. The input width is
    inferred by ``LazyLinear`` on the first forward call.

    NOTE(review): the output activation is ReLU, so pixel values lie in [0, inf),
    while MNIST loaded with ``ToTensor`` is in [0, 1]. A Tanh output paired with
    inputs normalized to [-1, 1] is the common alternative.
    """

    def __init__(self) -> None:
        super().__init__()

        # Layers are appended in the same order as an explicit Sequential listing,
        # so module indices (and therefore state_dict keys) are unchanged.
        layers = [nn.LazyLinear(128), nn.LeakyReLU(0.2, inplace=True)]
        for width in (256, 512, 1024):
            layers += [
                nn.LazyLinear(width),
                nn.BatchNorm1d(width),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Project to 28 * 28 = 784 pixels.
        layers += [nn.LazyLinear(784), nn.ReLU()]

        self.upsample = nn.Sequential(*layers)

    def forward(self, x):
        flat = self.upsample(x)
        # (N, 784) -> (N, 1, 28, 28); the channel count is inferred from the width.
        return flat.reshape(flat.shape[0], -1, 28, 28)

# The Discriminator
class D(nn.Module):
    """Fully-connected discriminator: image batch -> probability each image is real.

    Each (C, H, W) image is flattened, then passed through
    LazyLinear(512) -> LeakyReLU(0.2) -> LazyLinear(256) -> LeakyReLU(0.2)
    -> LazyLinear(1) -> Sigmoid, producing an (N, 1) tensor of values in (0, 1).
    """

    def __init__(self) -> None:
        super().__init__()

        # Built in the same order as an explicit listing, so module indices match.
        hidden = []
        for width in (512, 256):
            hidden.extend([nn.LazyLinear(width), nn.LeakyReLU(0.2, inplace=True)])

        self.downsample = nn.Sequential(*hidden, nn.LazyLinear(1), nn.Sigmoid())

    def forward(self, x):
        # Expects a 4-D (N, C, H, W) batch; flatten each image to one vector.
        n, c, h, w = x.shape
        return self.downsample(x.reshape(n, c * h * w))


# The GAN class

class GAN(pl.LightningModule):
    """Vanilla GAN on MNIST using the fully-connected ``G`` generator and ``D`` discriminator.

    Uses Lightning 1.x multiple-optimizer automatic optimization:
    ``configure_optimizers`` returns ``[opt_g, opt_d]`` and ``training_step``
    receives ``optimizer_idx`` 0 for the generator and 1 for the discriminator.

    Args:
        latent_dim: Size of the noise vector ``z`` fed to the generator.
        lr: Adam learning rate shared by both optimizers.
        data_dir: Root directory where MNIST is stored / downloaded.
    """

    def __init__(
        self,
        latent_dim=32,
        lr=0.0002,  # as described in the DCGAN paper
        data_dir='/home/ibrahim/Projects/Datasets/',
    ) -> None:
        super().__init__()

        self.lr = lr
        self.gen = G()
        self.disc = D()
        self.latent_dim = latent_dim
        self.data_dir = data_dir
        # Fixed noise so the sample grid logged each epoch is comparable across epochs.
        # It is a plain attribute (not a registered buffer), so it stays on the CPU
        # and is moved to the module's device when used.
        self.val_z = torch.randn(8, self.latent_dim)

        # ToTensor scales pixels to [0, 1]. No Normalize is applied, so real images
        # stay in [0, 1] (normalizing to [-1, 1] would pair with a Tanh generator output).
        self.transform = transforms.Compose([transforms.ToTensor()])

    def forward(self, z):
        """Generate images from latent noise ``z`` (Lightning calls this via ``self(z)``)."""
        return self.gen(z)

    def bce_loss(self, y_pred, y):
        """Mean binary cross-entropy between probabilities ``y_pred`` and targets ``y``."""
        return F.binary_cross_entropy(y_pred, y)

    def configure_optimizers(self):
        """Create one Adam optimizer per network: index 0 = generator, 1 = discriminator."""
        # beta1 = 0.5 follows the DCGAN paper; beta2 = 0.99 is this notebook's choice
        # (PyTorch's Adam default for beta2 is 0.999).
        betas = (0.5, 0.99)
        opt_g = torch.optim.Adam(self.gen.parameters(), lr=self.lr, betas=betas)
        opt_d = torch.optim.Adam(self.disc.parameters(), lr=self.lr, betas=betas)
        return [opt_g, opt_d], []

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the generator loss (``optimizer_idx == 0``) or discriminator loss (``== 1``).

        Args:
            batch: ``(imgs, labels)`` from the MNIST loader; class labels are ignored.
            batch_idx: Index of the batch within the epoch (unused).
            optimizer_idx: Which optimizer Lightning is currently stepping.

        Returns:
            The scalar loss for the selected network.
        """
        imgs, _ = batch

        # Fresh noise every step. type_as moves it to the batch's device/dtype instead
        # of hard-coding .cuda(), so the module also runs on CPU.
        z = torch.randn(imgs.size(0), self.latent_dim).type_as(imgs)

        if optimizer_idx == 0:
            # Generator step: BCE against all-ones targets equals mean(-log D(G(z))),
            # the non-saturating generator loss (without the 1/2 factor from the notes).
            fake_imgs = self(z)
            d_fake = self.disc(fake_imgs)
            g_loss = self.bce_loss(d_fake, torch.ones_like(d_fake))
            self.log("g_loss", g_loss, prog_bar=True)
            return g_loss

        if optimizer_idx == 1:
            # Discriminator step:
            # J(D) = -1/2 E_x[log D(x)] - 1/2 E_z[log(1 - D(G(z)))]
            d_real = self.disc(imgs)
            real_loss = 0.5 * self.bce_loss(d_real, torch.ones_like(d_real))

            # detach() so the discriminator loss does not backpropagate into G.
            d_fake = self.disc(self.gen(z).detach())
            fake_loss = 0.5 * self.bce_loss(d_fake, torch.zeros_like(d_fake))

            d_loss = real_loss + fake_loss
            self.log("d_loss", d_loss, prog_bar=True)
            return d_loss

    def on_train_epoch_end(self):
        """Log a 2-column grid of images generated from the fixed noise ``val_z``."""
        # The samples are only logged, so skip building an autograd graph.
        with torch.no_grad():
            sample_imgs = self(self.val_z.to(self.device))

        # Real images lie in [0, 1] (ToTensor, no Normalize), so display that range.
        # `value_range` replaces make_grid's deprecated `range` keyword.
        grid = torchvision.utils.make_grid(
            sample_imgs.cpu(), nrow=2, normalize=True, value_range=(0, 1)
        )
        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

    def train_dataloader(self):
        """Shuffled MNIST training loader (batch size 32; the incomplete last batch is dropped)."""
        dataset = datasets.MNIST(self.data_dir, train=True, download=True, transform=self.transform)
        return DataLoader(dataset=dataset, batch_size=32, shuffle=True, drop_last=True, num_workers=12)

    def val_dataloader(self):
        """MNIST test-split loader (batch size 16).

        NOTE(review): this class defines no ``validation_step``, so this loader is
        likely unused by Lightning's validation loop — confirm.
        """
        val_dataset = datasets.MNIST(self.data_dir, train=False, download=True, transform=self.transform)
        # Evaluation data should not be shuffled (Lightning warns when it is).
        return DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, drop_last=True, num_workers=12)


model = GAN()

# `Trainer(gpus=...)` is deprecated since Lightning 1.7 (removed in 2.0);
# `accelerator` + `devices` is the supported way to request one GPU.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=100,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    enable_checkpointing=False,
    logger=True,
)

trainer.fit(model)






GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


**Some GAN Hacks (take with a grain of salt)**
* Normalize inputs between -1 and 1
* use Tanh as the last layer of the generator output (because of above)
* Sample $z$ from a Gaussian dist
* SGD for $D$
* ADAM for $G$
* Use BatchNorm between layers
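* One-sided label smoothing: use a target of $\lambda$ instead of 1 for real samples (fake targets stay at 0):
$$
J^{(D)} = -\frac{1}{2}\mathbb{E}_{x\sim p_{data}}\left[\lambda\log D(x) + (1-\lambda)\log\left(1-D(x)\right)\right] - \frac{1}{2}\mathbb{E}_{z}\left[\log\left(1-D(G(z))\right)\right]
$$
where $\lambda$ is a value slightly less than 1, e.g. $0.9$.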