<a href="https://colab.research.google.com/github/ParthKhandelwal/GAN/blob/main/GANs_assignment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CPSC 440 Assignment: Generative Adversarial Networks (GANs)

This assignment will serve as a practical introduction to GANs by building a model to generate "fake" images of handwritten digits. We will use the infamous MNIST dataset.

## Setup

Consider setting `hardware accelerator` to `GPU` in `Runtime > Change runtime type` for faster learning. We will use PyTorch for modeling and `matplotlib` to display images.

In [361]:
import numpy as np

import os
%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

from tqdm import tqdm

import IPython.display
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = torch.Generator('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
# if torch.cuda.is_available():
#   torch.set_default_tensor_type('torch.cuda.FloatTensor')

# upload external file before import
from google.colab import files

The data loading step is already setup for you. The MNIST data set is normalized.

In [362]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])
data_train = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)


def emptyFolder():
  for file in os.scandir("./generated_images"):
    if file.name.endswith(".png"):
        os.unlink(file.path)
# creating a directory (if not exists) to save epoch images, otherwise clear 
# the folder for any previous results
if not os.path.exists("generated_images"):
    os.makedirs("generated_images")
else:
  emptyFolder()

# 1. A GAN for MNIST
You will build a GAN model that can generate convincing handwritten numbers.

(1.1) [1 point] Discover the dimensions of the dataset.

In [363]:
# TODO: Check the size of the inputs

## BEGIN SOLUTION
data_train.data.size()
## END SOLUTION

torch.Size([60000, 28, 28])

We will take the topdown approach, and implement the GAN model before implementing the inner adversarial networks.

(1.2) [10 points] Configure the Generator

In [364]:
from torch.nn.modules.loss import BCELoss
class Generator(nn.Module):
    def __init__(
        self,
        in_dim,
        learning_rate=0.01,
        batch_size=100
        ):
      super().__init__()

      self.learning_rate = learning_rate
      self.batch_size = batch_size
      self.in_dim = in_dim

      # TODO: build a NN model with in_features=in_dim and out_features=in_dim
      # For MNIST in_dim is 28 * 28 = 784. 
      # TODO: store the model in self.layers.
      # HINT: Only use TanH activation at the end.
      # ANOTHER HINT: In most cases, bigger networks are better networks.
      
      # self.layers = ...
      # self.loss_func = ...
      ## Begin Solution
      self.layers = nn.Sequential(
          nn.Linear(in_dim, 256),
          nn.LeakyReLU(0.2),
          nn.Linear(256, 512),
          nn.LeakyReLU(0.2),
          nn.Linear(512, 1024),
          nn.LeakyReLU(0.2),
          nn.Linear(1024, in_dim),
          nn.Tanh(),
      )
      self.loss_func = nn.BCELoss()
      ## End Solution

      self.optimizer = optim.Adam(self.parameters(), lr = learning_rate)
    
    def forward(self, z):
      # Use the layers you setup in __init__ for forward.
       #z =  z.view(z.size(0), 1, 28, 28)
       return self.layers(z).view(z.size(0), 1, 28, 28).to(device)

    # Helper function to generate inputs for this generator
    def generate_noise(self, batch_size):
      return torch.randn(batch_size, self.in_dim)

    def fit(self, discriminator_output):
      self.optimizer.zero_grad()

      # TODO: fill in the variable 'output' and 'loss' with the correct value
      # loss = ...
      ## BEGIN SOLUTION
      loss = self.loss_func(discriminator_output, torch.ones(discriminator_output.shape))
      ## END SOLUTION 

      loss.backward()
      self.optimizer.step()
      return loss.item()

(1.3) [9 points] Configure the Discriminator

In [365]:
class Discriminator(nn.Module):
    def __init__(
        self,
        in_dim,
        learning_rate=0.01,
        batch_size=100,
    ):
      super().__init__()

      self.learning_rate = learning_rate
      self.batch_size = batch_size
      self.in_dim = in_dim

      # TODO: build an NN model with in_features=784 and out_features=1.
      # HINT: What should the range of the last activation be?
      # ANOTHER HINT: Again, bigger networks are usually better networks.

      # self.layers = ...
      # self.loss_func = ...
      ## BEGIN SOLUTION
      self.layers = nn.Sequential(
            nn.Linear(in_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
      self.loss_func = nn.BCELoss()
      ## END SOLUTION

      self.optimizer = optim.Adam(self.parameters(), lr = learning_rate)

    def forward(self, x):
      # Flatten the input
      x = x.view(x.size(0), self.in_dim).to(device)
      return self.layers(x)

    def fit(self, X_fake, X_real):
      self.optimizer.zero_grad()
      # TODO: train the discriminator using the X_fake parameter generated by 
      # the Generator and the X_real parameter representing the real images 
      # from MNIST dataset

      output = ... # forward
      loss = ... # compute loss
      ## BEGIN SOLUTION
      output = torch.cat((self.forward(X_fake), self.forward(X_real)))
      loss = self.loss_func(
            output, torch.cat((
                torch.zeros((X_fake.shape[0], 1)),
                torch.ones((X_real.shape[0], 1)))) 
      )
      ## END SOLUTION
      
      loss.backward()
      self.optimizer.step()
      return loss.item()

(1.4) [10 points] Implment the GAN using a Generator and a Discriminator. The skeleton code for Generator and Discriminator is availble below.

In [366]:
class GenerativeAdversarialNetwork(nn.Module):
  def __init__(self,
               in_dim,
               learning_rate=0.01,
               epochs=5,
               batch_size=100,
               ):
    super().__init__()
    
    self.learning_rate = learning_rate
    self.batch_size = batch_size
    self.epochs = epochs

    # TODO: create an assign the two adversarial nets using the arguments into the variable 
    # HINT: Refer to the classes below.

    self.G = ... # generator
    self.D = ... # discriminator

    ## Begin Solution
    self.G = Generator(in_dim, learning_rate=self.learning_rate, batch_size=self.batch_size).to(device)
    self.D = Discriminator(in_dim, learning_rate=self.learning_rate, batch_size=self.batch_size).to(device)
    ## End Solution

    # noise to be used for testing the generator across training.
    self.noise = self.G.generate_noise(24)

  def fit(self, X, show_progress = True, download_trainig_image = False):
    loader = DataLoader(dataset=X, batch_size=self.batch_size, shuffle=True, generator=generator)

    X.data.to(device)
    X.targets.to(device)
    X.train_labels.to(device)

    self.G.train()
    self.D.train()

    g_mean_losses = []
    d_mean_losses = []

    for epoch in range(self.epochs):
      g_total_loss = 0
      d_total_loss = 0
      for batch_index, (x, labels) in tqdm(enumerate(loader), total=int(len(X)/self.batch_size)):
        # TODO: Update g_total_loss and d_total_loss
        ...
        ## Begin Solution
        noise = self.G.generate_noise(self.batch_size)
        generator_samples = self.G(noise).detach()
        d_total_loss += self.D.fit(generator_samples, x)
        discriminator_output = self.D(self.G(noise))
        g_total_loss += self.G.fit(discriminator_output)
        ## End Solution
      # TODO: Compute the mean losses per image for this epoch.
      # g_mean_loss = ...
      # d_mean_loss = ...
      ## Begin Solution
      g_mean_loss = g_total_loss / self.batch_size
      d_mean_loss = d_total_loss / self.batch_size
      ## End Solution
      g_mean_losses.append(g_mean_loss)
      d_mean_losses.append(d_mean_loss)
      if download_trainig_image:
        # create the fake image for the epoch
        generated_img = self.G(self.noise).cpu().detach()
        # make the images as grid
        generated_img = make_grid(generated_img)
        # save the generated torch tensor models to disk
        save_image(generated_img, f"./generated_images/gen_img{epoch}.png")
      if show_progress:
        print(f'Epoch: {epoch}/{self.epochs}, G mean loss: {g_mean_loss}, D mean loss: {d_mean_loss}')
    
    return [g_mean_losses, d_mean_losses]
  
  def forward(self, z):
    z.to(device)
    # TODO: Generate an image with the generator and return it.
    ## Begin Solution
    return self.G.forward(z)
    ## End Solution


(1.5) [3 points] Instantiate a `GenerativeAdversarialNetwork`, and train it, and graph the mean losses over epoch.


In [None]:
model = ...
g_mean_losses, d_mean_losses = [..., ...]

## BEGIN SOLUTION
model = GenerativeAdversarialNetwork(in_dim=28*28, epochs=50, learning_rate=0.0002, batch_size=100)
model.to(device)
g_mean_losses, d_mean_losses = model.fit(data_train, True, True)
## END SOLUTION

100%|██████████| 600/600 [00:28<00:00, 21.28it/s]


Epoch: 0/50, G mean loss: 30.85617324203253, D mean loss: 1.7904207569733261


100%|██████████| 600/600 [00:27<00:00, 22.02it/s]


Epoch: 1/50, G mean loss: 14.050667706131936, D mean loss: 3.1394958236813544


100%|██████████| 600/600 [00:28<00:00, 21.36it/s]


Epoch: 2/50, G mean loss: 12.847907614707946, D mean loss: 2.52256423920393


100%|██████████| 600/600 [00:28<00:00, 21.39it/s]


Epoch: 3/50, G mean loss: 18.91187769293785, D mean loss: 1.8025971556454896


100%|██████████| 600/600 [00:31<00:00, 18.83it/s]


Epoch: 4/50, G mean loss: 20.09537203788757, D mean loss: 1.4632513565570116


100%|██████████| 600/600 [00:27<00:00, 21.73it/s]


Epoch: 5/50, G mean loss: 16.75461251497269, D mean loss: 1.6426751981675625


100%|██████████| 600/600 [00:27<00:00, 21.80it/s]


Epoch: 6/50, G mean loss: 16.967262288331984, D mean loss: 1.6551528795808554


100%|██████████| 600/600 [00:27<00:00, 21.75it/s]


Epoch: 7/50, G mean loss: 15.136493203639985, D mean loss: 1.8627769979834556


100%|██████████| 600/600 [00:28<00:00, 21.22it/s]


Epoch: 8/50, G mean loss: 13.705176984071732, D mean loss: 1.9981045046448707


100%|██████████| 600/600 [00:27<00:00, 21.95it/s]


Epoch: 9/50, G mean loss: 12.03156141757965, D mean loss: 2.247503556907177


100%|██████████| 600/600 [00:27<00:00, 21.85it/s]


Epoch: 10/50, G mean loss: 11.087977709770202, D mean loss: 2.390969995111227


100%|██████████| 600/600 [00:27<00:00, 21.95it/s]


Epoch: 11/50, G mean loss: 11.073222463130952, D mean loss: 2.429570034146309


100%|██████████| 600/600 [00:28<00:00, 21.33it/s]


Epoch: 12/50, G mean loss: 10.970747635364532, D mean loss: 2.507267704308033


100%|██████████| 600/600 [00:27<00:00, 21.89it/s]


Epoch: 13/50, G mean loss: 11.246565163731574, D mean loss: 2.4604398849606515


100%|██████████| 600/600 [00:27<00:00, 21.92it/s]


Epoch: 14/50, G mean loss: 11.439666665792465, D mean loss: 2.4406089921295644


100%|██████████| 600/600 [00:27<00:00, 21.99it/s]


Epoch: 15/50, G mean loss: 10.72113300204277, D mean loss: 2.5512314726412297


100%|██████████| 600/600 [00:27<00:00, 21.95it/s]


Epoch: 16/50, G mean loss: 10.96485235452652, D mean loss: 2.4203338894248008


100%|██████████| 600/600 [00:28<00:00, 21.12it/s]


Epoch: 17/50, G mean loss: 11.237904928922653, D mean loss: 2.395324236601591


100%|██████████| 600/600 [00:27<00:00, 21.96it/s]


Epoch: 18/50, G mean loss: 9.991245486736297, D mean loss: 2.7182004007697107


100%|██████████| 600/600 [00:27<00:00, 21.79it/s]


Epoch: 19/50, G mean loss: 9.968625569343567, D mean loss: 2.636398644447327


100%|██████████| 600/600 [00:27<00:00, 21.77it/s]


Epoch: 20/50, G mean loss: 9.654106323719025, D mean loss: 2.821571376025677


100%|██████████| 600/600 [00:27<00:00, 21.58it/s]


Epoch: 21/50, G mean loss: 9.391401609778404, D mean loss: 2.7369176393747328


100%|██████████| 600/600 [00:27<00:00, 21.54it/s]


Epoch: 22/50, G mean loss: 9.272553932070732, D mean loss: 2.8136400523781777


100%|██████████| 600/600 [00:27<00:00, 21.78it/s]


Epoch: 23/50, G mean loss: 9.42173814535141, D mean loss: 2.769018847346306


100%|██████████| 600/600 [00:27<00:00, 21.86it/s]


Epoch: 24/50, G mean loss: 9.001519318819046, D mean loss: 2.870273160934448


 78%|███████▊  | 465/600 [00:21<00:05, 23.40it/s]

(1.6) [1 points] Plot the losses over epoch, and save it as an image.


In [None]:
## BEGIN SOLUTION
plt.figure()
plt.plot(g_mean_losses, label='Generator mean loss')
plt.plot(d_mean_losses, label='Discriminator mean Loss')
plt.legend()
plt.show()
## END SOLUTION

(1.7) [5 points] Generate an image with the model, and display.

In [None]:
image = ...
## Begin Solution
image = model.forward(model.G.generate_noise(20)).cpu().detach()
## End Solution
grid_img = make_grid(image)
save_image(grid_img, "final_image.png")
plt.imshow(grid_img.permute(1, 2, 0))

# 2. Short Answer Questions

(2.1) [2 points] What is the main difference between GANs and variational autoencoders (VAEs)?

(2.2) [3 points] What is a possible cause of the low output diversity of GANs?

(2.3) [2 points] Why can't pre-trained discriminators be used in GANs?