# Generative Adversarial Networks

A GAN is a deep neural network architecture that consists of a generator (G) and a discriminator (D). Both G and D are neural networks. The generator learns to generate samples that are as close as possible to the real data, while the discriminator learns to distinguish the generated samples from the real data. That is why the training is called adversarial. Unlike in a VAE, we don't have an explicit density function in this two-player game. So a GAN is mainly used to generate realistic samples, while a VAE is also used to extract a representation of the input. The architecture of a GAN is shown below:

<img src="src/gan.png" width="50%" height="50%" />

The generator outputs a data instance, while the discriminator outputs a value in (0, 1) that represents the probability that its input is real data. So the discriminator wants to maximize its output for the real data and minimize its output for the generated data, while the generator wants to maximize the discriminator's output for the generated data. As a result, the optimization is a minimax game:

$$
\min\limits_{\theta_g}\max\limits_{\theta_d} \big[\mathbb{E}_{x\sim p_{data}}\log D_{\theta_d}(x) + \mathbb{E}_{z\sim p(z)} \log(1 - D_{\theta_d}(G_{\theta_g}(z)))\big]
$$

Okay! Let's move on to the code!

In [21]:
# import packages
import os
import random
from easydict import EasyDict as edict

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [22]:
# hyper-parameters, compute device and the MNIST input pipeline
args = edict(
    batch_size=64,
    epochs=25,
    lr=0.0002,
    nz=100,  # size of input vector to generator
    ngf=64,  # base channels used in generator
    ndf=64,  # base channels used in discriminator
    input_size=64,  # resize input image
    real_label=1,
    fake_label=0,
    beta1=0.5,  # for optimizer
)

# run on the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data/MNIST',
        download=True,
        transform=transforms.Compose([
            # upscale the digits to the resolution the networks expect
            transforms.Resize(size=args.input_size),
            transforms.ToTensor(),
            # single-channel mean/std of 0.5
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ]),
    ),
    batch_size=args.batch_size,
    shuffle=True,
)

In [23]:
# weight initialization for generator and discriminator from dcgan paper
def weights_init(m):
    """Re-initialise the parameters of module ``m`` in place, keyed on its class name.

    Modules whose class name contains ``'Conv'`` get weights drawn from a
    normal distribution with mean 0.0 and std 0.02. Modules whose class name
    contains ``'BatchNorm'`` get weights drawn from a normal distribution
    with mean 1.0 and std 0.02 and a bias filled with zeros. Every other
    module is left untouched.

    Args:
        m: the module to initialise (e.g. as visited by ``Module.apply``).

    Returns:
        None; ``m`` is modified in place.
    """
    layer_name = m.__class__.__name__

    if 'Conv' in layer_name:
        # covers both regular and transposed convolutions
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


In [24]:
# define generator and discriminator 
class Generator(nn.Module):
    """DCGAN-style generator: maps a latent vector to a 64 x 64 image.

    The network is a stack of five transposed convolutions. The first one
    turns the ``nz x 1 x 1`` latent vector into a 4 x 4 feature map and each
    of the following four doubles the spatial size, ending at 64 x 64.
    Every hidden stage is followed by batch normalization and ReLU; the
    output stage uses Tanh, so generated pixels lie in [-1, 1].

    Args:
        nz: number of channels of the latent input vector.
        ngf: base number of feature maps; hidden stages use
            ``ngf * 8``, ``ngf * 4``, ``ngf * 2`` and ``ngf`` channels.
        nc: number of channels of the generated image. Defaults to 1
            (grayscale), which is what the original hard-coded value was.
    """

    def __init__(self, nz, ngf, nc=1):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is z (nz x 1 x 1), going into a transposed convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # nc x 64 x 64
        )

    def forward(self, input):
        """Generate images from a batch of latent vectors.

        Args:
            input: tensor of shape ``(N, nz, 1, 1)``.

        Returns:
            Tensor of shape ``(N, nc, 64, 64)`` with values in [-1, 1].
        """
        return self.main(input)

class Discriminator(nn.Module):
    """DCGAN-style discriminator: scores 64 x 64 images as real or generated.

    Four stride-2 convolutions halve the spatial size from 64 x 64 down to
    4 x 4, and a final 4 x 4 convolution collapses that map to a single
    value which a sigmoid squashes into (0, 1). LeakyReLU(0.2) is used after
    every hidden convolution; batch normalization is used on all hidden
    stages except the first one.

    Args:
        ndf: base number of feature maps; hidden stages use
            ``ndf``, ``ndf * 2``, ``ndf * 4`` and ``ndf * 8`` channels.
        nc: number of channels of the input image. Defaults to 1
            (grayscale), which is what the original hard-coded value was.
    """

    def __init__(self, ndf, nc=1):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # nc x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # 1 x 1 x 1
        )

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: tensor of shape ``(N, nc, 64, 64)``.

        Returns:
            1-D tensor with one score in (0, 1) per image for 64 x 64
            inputs (the network output is flattened to a vector).
        """
        output = self.main(input)
        # flatten (N, 1, 1, 1) to (N,) so it lines up with a 1-D label tensor
        return output.view(-1, 1).squeeze(1)
    

The configurations of the generator and the discriminator are based on the DCGAN paper. The suggestions from the paper include:
1. Replace all max pooling with strided convolutions.
2. Use transposed convolution for upsampling.
3. Eliminate fully connected layers.
4. Use batch normalization everywhere except in the output layer of the generator and the input layer of the discriminator.
5. Use ReLU in the generator except for the output which uses tanh.
6. Use LeakyReLU in the discriminator. 

In [25]:
# build both networks on the target device, re-initialise their weights,
# then set up the loss and one Adam optimizer per network
netG = Generator(args.nz, args.ngf).to(device).apply(weights_init)

netD = Discriminator(args.ndf).to(device).apply(weights_init)

# binary cross-entropy on the discriminator's sigmoid output
criterion = nn.BCELoss()

optimizerD = optim.Adam(
    params=netD.parameters(),
    lr=args.lr,
    betas=(args.beta1, 0.999),
)
optimizerG = optim.Adam(
    params=netG.parameters(),
    lr=args.lr,
    betas=(args.beta1, 0.999),
)


In [None]:
# Do Training

# samples and checkpoints are written here; create it up front so that
# save_image / torch.save do not fail on a missing directory
os.makedirs('GAN_result', exist_ok=True)

# generate an image based on fixed_noise to check the progress of training
# (the same noise is reused every epoch so the saved grids are comparable)
fixed_noise = torch.randn(args.batch_size, args.nz, 1, 1, device=device)

for epoch in range(args.epochs):
    print("now training %dth epoch" % epoch)
    for i, data in enumerate(data_loader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_input = data[0].to(device)
        # the last batch may be smaller than args.batch_size
        batch_size = real_input.size(0)
        # BCELoss needs floating-point targets; without an explicit dtype
        # torch.full infers an integer tensor from the integer label value
        label = torch.full((batch_size,), args.real_label,
                           dtype=torch.float, device=device)

        output = netD(real_input)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, args.nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(args.fake_label)
        # detach so this backward pass does not reach the generator
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(args.real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # report the running diagnostics from time to time
        if i % 100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, args.epochs, i, len(data_loader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    # generate image to check progress; no gradients are needed here
    with torch.no_grad():
        fake = netG(fixed_noise)
    save_image(fake, 'GAN_result/sample_%03d.png' % (epoch), normalize=True)

# save model after training
torch.save(netG.state_dict(), 'GAN_result/netG_epoch.pth')
torch.save(netD.state_dict(), 'GAN_result/netD_epoch.pth')

One interesting thing is that, if you train the model for more epochs, you may find that the generated images become worse. This is because GAN training is not stable. You may add noise to the input of the discriminator to mitigate this problem. How to make GANs more stable is still a hot topic. You could look into other GAN variants for more information! Enjoy!