# 54: Generative Adversarial Networks (GANs)

## 🎯 Objective
In this notebook, we embark on one of the most exciting topics in Deep Learning: **Generative Adversarial Networks (GANs)**. Instead of classifying existing data, we will train a model to **create** new data.

We will build two networks that compete against each other:
1.  **The Generator:** Tries to create fake handwritten digits (MNIST) that look real.
2.  **The Discriminator:** Tries to distinguish between real images from the dataset and fake images from the generator.

By the end, we hope to see our Generator produce recognizable digits from pure random noise.

## 📚 Key Concepts
* **Adversarial Training:** A minimax game in which two networks optimize opposing goals. The Generator tries to minimize the probability that the Discriminator catches its fakes, while the Discriminator tries to maximize its accuracy at telling real from fake.
* **Latent Space:** A vector of random numbers (noise) that serves as the input to the Generator. The Generator learns to map this random noise to meaningful data (images).
* **LeakyReLU:** A common activation function in GANs. Unlike ReLU, it keeps a small nonzero slope for negative inputs, which prevents "dying ReLUs" and helps gradients flow during the often unstable training process.
* **Tanh Activation:** Often used as the final layer of the Generator to map pixel values to the range [-1, 1].
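
As a small illustration of the last two points, the sketch below compares `relu` and `leaky_relu` on negative inputs and confirms that `tanh` keeps its output inside [-1, 1]. The sample values are arbitrary, and the slope of 0.01 is simply PyTorch's default `negative_slope`, not a value tuned for this notebook.

```python
import torch
import torch.nn.functional as F

# arbitrary sample inputs, including negative values
x = torch.tensor([-2.0, -0.5, 0.0, 1.5], requires_grad=True)

relu_out  = F.relu(x)
leaky_out = F.leaky_relu(x)  # default negative_slope=0.01

# gradient of the summed leaky output with respect to x
leaky_out.sum().backward()
leaky_grad = x.grad.clone()

# tanh squashes any real number into [-1, 1]
tanh_out = torch.tanh(10 * torch.randn(1000))

print(relu_out.detach())   # negative inputs are clipped to 0
print(leaky_out.detach())  # negative inputs are scaled by 0.01
print(leaky_grad)          # gradient stays nonzero for negative inputs
print(tanh_out.min().item(), tanh_out.max().item())
```

With plain ReLU the gradient for the negative inputs would be exactly zero, which is the "dying ReLU" situation the bullet above refers to.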

## 1. Import Libraries

We import PyTorch for building the networks and standard libraries for data manipulation and plotting.

In [None]:
# import libraries
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import sys

import matplotlib.pyplot as plt
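# render figures as SVG (IPython.display.set_matplotlib_formats is deprecated)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')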

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 2. Data Preparation

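We load the small MNIST training file that ships with Google Colab (`sample_data/mnist_train_small.csv`). Each row holds a label followed by 784 pixel values.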

### Normalization Strategy
Image data is often scaled to [0, 1]. However, GAN Generators typically use a **`Tanh`** activation function in the final layer, which outputs values in the range **[-1, 1]**. Therefore, we must normalize our real training images to the same range; otherwise the Discriminator could separate real from fake by pixel range alone.
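
The rescaling can be sketched on a toy array, assuming 8-bit pixel values in [0, 255]. The names `toy`, `toy_norm`, and `restored` are illustrative only and are not used elsewhere in this notebook. The inverse mapping is handy if you later want generated images back in [0, 1].

```python
import numpy as np

# toy "image" values spanning the 8-bit range (illustrative only)
toy = np.array([0.0, 64.0, 128.0, 255.0])

# step 1: scale to [0, 1]; step 2: stretch and shift to [-1, 1]
toy_norm = 2 * (toy / np.max(toy)) - 1

# inverse mapping back to [0, 1], e.g. for display
restored = (toy_norm + 1) / 2

print(toy_norm.min(), toy_norm.max())
```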

In [None]:
# import dataset (comes with colab!)
data = np.loadtxt('sample_data/mnist_train_small.csv', delimiter=',')

# drop the first column (the labels); the GAN only needs the pixels
data = data[:,1:]

# normalize the data to a range of [-1, 1] (b/c tanh output)
dataNorm = data / np.max(data)  # -> [0, 1]
dataNorm = 2*dataNorm - 1       # -> [-1, 1]

# convert to tensor
dataT = torch.tensor( dataNorm ).float()

# no dataloaders! minibatches are sampled at random inside the training loop
batchsize = 100

## 3. Model Architecture

We define our two competing networks.

### The Discriminator
This is a standard binary classifier. 
* **Input:** An image (flattened to 784 pixels).
* **Output:** A single probability score (0 = Fake, 1 = Real).
* **Activation:** We use `LeakyReLU` to allow small gradients for negative values, which stabilizes GAN training.

In [None]:
class discriminatorNet(nn.Module):
  def __init__(self):
    super().__init__()

    self.fc1 = nn.Linear(28*28,256)
    self.fc2 = nn.Linear(256,256)
    self.out = nn.Linear(256,1)

  def forward(self,x):
    x = F.leaky_relu( self.fc1(x) )
    x = F.leaky_relu( self.fc2(x) )
    x = self.out(x)
    return torch.sigmoid( x )

dnet = discriminatorNet()
y = dnet(torch.randn(10,784))
y

### The Generator
This network does the reverse of a classifier.
* **Input:** A noise vector of size 64 (latent variable).
* **Output:** A flattened image (784 pixels).
* **Activation:** The final activation is `Tanh` to ensure pixel values are between -1 and 1.

In [None]:
class generatorNet(nn.Module):
  def __init__(self):
    super().__init__()

    self.fc1 = nn.Linear(64,256)
    self.fc2 = nn.Linear(256,256)
    self.out = nn.Linear(256,784)

  def forward(self,x):
    x = F.leaky_relu( self.fc1(x) )
    x = F.leaky_relu( self.fc2(x) )
    x = self.out(x)
    return torch.tanh( x )


gnet = generatorNet()
y = gnet(torch.randn(10,64))
plt.imshow(y[0,:].detach().squeeze().view(28,28));

## 4. Training Setup

We need two optimizers: one for the Discriminator (`d_optimizer`) and one for the Generator (`g_optimizer`). They learn independently but rely on each other's outputs.

We use **Binary Cross Entropy Loss (`BCELoss`)** because the Discriminator is performing a binary classification task (Real vs. Fake).
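
To make the loss concrete: for a single prediction p, `BCELoss` reduces to -log(p) when the target is 1 and to -log(1 - p) when the target is 0. The check below uses an arbitrary prediction of 0.8; `bce`, `p`, and the two loss names are illustrative.

```python
import math
import torch
import torch.nn as nn

bce = nn.BCELoss()

# an arbitrary discriminator output (probability of "real")
p = torch.tensor([[0.8]])

loss_if_real = bce(p, torch.ones(1, 1)).item()   # target = 1
loss_if_fake = bce(p, torch.zeros(1, 1)).item()  # target = 0

print(loss_if_real, -math.log(0.8))       # the two numbers should match
print(loss_if_fake, -math.log(1 - 0.8))   # the two numbers should match
```

The same prediction is cheap when the label agrees with it and expensive when it does not, which is what lets one loss function serve both networks.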

In [None]:
# loss function (same for both phases of training)
lossfun = nn.BCELoss()

# create instances of the models
dnet = discriminatorNet().to(device)
gnet = generatorNet().to(device)

# optimizers (same algo but different variables b/c different parameters)
d_optimizer = torch.optim.Adam(dnet.parameters(), lr=.0003)
g_optimizer = torch.optim.Adam(gnet.parameters(), lr=.0003)

## 5. The GAN Training Loop

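This loop is more complex than standard training. Each iteration (called an "epoch" in the code, although it processes only one random minibatch rather than the full dataset) performs two distinct training steps: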

1.  **Train the Discriminator:**
    * Show it real data with label **1**.
    * Show it fake data (from Generator) with label **0**.
    * Calculate loss and update *only* the Discriminator's weights.

2.  **Train the Generator:**
    * Generate new fake data.
    * Pass it through the Discriminator.
    * **Crucial Trick:** Calculate loss using label **1** (Real). We are punishing the Generator if the Discriminator successfully identified the image as fake (0). We want the Discriminator to think it's real (1).
    * Update *only* the Generator's weights.
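
The two steps above can be sketched on tiny stand-in networks to check which weights actually move in each phase. Everything here (`toy_g`, `toy_d`, `snapshot`, the layer sizes, the random "real" batch) is hypothetical and chosen only for speed. The `.detach()` in step 1 is an optional optimization in this sketch that keeps the discriminator's backward pass out of the generator.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# tiny stand-in networks (sizes are arbitrary)
toy_g = nn.Sequential(nn.Linear(4, 8), nn.Tanh())
toy_d = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())

bce   = nn.BCELoss()
d_opt = torch.optim.Adam(toy_d.parameters(), lr=1e-3)
g_opt = torch.optim.Adam(toy_g.parameters(), lr=1e-3)

def snapshot(net):
  # copy of all parameters, flattened into one vector
  return torch.cat([p.detach().clone().flatten() for p in net.parameters()])

real = torch.randn(16, 8)  # placeholder for a batch of real data
ones, zeros = torch.ones(16, 1), torch.zeros(16, 1)

# --- step 1: discriminator (real -> 1, fake -> 0) ---
g_before = snapshot(toy_g)
fake   = toy_g(torch.randn(16, 4)).detach()  # no gradient into the generator
d_loss = bce(toy_d(real), ones) + bce(toy_d(fake), zeros)
d_opt.zero_grad()
d_loss.backward()
d_opt.step()
g_unchanged = torch.equal(g_before, snapshot(toy_g))

# --- step 2: generator (fresh fakes, scored against label 1) ---
d_before = snapshot(toy_d)
g_before = snapshot(toy_g)
g_loss = bce(toy_d(toy_g(torch.randn(16, 4))), ones)
g_opt.zero_grad()
g_loss.backward()
g_opt.step()
d_unchanged = torch.equal(d_before, snapshot(toy_d))
g_changed   = not torch.equal(g_before, snapshot(toy_g))

print(g_unchanged, d_unchanged, g_changed)
```

Note that `g_loss.backward()` does compute gradients for the discriminator's parameters as well; they are simply never applied, because only `g_opt.step()` is called in that phase.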

In [None]:
# this cell takes ~3 mins with 50k epochs
num_epochs = 50000

losses  = np.zeros((num_epochs,2))
disDecs = np.zeros((num_epochs,2)) # disDecs = discriminator decisions

for epochi in range(num_epochs):

  # create minibatches of REAL and FAKE images
  randidx     = torch.randint(dataT.shape[0],(batchsize,))
  real_images = dataT[randidx,:].to(device)
  fake_images = gnet( torch.randn(batchsize,64).to(device) ) # output of generator


  # labels used for real and fake images
  real_labels = torch.ones(batchsize,1).to(device)
  fake_labels = torch.zeros(batchsize,1).to(device)



  ### ---------------- Train the discriminator ---------------- ###

  # forward pass and loss for REAL pictures
  pred_real   = dnet(real_images)              # REAL images into discriminator
  d_loss_real = lossfun(pred_real,real_labels) # all labels are 1

  # forward pass and loss for FAKE pictures
  # (detached so this step does not backpropagate into the generator)
  pred_fake   = dnet(fake_images.detach())     # FAKE images into discriminator
  d_loss_fake = lossfun(pred_fake,fake_labels) # all labels are 0

  # collect loss (using combined losses)
  d_loss = d_loss_real + d_loss_fake
  losses[epochi,0]  = d_loss.item()
  disDecs[epochi,0] = torch.mean((pred_real>.5).float()).item() # proportion of REAL images labeled "real"

  # backprop
  d_optimizer.zero_grad()
  d_loss.backward()
  d_optimizer.step()




  ### ---------------- Train the generator ---------------- ###

  # create fake images and compute loss
  fake_images = gnet( torch.randn(batchsize,64).to(device) )
  pred_fake   = dnet(fake_images)

  # compute and collect loss and accuracy
  g_loss = lossfun(pred_fake,real_labels)
  losses[epochi,1]  = g_loss.item()
  disDecs[epochi,1] = torch.mean((pred_fake>.5).float()).item() # proportion of FAKE images labeled "real"

  # backprop
  g_optimizer.zero_grad()
  g_loss.backward()
  g_optimizer.step()


  # print out a status message
  if (epochi+1)%500==0:
    msg = f'Finished epoch {epochi+1}/{num_epochs}'
    sys.stdout.write('\r' + msg)

## 6. Visualization

We visualize the training progress and the final outputs.

### Understanding the Plots
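* **Losses:** GAN losses usually do not converge to zero. Instead, they oscillate. If one network's loss drops to 0, that network has "won" too decisively against its opponent, and learning usually stalls.
* **Discriminator Output:** Ideally, for a perfect Generator, the Discriminator should output 0.5 for everything (it's guessing randomly). Note that the third panel plots the proportion of each minibatch that the Discriminator labels "real" (output > 0.5), computed separately for real and fake images.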
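
As a reference point when reading the loss curves, the sketch below computes what the two BCE losses would be if the Discriminator output exactly 0.5 for every image: 2·ln 2 (about 1.386) for the summed Discriminator loss and ln 2 (about 0.693) for the Generator loss. These are idealized values under that assumption, not a target the training is guaranteed to reach. The variable names are illustrative.

```python
import math
import torch
import torch.nn as nn

bce = nn.BCELoss()

# a discriminator that is maximally unsure: 0.5 for every image
pred = torch.full((100, 1), 0.5)
ones, zeros = torch.ones(100, 1), torch.zeros(100, 1)

d_loss_equil = (bce(pred, ones) + bce(pred, zeros)).item()  # real + fake terms
g_loss_equil = bce(pred, ones).item()                        # fakes scored as "real"

print(d_loss_equil, 2 * math.log(2))
print(g_loss_equil, math.log(2))
```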

In [None]:
fig,ax = plt.subplots(1,3,figsize=(18,5))

ax[0].plot(losses)
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].set_title('Model loss')
ax[0].legend(['Discriminator','Generator'])
# ax[0].set_xlim([4000,5000])

ax[1].plot(losses[::5,0],losses[::5,1],'k.',alpha=.1)
ax[1].set_xlabel('Discriminator loss')
ax[1].set_ylabel('Generator loss')

ax[2].plot(disDecs)
ax[2].set_xlabel('Epochs')
ax[2].set_ylabel('Proportion classified as "real"')
ax[2].set_title('Discriminator output')
ax[2].legend(['Real','Fake'])

plt.show()

## 7. Generating Fake Digits

The moment of truth: We feed random noise into our trained Generator and view the output. The results should resemble handwritten digits, although they might look a bit fuzzy or strange.

In [None]:
# generate the images from the generator network
gnet.eval()
fake_data = gnet(torch.randn(12,64).to(device)).cpu()

# and visualize...
fig,axs = plt.subplots(3,4,figsize=(8,6))
for i,ax in enumerate(axs.flatten()):
  ax.imshow(fake_data[i,:].detach().view(28,28),cmap='gray')
  ax.axis('off')

plt.show()

## 8. Additional Explorations

GANs are notoriously difficult to train. Try these experiments to understand why:
1.  **Batch Normalization:** Try adding batchnorm layers. In this simple linear GAN, it might not help much, but it's crucial for Deep Convolutional GANs (DCGANs).
2.  **Learning Rate:** GANs are very sensitive to learning rates. Try increasing or decreasing it and observe if the Generator collapses (outputs garbage) or if the Discriminator becomes too powerful too quickly.
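
For the first experiment, one possible starting point is sketched below. `generatorNetBN` is a hypothetical variant of the Generator, not part of the notebook above, and placing `BatchNorm1d` after each hidden linear layer is just one reasonable choice among several.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class generatorNetBN(nn.Module):  # hypothetical variant for experimentation
  def __init__(self):
    super().__init__()

    self.fc1 = nn.Linear(64,256)
    self.bn1 = nn.BatchNorm1d(256)
    self.fc2 = nn.Linear(256,256)
    self.bn2 = nn.BatchNorm1d(256)
    self.out = nn.Linear(256,784)

  def forward(self,x):
    x = F.leaky_relu( self.bn1(self.fc1(x)) )
    x = F.leaky_relu( self.bn2(self.fc2(x)) )
    x = self.out(x)
    return torch.tanh( x )

# quick shape check (training mode needs more than one sample per batch)
gnet_bn = generatorNetBN()
y_bn = gnet_bn(torch.randn(10,64))
print(y_bn.shape)
```

With batchnorm in the model, switching to `eval()` before generating images matters more than it does for the plain network, because the layers then use their running statistics instead of the statistics of the current batch.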