# **Generative Adversarial Networks (GANs)**
<img align='right' width='800' src="https://cdn-images-1.medium.com/v2/resize:fit:851/0*pPEL7ryJR51VpnDO.jpg">

In [18]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [2]:
# Visualize the data
def show(tensor, ch=1, size=(28, 28), num_to_display=16):
    """
    Plot the first ``num_to_display`` images of ``tensor`` as a 4-column grid.

    The input is reshaped to (batch, ch, *size), so flattened generator output
    works as well as (batch, channel, height, width) image batches. It is
    detached from the autograd graph and copied to the CPU before plotting.
    Because Matplotlib expects (height, width, channel) images, the grid built
    by ``make_grid`` is permuted from (channel, height, width) before display.
    """
    batch = tensor.detach().cpu().view(-1, ch, *size)
    selected = batch[:num_to_display]
    grid_chw = make_grid(selected, nrow=4)
    grid_hwc = grid_chw.permute(1, 2, 0)

    plt.axis(False)
    plt.imshow(grid_hwc)
    plt.show()

In [45]:
def get_data(data=MNIST, bs=128, normalize=False):
    """
    Build a shuffled DataLoader over the training split of a torchvision dataset.

    Only the training split is loaded: training a GAN needs no test set.

    Parameters:
        data: torchvision dataset class to instantiate (default ``MNIST``). It
            is called as ``data('.', train=True, transform=..., download=True)``,
            so it must accept that signature.
        bs (int): mini-batch size.
        normalize (bool): if True, pixels are additionally mapped with
            ``Normalize((0.5,), (0.5,))``, i.e. ``(x - 0.5) / 0.5``, taking the
            [0, 1] output of ``ToTensor`` to [-1, 1] -- the range of a Tanh
            generator. If False (default, the previous behaviour) images are
            only converted with ``ToTensor``.

    Returns:
        DataLoader: yields (images, labels) batches, reshuffled every epoch.
    """
    steps = [torchvision.transforms.ToTensor()]
    if normalize:
        steps.append(torchvision.transforms.Normalize((0.5,), (0.5,)))
    transform = torchvision.transforms.Compose(steps)

    # Use the requested dataset class instead of always loading MNIST.
    train_set = data('.',
                     train=True,
                     transform=transform,
                     download=True)

    # group the data in batches of size ``bs``
    data_loader = DataLoader(train_set, bs, shuffle=True)

    return data_loader

In [55]:
# Hyperparameters
EPOCH = 50    # number of passes over the training set
Z_DIM = 100   # length of the generator's input noise vector
LR = 1e-4     # Adam learning rate for both networks
BS = 128      # mini-batch size

# Training data and the image geometry (channels, height, width) of one sample.
data = get_data(MNIST, BS)
(C, H, W) = next(iter(data))[0][0].shape

# Binary cross-entropy on the discriminator's sigmoid output.
loss_func = nn.BCELoss()

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Availabe device is: ", device)

Availabe device is:  cuda


## **Discriminator and Generator Networks**

In [56]:
# Generator
class Generator(nn.Module):
    """
    Fully connected generator of a Generative Adversarial Network (GAN).

    Maps a noise vector of length ``z_dim`` to a flattened sample of length
    ``out_dim`` whose values lie in [-1, 1] (Tanh output).

    Parameters:
        z_dim (int): The dimension of the input noise vector.
        hidden_dim (int): Width of the first hidden layer; the following
            hidden layers are 2x, 4x and 8x this width.
        out_dim (int): The dimension of the output data (flattened).

    Methods:
        forward(noise): Run the generator on a (batch, z_dim) noise tensor.
        _gen_block(in_dim, out): Linear -> BatchNorm1d -> ReLU helper block.

    Attributes:
        gen (nn.Sequential): The sequential model representing the generator architecture.
    """
    def __init__(self, z_dim=100, hidden_dim=128, out_dim=28*28):
        super().__init__()
        self.gen = nn.Sequential(
            self._gen_block(z_dim, hidden_dim),
            self._gen_block(hidden_dim, hidden_dim*2),
            self._gen_block(hidden_dim*2, hidden_dim*4),
            self._gen_block(hidden_dim*4, hidden_dim*8),
            # The output layer is a plain Linear. A ReLU here would make every
            # pre-activation >= 0, so Tanh could only produce values in [0, 1)
            # and the negative half of the output range would be unreachable.
            nn.Linear(hidden_dim*8, out_dim),
            nn.Tanh()
        )

    def forward(self, noise):
        """Return generated samples of shape (batch, out_dim) for ``noise``."""
        return self.gen(noise)

    def _gen_block(self, in_dim, out):
        """Hidden block: Linear -> BatchNorm1d -> ReLU."""
        return nn.Sequential(
            nn.Linear(in_dim, out),
            nn.BatchNorm1d(out),
            nn.ReLU(inplace=True)
        )

# Discriminator
class Discriminator(nn.Module):
    """
    Fully connected discriminator of a Generative Adversarial Network (GAN).

    Parameters:
        in_dim (int): Number of features per (flattened) input sample.
        hidden_dim (int): Width of the last hidden layer; the earlier hidden
            layers are 4x and 2x this width.
        out_dim (int): The dimension of the output (single value for binary classification).

    Methods:
        forward(x): Flatten ``x`` and return per-sample scores in (0, 1).
        _disc_block(in_dim, out): Linear -> BatchNorm1d -> LeakyReLU(0.2) helper block.

    Attributes:
        in_dim (int): Feature count every sample is flattened to in ``forward``.
        disc (nn.Sequential): The sequential model representing the discriminator architecture.
    """
    def __init__(self, in_dim=28*28, hidden_dim=128, out_dim=1):
        super().__init__()
        self.in_dim = in_dim
        self.disc = nn.Sequential(
            self._disc_block(in_dim, hidden_dim*4),
            self._disc_block(hidden_dim*4, hidden_dim*2),
            self._disc_block(hidden_dim*2, hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Flatten each sample to ``in_dim`` features (instead of a hard-coded
        # 28*28) so both image batches and already-flat inputs are accepted.
        x = x.reshape(-1, self.in_dim)
        return self.disc(x)

    def _disc_block(self, in_dim, out):
        """Hidden block: Linear -> BatchNorm1d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Linear(in_dim, out),
            nn.BatchNorm1d(out),
            nn.LeakyReLU(0.2)
        )

In [None]:
def gen_noise(number, z_dim, dev=None):
    """
    Sample a batch of random noise vectors for the generator.

    Parameters:
        number (int): Number of noise vectors to generate (batch size).
        z_dim (int): Length of each noise vector.
        dev: Optional device override; when None the module-level ``device``
            is used.

    Returns:
        torch.Tensor: shape (number, z_dim), drawn from a standard normal
        distribution (mean 0, standard deviation 1).
    """
    target = device if dev is None else dev
    # Allocate directly on the target device instead of sampling on the CPU
    # and then copying with ``.to(device)``.
    return torch.randn(number, z_dim, device=target)


In [57]:
def gen_noise(number, z_dim):
    # Standard-normal noise of shape (number, z_dim), moved to ``device``.
    return torch.randn(number, z_dim).to(device)

# Generator and its Adam optimizer.
gen = Generator().to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=LR)

# Discriminator and its Adam optimizer.
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=LR)

# StepLR multiplies each optimizer's learning rate by 0.5 after every 10 calls
# to its ``.step()``.
gen_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(gen_opt, step_size=10, gamma=0.5)
disc_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(disc_opt, step_size=10, gamma=0.5)

In [None]:
# Check the data: one real batch and one batch of (untrained) generator output.
x, y = next(iter(data))
print("Shape of dataset images: ", x.shape)
print("label of images", y[:16])

# A visual preview needs no gradients, so skip building an autograd graph.
with torch.no_grad():
    noise = gen_noise(BS, Z_DIM)
    fake = gen(noise)
show(x, ch=C, size=(H, W))
# Reshape the flattened fakes with the dataset's geometry, like the real batch.
show(fake, ch=C, size=(H, W))

## **BCELoss**
$$
    \large -\frac{1}{n}\sum_{i=1}^{n}\left[y_i\log(\hat{y}_i) + (1-y_i)\log(1-\hat{y}_i)\right]
$$

<br>

**Generator Loss**\
<br>
The generator wants to fool the discriminator, so its loss is computed against a target matrix of **ones** (every generated sample labeled "real"). With $y_i = 1$ the **BCE** formula reduces to:

$$
    -\frac{1}{n}\sum_{i=1}^{n}\log(\hat{y}_i)
$$

where $\hat{y}$ is the final output of the discriminator:

$$
   \large z \longrightarrow \text{Generator} \xrightarrow{G(z)} \text{Discriminator} \xrightarrow{D(G(z))} -\frac{1}{n}\sum_{i=1}^{n}\log(D(G(z_i)))
$$

**Discriminator Loss**
<br>
The discriminator uses a target matrix of **ones** for **real** inputs and a matrix of **zeros** for **fake** inputs (the generator's output), so the **BCE** terms become:

<br>

$\text{When the label is 1 (real images)}:$
$$
    -\frac{1}{n}\sum_{i=1}^{n}\log(\hat{y}_i)
$$

<br>

$\text{When the label is 0 (generated images)}:$
$$
    -\frac{1}{n}\sum_{i=1}^{n}\log(1 - \hat{y}_i)
$$

where $\hat{y}$ is the final output of the discriminator. The real and fake terms are averaged, so the discriminator loss is:

$$
    \large -\frac{1}{2n}\sum_{i=1}^{n}\left[\log(D(x_i)) + \log(1 - D(G(z_i)))\right]
$$


In [59]:
# This is the most important part of the training
def gen_loss_func(gen_net, disc_net, loss_func, num, z_dim):
    """
    Compute the generator's adversarial loss.

    ``num`` noise vectors of length ``z_dim`` are turned into fake samples by
    ``gen_net`` and scored by ``disc_net``. The generator wants its fakes to be
    classified as real, so the scores are compared with an all-ones target.
    The fakes are not detached, so gradients flow back into ``gen_net``.

    Parameters:
        gen_net (nn.Module): The generator network.
        disc_net (nn.Module): The discriminator network.
        loss_func: Callable ``loss_func(prediction, target)`` (e.g. BCELoss).
        num (int): Number of noise samples to generate.
        z_dim (int): Dimensionality of each noise vector.

    Returns:
        torch.Tensor: The generator loss.
    """
    scores = disc_net(gen_net(gen_noise(num, z_dim)))
    all_real = torch.ones_like(scores)
    return loss_func(scores, all_real)



def disc_loss_func(gen_net, disc_net, loss_func, image, num, z_dim):
    """
    Compute the discriminator's loss on one real batch and one fake batch.

    Fake samples come from ``gen_net`` applied to ``num`` fresh noise vectors.
    The discriminator should score them 0 and the real ``image`` batch 1; the
    two loss terms are averaged.

    Parameters:
        gen_net (nn.Module): The generator network.
        disc_net (nn.Module): The discriminator network.
        loss_func: Callable ``loss_func(prediction, target)`` (e.g. BCELoss).
        image (torch.Tensor): Batch of real training images.
        num (int): Number of noise samples to generate.
        z_dim (int): Dimensionality of each noise vector.

    Returns:
        torch.Tensor: Mean of the real and fake loss terms.
    """
    generated = gen_net(gen_noise(num, z_dim))

    # detach(): the discriminator update must not send gradients into the generator.
    fake_scores = disc_net(generated.detach())
    real_scores = disc_net(image)

    real_term = loss_func(real_scores, torch.ones_like(real_scores))
    fake_term = loss_func(fake_scores, torch.zeros_like(fake_scores))
    return (real_term + fake_term) / 2


In [60]:
# TensorBoard logging: one writer for the loss scalars and one each for the
# fake and real image grids.
from torch.utils.tensorboard import SummaryWriter
# Delete logs from a previous run so TensorBoard shows only this run.
# NOTE(review): on a fresh runtime /content/runs does not exist yet, so `rm -r`
# presumably prints an error here; `rm -rf` would be silent — confirm intent.
!rm -r /content/runs
writer = SummaryWriter("/content/runs")
writer_fake = SummaryWriter("/content/runs/fake")
writer_real = SummaryWriter("/content/runs/real")

In [None]:
# Launch TensorBoard inside the notebook.
%load_ext tensorboard
# NOTE(review): logdir is the relative path "runs" while the writers use the
# absolute /content/runs; these only match when the working directory is /content.
%tensorboard --logdir=runs
# %reload_ext tensorboard

In [None]:
step = 0  # TensorBoard global step for the logged image grids
for epoch in range(EPOCH):
    # Running sums of per-batch losses (Python floats), averaged after the epoch.
    mean_gen_loss = 0.0
    mean_disc_loss = 0.0
    print(f"\nEpoch: {epoch + 1}")

    for batch, (real, _) in enumerate(tqdm(data)):
        real = real.to(device)

        # Train the discriminator. No retain_graph: the generator step below
        # samples new noise and builds its own graph, so this one is not reused.
        disc_loss = disc_loss_func(gen, disc, loss_func, real, BS, Z_DIM)
        disc_opt.zero_grad()
        disc_loss.backward()
        disc_opt.step()

        # Train the generator
        gen_loss = gen_loss_func(gen, disc, loss_func, BS, Z_DIM)
        gen_opt.zero_grad()
        gen_loss.backward()
        gen_opt.step()

        # .item() converts to a float, so the sums keep no autograd graph alive.
        mean_disc_loss += disc_loss.item()
        mean_gen_loss += gen_loss.item()

        if batch % 150 == 0 and batch != 0:
            # Generate fake images and visualize
            with torch.no_grad():
                step += 1
                fake = gen(gen_noise(BS, Z_DIM)).view(-1, C, H, W)
                image = real.view(-1, C, H, W)
                real_grid = make_grid(image[:24], normalize=True)
                fake_grid = make_grid(fake[:24], normalize=True)

                writer_fake.add_image(
                    "MNIST fake image", fake_grid, global_step=step
                )
                writer_real.add_image(
                    "MNIST real image", real_grid, global_step=step
                )

    # Mean losses over all batches of the epoch (previously the last batch's
    # loss was divided by the batch count, which is not a mean).
    mean_disc_loss /= len(data)
    mean_gen_loss /= len(data)
    print(f'  Discriminator Loss: {mean_disc_loss:.4f} -- Generator Loss: {mean_gen_loss:.4f}')

    # Adjust learning rates using schedulers
    gen_exp_lr_scheduler.step()
    disc_exp_lr_scheduler.step()

    # Write the epoch-mean losses to tensorboard
    writer.add_scalar("Loss/Disc", mean_disc_loss, epoch)
    writer.add_scalar("Loss/Gen", mean_gen_loss, epoch)

    # Print learning rates every 10 epochs
    if epoch % 10 == 0 and epoch > 0:
        print(f"  >>> Discriminator Learning Rate: {disc_opt.param_groups[0]['lr']}")
        print(f"  >>> Generator Learning Rate: {gen_opt.param_groups[0]['lr']}")
