# Generative Adversarial Networks (GANs)

In what follows, we explain how to implement a GAN in Keras. Our specific implementation will be a deep convolutional GAN, or DCGAN: a GAN where the generator and discriminator are deep convnets. In particular, it uses a Conv2DTranspose layer for image upsampling in the generator.

We will train our GAN on images from CIFAR10, a dataset of 50,000 32x32 RGB training images belonging to 10 classes (5,000 images per class). We will only use images belonging to the class "frog".

Schematically, our GAN looks like this:
- A `generator` network maps vectors of shape `(latent_dim,)` to images of shape `(32, 32, 3)`;
- A `discriminator` network maps images of shape `(32, 32, 3)` to a binary score estimating whether the image is real or generated;
- A `gan` network chains the generator and the discriminator together;
- We train the discriminator using examples of real and fake images along with "real"/"fake" labels, as we would train any regular image classification model;
- To train the generator, we use the gradients of the `gan` model's loss with respect to the generator's weights. This means that, at every step, we move the weights of the generator in a direction that makes the discriminator more likely to classify the images decoded by the generator as "real". In other words, we train the generator to fool the discriminator.
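To make the two objectives concrete, here is a minimal NumPy sketch of the binary cross-entropy losses involved. It follows the label convention used in the training loop later in this notebook (1 = "fake", 0 = "real"); the helper name `bce` and the example scores are hypothetical.

```python
import numpy as np

def bce(targets, scores, eps=1e-7):
    """Mean binary cross-entropy between 0/1 targets and sigmoid scores."""
    scores = np.clip(scores, eps, 1.0 - eps)
    return float(-np.mean(targets * np.log(scores) + (1 - targets) * np.log(1 - scores)))

# Hypothetical discriminator outputs, read as "probability that the image is fake".
scores_on_fake = np.array([0.9, 0.8])   # confident these generated images are fake
scores_on_real = np.array([0.1, 0.2])   # confident these real images are real

# Discriminator step: generated images labeled 1, real images labeled 0.
d_loss = bce(np.concatenate([np.ones(2), np.zeros(2)]),
             np.concatenate([scores_on_fake, scores_on_real]))

# Generator step: the same generated images, but with "misleading" targets of 0 ("real").
g_loss = bce(np.zeros(2), scores_on_fake)

# While the discriminator is winning, its loss is low and the generator's loss is high.
print(d_loss, g_loss)
```

The same scores give opposite verdicts depending on the targets: this is why the generator is trained through `gan` with "real" targets rather than through the discriminator's own labels.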

**Import Keras.**

In [None]:
import keras
keras.__version__

### Some tricks

Training GANs and tuning GAN implementations is notoriously difficult. There are a number of known "tricks" that one should keep in mind. Like most things in deep learning, it is more alchemy than science: these tricks are really just heuristics, not theory-backed guidelines. They are backed by some level of intuitive understanding of the phenomenon at hand, and they are known to work well empirically, albeit not necessarily in every context.

Here are a few of the tricks that we leverage in our own implementation of a GAN generator and discriminator below. It is not an exhaustive list of GAN-related tricks; you will find many more across the GAN literature.

- We use `tanh` as the last activation in the generator, instead of `sigmoid`, which would be more commonly found in other types of models.
- We sample points from the latent space using a normal distribution (Gaussian distribution), not a uniform distribution.
- Stochasticity is good to induce robustness. Since GAN training results in a dynamic equilibrium, GANs are likely to get "stuck" in all sorts of ways. Introducing randomness during training helps prevent this. We introduce randomness in two ways: 1) we use dropout in the discriminator, 2) we add some random noise to the labels for the discriminator.
- Sparse gradients can hinder GAN training. In deep learning, sparsity is often a desirable property, but not in GANs. There are two things that can induce gradient sparsity: 1) max pooling operations, 2) ReLU activations. Instead of max pooling, we recommend using strided convolutions for downsampling, and we recommend using a `LeakyReLU` layer instead of a ReLU activation. It is similar to ReLU but it relaxes sparsity constraints by allowing small negative activation values.
- In generated images, it is common to see "checkerboard artifacts" caused by unequal coverage of the pixel space in the generator. To fix this, we use a kernel size that is divisible by the stride size whenever we use a strided `Conv2DTranspose` or `Conv2D` in either the generator or the discriminator.
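The coverage argument behind the last trick can be checked by counting how many kernel taps write to each output position of a transposed convolution. This is a simplified one-dimensional illustration that ignores padding and cropping; the function name is hypothetical. In 2-D the same unevenness appears along both axes, which is what produces the checkerboard pattern.

```python
def transposed_conv_coverage(n_inputs, kernel_size, stride):
    """Count how many (input, kernel-tap) pairs write to each output position
    of a 1-D transposed convolution with no padding or cropping."""
    n_outputs = (n_inputs - 1) * stride + kernel_size
    counts = [0] * n_outputs
    for i in range(n_inputs):
        for j in range(kernel_size):
            counts[i * stride + j] += 1
    return counts

# Kernel 3, stride 2: 3 is not divisible by 2, so interior coverage alternates.
uneven = transposed_conv_coverage(4, kernel_size=3, stride=2)
# Kernel 4, stride 2 (as in the generator): interior coverage is uniform.
even = transposed_conv_coverage(4, kernel_size=4, stride=2)

print(uneven)  # [1, 1, 2, 1, 2, 1, 2, 1, 1]
print(even)    # [1, 1, 2, 2, 2, 2, 2, 2, 1, 1]
```

Only the border positions differ in the divisible case; in the non-divisible case every other interior pixel receives half as many contributions.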

### The Generator

First, we develop a `generator` model, which turns a vector from the latent space (during training, it will be sampled at random) into a candidate image. One of the many issues that commonly arise with GANs is that the generator gets stuck with generated images that look like noise. A possible solution is to use dropout on both the discriminator and generator.

**Create the generator as described here:**

- Dense layer with 128x16x16 units and ReLU activation;
- Reshape its output into a 16x16 feature map with 128 channels;
- Convolution layer with 256 filters and a 5x5 kernel, with ReLU activation;
- Upsample to 32x32 with a `Conv2DTranspose` layer (256 filters, 4x4 kernel, stride 2), with ReLU activation;
- Convolution layer with 256 filters and a 5x5 kernel, with ReLU activation;
- Convolution layer with 256 filters and a 5x5 kernel, with ReLU activation;
- Convolution layer with `channels` filters and a 7x7 kernel, with `tanh` activation.

In [None]:
import numpy as np
import keras
from keras.models import Sequential
# Dropout and Flatten are used by the discriminator below
from keras.layers import Dense, Activation, Dropout, Flatten, Reshape
from keras.layers import Convolution2D, Conv2DTranspose

latent_dim = 32  # dimensionality of the latent space
height = 32
width = 32
channels = 3

generator = Sequential()
# Project the latent vector and reshape it into a 16x16x128 feature map
generator.add(Dense(128 * 16 * 16, activation='relu', input_shape=(latent_dim,)))
generator.add(Reshape((16, 16, 128)))
generator.add(Convolution2D(256, 5, padding='same'))
generator.add(Activation('relu'))
# Upsample 16x16 -> 32x32 (kernel size 4 is divisible by stride 2)
generator.add(Conv2DTranspose(256, 4, strides=2, padding='same'))
generator.add(Activation('relu'))
generator.add(Convolution2D(256, 5, padding='same'))
generator.add(Activation('relu'))
generator.add(Convolution2D(256, 5, padding='same'))
generator.add(Activation('relu'))
# Map to a 32x32x3 image with values in [-1, 1]
generator.add(Convolution2D(channels, 7, padding='same'))
generator.add(Activation('tanh'))
generator.summary()
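`generator.summary()` reports per-layer parameter counts. As a hand cross-check: a `Dense` layer has `in * units + units` parameters, and a `Conv2D` or `Conv2DTranspose` layer with bias has `k * k * in_channels * filters + filters`; `Reshape` and `Activation` layers have none. The helper names below are hypothetical, and the sketch assumes the layer sizes listed above.

```python
def dense_params(in_units, units):
    # weight matrix + one bias per unit
    return in_units * units + units

def conv_params(kernel_size, in_channels, filters):
    # square kernel weights + one bias per filter (same count for transposed convs)
    return kernel_size * kernel_size * in_channels * filters + filters

latent_dim, channels = 32, 3

param_counts = {
    "dense": dense_params(latent_dim, 128 * 16 * 16),
    "conv_5x5_a": conv_params(5, 128, 256),          # on the 16x16x128 map
    "conv_transpose_4x4": conv_params(4, 256, 256),  # 16x16 -> 32x32 ('same', stride 2)
    "conv_5x5_b": conv_params(5, 256, 256),
    "conv_5x5_c": conv_params(5, 256, 256),
    "conv_7x7_out": conv_params(7, 256, channels),
}
total = sum(param_counts.values())
print(total)  # 6264579
```

With `padding='same'`, a stride-2 `Conv2DTranspose` multiplies the spatial size by 2, so the generator's output is 32x32x3 and matches the discriminator's expected input.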

### The Discriminator

Then, we develop a `discriminator` model, which takes as input a candidate image (real or synthetic) and classifies it into one of two classes: "generated image" or "real image from the training set".

**Create the discriminator as described here:**

- Convolution layer with 128 filters and a 3x3 kernel, `input_shape=(height, width, channels)`, with ReLU activation;
- Convolution layer with 128 filters, a 4x4 kernel, and stride 2, with ReLU activation;
- Convolution layer with 128 filters, a 4x4 kernel, and stride 2, with ReLU activation;
- Convolution layer with 128 filters, a 4x4 kernel, and stride 2, with ReLU activation;
- Flatten layer;
- Dropout with rate 0.4;
- Fully connected layer with a single unit and `sigmoid` activation.

In [None]:
discriminator = Sequential()
# All convolutions use the default 'valid' padding: 32 -> 30 -> 14 -> 6 -> 2
discriminator.add(Convolution2D(128, 3, input_shape=(height, width, channels)))
discriminator.add(Activation('relu'))
# Strided convolutions downsample instead of max pooling (kernel 4 divisible by stride 2)
discriminator.add(Convolution2D(128, 4, strides=2))
discriminator.add(Activation('relu'))
discriminator.add(Convolution2D(128, 4, strides=2))
discriminator.add(Activation('relu'))
discriminator.add(Convolution2D(128, 4, strides=2))
discriminator.add(Activation('relu'))
discriminator.add(Flatten())
discriminator.add(Dropout(0.4))  # injects stochasticity into the discriminator
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.summary()
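Because these convolutions use Keras's default `'valid'` padding, each one shrinks the feature map according to `floor((size - kernel) / stride) + 1`. A short sketch of that arithmetic for the layers above (the helper name is hypothetical) shows why the `Flatten` layer outputs 512 values:

```python
def valid_conv_output(size, kernel_size, stride=1):
    """Spatial size after an unpadded ('valid') convolution."""
    return (size - kernel_size) // stride + 1

size = valid_conv_output(32, 3)               # 3x3 kernel, stride 1
sizes = [size]
for _ in range(3):                            # three 4x4, stride-2 convolutions
    size = valid_conv_output(size, 4, stride=2)
    sizes.append(size)

flattened = size * size * 128                 # 128 filters in the last convolution
print(sizes, flattened)  # [30, 14, 6, 2] 512
```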

**Compile the discriminator with an appropriate optimizer and loss.**

In [None]:
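# Newer Keras versions expect `learning_rate` instead of `lr` and no longer support `decay`;
# the original decay=1e-8 had a negligible effect, so it is omitted here.
discriminator_optimizer = keras.optimizers.RMSprop(learning_rate=0.0008, clipvalue=1.0)
discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy')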
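In Keras 2.x's legacy optimizers, `decay` applied inverse-time decay at every update: `lr / (1 + decay * iterations)`. A quick calculation, assuming the 10,000 training steps used below and one discriminator update per step, shows how small the effect of `decay=1e-8` is over this run (the helper name is hypothetical):

```python
def inverse_time_decay(initial_lr, decay, iterations):
    # Learning rate after `iterations` updates under legacy Keras `decay`
    return initial_lr / (1.0 + decay * iterations)

lr_end = inverse_time_decay(0.0008, 1e-8, 10_000)
relative_change = (0.0008 - lr_end) / 0.0008
print(f"final lr: {lr_end:.8f}, relative change: {relative_change:.6%}")
```

The learning rate changes by about 0.01% over the whole run, so a constant learning rate behaves essentially the same.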

### The Adversarial Network

Finally, we set up the GAN, which chains the generator and the discriminator. This is the model that, when trained, will move the generator in a direction that improves its ability to fool the discriminator. This model turns latent space points into a classification decision, "fake" or "real", and it is meant to be trained with labels that always say "these are real images". So training `gan` updates the weights of `generator` in a way that makes `discriminator` more likely to predict "real" when looking at fake images. Very importantly, we set the discriminator to be frozen during training (non-trainable): its weights will not be updated when training `gan`. If the discriminator weights could be updated during this process, then we would be training the discriminator to always predict "real", which is not what we want!
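A toy scalar sketch of what one `gan` training step does, using hypothetical linear stand-ins for the two convnets: the "discriminator" `D(x) = sigmoid(w * x + b)` is read as the probability that `x` is fake (the training loop's convention), and the "generator" is `G(z) = a * z + c`. Only `a` and `c` receive gradient updates, mirroring `discriminator.trainable = False`; all names and numbers are illustrative assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
z = rng.normal(size=1000)      # latent samples

w, b = 1.0, 0.0                # frozen toy discriminator parameters
a, c = 1.0, 2.0                # trainable toy generator parameters

def fake_scores(a, c):
    # D(G(z)): probability that the generated samples are fake
    return sigmoid(w * (a * z + c) + b)

def generator_loss(a, c):
    # Binary cross-entropy against "real" targets (0): -mean(log(1 - D(G(z))))
    return -np.mean(np.log(1.0 - fake_scores(a, c)))

loss_before = generator_loss(a, c)
score_before = fake_scores(a, c).mean()

# d/dlogit of -log(1 - sigmoid(logit)) is sigmoid(logit); the chain rule through
# logit = w * (a * z + c) + b gives gradients for the generator parameters a and c.
s = fake_scores(a, c)
grad_a = np.mean(s * w * z)
grad_c = np.mean(s * w)

learning_rate = 0.5
a -= learning_rate * grad_a    # generator parameters move...
c -= learning_rate * grad_c    # ...while w and b are never updated

loss_after = generator_loss(a, c)
score_after = fake_scores(a, c).mean()
print(loss_before, loss_after, score_before, score_after)
```

After the step, the frozen discriminator rates the generated samples as less likely to be fake, even though its own parameters are unchanged.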

#### Set the discriminator weights to non-trainable (will only apply to the `gan` model). Then create an input and an output for the GAN, named `gan_input` and `gan_output`:
- for `gan_input`, use the Keras `Input` function and pass it the correct shape;
- `gan_output` is the discriminator's output on the images that the generator produces from `gan_input`.

Finally, create the GAN model with the Keras `Model` class and name it `gan`.

In [None]:
# Set discriminator weights to non-trainable
# (will only apply to the `gan` model)
discriminator.trainable = False

gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)

**Compile your `gan` model with an appropriate optimizer and loss.**

In [None]:
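# Newer Keras versions expect `learning_rate` instead of `lr` and no longer support `decay`;
# the original decay=1e-8 had a negligible effect, so it is omitted here.
gan_optimizer = keras.optimizers.RMSprop(learning_rate=0.0004, clipvalue=1.0)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')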

### How to Train a DCGAN

Now we can start training. To recapitulate, this is schematically what the training loop looks like:

For each training step:
- Draw random points in the latent space (random noise).
- Generate images with `generator` using this random noise.
- Mix the generated images with real ones.
- Train `discriminator` using these mixed images, with corresponding targets: either "real" (for the real images) or "fake" (for the generated images).
- Draw new random points in the latent space.
- Train `gan` using these random vectors, with targets that all say "these are real images". This updates only the generator's weights (the discriminator is frozen inside `gan`), moving them toward getting the discriminator to predict "these are real images" for generated images; i.e., it trains the generator to fool the discriminator.

Let's implement it:

**Fill in the #TO DOs.**

In [None]:
import os
import numpy as np
from keras.preprocessing import image

# Load CIFAR10 data - #TO DO
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()

# Select frog images (class 6); y_train has shape (N, 1), hence flatten()
x_train = x_train[y_train.flatten() == 6]

# Normalize pixel values to [0, 1] - #TO DO
x_train = x_train.reshape(
    (x_train.shape[0],) + (height, width, channels)).astype('float32') / 255.

# TO DO: set iterations to 10000 if a GPU is available, otherwise to 100; set batch size to 20.
# Also set the path of the folder where you want to save the images.
iterations = 10000
batch_size = 20
save_dir = './'

# Start training loop
start = 0
for step in range(iterations):
    # Sample random points in the latent space
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))

    # Decode them to fake images (verbose=0 suppresses a progress bar per call)
    generated_images = generator.predict(random_latent_vectors, verbose=0)

    # Combine them with real images
    stop = start + batch_size
    real_images = x_train[start:stop]
    combined_images = np.concatenate([generated_images, real_images])

    # Assemble labels discriminating real from fake images
    # (convention: 1 = fake/generated, 0 = real)
    labels = np.concatenate([np.ones((batch_size, 1)),
                             np.zeros((batch_size, 1))])
    # Add random noise to the labels - important trick!
    labels += 0.05 * np.random.random(labels.shape)

    # Train the discriminator
    d_loss = discriminator.train_on_batch(combined_images, labels)

    # Sample new random points in the latent space
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))

    # Assemble labels that say "all real images" (0 = real)
    misleading_targets = np.zeros((batch_size, 1))

    # Train the generator (via the gan model,
    # where the discriminator weights are frozen)
    a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)

    # Advance through the real images, wrapping around at the end
    start += batch_size
    if start > len(x_train) - batch_size:
        start = 0

    # Occasionally save / plot
    if step % 100 == 0:
        # Save model weights (Keras 3 requires the '.weights.h5' suffix)
        gan.save_weights(os.path.join(save_dir, 'gan.weights.h5'))

        # Print metrics
        print('discriminator loss at step %s: %s' % (step, d_loss))
        print('adversarial loss at step %s: %s' % (step, a_loss))

        # Save one generated image; clip first because tanh outputs can be negative
        img = image.array_to_img(np.clip(generated_images[0], 0., 1.) * 255., scale=False)
        img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))

        # Save one real image, for comparison
        img = image.array_to_img(real_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))
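The training cell scales real pixels to [0, 1], while the generator's `tanh` output spans [-1, 1]. A common alternative, not what this notebook does, is to scale the real images to [-1, 1] as well so that real and generated images share the same range; saving and plotting code then has to map outputs back with the inverse transform. A minimal sketch of both directions, with hypothetical helper names:

```python
import numpy as np

def to_tanh_range(images_uint8):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return images_uint8.astype("float32") / 127.5 - 1.0

def from_tanh_range(images):
    """Map values in [-1, 1] back to [0, 255], clipping anything outside."""
    return np.clip((images + 1.0) * 127.5, 0.0, 255.0)

pixels = np.array([0, 127, 255], dtype=np.uint8)
scaled = to_tanh_range(pixels)
restored = from_tanh_range(scaled)
print(scaled, restored)
```

If you adopt this scaling, every place that multiplies images by 255 (the saving code above and the plotting cell below) would need to use `from_tanh_range` instead.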

Let's display some images now.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing import image

# Sample random points in the latent space
random_latent_vectors = np.random.normal(size=(10, latent_dim))

# Decode them to fake images
generated_images = generator.predict(random_latent_vectors, verbose=0)

for i in range(generated_images.shape[0]):
    # Clip to [0, 1] first: tanh outputs can be negative
    img = image.array_to_img(np.clip(generated_images[i], 0., 1.) * 255., scale=False)
    plt.figure()
    plt.imshow(img)

plt.show()

The samples look froggy, with some pixelated artifacts.