# Generative Adversarial Networks 

In [1]:
__author__ = 'andrea munafo'

This notebook plays with GANs using the MNIST dataset.

Interesting references are reported at the end.

## GANs in brief

Suppose that we are interested in generating black and white square images of something (e.g. dogs) with a size of $n \times n$ pixels.   
We can reshape each image into a vector of size $N = n \times n$, so the image of a dog can be represented as a vector of size $N$.
This, of course, does not mean that all vectors of size $N$ represent dogs (once reshaped back to a square image), but we can say that the $N$-dimensional vectors that look like a dog are distributed according to a very specific probability distribution over the entire $N$-dimensional vector space.
Some points of this space represent dogs, others might represent cats, etc.
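
As a small illustration of the reshaping step, the sketch below flattens a square image into a vector and back. The side length `n = 28` is an assumption chosen to match the MNIST images used later in this notebook, and the random tensor is only a stand-in for a real image.

```python
import torch

n = 28                            # assumed side length (matches MNIST)
image = torch.rand(n, n)          # stand-in "image" with random pixel values
vector = image.reshape(-1)        # flatten to a vector of size N = n * n
restored = vector.reshape(n, n)   # and back to a square image

print(vector.shape)                   # number of elements is N = n * n
print(torch.equal(image, restored))   # the round trip loses nothing
```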

The problem of generating a new image of a dog is equivalent to the problem of generating a new vector according to the correct "dog probability distribution" [[1](https://towardsdatascience.com/understanding-generative-adversarial-networks-gans-cd6e4651a29)] over the $N$-dimensional vector space. 
This is the general problem of generating a random variable with respect to a specific probability distribution.

The problem then becomes that of using a neural network to approximate the target probability distribution.
This is analogous to the inverse transform sampling method, with a neural network playing the role of the transform function.

Our first problem when trying to generate a new image of a dog is that the “dog probability distribution” over the $N$-dimensional vector space is very complex, and we don’t know how to generate such complex random variables directly. However, since we know how to generate $N$ uncorrelated uniform random variables, we can make use of the transform method. To do so, we need to express our $N$-dimensional random variable as the result of a very complex function applied to a simple $N$-dimensional random variable.
Finding this transform function is not as straightforward as taking the closed-form inverse of the cumulative distribution function (which we obviously don’t know), as one does in the inverse transform method. The transform function can’t be expressed explicitly, so we have to learn it from data.
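
For contrast, here is what the inverse transform method looks like when the CDF does have a closed-form inverse. The sketch uses the exponential distribution as an illustrative choice (it is not part of this notebook): if $U$ is uniform on $[0, 1)$, then $-\ln(1 - U)/\lambda$ follows an exponential distribution with rate $\lambda$. The helper name `sample_exponential` is hypothetical.

```python
import math
import random

def sample_exponential(rate, rng):
    """Draw one sample from Exp(rate) by inverting its CDF F(x) = 1 - exp(-rate * x)."""
    u = rng.random()                   # u ~ Uniform[0, 1)
    return -math.log(1.0 - u) / rate   # F^{-1}(u)

rng = random.Random(0)                 # fixed seed for reproducibility
rate = 2.0
samples = [sample_exponential(rate, rng) for _ in range(100000)]
sample_mean = sum(samples) / len(samples)
print(sample_mean)                     # should be close to 1 / rate
```

For images no such closed form is available, which is why the transform is learned by a network instead.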

Then, the idea is to model the transform function using a neural network that takes as input a simple N dimensional uniform random variable and that returns as output another N dimensional random variable that should follow, after training, the right “dog probability distribution”.

To train this network we can use two methods. The direct one compares the true and the generated probability distributions and backpropagates the error through the network. This is the idea behind Generative Matching Networks (GMNs).
In the indirect method we do not compare the distributions directly; instead we add a downstream task (discriminating between true and generated samples) whose optimisation pushes the generated distribution towards the true one.
The indirect method is the one used by Generative Adversarial Networks.
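
To make the direct method more concrete, the sketch below compares two sets of samples with a Maximum Mean Discrepancy (MMD) estimate based on a Gaussian kernel. Using MMD as the distance is an assumption made here for illustration; the text above does not say which measure is used. The function names, the bandwidth and the toy 2D data are all hypothetical.

```python
import torch

def gaussian_kernel(a, b, bandwidth=1.0):
    """Gaussian kernel matrix between the rows of a and the rows of b."""
    sq_dist = torch.cdist(a, b) ** 2
    return torch.exp(-sq_dist / (2 * bandwidth ** 2))

def mmd2(x, y, bandwidth=1.0):
    """Biased estimate of the squared MMD between sample sets x and y."""
    k_xx = gaussian_kernel(x, x, bandwidth).mean()
    k_yy = gaussian_kernel(y, y, bandwidth).mean()
    k_xy = gaussian_kernel(x, y, bandwidth).mean()
    return k_xx + k_yy - 2 * k_xy

torch.manual_seed(0)
true_samples = torch.randn(200, 2)
close_samples = torch.randn(200, 2)        # drawn from the same distribution
far_samples = torch.randn(200, 2) + 3.0    # drawn from a shifted distribution

same = mmd2(true_samples, close_samples).item()
shifted = mmd2(true_samples, far_samples).item()
print(same, shifted)   # the shifted set should give the larger value
```

Because every operation is differentiable, a value like this could in principle serve as a training loss for the generator.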

So, in a GAN architecture, we have a discriminator, which takes samples of true and generated data and tries to classify them as well as possible, and a generator, which is trained to fool the discriminator as much as possible.
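
This competition is usually written as a minimax objective. The symbols are introduced here for illustration: $D$ is the discriminator, $G$ the generator, $p_{\text{data}}$ the true data distribution and $p_z$ the simple noise distribution.

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

The discriminator maximises $V$ by pushing $D(x)$ towards 1 on real samples and $D(G(z))$ towards 0 on generated ones, while the generator minimises $V$ by pushing $D(G(z))$ towards 1.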

## Implementing a simple GAN

In [9]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import os

import matplotlib.pyplot as plt
import numpy as np

# from torch import nn

In [10]:
print(torch.__version__)

1.3.0


In [52]:
import pathlib

pathlib.Path("../results/08-generative-adversarial-networks").mkdir(parents=True, exist_ok=True)
pathlib.Path("../saved-mdls/08-generative-adversarial-networks").mkdir(parents=True, exist_ok=True)

In [63]:
num_epochs = 10
bs = 64
learning_rate = 1e-3

fake_img_size = 100  # size of the noise vector z fed to the generator

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [35]:
def toImg(x, mu=0.5, std=1):
    """Converts x to an image shape. It works for batches of inputs."""
    x = mu * (x + std)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x
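
A quick check of what `toImg` does with its default arguments: it maps values from $[-1, 1]$ (the range of a Tanh output) to $[0, 1]$ and reshapes each flat vector to a single-channel 28x28 image. The function is repeated so the snippet runs on its own; the constant-valued batch is made up for illustration.

```python
import torch

def toImg(x, mu=0.5, std=1):
    """Converts x to an image shape. It works for batches of inputs."""
    x = mu * (x + std)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x

batch = torch.full((2, 784), -1.0)   # first row: all pixels at -1
batch[1] = 1.0                       # second row: all pixels at +1
imgs = toImg(batch)

print(imgs.shape)                                    # (batch, channel, height, width)
print(imgs[0].max().item(), imgs[1].min().item())    # -1 maps to 0, +1 maps to 1
```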

## Get the data

### Define some transforms to normalise the images  

In [36]:
ds_mean = 0.1307
ds_std = 0.3081

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((ds_mean,), (ds_std,))  # One mean and one standard deviation per channel; MNIST images have a single channel.
])
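
It can help to know the value range these statistics produce. `ToTensor` scales pixels to $[0, 1]$ and `Normalize` then computes `(x - mean) / std`, so the extremes can be worked out by hand. This is plain arithmetic on the constants above, not a torchvision call.

```python
ds_mean, ds_std = 0.1307, 0.3081

lo = (0.0 - ds_mean) / ds_std   # value of a black pixel after normalisation
hi = (1.0 - ds_mean) / ds_std   # value of a white pixel after normalisation
print(lo, hi)                   # roughly -0.42 and 2.82
```

The normalised real images therefore extend beyond $[-1, 1]$, which is the range a Tanh output layer can produce.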

In [37]:
train_ds = MNIST('./data', train=True, transform=img_transform, download=True)
valid_ds = MNIST('./data', train=False, transform=img_transform, download=True)

In [38]:
# plt.imshow(train_ds.data[1])

In [39]:
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=True)
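
One detail worth keeping in mind when iterating over a `DataLoader`: unless `drop_last=True` is set, the last batch can be smaller than `batch_size`. The sketch below shows this with a synthetic dataset of 150 random "images" (an assumption, used to avoid downloading MNIST).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 150 random single-channel 28x28 "images" with random labels in 0..9
images = torch.randn(150, 1, 28, 28)
labels = torch.randint(0, 10, (150,))
toy_dl = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)

sizes = [x.shape[0] for x, y in toy_dl]
print(sizes)   # two full batches and one smaller final batch
```

This is why code that builds per-batch tensors should read the size from the batch itself instead of assuming `bs`.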

## Define the model 

Let's create the two competing networks, the generator:

In [40]:
# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.generator = nn.Sequential(
            nn.Linear(fake_img_size, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.generator(x)
        return x
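
A quick sanity check on a network of this shape: a batch of noise vectors should come out as a batch of 784-dimensional vectors with values in $[-1, 1]$ because of the final Tanh. The sketch rebuilds the same layer stack as a plain `nn.Sequential` so that it runs on its own; `latent_dim` is a hypothetical name standing in for `fake_img_size`.

```python
import torch
import torch.nn as nn

latent_dim = 100   # stands in for fake_img_size
generator = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.ReLU(True),
    nn.Linear(256, 256), nn.ReLU(True),
    nn.Linear(256, 784), nn.Tanh(),
)

z = torch.randn(16, latent_dim)   # a batch of 16 noise vectors
with torch.no_grad():             # no gradients needed for a shape check
    fake = generator(z)

print(fake.shape)                              # one flat 28x28 image per noise vector
print(fake.min().item(), fake.max().item())    # both within [-1, 1]
```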

And the discriminator:

In [41]:
# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.discriminator(x)
        return x

And now we put everything together:

In [42]:
Dnet = Discriminator()
Gnet = Generator()

if torch.cuda.is_available():
    Dnet = Dnet.cuda()
    Gnet = Gnet.cuda()
    
loss_fn = nn.BCELoss() # Binary cross entropy loss


# We need two optimizers, one per network.
d_optimizer = torch.optim.Adam(Dnet.parameters(), lr=learning_rate)
g_optimizer = torch.optim.Adam(Gnet.parameters(), lr=learning_rate)
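
Before training, it is worth checking how the discriminator output and `nn.BCELoss` fit together. A discriminator ending in `nn.Linear(256, 1)` returns a tensor of shape `(batch, 1)`, so the target passed to the loss should have that same shape. The sketch rebuilds the layer stack as a plain `nn.Sequential` and uses random inputs, both purely for illustration.

```python
import torch
import torch.nn as nn

discriminator = nn.Sequential(
    nn.Linear(784, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
)
loss_fn = nn.BCELoss()

x = torch.randn(8, 784)              # a batch of 8 flattened stand-in images
out = discriminator(x)               # shape (8, 1), values in (0, 1)
real_label = torch.ones(8, 1)        # target shaped like the output
loss = loss_fn(out, real_label)      # scalar tensor (mean over the batch)

print(out.shape, loss.item())
```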

## Train the model

In [None]:
for epoch in range(num_epochs):
    for x, y in train_dl: # x is the real image, y is the real label
        img_nums = x.shape[0]
        
        # The discriminator has to tell real images from fake ones,
        # so we create the labels accordingly: 1 for real, 0 for fake.
        # The shape (img_nums, 1) matches the discriminator output.
        real_label = torch.ones(img_nums, 1).to(device)
        fake_label = torch.zeros(img_nums, 1).to(device)

        # Train the discriminator network
        x = x.view(x.size(0), -1).to(device)
        
        # compute loss of real_img
        d_real_out = Dnet(x)
        d_loss_real = loss_fn(d_real_out, real_label) 
        
        # Generate a batch of fake images from noise vectors z
        # of size fake_img_size.
        z = torch.randn(img_nums, fake_img_size).to(device)
        fake_img = Gnet(z)
        # detach() keeps this step from backpropagating into the generator
        d_fake_out = Dnet(fake_img.detach())
        d_loss_fake = loss_fn(d_fake_out, fake_label)
        
        real_scores = d_real_out  # closer to 1 means better
        fake_scores = d_fake_out  # closer to 0 means better

        # backprop
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()        
        
        # Train the generator network
        
        # compute the generator loss on a fresh batch of fake images
        z = torch.randn(img_nums, fake_img_size).to(device)
        fake_img = Gnet(z)
        output = Dnet(fake_img)
        # The generator wants the discriminator to label its fakes as real,
        # so the target here is real_label.
        g_loss = loss_fn(output, real_label)

        # backprop and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        
    if epoch == 0:
        print('Epoch [{}/{}], saving a sample of real images.'.format(epoch, num_epochs))
        real_images = toImg(x.cpu().data)
        save_image(real_images, '../results/08-generative-adversarial-networks/real-images.png')
        
    if (epoch + 1) % 100 == 0:
        print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
              'D real: {:.6f}, D fake: {:.6f}'.format(
              epoch, num_epochs, d_loss.item(), g_loss.item(),
              real_scores.data.mean(), fake_scores.data.mean()))



    fake_images = toImg(fake_img.cpu().data)
    save_image(fake_images, '../results/08-generative-adversarial-networks/fake-images-{}.png'.format(epoch + 1))

Epoch [0/10]. Saving a sample of real images.
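
The generator step in the loop above passes `real_label` as the target for fake images. With a target of 1, binary cross entropy reduces to $-\log D(G(z))$, so minimising it pushes the discriminator's score on generated images towards 1. The check below confirms this identity numerically; the three discriminator outputs are made-up values.

```python
import torch
import torch.nn as nn

loss_fn = nn.BCELoss()

d_out = torch.tensor([[0.1], [0.5], [0.9]])   # hypothetical D(G(z)) values
target = torch.ones_like(d_out)               # "real" labels for fake images

loss = loss_fn(d_out, target)                 # BCE with target 1
manual = -torch.log(d_out).mean()             # mean of -log D(G(z))
print(loss.item(), manual.item())             # the two values agree
```

The loss is smallest when the discriminator is fooled, which is the behaviour the generator is trained for.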


In [None]:
# As suggested in https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save(Gnet, '../saved-mdls/08-generative-adversarial-networks/generator-{}e.pt'.format(epoch+1))
torch.save(Dnet, '../saved-mdls/08-generative-adversarial-networks/distriminator-

## References

1. [Understanding Generative Adversarial Networks (GANs)](https://towardsdatascience.com/understanding-generative-adversarial-networks-gans-cd6e4651a29)
2. [Data science courses](https://www.youtube.com/channel/UCKJNzy_GuvX3SAg3ipaGa8A)
