# Generative Adversarial Networks (GANs)

The GAN architecture consists of a **generator** network and a **discriminator** network. The generator maps random noise vectors to synthetic images, with the goal of fooling the discriminator. The discriminator, in turn, learns to tell real images apart from fake ones. The two networks therefore play a two-player **minimax adversarial game**. Through this joint optimization, the generator learns to produce realistic images that resemble the real data.
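
For reference, the standard GAN value function makes this game explicit. Here $p_{\text{data}}$ is the distribution of real images and $p_z$ is the noise prior (in this lab, a standard normal sampled with `torch.randn`). The discriminator $D$ maximizes $V$ while the generator $G$ minimizes it:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$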

![DCGAN generator and discriminator architecture](https://www.researchgate.net/publication/343597759/figure/fig4/AS:923532934529034@1597198818441/The-architecture-of-the-generator-and-the-discriminator-in-a-DCGAN-model-FSC-is-the.ppm)

In this lab session we will learn how to code a GAN to generate anime faces from noise inputs. Let's start by importing, as usual, the necessary modules.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.optim as optim
from torchvision.utils import save_image

Let's mount Google Drive into our Colab environment.

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

We need the [tar](https://drive.google.com/file/d/1PrLr4jwkIwHxQJvQaKJNxw5yAReFpJnA/view?usp=sharing) archive in our Google Drive storage (you can do this by simply creating a shortcut to it from the link). We will then copy it into the Colab environment and extract the files. Please change the paths according to the directory structure of your Google Drive.

In [None]:
# specify the path of the tar archive in your Google Drive
!cp "/content/gdrive/My Drive/datasets/anime-faces.tar.gz" ./

Let's extract!

In [None]:
! tar -xf anime-faces.tar.gz
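
If you prefer to stay in Python, the copy-and-extract step can be sketched with the standard library's `shutil` and `tarfile` modules. `copy_and_extract` is a hypothetical helper, not part of the lab. It assumes a gzip-compressed archive, as the `.tar.gz` extension suggests, and should only be used on archives you trust, since `extractall` writes whatever paths the archive contains.

```python
import os
import shutil
import tarfile


def copy_and_extract(archive_path, dest_dir="."):
    """Hypothetical helper: copy a .tar.gz archive into dest_dir and extract it there.

    Mirrors the `cp` + `tar -xf` commands above using only the standard library.
    """
    os.makedirs(dest_dir, exist_ok=True)
    # shutil.copy returns the path of the newly created copy
    local_copy = shutil.copy(archive_path, dest_dir)
    # "r:gz" assumes gzip compression; only extract archives you trust
    with tarfile.open(local_copy, mode="r:gz") as tar:
        tar.extractall(path=dest_dir)
    return local_copy
```

Assuming the archive unpacks into an `anime-faces/` folder, as `root_folder` below expects, a call such as `copy_and_extract("/content/gdrive/My Drive/datasets/anime-faces.tar.gz", "/content")` would reproduce the two shell cells.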

We now define some constants that we will use in our implementation:

In [None]:
img_size = 64
root_folder = "/content/anime-faces/"
batch_size = 128
nc = 3 # number of channels in an image (RGB)
ngf = 64 # base number of feature maps in the generator
ndf = 64 # base number of feature maps in the discriminator
nz = 100 # dimension of the latent space
lr = 0.0002 # learning rate for the networks
num_epochs = 100 # total number of training epochs
save_dir = "/content/gan_log/" # output folder for the generated samples

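Next, we build the dataset and the data loader, and for visualization purposes we plot a batch of training images. Note that `ImageFolder` expects the images to be stored in at least one subdirectory of `root_folder`; the class labels it returns are not used in this lab.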

In [None]:
# create a set of transforms for the dataset
dset_transforms = list()
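dset_transforms.append(transforms.Resize(img_size))
# crop to exactly img_size x img_size: Resize(int) only fixes the shorter side,
# while the discriminator needs 64x64 inputs to produce one score per image
dset_transforms.append(transforms.CenterCrop(img_size))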
dset_transforms.append(transforms.ToTensor())
dset_transforms.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], 
                                            std=[0.5, 0.5, 0.5]))
dset_transforms = transforms.Compose(dset_transforms)

# create a dataset using ImageFolder of pytorch
dataset = dset.ImageFolder(root=root_folder, transform=dset_transforms)

# create a data loader
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4,
                        shuffle=True, drop_last=True)

# Plot images
fixed_batch = next(iter(dataloader))  # Gets the first batch
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training images")
plt.imshow(np.transpose(vutils.make_grid(fixed_batch[0][:64], padding=2, 
                                         normalize=True), (1, 2, 0)))
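
The `Normalize` transform above is what ties the data to the generator's output range. `ToTensor()` produces pixel values in [0, 1], and subtracting a mean of 0.5 and dividing by a standard deviation of 0.5 maps them to [-1, 1], the range of the `Tanh` at the end of the generator defined below. The NumPy sketch reproduces that per-channel arithmetic on a few sample values. It is an illustration, not a replacement for `transforms.Normalize`.

```python
import numpy as np


def normalize(x, mean=0.5, std=0.5):
    # what transforms.Normalize computes per channel: (x - mean) / std
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    # inverse mapping, back to the [0, 1] range produced by ToTensor()
    return y * std + mean


pixels = np.array([0.0, 0.25, 0.5, 1.0])  # example pixel values after ToTensor()
scaled = normalize(pixels)                 # now in [-1, 1], like the generator's Tanh output
restored = denormalize(scaled)
print(scaled, restored)
```

This is also why the plots call `make_grid(..., normalize=True)`, which shifts the images back into [0, 1] before display.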

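We now define a custom weight initialization for the generator and the discriminator. Following DCGAN, convolutional weights are drawn from a zero-centered normal distribution with standard deviation 0.02. BatchNorm scale parameters are drawn from a normal distribution with mean 1 and standard deviation 0.02, and BatchNorm biases are set to 0.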

In [None]:
def weight_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    nn.init.normal_(m.weight.data, 0.0, 0.02)
  elif classname.find('BatchNorm') != -1:
    nn.init.normal_(m.weight.data, 1.0, 0.02)
    nn.init.constant_(m.bias.data, 0.0)
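
`netG.apply(weight_init)` calls `weight_init` recursively on every submodule (and on the module itself), so the function dispatches on the class name. Because the test is a substring match on `'Conv'`, both `Conv2d` and `ConvTranspose2d` receive the convolutional initialization. The pure-Python sketch below replays that dispatch on the class names used in this lab without touching PyTorch. `init_rule` is a hypothetical helper written only for illustration.

```python
# Pure-Python sketch of the class-name dispatch used by weight_init (no PyTorch needed).
def init_rule(classname):
    if classname.find('Conv') != -1:
        return 'weight ~ N(0.0, 0.02)'
    elif classname.find('BatchNorm') != -1:
        return 'weight ~ N(1.0, 0.02), bias = 0'
    return 'PyTorch default (unchanged)'


# Class names that appear in this lab's Generator and Discriminator
for name in ['Generator', 'Sequential', 'ConvTranspose2d', 'Conv2d',
             'BatchNorm2d', 'ReLU', 'LeakyReLU', 'Tanh', 'Sigmoid']:
    print(f'{name:16s} -> {init_rule(name)}')
```

An `isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))` check would be a more explicit alternative to matching on the class name.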

## Generator
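In the upcoming code block, we define our generator following the DCGAN design: a stack of transposed convolutions, each followed by BatchNorm and ReLU (except the last one, which uses Tanh), that upsamples a noise vector of shape `(nz, 1, 1)` into an `nc`×64×64 image with values in [-1, 1]. If interested, you may take a look at the [original paper](https://arxiv.org/abs/1511.06434).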

In [None]:
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()

    self.main = nn.Sequential(
        nn.ConvTranspose2d(in_channels=nz, out_channels=ngf * 8,
                           kernel_size=(4, 4), stride=1, padding=0,
                           bias=False),
        nn.BatchNorm2d(num_features=ngf * 8),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(in_channels=ngf * 8, out_channels=ngf * 4,
                           kernel_size=(4, 4), stride=2, padding=1,
                           bias=False),
        nn.BatchNorm2d(num_features=ngf * 4),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(in_channels=ngf * 4, out_channels=ngf * 2,
                           kernel_size=(4, 4), stride=2, padding=1,
                           bias=False),
        nn.BatchNorm2d(num_features=ngf * 2),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(in_channels=ngf * 2, out_channels=ngf,
                           kernel_size=(4, 4), stride=2, padding=1,
                           bias=False),
        nn.BatchNorm2d(num_features=ngf),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(in_channels=ngf, out_channels=nc,
                           kernel_size=(4, 4), stride=2, padding=1,
                           bias=False),
        nn.Tanh()
    )
  
  def forward(self, x):
    out = self.main(x)
    return out 
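
To check that this stack really turns a `(nz, 1, 1)` noise tensor into an `(nc, 64, 64)` image, we can trace the spatial size through each layer with the output-size formula from the `ConvTranspose2d` documentation, $H_{out} = (H_{in} - 1)\cdot\text{stride} - 2\cdot\text{padding} + \text{kernel\_size}$ (assuming the default `dilation=1` and `output_padding=0`). The plain-Python sketch below only mirrors the layer hyperparameters above; it does not build the network.

```python
def conv_transpose_out(size, kernel=4, stride=1, padding=0):
    # ConvTranspose2d output size, assuming dilation=1 and output_padding=0
    return (size - 1) * stride - 2 * padding + kernel


nz, ngf, nc = 100, 64, 3  # same values as the constants defined earlier

# (out_channels, stride, padding) of each ConvTranspose2d in Generator.main
gen_layers = [(ngf * 8, 1, 0), (ngf * 4, 2, 1), (ngf * 2, 2, 1), (ngf, 2, 1), (nc, 2, 1)]

size, channels = 1, nz  # the input noise has shape (batch, nz, 1, 1)
gen_shapes = [(channels, size, size)]
for out_channels, stride, padding in gen_layers:
    size = conv_transpose_out(size, kernel=4, stride=stride, padding=padding)
    channels = out_channels
    gen_shapes.append((channels, size, size))

for shape in gen_shapes:
    print(shape)
```

The first layer (stride 1, no padding) expands the 1×1 input to 4×4, and each following stride-2 layer doubles the resolution up to 64×64.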

Let's apply the custom initialization to our generator, and print its structure.

In [None]:
netG = Generator().cuda()
netG.apply(weight_init)
print(netG)

## Discriminator

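The discriminator is a binary classifier: given an input image, it outputs the probability that the image is real rather than generated. Its architecture roughly mirrors the generator, using strided convolutions with LeakyReLU activations (and BatchNorm on the intermediate layers) followed by a final Sigmoid.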

In [None]:
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    
    self.main = nn.Sequential(
        nn.Conv2d(in_channels=nc, out_channels=ndf, kernel_size=(4, 4),
                  stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),

        nn.Conv2d(in_channels=ndf, out_channels=ndf * 2, kernel_size=(4, 4),
                  stride=2, padding=1, bias=False),
        nn.BatchNorm2d(num_features=ndf * 2),
        nn.LeakyReLU(0.2, inplace=True),

        nn.Conv2d(in_channels=ndf * 2, out_channels=ndf * 4, kernel_size=(4, 4),
                  stride=2, padding=1, bias=False),
        nn.BatchNorm2d(num_features=ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),

        nn.Conv2d(in_channels=ndf * 4, out_channels=ndf * 8, kernel_size=(4, 4),
                  stride=2, padding=1, bias=False),
        nn.BatchNorm2d(num_features=ndf * 8),
        nn.LeakyReLU(0.2, inplace=True),

        nn.Conv2d(in_channels=ndf * 8, out_channels=1, kernel_size=(4, 4),
                  stride=1, padding=0, bias=False),
        nn.Sigmoid()
    )

  def forward(self, x):
    out = self.main(x)
    return out
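
The discriminator reduces the image resolution step by step. Using the output-size formula from the `Conv2d` documentation, $H_{out} = \lfloor (H_{in} + 2\cdot\text{padding} - \text{kernel\_size})/\text{stride} \rfloor + 1$ (assuming the default `dilation=1`), the plain-Python sketch below traces a 64×64 RGB image down to a single score. It only mirrors the layer hyperparameters above and does not build the network.

```python
def conv_out(size, kernel=4, stride=1, padding=0):
    # Conv2d output size, assuming dilation=1
    return (size + 2 * padding - kernel) // stride + 1


nc, ndf, img_size = 3, 64, 64  # same values as the constants defined earlier

# (out_channels, stride, padding) of each Conv2d in Discriminator.main
disc_layers = [(ndf, 2, 1), (ndf * 2, 2, 1), (ndf * 4, 2, 1), (ndf * 8, 2, 1), (1, 1, 0)]

size, channels = img_size, nc
disc_shapes = [(channels, size, size)]
for out_channels, stride, padding in disc_layers:
    size = conv_out(size, kernel=4, stride=stride, padding=padding)
    channels = out_channels
    disc_shapes.append((channels, size, size))

for shape in disc_shapes:
    print(shape)
```

The final `(1, 1, 1)` map is why `netD(...).view(-1)` in the training loop yields exactly one probability per image, matching the shape of the label vector.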

Let's initialize and print.

In [None]:
netD = Discriminator().cuda()
netD.apply(weight_init)
print(netD)
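
As a sanity check on the printed structures, the number of learnable parameters can be counted by hand. Every convolution has `bias=False`, so it contributes $C_{in} \cdot C_{out} \cdot 4 \cdot 4$ weights, and each `BatchNorm2d` contributes one scale and one shift per channel. Assuming the default hyperparameters defined earlier, the plain-Python sketch below should agree with `sum(p.numel() for p in netG.parameters())` and the same expression for `netD`. BatchNorm running statistics are buffers, not parameters, so they are not counted.

```python
def conv_params(c_in, c_out, k=4):
    # weights only: every convolution in this lab uses bias=False
    return c_in * c_out * k * k


def bn_params(c):
    # learnable scale (weight) and shift (bias) for each channel
    return 2 * c


nz, ngf, ndf, nc = 100, 64, 64, 3

g_channels = [nz, ngf * 8, ngf * 4, ngf * 2, ngf, nc]
g_total = sum(conv_params(a, b) for a, b in zip(g_channels, g_channels[1:]))
g_total += sum(bn_params(c) for c in g_channels[1:-1])  # no BatchNorm after the last layer

d_channels = [nc, ndf, ndf * 2, ndf * 4, ndf * 8, 1]
d_total = sum(conv_params(a, b) for a, b in zip(d_channels, d_channels[1:]))
d_total += sum(bn_params(c) for c in d_channels[2:-1])  # no BatchNorm after the first or last conv

print(f"Generator:     {g_total:,} parameters")
print(f"Discriminator: {d_total:,} parameters")
```

With these settings the count is 3,576,704 parameters for the generator and 2,765,568 for the discriminator.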

## Training setup

We now set up the training ingredients: the loss criterion, the target labels, one optimizer per network, and the fixed noise vectors used to monitor the training progress.

In [None]:
# define the loss criterion
criterion = nn.BCELoss()

# sample a fixed noise vector that will be used to visualize the training
# progress
fixed_noise = torch.randn(64, nz, 1, 1).cuda()

# define the ground truth labels.
real_labels = 1.0 # for the real images
fake_labels = 0.0 # for the fake images

# define the optimizers, one for each network
netD_optimizer = optim.Adam(params=netD.parameters(), lr=lr, betas=(0.5, 0.999))
netG_optimizer = optim.Adam(params=netG.parameters(), lr=lr, betas=(0.5, 0.999))

# sample two fixed noise vectors and linearly interpolate between them
# to get the intermediate noise vectors. We will generate samples from the
# interpolated noise vectors to see the effect of interpolation in the latent space (see later!)
z_1 = torch.randn(1, nz, 1, 1)
z_2 = torch.randn(1, nz, 1, 1)
fixed_interpolate = []
for i in range(64):
  lambda_interp = i / 63
  z_interp = z_1 * (1 - lambda_interp) + lambda_interp * z_2
  fixed_interpolate.append(z_interp)
# fixed_interpolate is (64, nz, 1, 1)
fixed_interpolate = torch.cat(fixed_interpolate, dim=0).cuda()
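
The interpolation loop above can also be written without a Python loop. `linspace(0, 1, 64)` produces exactly the weights `i / 63`, and reshaping them to `(64, 1, 1, 1)` lets broadcasting combine them with `z_1` and `z_2`. The sketch below uses NumPy with an arbitrary seed so that it runs on its own; it is an equivalent formulation, not a change to the lab's code.

```python
import numpy as np

nz = 100
rng = np.random.default_rng(0)  # arbitrary seed, only for this sketch
z_1 = rng.standard_normal((1, nz, 1, 1))
z_2 = rng.standard_normal((1, nz, 1, 1))

# 64 evenly spaced weights in [0, 1] (i.e. i / 63), shaped to broadcast over (nz, 1, 1)
lambdas = np.linspace(0.0, 1.0, 64).reshape(-1, 1, 1, 1)
z_interp = (1 - lambdas) * z_1 + lambdas * z_2  # shape (64, nz, 1, 1)

# loop version, as in the cell above, for comparison
z_loop = np.concatenate([z_1 * (1 - i / 63) + (i / 63) * z_2 for i in range(64)], axis=0)
print(z_interp.shape, np.allclose(z_interp, z_loop))
```

In the notebook, `torch.linspace(0, 1, 64).view(-1, 1, 1, 1)` would play the same role with the `torch.randn` samples.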

## Put it all together

We will now create our main function in order to carry out the actual GAN training.
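
The loop relies on how `nn.BCELoss` turns each objective into a loss. For a predicted probability $p$ and a target $y$, binary cross-entropy is $-[y \log p + (1 - y)\log(1 - p)]$, averaged over the batch by default. With `real_labels` ($y = 1$) it reduces to $-\log p$, and with `fake_labels` ($y = 0$) it reduces to $-\log(1 - p)$. The NumPy sketch below uses made-up discriminator outputs, chosen only for illustration, to show which term each of the three losses in the loop corresponds to. Unlike `nn.BCELoss`, it does not clamp the logarithms, so it assumes $0 < p < 1$.

```python
import numpy as np


def bce(p, y):
    # mean binary cross-entropy, matching nn.BCELoss's default reduction='mean'
    # (nn.BCELoss also clamps the logs at -100; this sketch assumes 0 < p < 1)
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))


# made-up discriminator outputs, for illustration only
d_real = np.array([0.9, 0.8, 0.7])  # D(x) on a batch of real images
d_fake = np.array([0.2, 0.1, 0.3])  # D(G(z)) on a batch of fake images

err_d_real = bce(d_real, np.ones(3))   # real_labels -> mean(-log D(x))
err_d_fake = bce(d_fake, np.zeros(3))  # fake_labels -> mean(-log(1 - D(G(z))))
err_g = bce(d_fake, np.ones(3))        # real_labels on fakes -> mean(-log D(G(z)))

print(err_d_real, err_d_fake, err_g)
```

Minimizing `errD_real + errD_fake` is therefore equivalent to maximizing $\log D(x) + \log(1 - D(G(z)))$, and minimizing `errG` maximizes $\log D(G(z))$, which are the objectives stated in the comments of the training loop.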

In [None]:
def main():

  # iterations counter
  iters = 0

  # iterate over the number of epochs
  for epoch in range(num_epochs):

    # iterate over the data loader
    for i, data in enumerate(dataloader, 0):
      
      ## Discriminator training ##
      # maximize log(D(x)) + log(1 - D(G(z)))

      # D is updated with a single optimizer step that combines two terms:
      # first we compute the gradients for the real images (the first term of
      # the D objective), then the gradients for the fake images generated by G
      # (the second term). backward() accumulates both into the gradients of D,
      # and only then is netD_optimizer.step() called to update the weights of D.
      # IMPORTANT: while D is updated, G is kept frozen.

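      # put both networks in training mode; netG may have been switched to
      # eval mode by the visualization step at the end of a previous iteration
      netD.train()
      netG.train()

      # train D with real images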

      # zero the gradients of the discriminator
      netD.zero_grad()

      # data is a Tuple (images, labels). We only need the images
      real_images = data[0].cuda() 
      bs = real_images.shape[0]
      
      # we train the discriminator labelling these as real images
      label = torch.full((bs,), real_labels).cuda()
      output = netD(real_images).view(-1)

      # calculate loss on real images. This encourages D to output 1 for
      # real images
      errD_real = criterion(output, label)

      # calculate gradients for D
      errD_real.backward()

      # track D outputs for real images
      D_x = output.mean().item()

      # train D with fake images: sample a batch of noise vectors
      noise = torch.randn(bs, nz, 1, 1).cuda()

      # generate fake data
      fake_images = netG(noise)

      # we train the discriminator labelling these as fake images
      label.fill_(fake_labels)

      # run the fake images through the discriminator.
      # IMPORTANT: detach fake_images so that errD_fake.backward() does not
      # backpropagate into the generator; only D is updated in this step.
      output = netD(fake_images.detach()).view(-1)
      
      # calculate loss on the fake images. This encourages D to output 0 for
      # fake images
      errD_fake = criterion(output, label)

      # calculate the gradients for D
      # Note that gradients are not reset between backward() calls, so gradients
      # are summed to the ones computed for real images
      errD_fake.backward()
      errD = (errD_real + errD_fake)

      # track D outputs for fake images
      D_G_x_1 = output.mean().item()

      # update the D weights with the gradients accumulated
      netD_optimizer.step()

      ## Generator training ##
      # minimize log(1 - D(G(z)))
      # However, this formulation saturates: early in training, when D confidently
      # rejects the fakes, it provides almost no gradient to G. It is therefore
      # reformulated as:
      # maximize log(D(G(z)))

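      # during the G update the weights of D are not changed: gradients flow
      # through D, but netD_optimizer.step() is not called here, and the D
      # gradients are zeroed again at the start of the next D update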
      netG.train()
      netG.zero_grad()

      # we want to train G to confuse D, so we compute the loss between
      # the output of D on the fake images and a vector filled with
      # real labels. This way, G is encouraged to produce fake images that
      # push D to output 1, i.e. fooling D into believing that the images
      # are real
      label.fill_(real_labels)
      output = netD(fake_images).view(-1)
      errG = criterion(output, label)

      # calculate the gradients for G
      errG.backward()

      # track the outputs for fake images
      D_G_x_2 = output.mean().item()

      # update the G weights with the gradients accumulated
      netG_optimizer.step()

      # print the training losses
      if iters % 100 == 0:
        print('[%3d/%d][%3d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' % (epoch, 
            num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_x_1, D_G_x_2))
      
      # visualize the samples produced by the generator.
      # The 'out' and 'interpolate' subfolders of save_dir store the images
      # generated from the fixed_noise and fixed_interpolate noise vectors, respectively
      if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
        out_dir = os.path.join(save_dir, 'out/')
        os.makedirs(out_dir, exist_ok=True)
        interp_dir = os.path.join(save_dir, 'interpolate/')
        os.makedirs(interp_dir, exist_ok=True)
        netG.eval()
        with torch.no_grad():
          fake_fixed = netG(fixed_noise).cpu()
          save_image(fake_fixed, os.path.join(out_dir, str(iters).zfill(7) + '.png'),
                    normalize=True)
          
          interp_fixed = netG(fixed_interpolate).cpu()
          save_image(interp_fixed, os.path.join(interp_dir, str(iters).zfill(7) + '.png'),
                    normalize=True)
          
          # display the images inline as well
          fig, ax = plt.subplots(1, 2, figsize=(15, 15))
          ax[0].imshow(np.transpose(vutils.make_grid(fake_fixed, padding=2, normalize=True), (1, 2, 0)))
          ax[0].axis('off')
          ax[0].set_title('Random Samples')

          ax[1].imshow(np.transpose(vutils.make_grid(interp_fixed, padding=2, normalize=True), (1, 2, 0)))
          ax[1].axis('off')
          ax[1].set_title('Interpolations')

          plt.show()
      
      iters += 1
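
The comment about reformulating the generator loss can be made concrete. Write $D(G(z)) = \sigma(a)$, where $a$ is the discriminator's output before the final `Sigmoid`. The gradient of $\log(1 - \sigma(a))$ with respect to $a$ is $-\sigma(a)$, whereas the gradient of $\log \sigma(a)$ is $1 - \sigma(a)$. Early in training the discriminator rejects fakes confidently, so $\sigma(a) \approx 0$. The original objective then gives almost no gradient, while the reformulated one gives a gradient close to 1. The plain-Python sketch below evaluates both expressions at a few logit values. The helper names are hypothetical and only for illustration.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def saturating_grad(a):
    # d/da log(1 - sigmoid(a)): gradient of the original G objective
    return -sigmoid(a)


def non_saturating_grad(a):
    # d/da log(sigmoid(a)): gradient of the reformulated G objective
    return 1.0 - sigmoid(a)


# Negative logits correspond to a discriminator that confidently labels the fakes as fake
for a in [-6.0, -3.0, 0.0, 3.0]:
    print(f"D(G(z)) = {sigmoid(a):.4f}   "
          f"saturating: {saturating_grad(a):+.4f}   "
          f"non-saturating: {non_saturating_grad(a):+.4f}")
```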

Let's make it happen!

In [None]:
if __name__ == '__main__':
  main()