# **Generative Adversarial Networks (GANs)**
<img align='right' width='800' src="https://cdn-images-1.medium.com/v2/resize:fit:851/0*pPEL7ryJR51VpnDO.jpg">

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize

In [None]:
# Hyperparameters
EPOCH = 15   # number of passes over the training set
Z_DIM = 100  # length of the generator's input noise vector
LR = 2e-4    # initial learning rate
BS = 128     # mini-batch size
# Image layout: channels, height, width
C = 1
H = 32
W = 32

# Binary cross-entropy computed on probabilities
loss_func = nn.BCELoss()
# Prefer the GPU when one is visible to PyTorch
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print("Availabe device is: ", device)

Availabe device is:  cuda


In [None]:
# Visualize the data
def show(tensor, ch=1, size=(28, 28), num_to_display=16):
    """
    Display the first images of a batch as a grid with four images per row.

    The tensor is detached from the autograd graph and moved to the CPU, then
    reshaped to (-1, ch, *size). The grid built by `make_grid` is permuted
    from (channel, height, width) to (height, width, channel), which is the
    layout Matplotlib expects.

    Parameters:
        tensor (torch.Tensor): Batch of images; it must be reshapeable to (-1, ch, *size).
        ch (int): Number of channels per image.
        size (tuple): (height, width) of each image.
        num_to_display (int): Maximum number of images placed in the grid.

    """
    batch = tensor.detach().cpu().view(-1, ch, *size)
    selected = batch[:num_to_display]
    grid = make_grid(selected, nrow=4)
    plt.axis(False)
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

In [None]:
def get_data(data=MNIST, bs=128):
    """
    Build a shuffled DataLoader over the training split of a torchvision dataset.

    For training GANs we don't need a test set, so only the training split is
    loaded. Images are resized using the module-level `H`, converted to
    tensors, and normalized per channel with mean 0.5 and std 0.5 (the number
    of channels comes from the module-level `C`).

    Parameters:
        data: Dataset class, called as data(root, train=..., transform=..., download=...).
            Defaults to MNIST.
        bs (int): Batch size.

    Returns:
        DataLoader: Shuffled loader over the training set.

    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(H),
        torchvision.transforms.ToTensor(),
        # ToTensor yields values in [0, 1]; mean 0.5 / std 0.5 maps them to
        # [-1, 1], the same range as the generator's Tanh output.
        torchvision.transforms.Normalize([0.5] * C, [0.5] * C),
    ])
    # Use the dataset class that was passed in (MNIST used to be hard-coded
    # here, so the `data` argument was silently ignored).
    train_set = data('.',
                     train=True,
                     transform=transform,
                     download=True)

    # Group the samples into shuffled mini-batches
    return DataLoader(train_set, bs, shuffle=True)

In [None]:
data = get_data(MNIST, BS)
# Read the real (channels, height, width) back from the first image of one batch
first_batch = next(iter(data))
C, H, W = first_batch[0][0].shape

## **Discriminator and Generator Networks**
**DCGAN** stands for Deep Convolutional Generative
Adversarial Network. It is a type of GAN that uses convolutional layers in both the generative and discriminative models.

In [None]:
# Generator Module
class Generator(nn.Module):
    """
    Generator module for a Generative Adversarial Network (GAN).

    A stack of transposed convolutions that upsamples a noise vector, viewed
    as a (z_dim, 1, 1) feature map, to an image: 1x1 -> 4x4 -> 8x8 -> 16x16
    -> 32x32. The final Tanh keeps the output in [-1, 1].

    Parameters:
        z_dim (int): Dimensionality of the input noise vector.
        hidden_ch (int): Base number of channels in the hidden layers.
        out_ch (int): Number of channels in the output image.

    """
    def __init__(self, z_dim=100, hidden_ch=8, out_ch=C):
        super().__init__()
        # Remember the noise size so forward() does not depend on the
        # module-level BS / Z_DIM constants.
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            self._gen_block(z_dim, hidden_ch*8, 4, 1, 0),        # 1x1 -> 4x4
            self._gen_block(hidden_ch*8, hidden_ch*4, 4, 2, 1),  # 4x4 -> 8x8
            self._gen_block(hidden_ch*4, hidden_ch*2, 4, 2, 1),  # 8x8 -> 16x16
            nn.ConvTranspose2d(hidden_ch*2, out_ch, 4, 2, 1),    # 16x16 -> 32x32
            nn.Tanh()
        )

    def forward(self, noise):
        """
        Forward pass of the generator.

        Parameters:
            noise (torch.Tensor): Input noise tensor that can be viewed as
                (batch, z_dim); any batch size is accepted.

        Returns:
            torch.Tensor: Output generated image tensor.

        """
        # Infer the batch dimension instead of assuming the global batch size
        noise = noise.view(-1, self.z_dim, 1, 1)
        return self.gen(noise)

    def _gen_block(self, in_ch, out_ch, kernel, stride, padding):
        """
        Helper function to define a generator block.

        Parameters:
            in_ch (int): Number of input channels.
            out_ch (int): Number of output channels.
            kernel (int): Size of the convolutional kernel.
            stride (int): Stride of the convolutional operation.
            padding (int): Padding size.

        Returns:
            nn.Sequential: Transposed convolution (no bias), batch norm, ReLU.

        """
        return nn.Sequential(
            # The bias is omitted because BatchNorm2d has its own shift term
            nn.ConvTranspose2d(
                in_ch, out_ch, kernel, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

# Discriminator Module
class Discriminator(nn.Module):
    """
    Convolutional discriminator for a Generative Adversarial Network (GAN).

    Four stride-2 convolutions halve the spatial size each time. The linear
    head expects a 2x2 feature map after them, so it is sized for 32x32
    inputs. A final Sigmoid turns the score into a probability.

    Parameters:
        img_ch (int): Number of channels in the input image.
        hidden_ch (int): Base number of channels in the hidden layers.
        out_dim (int): Dimensionality of the discriminator's output.

    """
    def __init__(self, img_ch=1, hidden_ch=8, out_dim=1):
        super().__init__()
        # Channel widths of the four convolutional stages: 1x, 2x, 4x, 8x
        widths = [hidden_ch * 2 ** i for i in range(4)]

        # First stage has no batch norm
        layers = [
            nn.Conv2d(img_ch, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for narrow, wide in zip(widths, widths[1:]):
            layers.append(self._disc_block(narrow, wide, 4, 2, 1))
        layers.append(nn.Flatten())
        layers.append(nn.Linear(widths[-1] * 2 * 2, out_dim))
        layers.append(nn.Sigmoid())

        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """
        Score a batch of images.

        Parameters:
            x (torch.Tensor): Input image tensor.

        Returns:
            torch.Tensor: Probabilities reshaped to (-1, 1).

        """
        scores = self.disc(x)
        return scores.view(-1, 1)

    def _disc_block(self, in_ch, out_ch, kernel, stride, padding):
        """
        Build one discriminator stage: convolution (no bias), batch norm, LeakyReLU.

        Parameters:
            in_ch (int): Number of input channels.
            out_ch (int): Number of output channels.
            kernel (int): Size of the convolutional kernel.
            stride (int): Stride of the convolutional operation.
            padding (int): Padding size.

        Returns:
            nn.Sequential: Discriminator block.

        """
        conv = nn.Conv2d(
            in_ch, out_ch,
            kernel_size=kernel, stride=stride, padding=padding, bias=False,
        )
        norm = nn.BatchNorm2d(out_ch)
        activation = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(conv, norm, activation)

def weights_init(m):
    """
    Re-initialize a module in place, chosen by its class name.

    Modules whose class name contains 'Conv' get weights drawn from
    N(0, 0.02). Modules whose class name contains 'BatchNorm' get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left
    untouched. Intended to be passed to `nn.Module.apply`.

    Parameters:
        m (nn.Module): Neural network module.

    """
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

# Function to generate noise
def gen_noise(number, z_dim):
    """Return a (number, z_dim) tensor of standard-normal noise moved to the global `device`."""
    return torch.randn(number, z_dim).to(device)


In [None]:
# img, _ = next(iter(data))
# pred = disc(img.to(device))
# pred.shape

In [None]:
# Generator: move to the device, re-initialize the weights, then attach Adam
gen = Generator().to(device).apply(weights_init)
gen_opt = torch.optim.Adam(gen.parameters(), lr=LR, betas=(0.5, 0.999))
# If we need to change the learning rate during training, we can use the lr_scheduler module in torch.optim.
# StepLR multiplies the learning rate by gamma every step_size calls to scheduler.step().
gen_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(gen_opt, step_size=2, gamma=0.9)

# Discriminator: same initialization, optimizer settings and schedule as the generator
disc = Discriminator().to(device).apply(weights_init)
disc_opt = torch.optim.Adam(disc.parameters(), lr=LR, betas=(0.5, 0.999))
disc_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(disc_opt, step_size=2, gamma=0.9)

In [None]:
#Check the data
# Pull one batch from the loader to inspect the image shape and the first 16 labels
x, y = next(iter(data))
print("Shape of dataset images: ", x.shape)
print("label of images", y[:16])

# Run the current generator once on a full batch of noise
noise = gen_noise(BS, Z_DIM)
fake = gen(noise)
# Show the real images first, then the generator output
show(x, ch=C, size=(H, W))
show(fake, ch=C, size=(H, W))

In [None]:
def gen_loss_func(gen_net, disc_net, loss_func, num, z_dim):
    """
    Compute the generator loss on a fresh batch of generated images.

    The discriminator's predictions on the generated images are compared
    against an all-ones target, so the loss is small when the discriminator
    rates the fakes as real.

    Parameters:
        gen_net (nn.Module): Generator network.
        disc_net (nn.Module): Discriminator network.
        loss_func: Loss function called as loss_func(prediction, target).
        num (int): Number of samples to generate.
        z_dim (int): Dimension of the noise vector.

    Returns:
        torch.Tensor: Generator loss.

    """
    generated = gen_net(gen_noise(num, z_dim))
    scores = disc_net(generated)
    real_labels = torch.ones_like(scores)
    return loss_func(scores, real_labels)


def disc_loss_func(gen_net, disc_net, loss_func, image, num, z_dim):
    """
    Calculate the discriminator loss.

    The loss is the mean of two terms: predictions on real images against an
    all-ones target, and predictions on generated images against an
    all-zeros target.

    Parameters:
        gen_net (nn.Module): Generator network.
        disc_net (nn.Module): Discriminator network.
        loss_func: Loss function called as loss_func(prediction, target).
        image (torch.Tensor): Real image tensor.
        num (int): Number of samples to generate.
        z_dim (int): Dimension of the noise vector.

    Returns:
        torch.Tensor: Discriminator loss.

    """
    # The generator must not learn from this loss. Sampling under no_grad()
    # gives the same values as detaching afterwards, but never builds the
    # generator's autograd graph in the first place.
    with torch.no_grad():
        fake = gen_net(gen_noise(num, z_dim))

    fake_pred = disc_net(fake)
    real_pred = disc_net(image)

    loss_real = loss_func(real_pred, torch.ones_like(real_pred))
    loss_fake = loss_func(fake_pred, torch.zeros_like(fake_pred))

    return (loss_real + loss_fake) / 2


In [None]:
from torch.utils.tensorboard import SummaryWriter
# Remove the log directory left over from a previous run (IPython shell escape)
!rm -r /content/runs
# `writer` receives the scalar losses; the two sub-directory writers receive
# the generated and real image grids respectively.
writer = SummaryWriter("/content/runs")
writer_fake = SummaryWriter("/content/runs/fake")
writer_real = SummaryWriter("/content/runs/real")

In [None]:
# Load the TensorBoard notebook extension and open the dashboard inline.
# NOTE(review): --logdir is the relative path "runs" while the writers log to
# "/content/runs"; this presumably relies on the working directory being
# /content — confirm.
%load_ext tensorboard
%tensorboard --logdir=runs

In [None]:
import os
import shutil

# Directory for the per-epoch checkpoints. The trailing slash matters: the
# save paths are built by plain string concatenation.
PATH = "/content/model/"
# PATH is now assigned before it is used. Previously `rm -r $PATH` ran first,
# so on a fresh kernel the variable was not yet defined in Python
# (NOTE(review): in that case the shell presumably substituted its own $PATH
# environment variable — confirm).
shutil.rmtree(PATH, ignore_errors=True)  # start from an empty directory; missing is fine
os.makedirs(PATH, exist_ok=True)

In [None]:
# Training loop: alternate one discriminator update and one generator update
# per batch, log image grids to TensorBoard every 150 batches, and at the end
# of each epoch step the LR schedulers, checkpoint both networks and log the
# mean losses.
step = 0  # global step used for the TensorBoard image logs
num_batches = len(data)
for epoch in range(EPOCH):
    # Per-epoch mean losses, kept as Python floats so that no autograd graph
    # is kept alive across batches.
    discLoss, genLoss = 0.0, 0.0
    print(f"\nEpoch: {epoch + 1}")

    for batch, (real, _) in enumerate(tqdm(data)):
        real = real.to(device)

        # Discriminator update. Each loss is built from a fresh forward pass,
        # so retain_graph is not needed.
        disc_loss = disc_loss_func(gen, disc, loss_func, real, BS, Z_DIM)
        disc_opt.zero_grad()
        disc_loss.backward()
        disc_opt.step()

        # Generator update
        gen_loss = gen_loss_func(gen, disc, loss_func, BS, Z_DIM)
        gen_opt.zero_grad()
        gen_loss.backward()
        gen_opt.step()

        # .item() detaches the value from the graph before accumulating
        discLoss += disc_loss.item() / num_batches
        genLoss += gen_loss.item() / num_batches

        if batch % 150 == 0 and batch != 0:
            with torch.no_grad():
                step += 1
                fake = gen(gen_noise(BS, Z_DIM)).view(-1, C, H, W)
                image = real.view(-1, C, H, W)
                # normalize=True rescales the grids for display
                real_grid = make_grid(image[:32], normalize=True)
                fake_grid = make_grid(fake[:32], normalize=True)

                writer_fake.add_image(
                    "MNIST fake image", fake_grid, global_step=step
                )
                writer_real.add_image(
                    "MNIST real image", real_grid, global_step=step
                )

    print(f'  Discriminator Loss: {discLoss:.4f} -- Generator Loss: {genLoss:.4f}')
    gen_exp_lr_scheduler.step()
    disc_exp_lr_scheduler.step()
    #Save model
    torch.save(gen.state_dict(), f"{PATH}Gen_{epoch+1}")
    torch.save(disc.state_dict(), f"{PATH}Disc_{epoch+1}")
    #Tensorboard (each scalar is logged once per epoch)
    writer.add_scalar("Loss/Disc", discLoss, epoch)
    writer.add_scalar("Loss/Gen", genLoss, epoch)

    # The schedulers change the learning rate every 2 epochs; report it then
    if (epoch + 1) % 2 == 0:
        print(f"  >>> Discriminator Learning Rate: {disc_opt.param_groups[0]['lr']}")
        print(f"  >>> Generator Learning Rate: {gen_opt.param_groups[0]['lr']}")

# Push any buffered events to disk so TensorBoard shows the full run
for summary_writer in (writer, writer_fake, writer_real):
    summary_writer.flush()

In [None]:
# Sample one batch from the generator; gradients are not needed for display
with torch.no_grad():
    generated = gen(gen_noise(BS, Z_DIM)).view(-1, C, H, W)
# Pass the real image layout explicitly. show() defaults to 28x28, and
# reshaping this batch of H x W images to 28x28 fails.
show(generated, ch=C, size=(H, W))