In [None]:
%matplotlib inline

Attribution: 

   * Most material is adapted from [Agustinus Kristiadi's blog](https://wiseodd.github.io/techblog/2017/01/20/gan-pytorch/) and [Generative models repository](https://github.com/wiseodd/generative-models)

# Generative Adversarial Networks

We're now going to use PyTorch to implement a *vanilla* version of one of the most popular generative models today, the GAN.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
from torch.autograd import Variable

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import os

import torchvision
import torch.utils.data as Data

In [None]:
mnist_train = torchvision.datasets.MNIST(
    root='./data/',
    train=True,                                     # this is training data
    transform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                                    # torch.FloatTensor of shape (C x H x W) and normalizes in the range [0.0, 1.0]
    download=True,                                  # download it if you don't have it
)

mb_size = 64
Z_dim = 100
n_r = mnist_train.data.shape[1]  # image height; `.data` replaces the deprecated `.train_data`
n_c = mnist_train.data.shape[2]  # image width
X_dim =  n_r * n_c
h_dim = 128
lr = 1e-3
c = 0
n_samples = mb_size # used for visualization during learning

assert(n_samples <= mb_size)  # this is assumed below

train_loader = Data.DataLoader(dataset=mnist_train, batch_size=mb_size,
                               shuffle=True)

## Meet the Generator and Discriminator

A GAN has two crucial parts, the Generator $G(\mathbf{z})$ (which we are ultimately interested in), and a Discriminator $D(\mathbf{x})$ which is used to train the Generator and then discarded.

Let's start by constructing $G(\mathbf{z})$.

We are going to use a popular type of weight initialization commonly called "Xavier" initialization after the first author of this paper:

    Glorot, Xavier, and Yoshua Bengio. 2010. “Understanding the Difficulty of
    Training Deep Feedforward Neural Networks.” In Proceedings of the
    Thirteenth International Conference on Artificial Intelligence and
    Statistics, 249–56.
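As a rough illustration of what this initialization does, the sketch below checks the spread of the values that `nn.init.xavier_normal_` produces. With the default gain of 1, the weights are drawn from a zero-mean normal distribution with standard deviation $\sqrt{2 / (n_{in} + n_{out})}$. The names `n_in`, `n_out`, and `W` are chosen here for illustration only, and the shape is assumed to match the Generator's first layer.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes, assumed to match the Generator's first layer (Z_dim x h_dim)
n_in, n_out = 100, 128
W = torch.empty(n_in, n_out)
nn.init.xavier_normal_(W)

# Xavier normal initialization with gain 1 uses
# std = sqrt(2 / (n_in + n_out)), which is symmetric in the two sizes
expected_std = math.sqrt(2.0 / (n_in + n_out))
empirical_std = W.std().item()

print('expected std: {:.4f}, empirical std: {:.4f}'.format(expected_std, empirical_std))
```

The two numbers should agree closely, since the matrix holds 12,800 samples from that distribution.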

In [None]:
Wzh = torch.empty(Z_dim, h_dim)
nn.init.xavier_normal_(Wzh)
Wzh = Variable(Wzh, requires_grad=True)
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

Whx = torch.empty(h_dim, X_dim)
nn.init.xavier_normal_(Whx)
Whx = Variable(Whx, requires_grad=True)
bhx = Variable(torch.zeros(X_dim), requires_grad=True)

def G(z):
    h = F.relu(z.mm(Wzh) + bzh)
    X = torch.sigmoid(h.mm(Whx) + bhx)
    return X

In essence, the Generator takes a noise vector as input and maps it through a ReLU hidden layer to the image domain.
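To make the shapes concrete, here is a small self-contained sketch of the same two-layer mapping. It uses stand-in weights (`W1`, `b1`, `W2`, `b2`) and a helper called `toy_G`, all of which are hypothetical names rather than the notebook's own parameters. It shows that a batch of noise vectors becomes a batch of flattened 28 x 28 images, and that the final sigmoid keeps every pixel in $[0, 1]$, the same range that `ToTensor()` produces for the real data.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

Z_dim, h_dim, X_dim = 100, 128, 784  # 784 = 28 * 28 MNIST pixels

# Stand-in weights: small random values, not the notebook's Xavier-initialized ones
W1, b1 = 0.1 * torch.randn(Z_dim, h_dim), torch.zeros(h_dim)
W2, b2 = 0.1 * torch.randn(h_dim, X_dim), torch.zeros(X_dim)

def toy_G(z):
    h = F.relu(z.mm(W1) + b1)            # noise -> hidden layer
    return torch.sigmoid(h.mm(W2) + b2)  # hidden layer -> pixel intensities

z = torch.randn(5, Z_dim)  # a batch of 5 noise vectors
fake = toy_G(z)

print(fake.shape)                            # one flattened image per noise vector
print(fake.min().item(), fake.max().item())  # both lie inside [0, 1]
```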

Now, we define $D(\mathbf{x})$ which is just a classifier as we have seen before:

In [None]:
Wxh = torch.empty(X_dim, h_dim)
nn.init.xavier_normal_(Wxh)
Wxh = Variable(Wxh, requires_grad=True)
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

Why = torch.empty(h_dim, 1)
nn.init.xavier_normal_(Why)
Why = Variable(Why, requires_grad=True)
bhy = Variable(torch.zeros(1), requires_grad=True)

def D(X):
    h = F.relu(X.mm(Wxh) + bxh)
    y = torch.sigmoid(h.mm(Why) + bhy)  # F.sigmoid is deprecated; matches G above
    return y

Since each of these networks will be optimized according to its own criterion, we have deliberately kept them separate.

This requires a little extra bookkeeping:

In [None]:
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = G_params + D_params

def reset_grad():
    for p in params:
        if p.grad is not None:
            p.grad.data.zero_()

Now let's set up the optimization procedure:

In [None]:
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)

# these will be targets used to denote "real" or "fake" data
ones_label = Variable(torch.ones(mb_size))
zeros_label = Variable(torch.zeros(mb_size))

## Forward and backward pass

Eventually we will wrap the Generator's and the Discriminator's forward and backward steps in a loop over mini-batches and epochs, but let's first walk through the forward and backward pass on a single mini-batch.

Let's grab a batch of data for demonstration purposes:

In [None]:
# noise as input to the Generator
z = Variable(torch.randn(mb_size, Z_dim))

# data as input to the discriminator
X, _ = next(iter(train_loader))
X = Variable(X.view(-1, X_dim))

We'll show the Discriminator's forward and backward pass first.

We pass some noise through the Generator to create some "fake" data. We assign labels that say "fake" to the "fake" data and labels that say "real" to the batch of data we took from the training set.

Then the Discriminator tries to assign "real" to the real data and "fake" to the data coming from the Generator.
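In terms of the loss, "assigning labels" means using binary cross-entropy with a target of 1 for real data and 0 for fake data. The sketch below uses made-up Discriminator outputs to check that these two terms reduce to $-\text{mean}(\log D(\mathbf{x}))$ and $-\text{mean}(\log(1 - D(G(\mathbf{z}))))$ respectively. The probability values are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical Discriminator outputs (probability that the input is real)
D_real = torch.tensor([[0.9], [0.8], [0.7]])  # on real data
D_fake = torch.tensor([[0.2], [0.1], [0.4]])  # on generated data

ones = torch.ones_like(D_real)    # "real" targets
zeros = torch.zeros_like(D_fake)  # "fake" targets

bce_real = F.binary_cross_entropy(D_real, ones)
bce_fake = F.binary_cross_entropy(D_fake, zeros)

# The same quantities written out by hand
manual_real = -torch.log(D_real).mean()
manual_fake = -torch.log(1 - D_fake).mean()

print(torch.allclose(bce_real, manual_real))
print(torch.allclose(bce_fake, manual_fake))
```

Both terms shrink as the Discriminator becomes more confident in the correct answer, which is why minimizing their sum trains it to tell the two sources apart.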

In [None]:
# Discriminator forward-loss-backward-update
G_sample = G(z)

D_fake = D(G_sample)
D_real = D(X)

D_loss_real = F.binary_cross_entropy(D_real, ones_label.view(-1, 1))
D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label.view(-1, 1))
D_loss = D_loss_real + D_loss_fake

Now that we've computed the loss we can backprop through the Discriminator and update its parameters:

In [None]:
D_loss.backward()
D_solver.step()

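Before moving on, we need to clear the gradients that this backward pass left behind. PyTorch accumulates gradients in each parameter's `.grad` rather than overwriting them, and because `G_sample` was produced by the Generator, `D_loss.backward()` computed gradients for the Generator's parameters as well as the Discriminator's. This matters because the Generator step below calls `backward()` again: without a reset, the stale gradients would be added to the new ones.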
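The toy example below illustrates why the reset matters. It uses one scalar "Generator" parameter `g` and one scalar "Discriminator" parameter `d`, both hypothetical stand-ins for the real weight matrices. A loss that flows through both leaves gradients on both, and a second `backward()` without a reset adds to what is already stored.

```python
import torch

# One scalar parameter per network, purely for illustration
g = torch.ones(1, requires_grad=True)  # stands in for the Generator's weights
d = torch.ones(1, requires_grad=True)  # stands in for the Discriminator's weights
z = torch.tensor([2.0])

# Toy Discriminator loss on generated data: it depends on both g and d
d_loss = (d * (g * z)).sum()
d_loss.backward()
grad_g_first = g.grad.clone()  # g received a gradient even though only d would be updated

# A second backward pass without a reset adds to the stored gradient
d_loss_again = (d * (g * z)).sum()
d_loss_again.backward()
grad_g_accumulated = g.grad.clone()

# Clearing the gradients, as reset_grad() does for the real parameters
for p in (g, d):
    p.grad.zero_()

print(grad_g_first.item(), grad_g_accumulated.item(), g.grad.item())
```

Here the gradient with respect to `g` is `d * z = 2` after the first pass, `4` after the second, and `0` after the reset.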

In [None]:
reset_grad()

Now, let's look at the forward and backward pass for the Generator.

Again, we sample some noise, and pass it through the Generator to create some "fake" data.

Now we pass the fake data through the Discriminator. This time, however, we take the Generator's perspective: it wants to *fool* the Discriminator, so we use "real" labels as the ground truth.

We backprop through the Generator, update its parameters, and clear the gradients.
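A quick numerical check shows why "real" labels give the Generator the right incentive. With targets of 1, the binary cross-entropy is small when the Discriminator rates the fake samples as likely real and large when it correctly rejects them. The Discriminator outputs below are made up for illustration.

```python
import torch
import torch.nn.functional as F

ones = torch.ones(3, 1)  # "real" targets, as used for the Generator's loss

# Hypothetical Discriminator outputs on generated samples
D_fake_fooled = torch.tensor([[0.90], [0.80], [0.95]])  # Discriminator is fooled
D_fake_caught = torch.tensor([[0.10], [0.20], [0.05]])  # Discriminator spots the fakes

loss_fooled = F.binary_cross_entropy(D_fake_fooled, ones)
loss_caught = F.binary_cross_entropy(D_fake_caught, ones)

print('loss when fooled: {:.3f}'.format(loss_fooled.item()))
print('loss when caught: {:.3f}'.format(loss_caught.item()))
```

Minimizing this loss therefore pushes the Generator toward samples that the Discriminator classifies as real.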

In [None]:
z = Variable(torch.randn(mb_size, Z_dim))
G_sample = G(z)

D_fake = D(G_sample)

G_loss = F.binary_cross_entropy(D_fake, ones_label.view(-1,1))

G_loss.backward()
G_solver.step()

reset_grad()

#### Exercises

1. Notice that we re-sampled the noise before the Generator forward-backward step. Is this necessary?

2. Notice that for the Discriminator we summed `D_loss_real` and `D_loss_fake` before performing the backward pass, and updated the parameters based on the gradient of their sum. What would happen if we did a backward pass and update for each of them separately?

## Putting it together in a learning loop

Alright, let's put these steps together and see how it performs.

In [None]:
for it in range(100000):
    # Sample data
    z = Variable(torch.randn(mb_size, Z_dim))
    
    X, _ = next(iter(train_loader))
    X = Variable(X.view(-1, X_dim))

    # Discriminator forward-loss-backward-update
    G_sample = G(z)
    D_real = D(X)
    D_fake = D(G_sample)

    D_loss_real = F.binary_cross_entropy(D_real, ones_label.view(-1, 1))
    D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label.view(-1, 1))
    D_loss = D_loss_real + D_loss_fake

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update
    z = Variable(torch.randn(mb_size, Z_dim))
    G_sample = G(z)
    D_fake = D(G_sample)

    G_loss = F.binary_cross_entropy(D_fake, ones_label.view(-1, 1))

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.data.numpy(), G_loss.data.numpy()))

        samples = G(z).data.numpy()[:n_samples]

        n_gs = int(np.sqrt(n_samples))
        
        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(n_gs, n_gs)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        plt.close(fig)

#### Exercises

1. Do you notice any trends in the trajectory of `D_loss` and `G_loss` over learning?

2. Wrap up the Generator and the Discriminator into their own classes, inheriting from `nn.Module`, similar to the way we did for the feedforward and convolutional neural networks. Does this make experimentation any more convenient?