# What are GANs?
* **GANs (Generative Adversarial Networks)** are a class of machine learning models used to generate new data samples that resemble a given dataset.
* They consist of two neural networks, a **Generator** and a **Discriminator**, that compete against each other in an **adversarial**, game-like framework.

# How do GANs work?
* **Generator (G):** Tries to create realistic fake data.
* **Discriminator (D):** Evaluates whether the input data is real (from the actual dataset) or fake (generated by G). It outputs a single scalar: the estimated probability that the input is real.
* **Adversarial Training:** The Generator improves by trying to fool the Discriminator, while the Discriminator gets better at detecting fakes. Over time, the Generator produces increasingly realistic outputs.

#### Note
Both the **Discriminator** and the **Generator** start from scratch, meaning they are both *randomly initialized* at the start and then trained simultaneously.

# Training of a Simple GAN
In order to implement our very first simple GAN, we need to understand the loss functions of **Generator** and **Discriminator** and how they are used for training.

## Training the Discriminator
The **Discriminator** needs to be trained so that it can distinguish between real and fake (generated) images.

### Discriminator Loss
The Discriminator's objective is to **maximize** the following loss function:
$$ L_{D} = \frac{1}{m} \sum \limits _{i=1}^{m} [ \log{D(x^i)} + \log(1-D(G(z^i))) ]$$
where:

* $x^i$: real image from the dataset
* $z^i$: random noise input to the Generator
* $G(z^i)$: fake image generated by the Generator

#### Understanding the two terms
* $\log D(x^i)$
    * Encourages the Discriminator to classify **real images correctly** ($D(x) \to 1$).
* $\log(1-D(G(z^i)))$
    * Encourages the Discriminator to classify **fake images correctly** ($D(G(z)) \to 0$).

Instead of maximizing $L_D$, we **minimize** its **negative**, $-L_D$, since optimizers (like Adam or SGD) perform **minimization** by default.
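
To make the sign convention concrete, here is a minimal sketch that evaluates $L_D$ in plain Python. The helper `discriminator_objective` and the output values are made up for illustration; they do not come from a trained model.

```python
import math

def discriminator_objective(d_real, d_fake):
    """Mean of log D(x) + log(1 - D(G(z))) over a batch (hypothetical helper)."""
    m = len(d_real)
    return sum(math.log(r) + math.log(1.0 - f) for r, f in zip(d_real, d_fake)) / m

# Made-up Discriminator outputs for a batch of m = 3 real and 3 fake images.
confident = discriminator_objective([0.9, 0.95, 0.9], [0.1, 0.05, 0.1])
undecided = discriminator_objective([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

# L_D is always <= 0; a better Discriminator pushes it toward 0,
# which is the same as pushing the minimized quantity -L_D toward 0.
print(f"confident D: L_D = {confident:.4f}, minimized loss = {-confident:.4f}")
print(f"undecided D: L_D = {undecided:.4f}, minimized loss = {-undecided:.4f}")
```

A Discriminator that answers 0.5 for every input scores $2\log 0.5 \approx -1.386$, while a confident and correct one scores closer to $0$.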

## Training the Generator
The **Generator (G)** tries to produce fake images that are as realistic as possible, **fooling** the Discriminator into classifying them as real.

### Standard Generator Loss (Saturating Version)
In its simplest form, the Generator tries to minimize:
$$L_G = \frac{1}{m} \sum \limits _{i=1}^{m} \log(1-D(G(z^i)))$$

**Issue with this loss function:**
* If the **Discriminator becomes too good**, $D(G(z))$ will be close to $0$.
* In that regime $\log(1-D(G(z)))$ saturates: its gradient with respect to the Discriminator's pre-sigmoid output is $-D(G(z))$, so the gradients reaching the Generator become **very small** (vanishing gradient problem).
* This slows down learning for the Generator.

### Standard Generator Loss (Non-Saturating Version)
Instead of minimizing $\log(1-D(G(z^i)))$, we maximize:
$$\log D(G(z^i))$$
Averaged over a batch, this is equivalent to minimizing:
$$L_G = - \frac{1}{m} \sum \limits _{i=1}^{m} \log D(G(z^i))$$
* This approach provides **stronger gradients**, allowing the Generator to learn **faster**.
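
The difference in gradient strength can be checked numerically. The sketch below assumes the Discriminator ends in a sigmoid, so that $D(G(z)) = \sigma(a)$ for some pre-sigmoid value $a$ (the name `a` and the helper functions are introduced here for illustration). It compares the derivative of each generator loss with respect to $a$.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a)
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da [-log(sigmoid(a))] = -(1 - sigmoid(a))
    return -(1.0 - sigmoid(a))

# Very negative a means the Discriminator confidently rejects the fake.
for a in (-6.0, -2.0, 0.0):
    print(
        f"D(G(z)) = {sigmoid(a):.4f}  "
        f"saturating grad = {grad_saturating(a):+.4f}  "
        f"non-saturating grad = {grad_non_saturating(a):+.4f}"
    )
```

When $D(G(z))$ is near $0$, the saturating gradient is also near $0$, while the non-saturating gradient stays close to $-1$. That is exactly the regime where the Generator most needs a learning signal.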

#### Note
Since both losses involve **binary classification** (real vs. fake), we use **Binary Cross-Entropy Loss (BCELoss)**.
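
To see why binary cross-entropy covers both cases, the sketch below writes the formula out by hand for a single prediction. The helper `bce` is a stand-in for illustration and is not the PyTorch implementation; the Discriminator outputs are made-up values.

```python
import math

def bce(p, y):
    """Binary cross-entropy for one prediction p in (0, 1) and label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

d_real, d_fake = 0.8, 0.3  # made-up values of D(x) and D(G(z))

# Discriminator: real images are labeled 1, fake images are labeled 0.
loss_d = bce(d_real, 1) + bce(d_fake, 0)   # = -[log D(x) + log(1 - D(G(z)))]

# Generator (non-saturating): fake images are labeled 1.
loss_g = bce(d_fake, 1)                    # = -log D(G(z))

print(f"loss_d = {loss_d:.4f}, loss_g = {loss_g:.4f}")
```

Choosing the label selects which of the two log terms survives, so a single loss function serves both networks.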

# Implementing a Simple GAN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
lr = 3e-4
z_dim = 64
image_dim = (1,28,28)
batch_size = 32
num_epochs = 100

In [None]:
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,),(0.5,)),
    ]
)
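
As a quick check of what this transform does to pixel values, the sketch below reproduces the arithmetic of `Normalize((0.5,), (0.5,))` in plain Python. `normalize` is a hypothetical helper, not a torchvision function, and it assumes `ToTensor()` has already scaled pixels to the range [0, 1].

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-channel normalization: (pixel - mean) / std."""
    return (pixel - mean) / std

# Darkest, middle, and brightest pixel values after ToTensor().
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```

The images end up in the range [-1, 1], which matches the output range of the `Tanh` layer at the end of the Generator.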

In [None]:
dataset = datasets.MNIST(root="dataset/", transform=data_transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
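class Discriminator(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        channels, height, width = in_features
        self.disc = nn.Sequential(
            nn.Conv2d(channels, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.01),
            nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1),
            # Collapse the (1, height, width) score map to one value per image,
            # so the Discriminator outputs a single real/fake probability.
            nn.Flatten(),
            nn.Linear(height * width, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)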

class Generator(nn.Module):
    def __init__(self, random_noise_dim):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.ConvTranspose2d(random_noise_dim, 128, kernel_size=7, stride=1, padding=0),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.01),
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.01),
            
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1, 1, 1)  # Reshape noise input to match ConvTranspose2d input
        return self.gen(x)
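
The Generator's layer settings can be sanity-checked with the output-size formula for transposed convolutions. The sketch below is plain Python; `conv_transpose_out` is a hypothetical helper that assumes `dilation=1` and `output_padding=0`, which are the defaults used above.

```python
def conv_transpose_out(size, kernel_size, stride, padding):
    """Output height/width of a transposed convolution (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel_size

size = 1  # the noise is reshaped to (z_dim, 1, 1)
sizes = [size]
# (kernel_size, stride, padding) of the three ConvTranspose2d layers in the Generator.
for kernel_size, stride, padding in [(7, 1, 0), (4, 2, 1), (4, 2, 1)]:
    size = conv_transpose_out(size, kernel_size, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 7, 14, 28]
```

The spatial size grows from 1 to 7 to 14 to 28, so the Generator produces 28×28 images that match `image_dim`.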

In [None]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim).to(device)

In [None]:
opt_disc = optim.Adam(disc.parameters(),lr=lr)
opt_gen = optim.Adam(gen.parameters(),lr=lr)
loss_fn = nn.BCELoss()

In [None]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        cur_batch_size = real.shape[0]  # the last batch may be smaller than batch_size

        # Sample noise and generate a batch of fake images.
        random_noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
        fake = gen(random_noise)

        # Train the Discriminator: minimize -[log D(x) + log(1 - D(G(z)))].
        disc_real = disc(real).view(-1)
        lossD_real = loss_fn(disc_real, torch.ones_like(disc_real))

        # detach() keeps this step from backpropagating into the Generator.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = loss_fn(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # Train the Generator: minimize -log D(G(z)) (non-saturating loss).
        output = disc(fake).view(-1)
        lossG = loss_fn(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}"
            )