# Generative Adversarial Networks
So far, all the models we have worked with except the VAE have been discriminative models. This means that they simply try to predict something about our existing dataset. Sometimes we want to generate new examples of the data rather than discriminate between existing ones, as in video or image generation. Technically, this is the problem of modelling a probability distribution from which we only have samples.

One approach which implicitly models the distribution, the generative adversarial network (GAN) introduced by Ian Goodfellow and colleagues, has been enjoying great success, often producing images that are hard to distinguish from the examples on which it was trained. GANs take a game-theoretic approach, pitting two networks against each other: the discriminator and the generator. The job of the generator is to turn latent variables into images which are indistinguishable from the training set, while the job of the discriminator is to catch out the generator by telling real data from generated data. Initially they will both be terrible at their jobs, but as the discriminator gets better, the generator is forced to get better to fool it, and vice versa. Ideally, this loop continues until both are excellent at their jobs and the generator can be used to produce very realistic data points.

![](GAN.png)

An analogy often used to describe this is the detective and the forger. The generator is like a forger who is trying to produce paintings indistinguishable from the famous paintings of an artist, while the discriminator is like a detective who is trying to catch the forger out. As the detective gets better at catching the forger out, the forger is forced to improve to fool the detective.
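
The game between the two networks is usually written as a minimax objective. The notation here is ours rather than the notebook's: $D(x)$ is the discriminator's estimated probability that $x$ is real, $G(z)$ is the image the generator produces from a latent vector $z$, and $p_{\text{data}}$ and $p_z$ are the data and latent distributions.

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

The discriminator cost used in the implementation in this notebook is the negative of $V$ averaged over a batch. The generator cost there is $-\log D(G(z))$ rather than $\log(1 - D(G(z)))$, a commonly used variant that tends to give the generator stronger gradients early in training.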

## Implementation
We will be training a GAN on the Fashion-MNIST dataset, so that we can produce images of items of clothing which look like they came from the original dataset.

We begin by importing the appropriate libraries.

In [None]:
%matplotlib notebook
import torch
from torch.autograd import Variable
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as F
import matplotlib.pyplot as plt    

We load our dataset into a PyTorch `DataLoader`, which we will use later to produce random batches of samples from our dataset.

In [None]:
batch_size = 100

train_data = datasets.FashionMNIST(root='fashiondata/',
                                 transform=transforms.ToTensor(),
                                 train=True,
                                 download=True
                                 )

train_samples = torch.utils.data.DataLoader(dataset=train_data,
                                           batch_size=batch_size,
                                           shuffle=True
                                            )

Define the NN model we will be using for our discriminator. It takes in a 1x28x28 image, applies two strided convolutions followed by one fully connected layer, and outputs the probability that the data point is real rather than generated.

In [None]:
class discriminator(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1) #1x28x28-> 64x14x14
        self.conv2 = torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1) #64x14x14-> 128x7x7
        self.dense1 = torch.nn.Linear(128*7*7, 1)

        self.bn1 = torch.nn.BatchNorm2d(64)
        self.bn2 = torch.nn.BatchNorm2d(128)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x))).view(-1, 128*7*7)
        x = torch.sigmoid(self.dense1(x)) #F.sigmoid is deprecated in favour of torch.sigmoid
        return x
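
The shape comments in the discriminator can be checked with the standard output-size formula for a convolution. This is a small sketch; the helper name `conv2d_output_size` is ours, and it assumes square inputs and a dilation of 1.

```python
def conv2d_output_size(size, kernel_size, stride=1, padding=0):
    """Height (or width) of a convolution's output, assuming dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1

# The two discriminator convolutions: kernel 4, stride 2, padding 1
after_conv1 = conv2d_output_size(28, kernel_size=4, stride=2, padding=1)
after_conv2 = conv2d_output_size(after_conv1, kernel_size=4, stride=2, padding=1)

print(after_conv1, after_conv2)         # 14 7
print(128 * after_conv2 * after_conv2)  # 6272 inputs to the linear layer
```

Each convolution halves the spatial size, which is why the linear layer takes `128*7*7` inputs.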

Define the NN model for the generator. This takes in a latent vector of size 128 and applies three fully connected layers followed by two transposed convolutions (upconvolutions) to output a 1x28x28 image.

In [None]:
class generator(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense1 = torch.nn.Linear(128, 256)
        self.dense2 = torch.nn.Linear(256, 1024)
        self.dense3 = torch.nn.Linear(1024, 128*7*7)
        self.uconv1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1) #128x7x7 -> 64x14x14
        self.uconv2 = torch.nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1) #64x14x14 -> 1x28x28

        self.bn1 = torch.nn.BatchNorm1d(256)
        self.bn2 = torch.nn.BatchNorm1d(1024)
        self.bn3 = torch.nn.BatchNorm1d(128*7*7)
        self.bn4 = torch.nn.BatchNorm2d(64)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.dense1(x)))
        x = F.relu(self.bn2(self.dense2(x)))
        x = F.relu(self.bn3(self.dense3(x))).view(-1, 128, 7, 7)
        x = F.relu(self.bn4(self.uconv1(x)))
        x = torch.sigmoid(self.uconv2(x)) #pixel values in (0, 1), matching ToTensor(); F.sigmoid is deprecated
        return x
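
The generator's shape comments follow from the output-size formula for a transposed convolution, which inverts the one for an ordinary convolution. A small sketch, where the helper name `conv_transpose2d_output_size` is ours and a dilation of 1 is assumed:

```python
def conv_transpose2d_output_size(size, kernel_size, stride=1, padding=0, output_padding=0):
    """Height (or width) of a transposed convolution's output, assuming dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding

# The two generator upconvolutions: kernel 4, stride 2, padding 1
after_uconv1 = conv_transpose2d_output_size(7, kernel_size=4, stride=2, padding=1)
after_uconv2 = conv_transpose2d_output_size(after_uconv1, kernel_size=4, stride=2, padding=1)

print(after_uconv1, after_uconv2)  # 14 28
```

Each upconvolution doubles the spatial size, taking the 7x7 feature maps back up to a 28x28 image.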

We instantiate our models from the classes we defined, create an Adam optimizer for each, and set up a figure on which to plot the costs.

In [None]:
#instantiate model
d = discriminator()
g = generator()

#training hyperparameters
no_epochs = 100
dlr = 0.0003
glr = 0.0003

d_optimizer = torch.optim.Adam(d.parameters(), lr=dlr)
g_optimizer = torch.optim.Adam(g.parameters(), lr=glr)

dcosts = []
gcosts = []
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_xlabel('Epoch')
ax.set_ylabel('Cost')
ax.set_xlim(0, no_epochs)
plt.show()

We define the training loop. For every batch of training data that we look at, we get the generator to produce an equally sized batch of generated images. We then get the discriminator to make predictions on both sets and calculate the cost for each network, before calculating the gradients and updating each one in turn.

In [None]:
def train(no_epochs, glr, dlr):
    for epoch in range(no_epochs):
        epochdcost = 0
        epochgcost = 0

        #iterate over mini-batches
        for k, (real_images, _ ) in enumerate(train_samples):
            real_images = Variable(real_images) #real images from training set

            z = Variable(torch.randn(batch_size, 128)) #generate random latent variable to generate images
            generated_images = g.forward(z) #generate images

            #train discriminator; detach so this step sends no gradients into the generator
            gen_pred = d.forward(generated_images.detach()) #prediction of discriminator on generated batch
            real_pred = d.forward(real_images) #prediction of discriminator on real batch
            dcost = -torch.sum(torch.log(real_pred) + torch.log(1-gen_pred))/batch_size #cost of discriminator
            d_optimizer.zero_grad()
            dcost.backward()
            d_optimizer.step()

            #train generator against the updated discriminator
            #(backpropagating through the old graph after d_optimizer.step() raises an error in recent PyTorch)
            gen_pred = d.forward(generated_images) #fresh prediction of discriminator on generated batch
            gcost = -torch.sum(torch.log(gen_pred))/batch_size #cost of generator
            g_optimizer.zero_grad()
            gcost.backward()
            g_optimizer.step()

            epochdcost += dcost.item() #.item() converts a 0-dim tensor to a Python float
            epochgcost += gcost.item()
            
            #give us an example of a generated image after every 10000 images produced
            if k*batch_size%10000 ==0:
                g.eval() #put in evaluation mode
                noise_input = Variable(torch.randn(1, 128))
                generated_image = g.forward(noise_input)

                plt.figure(figsize=(1, 1))
                plt.imshow(generated_image.data[0][0], cmap='gray_r')
                plt.show()
                g.train() #put back into training mode


        #once the epoch is finished, average the summed batch costs over the number of batches
        epochdcost /= len(train_samples)
        epochgcost /= len(train_samples)

        print('Epoch: ', epoch)
        print('Discriminator cost: ', epochdcost)
        print('Generator cost: ', epochgcost)
        
        #plot costs
        
        dcosts.append(epochdcost)
        gcosts.append(epochgcost)

        ax.plot(dcosts, 'b')
        ax.plot(gcosts, 'r')

        fig.canvas.draw()

train(no_epochs, glr, dlr)
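
To help interpret the costs printed and plotted during training, here is a minimal sketch that evaluates the same two cost formulas on hand-picked discriminator outputs using plain Python. The function names and the example probabilities are illustrative only, not part of the notebook.

```python
import math

def discriminator_cost(real_preds, gen_preds):
    """Mean of -(log D(x) + log(1 - D(G(z)))) over a batch."""
    total = sum(math.log(r) + math.log(1 - g) for r, g in zip(real_preds, gen_preds))
    return -total / len(real_preds)

def generator_cost(gen_preds):
    """Mean of -log D(G(z)) over a batch."""
    return -sum(math.log(g) for g in gen_preds) / len(gen_preds)

# A discriminator that cannot tell real from generated outputs 0.5 everywhere
undecided = [0.5, 0.5, 0.5]
d_undecided = discriminator_cost(undecided, undecided)  # 2*ln(2), about 1.386
g_undecided = generator_cost(undecided)                 # ln(2), about 0.693

# A discriminator that is winning: high on real images, low on generated ones
d_winning = discriminator_cost([0.9, 0.9, 0.9], [0.1, 0.1, 0.1])  # about 0.211
g_losing = generator_cost([0.1, 0.1, 0.1])                        # about 2.303

print(d_undecided, g_undecided)
print(d_winning, g_losing)
```

A discriminator cost near 1.39 together with a generator cost near 0.69 therefore suggests the discriminator is close to guessing, while a low discriminator cost with a high generator cost suggests the discriminator is easily spotting generated images.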