# Generative Adversarial Nets

This notebook will take you through implementing your first **Generative Adversarial Net**, or **GAN** for short. We will then train it on the MNIST dataset and try to generate some realistic hand-drawn digits. Let's start by importing the required libraries. This notebook is inspired by the original GAN paper, which you can find [here](https://arxiv.org/abs/1406.2661).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vutils
import torch
import torch.nn.functional as F
import tqdm
import os
from datetime import datetime

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print("Using device: ", device)

Let's start by downloading the data and taking a closer look at it!

## Digits, digits and more digits

In [3]:
import torch.utils.data

# Download the data and convert each image into a tensor.
# The images are grayscale (one channel), with pixel values between 0 and 1
# that indicate how bright each pixel is.
transform = transforms.Compose([transforms.ToTensor()])
train_ds = datasets.MNIST(root='./data', 
                          train=True,
                          download=True,
                          transform=transform)

# The DataLoader class wraps the dataset in an iterable that yields shuffled batches of images and labels.
# Depending on your hardware, you may want to adjust the batch size to fit your GPU or CPU memory.
BATCH_SIZE = 256
dl = torch.utils.data.DataLoader(dataset=train_ds,
                                 shuffle=True,
                                 batch_size=BATCH_SIZE)

Here is a simple code snippet for showing the images. Run it and see how they look.

In [None]:
images, labels = next(iter(dl))

plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(images.to('cpu')[:64], padding=2, normalize=True).cpu(),(1,2,0)));

Well, they look like handwritten digits to me. With our data loaded, let's define our two networks.

## A tale of two networks

Let us begin by implementing our discriminator. This network receives an image as input and outputs a single scalar: the probability that the image comes from the real data distribution, denoted $p_{data}$. If the image comes from our generated distribution, $p_g$, we want the discriminator to output 0, and if it comes from $p_{data}$, we want it to output 1. At least, that is what the discriminator wants. In reality, we want our generator to produce samples so good that the discriminator cannot tell them apart from real images and outputs 0.5 regardless of where an image comes from.
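For reference, the original paper formalizes this contest as a two-player minimax game over a value function $V(D, G)$, where $p_z$ is the prior the generator's input noise $z$ is drawn from (in the training loop of this notebook, a uniform distribution on $[-1, 1)$):

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

For a fixed generator, the paper shows that the optimal discriminator is

$$
D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)},
$$

so once $p_g = p_{data}$, it outputs $D^*(x) = \frac{1}{2}$ for every $x$, which is the 0.5 described above.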

### The Discriminator

In [5]:
class Discriminator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Discriminator, self).__init__()

        # Define the architecture of the discriminator
        ### START CODE HERE ###

        ### END CODE HERE ###

    def forward(self, x):

        # Define the forward pass. Remember to return the output of the last layer.
        # Hint:     it might be useful to use the view method to flatten the input tensor
        #           if you are using fully connected layers.
        # Hint 2:   it might be useful to use the sigmoid function to squash the output of the last layer.
        #           Alternatively, return the raw output (logits) and use the BCEWithLogitsLoss loss function
        #           when you define your loss later; it combines the sigmoid and the binary cross-entropy
        #           loss in a single, more numerically stable class.
        ### START CODE HERE ###

        ### END CODE HERE ###
        return
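The two options in Hint 2 compute the same loss. The sketch below, using a hypothetical batch of random logits rather than a real discriminator, checks that a sigmoid followed by `nn.BCELoss` matches `nn.BCEWithLogitsLoss` on the raw logits, and also shows the `view` flattening from the first hint.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical raw discriminator scores (logits) and targets: 1 = real image, 0 = generated image.
logits = torch.randn(8, 1)
targets = torch.tensor([1., 0., 1., 0., 1., 0., 1., 0.]).unsqueeze(1)

# Option 1: apply the sigmoid yourself, then use BCELoss on probabilities.
loss_sigmoid_bce = nn.BCELoss()(torch.sigmoid(logits), targets)
# Option 2: keep the raw logits and let BCEWithLogitsLoss apply the sigmoid internally.
loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)
print(loss_sigmoid_bce.item(), loss_with_logits.item())  # equal up to float rounding

# Flattening hint: view(batch, -1) turns B x 1 x 28 x 28 into B x 784 for fully connected layers.
batch_images = torch.rand(8, 1, 28, 28)
flat = batch_images.view(batch_images.size(0), -1)
print(flat.shape)  # torch.Size([8, 784])
```

Pick one option, not both: returning a sigmoid output from `forward` and then using `BCEWithLogitsLoss` would apply the sigmoid twice.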

### The Generator

In [6]:
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        
        # Define the architecture of the generator
        ### START CODE HERE ###

        ### END CODE HERE ###

    def forward(self, x):

        # Define the forward pass. Remember to return the output of the last layer.
        # Hint:     remember that we want the output of the generator to match
        #           the size of the input of the discriminator (the size of the images).
        ### START CODE HERE ###

        ### END CODE HERE ###
        return
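Besides matching the discriminator's input size, the generator's outputs should live in the same value range as the real images, which `ToTensor` put in [0, 1]. Assuming you keep that range, a sigmoid on the last layer is one common choice (tanh paired with data rescaled to [-1, 1] is another). The sketch below is not a generator architecture: `toy_output_layer` is a hypothetical single-layer stand-in used only to show the output shape and range.

```python
import torch
import torch.nn as nn

z_dim, image_dim, n = 100, 784, 16

# Noise drawn the same way as in the training loop: uniform on [-1, 1).
z = torch.rand(n, z_dim) * 2 - 1

# Hypothetical stand-in for a generator's final layer: 784 outputs squashed
# by a sigmoid, so every "pixel" is in [0, 1] like the real images.
toy_output_layer = nn.Sequential(nn.Linear(z_dim, image_dim), nn.Sigmoid())
flat_images = toy_output_layer(z)  # n x 784, the discriminator's input size

# Reshape to the n x 1 x 28 x 28 layout of the real images, e.g. for plotting.
fake_images = flat_images.view(-1, 1, 28, 28)
print(flat_images.shape, fake_images.shape)
print(fake_images.min().item(), fake_images.max().item())
```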

Let us take a look at how many trainable parameters each of our two networks has. This balance is important, as we want a "fair battle" between the networks. As a rule of thumb, we want the generator to have more parameters than the discriminator.

In [None]:
d = Discriminator(784, 1).to(device)
g = Generator(100, 784).to(device)

# count the number of trainable parameters
print("NUMBER OF TRAINABLE PARAMETERS")
d_trainable_params = sum(p.numel() for p in d.parameters() if p.requires_grad)
g_trainable_params = sum(p.numel() for p in g.parameters() if p.requires_grad)
print("Discriminator: ", d_trainable_params)
print("Generator: ", g_trainable_params)
print("Fraction of Generator parameters: ", g_trainable_params / (d_trainable_params + g_trainable_params))
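The counts above are easy to predict for fully connected networks: `nn.Linear(n_in, n_out)` holds an `n_out x n_in` weight matrix plus `n_out` biases, and activation layers add nothing. The sketch below checks this arithmetic on a hypothetical layer and a hypothetical small network (not a prescribed architecture), using the same counting expression as the cell above.

```python
import torch.nn as nn

def count_trainable(model: nn.Module) -> int:
    # Same expression as in the parameter-counting cell.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# One fully connected layer: 784 * 256 weights + 256 biases.
layer = nn.Linear(784, 256)
print(count_trainable(layer), 784 * 256 + 256)  # 200960 200960

# Activations have no parameters, so this adds only the second layer's 256 + 1.
mlp = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 1))
print(count_trainable(mlp))  # 201217
```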

## Training the GAN

In [43]:
def train(generator, discriminator, dl, latent_size, epochs=100, device='cpu'):
    g_losses = []
    d_losses = []

    # Define the loss function and the optimizers for the generator and discriminator
    ### START CODE HERE ###
    criterion = None
    d_optimizer = None
    g_optimizer = None
    ### END CODE HERE ###

    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    save_dir = f'models/{time_stamp}'
    os.makedirs(save_dir, exist_ok=True)

    for epoch in tqdm.tqdm(range(epochs)):
        generator.train()
        discriminator.train()

        d_running_loss = 0.0
        g_running_loss = 0.0

        for i, data in enumerate(dl):

            batch_size = data[0].size(0)

            ########################################################
            # Train the discriminator
            #
            # The discriminator is trained to classify images as real or fake.
            # We do this by training the discriminator on real images and then on fake images.
            # The loss for the discriminator is the sum of the losses for the real and fake images.
            ########################################################
            d_optimizer.zero_grad()

            # Train the discriminator on real images
            ### START CODE HERE ###

            real_images, _ = data
            real_images = real_images.to(device)
            
            real_labels = None # create labels for real images
            d_real_outputs = None # pass real images through the discriminator
            d_real_loss = criterion(d_real_outputs, real_labels)

            ### END CODE HERE ###

            # Train the discriminator on fake images
            ### START CODE HERE ###
            z = torch.rand(batch_size, latent_size).to(device) * 2 - 1  # latent noise, uniform on [-1, 1)
            with torch.no_grad():
                d_fake_images = None # generate fake images

            fake_labels = None
            d_fake_outputs = None
            d_fake_loss = criterion(d_fake_outputs, fake_labels)

            ### END CODE HERE ###

            # Backpropagate the loss and update the weights
            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            d_optimizer.step()

            d_running_loss += d_loss.item()

            ########################################################
            # Train the generator
            #
            # The generator is trained to generate images that the discriminator classifies as real.
            # We do this by generating images, feeding them to the discriminator, and labeling them as real.
            # Rather than minimizing log(1 - D(G(z))) from the minimax objective, the generator thus
            # minimizes -log D(G(z)), the non-saturating variant suggested in the original paper: it gives
            # much stronger gradients early in training, when the discriminator easily rejects the fakes.
            ########################################################
            g_optimizer.zero_grad()

            ### START CODE HERE ###

            z = None
            g_fake_images = None

            real_labels = None
            g_real_outputs = None
            g_loss = criterion(g_real_outputs, real_labels)

            ### END CODE HERE ###

            # Backpropagate the loss and update the weights
            g_loss.backward()
            g_optimizer.step()

            g_running_loss += g_loss.item()

        d_losses.append(d_running_loss / len(dl))
        g_losses.append(g_running_loss / len(dl))

        if epoch % 10 == 0:
            save_path = save_dir + f'/epoch_{epoch}'
            os.makedirs(save_path, exist_ok=True)
            # Save the models
            torch.save(generator.state_dict(), save_path + '/generator.pth')
            torch.save(discriminator.state_dict(), save_path + '/discriminator.pth')
    save_path = save_dir + '/final'
    os.makedirs(save_path, exist_ok=True)
    # Save the models
    torch.save(generator.state_dict(), save_path + '/generator.pth')
    torch.save(discriminator.state_dict(), save_path + '/discriminator.pth')
    return g_losses, d_losses
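Why label the fake images as real in the generator step? Assuming `criterion` is a binary cross-entropy loss, that choice makes the generator minimize $-\log D(G(z))$ instead of the minimax term $\log(1 - D(G(z)))$. The sketch below compares the gradients of both with respect to a single hypothetical discriminator logit of -6, i.e. a fake that the discriminator confidently rejects, as it typically does early in training.

```python
import torch
import torch.nn as nn

# Hypothetical logit for a generated image: D(G(z)) = sigmoid(-6), roughly 0.0025.
a_sat = torch.tensor(-6.0, requires_grad=True)
a_ns = torch.tensor(-6.0, requires_grad=True)

# Minimax generator objective: minimize log(1 - D(G(z))).
saturating = torch.log(1 - torch.sigmoid(a_sat))
saturating.backward()

# Non-saturating variant: BCE with target 1 ("real"), i.e. -log D(G(z)).
non_saturating = nn.BCEWithLogitsLoss()(a_ns, torch.tensor(1.0))
non_saturating.backward()

print(a_sat.grad.item())  # about -0.0025: almost no learning signal
print(a_ns.grad.item())   # about -0.9975: a strong learning signal
```

Both gradients point the same way (increase the logit, i.e. fool the discriminator more), but the non-saturating loss delivers a signal roughly 400 times larger at this point.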



Finally, let us train the model. Depending on your hardware and batch size, an epoch may take anywhere between a couple of seconds and a couple of minutes 😳

In [None]:
# Feel free to experiment with the latent size and the number of epochs
latent_size = 100
epochs = 200


g = None # create the generator by making an instance of the Generator class
d = None # create the discriminator by making an instance of the Discriminator class
# Hint: you can use the 'to' method to move the models to the device, e.g. g.to(device)


g_losses, d_losses = train(g, d, dl, latent_size=latent_size, epochs=epochs, device=device)

In [None]:
plt.plot(g_losses, label='Generator Loss')
plt.plot(d_losses, label='Discriminator Loss')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Losses')
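A useful landmark when reading this plot: if the discriminator outputs 0.5 for every image (the equilibrium described by the original paper) and `criterion` is a binary cross-entropy loss, every loss term in `train()` costs $-\log 0.5 = \log 2$. The short calculation below assumes exactly that and mirrors how `train()` sums the real and fake terms; treat the results as a reference point rather than a convergence criterion.

```python
import math

# Assumed equilibrium: D(x) = 0.5 for real and generated images alike.
d_at_equilibrium = 0.5

d_real_loss = -math.log(d_at_equilibrium)      # real images, target 1
d_fake_loss = -math.log(1 - d_at_equilibrium)  # fake images, target 0
d_loss = d_real_loss + d_fake_loss             # summed, as in train()
g_loss = -math.log(d_at_equilibrium)           # fakes labeled as real

print(round(d_loss, 3), round(g_loss, 3))  # 1.386 0.693
```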

## Generate some images

In [None]:
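# Replace this with the path to a checkpoint saved by train() on your machine.
generator_path = 'models/2024-10-01_15-02-08/epoch_180/generator.pth'

g = Generator(100, 784).to(device)
# map_location loads the weights onto `device`, so a checkpoint saved on a GPU also loads on a CPU-only machine.
g.load_state_dict(torch.load(generator_path, map_location=device))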

In [None]:
g.eval()

# Sample the latent noise from the same distribution used during training
# (uniform on [-1, 1)), not from a standard normal.
z = torch.rand(64, 100, device=device) * 2 - 1
with torch.no_grad():
    images = g(z).view(-1, 1, 28, 28).cpu()

plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Generated Images")
plt.imshow(np.transpose(vutils.make_grid(images[:64], padding=2, normalize=True),(1,2,0)));
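The generator only ever sees noise drawn as `torch.rand(...) * 2 - 1` during training, so sampling it with a different distribution (for example `torch.randn`, whose values are unbounded) feeds it inputs it was never trained on. One way to keep training and sampling in sync is a single helper; `sample_latent` below is a hypothetical name, sketched under that assumption.

```python
import torch

def sample_latent(n: int, latent_size: int, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper: draw n latent vectors uniformly from [-1, 1),
    matching torch.rand(...) * 2 - 1 in the training loop."""
    return torch.rand(n, latent_size, device=device) * 2 - 1

z = sample_latent(64, 100)
print(z.shape, z.min().item(), z.max().item())
```

Calling it both inside `train()` and in the generation cell means a change to the noise distribution only has to be made in one place.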