# **Generative Adversarial Networks (GANs)**
<img align='right' width='800' src="https://cdn-images-1.medium.com/v2/resize:fit:851/0*pPEL7ryJR51VpnDO.jpg">

In [18]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [2]:
# Visualize the data
def show(tensor, ch=1, size=(28, 28), num_to_display=16, nrow=4, normalize=False):
    """Plot the first images of a batch as a grid with matplotlib.

    Args:
        tensor: batch of images with ``ch * size[0] * size[1]`` values per image,
            either shaped ``(batch, ch, H, W)`` or flattened (e.g. generator output).
        ch: number of channels per image.
        size: ``(height, width)`` of each image.
        num_to_display: how many images from the front of the batch to plot.
        nrow: number of images per grid row.
        normalize: forwarded to ``make_grid``; when True the grid is min-max
            rescaled to [0, 1], which keeps matplotlib from clipping values
            outside that range (e.g. Tanh outputs in [-1, 1]).
    """
    # detach() drops the autograd graph and cpu() copies GPU tensors to host
    # memory so matplotlib can read them; view() restores the image shape for
    # flat inputs.
    batch = tensor.detach().cpu().view(-1, ch, *size)
    grid = make_grid(batch[:num_to_display], nrow=nrow, normalize=normalize)
    plt.axis(False)
    # make_grid returns (channels, H, W); imshow expects (H, W, channels).
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

In [45]:
def get_data(data=MNIST, bs=128, normalize=False):
    """Return a shuffled DataLoader over the training split of ``data``.

    Only the training split is loaded; GAN training does not need a test set.

    Args:
        data: a torchvision dataset class (default ``MNIST``), constructed as
            ``data('.', train=True, transform=..., download=True)``.
        bs: batch size.
        normalize: if True, pixels are mapped from [0, 1] to [-1, 1] with
            ``Normalize((0.5,), (0.5,))`` so real images share the range of a
            Tanh generator output. If False (the default), images are plain
            ``ToTensor`` output in [0, 1].
    """
    steps = [torchvision.transforms.ToTensor()]
    if normalize:
        steps.append(torchvision.transforms.Normalize((0.5,), (0.5,)))
    transform = torchvision.transforms.Compose(steps)

    # Honor the requested dataset class instead of always loading MNIST.
    train_set = data('.',
                     train=True,
                     transform=transform,
                     download=True)

    # Group the data into shuffled mini-batches of size ``bs``.
    return DataLoader(train_set, bs, shuffle=True)

In [55]:
# Hyperparameters
EPOCH = 50   # passes over the training set
Z_DIM = 100  # length of each generator input noise vector
LR = 1e-4    # starting learning rate for both Adam optimizers
BS = 128     # mini-batch size

data = get_data(MNIST, BS)
# Image geometry (channels, height, width), read from the first batch of images.
C, H, W = next(iter(data))[0].shape[1:]
loss_func = nn.BCELoss()  # binary cross-entropy on the discriminator's probabilities
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Availabe device is: ", device)

Availabe device is:  cuda


## **Discriminator and Generator Networks**

In [56]:
# Generator
class Generator(nn.Module):
    """MLP generator that maps a noise vector to a flattened image.

    Args:
        z_dim: length of the input noise vector.
        hidden_dim: width of the first hidden layer; the following hidden
            layers are 2x, 4x and 8x this width.
        out_dim: number of output values per sample (28*28 by default).

    Outputs pass through Tanh, so every value lies in [-1, 1].
    """

    def __init__(self, z_dim=100, hidden_dim=128, out_dim=28*28):
        super().__init__()
        self.gen = nn.Sequential(
            self._gen_block(z_dim, hidden_dim),
            self._gen_block(hidden_dim, hidden_dim * 2),
            self._gen_block(hidden_dim * 2, hidden_dim * 4),
            self._gen_block(hidden_dim * 4, hidden_dim * 8),
            # Output layer: a plain Linear feeding Tanh. No BatchNorm/ReLU here:
            # a ReLU before Tanh would clamp every output to [0, 1) and the
            # generator could never produce negative values.
            nn.Linear(hidden_dim * 8, out_dim),
            nn.Tanh(),
        )

    def forward(self, nois):
        """Map a ``(batch, z_dim)`` noise tensor to ``(batch, out_dim)`` samples."""
        return self.gen(nois)

    def _gen_block(self, in_dim, out):
        """Hidden layer: Linear -> BatchNorm1d -> ReLU."""
        return nn.Sequential(
            nn.Linear(in_dim, out),
            nn.BatchNorm1d(out),
            nn.ReLU(inplace=True),
        )

# Discriminator
class Discriminator(nn.Module):
    """MLP discriminator that scores each image with a probability of being real.

    Args:
        in_dim: number of values per flattened input image (28*28 by default).
        hidden_dim: width of the last hidden layer; the earlier hidden layers
            are 4x and 2x this width.
        out_dim: number of outputs per sample (1 = a single real/fake score).
    """

    def __init__(self, in_dim=28*28, hidden_dim=128, out_dim=1):
        super().__init__()
        self.in_dim = in_dim
        self.disc = nn.Sequential(
            self._disc_block(in_dim, hidden_dim * 4),
            self._disc_block(hidden_dim * 4, hidden_dim * 2),
            self._disc_block(hidden_dim * 2, hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            # Squash scores into (0, 1) because nn.BCELoss expects probabilities.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images; any shape with ``in_dim`` values per sample works."""
        # Flatten to (batch, in_dim) for the Linear layers. Using self.in_dim keeps
        # forward consistent with the constructor instead of assuming 28*28, and
        # reshape() also accepts non-contiguous inputs that view() rejects.
        x = x.reshape(-1, self.in_dim)
        return self.disc(x)

    def _disc_block(self, in_dim, out):
        """Hidden layer: Linear -> BatchNorm1d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Linear(in_dim, out),
            nn.BatchNorm1d(out),
            nn.LeakyReLU(0.2),
        )

In [57]:
def gen_noise(number, z_dim):
    """Return ``number`` noise vectors of length ``z_dim`` drawn from N(0, 1) on ``device``."""
    # Sampling on the target device directly skips a host-to-device copy per call.
    return torch.randn(number, z_dim, device=device)


gen = Generator().to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=LR)
# StepLR multiplies the learning rate by gamma=0.5 every step_size=10 calls to
# .step(); other schedules are available in torch.optim.lr_scheduler.
gen_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(gen_opt, step_size=10, gamma=0.5)

disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=LR)
disc_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(disc_opt, step_size=10, gamma=0.5)

In [None]:
# Check the data: inspect one real batch and compare it with the untrained generator's output.
x, y = next(iter(data))
print("Shape of dataset images: ", x.shape)
print("label of images", y[:16])

# Sampling is only for display, so skip building an autograd graph.
with torch.no_grad():
    noise = gen_noise(BS, Z_DIM)
    fake = gen(noise)
show(x, ch=C, size=(H, W))
show(fake, ch=C, size=(H, W))

## **BCELoss**
$$
    \large -\frac{1}{n}\sum_{i=1}^{n}\left[y_i\log(\hat{y}_i) + (1-y_i)\log(1-\hat{y}_i)\right]
$$

<br>

**Generator Loss**\
<br>
The generator wants to fool the discriminator, so its loss is computed against a matrix of **ones** (the "real" label), and the **BCE** formula reduces to:

$$
    -\frac{1}{n}\sum_{i=1}^{n}\log(\hat{y}_i)
$$

where $\hat{y}$ is the discriminator's output for a generated image, i.e. $D(G(z))$:

$$
   \large z \longrightarrow \text{Generator} \xrightarrow{G(z)} \text{Discriminator} \xrightarrow{D(G(z))} -\frac{1}{n}\sum_{i=1}^{n}\log(D(G(z_i)))
$$

**Discriminator Loss**
<br>
The discriminator uses a matrix of **ones** as the target for **real** inputs and a matrix of **zeros** as the target for **fake** inputs (the generator's output), so the **BCE** formula splits into two terms:

<br>

$\text{When the label is 1 (real images)}:$
$$
    -\frac{1}{n}\sum_{i=1}^{n}\log(\hat{y}_i)
$$

<br>

$\text{When the label is 0 (generated images)}:$
$$
    -\frac{1}{n}\sum_{i=1}^{n}\log(1 - \hat{y}_i)
$$

where $\hat{y}$ is the discriminator's output: $D(x)$ for real images and $D(G(z))$ for generated ones:

$$
    \large -\frac{1}{2n}\sum_{i=1}^{n}\left[\log(D(x_i)) + \log(1 - D(G(z_i)))\right]
$$


In [59]:
# Loss functions: the core of GAN training.
def gen_loss_func(gen_net, disc_net, loss_func, num, z_dim):
    """Compute the generator's adversarial loss.

    Samples ``num`` noise vectors of length ``z_dim``, runs them through the
    generator and then the discriminator, and scores the discriminator's
    predictions against an all-ones ("real") target with ``loss_func``.
    The loss is low when the discriminator is fooled into calling fakes real.
    """
    scores = disc_net(gen_net(gen_noise(num, z_dim)))
    target = torch.ones_like(scores)
    return loss_func(scores, target)


def disc_loss_func(gen_net, disc_net, loss_func, image, num, z_dim):
    """Compute the discriminator loss as the mean of its real and fake terms.

    ``num`` generated images are scored against an all-zeros target and the
    real batch ``image`` against an all-ones target; the two ``loss_func``
    values are averaged.
    """
    generated = gen_net(gen_noise(num, z_dim))
    # Score a detached copy so this loss does not back-propagate into gen_net.
    score_fake = disc_net(generated.detach())
    score_real = disc_net(image)

    real_term = loss_func(score_real, torch.ones_like(score_real))
    fake_term = loss_func(score_fake, torch.zeros_like(score_fake))
    return 0.5 * (real_term + fake_term)

In [60]:
import shutil

from torch.utils.tensorboard import SummaryWriter

# Start from an empty log directory so curves from earlier runs don't mix in.
# ignore_errors=True makes this a no-op when the directory doesn't exist yet,
# and unlike the shell escape `!rm -r` it is plain Python.
shutil.rmtree("/content/runs", ignore_errors=True)
writer = SummaryWriter("/content/runs")            # per-epoch loss scalars
writer_fake = SummaryWriter("/content/runs/fake")  # generated image grids
writer_real = SummaryWriter("/content/runs/real")  # real image grids

In [None]:
# Launch TensorBoard inside the notebook, reading event files from `runs`
# (resolved relative to the notebook's working directory).
%load_ext tensorboard
%tensorboard --logdir=runs
# %reload_ext tensorboard

In [None]:
# Training loop: each batch does one discriminator update followed by one generator update.
step = 0  # global_step for the TensorBoard image logs; advanced each time grids are written
for epoch in range(EPOCH):
    # Running sums of per-batch losses (Python floats), averaged at the end of the epoch.
    mean_gen_loss = 0.0
    mean_disc_loss = 0.0
    print(f"\nEpoch: {epoch + 1}")

    for batch, (real, _) in enumerate(tqdm(data)):
        real = real.to(device)
        # Match the fake batch to the real one: the final batch of an epoch is
        # smaller than BS when the dataset size is not a multiple of it.
        cur_bs = len(real)

        # Discriminator step. disc_loss_func detaches the generated images, so
        # this backward pass stays inside the discriminator and no graph needs
        # to be retained.
        disc_loss = disc_loss_func(gen, disc, loss_func, real, cur_bs, Z_DIM)
        disc_opt.zero_grad()
        disc_loss.backward()
        disc_opt.step()

        # Generator step on a fresh noise batch.
        gen_loss = gen_loss_func(gen, disc, loss_func, cur_bs, Z_DIM)
        gen_opt.zero_grad()
        gen_loss.backward()
        gen_opt.step()

        # .item() converts the 0-d loss tensors to floats so no graph is kept alive.
        mean_disc_loss += disc_loss.item()
        mean_gen_loss += gen_loss.item()

        if batch % 150 == 0 and batch != 0:
            with torch.no_grad():
                step += 1
                fake = gen(gen_noise(BS, Z_DIM)).view(-1, C, H, W)
                image = real.view(-1, C, H, W)
                real_grid = make_grid(image[:24], normalize=True)
                fake_grid = make_grid(fake[:24], normalize=True)

                writer_fake.add_image(
                    "MNIST fake image", fake_grid, global_step=step
                )
                writer_real.add_image(
                    "MNIST real image", real_grid, global_step=step
                )

    # Average over every batch of the epoch, not just the last one.
    mean_disc_loss /= len(data)
    mean_gen_loss /= len(data)
    print(f'  Discriminator Loss: {mean_disc_loss:.4f} -- Generator Loss: {mean_gen_loss:.4f}')
    gen_exp_lr_scheduler.step()
    disc_exp_lr_scheduler.step()

    writer.add_scalar("Loss/Disc", mean_disc_loss, epoch)
    writer.add_scalar("Loss/Gen", mean_gen_loss, epoch)
    if epoch % 10 == 0 and epoch > 0:
        # Report the current learning rates after the StepLR schedulers have stepped.
        print(f"  >>> Discriminator Learning Rate: {disc_opt.param_groups[0]['lr']}")
        print(f"  >>> Generator Learning Rate: {gen_opt.param_groups[0]['lr']}")