# Generative Adversarial Network

#### Select Processing Device

In [None]:
import os
# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0' # leave empty to run on CPUs only.

#### Load Dependencies

In [None]:
import numpy as np

import keras
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Conv2D, BatchNormalization, Dropout, Flatten
from keras.layers import Activation, Reshape, Conv2DTranspose, UpSampling2D
from keras.optimizers import RMSprop

from keras_contrib.layers.advanced_activations import SineReLU

import pandas as pd
from matplotlib import pyplot as plt
%matplotlib inline

print(keras.__version__)

#### Load Data

In [None]:
# Path to the QuickDraw "apple" drawings, stored as a NumPy .npy array.
input_images = '../../quickdraw_data/apple.npy'
data = np.load(input_images)

# Sanity check: show the array dimensions and the raw values of one sample.
print(data.shape)
print(data[4242])

#### Preprocess Data

In [None]:
# Quickdraw images are 28x28 greyscale bitmaps; CNN layers expect an explicit
# channel axis, so every image is reshaped to 28x28x1.
img_w, img_h, channels = 28, 28, 1

# Raw pixel values range from 0 to 255; scale them down to the [0, 1] interval.
data = data / 255
data = data.reshape((data.shape[0], img_w, img_h, channels))

print(data.shape)

#### Plot Image

In [None]:
# Display the single channel of sample 4242 as a 2-D greyscale image.
plt.imshow(data[4242, :, :, 0], cmap="Greys")

#### Build the Discriminator Network

In [None]:
def discriminator_builder(filters=64, kernel=5, drop=0.4):
    """Build the convolutional discriminator and print its summary.

    The network takes an (img_w, img_h, channels) image, passes it through four
    Conv2D -> SineReLU -> Dropout stages, flattens the result, applies one more
    dropout and ends in a single sigmoid unit.

    Args:
        filters: Number of filters of the first convolution; the following
            stages use 2x, 4x and 8x this value.
        kernel: Kernel size shared by all convolutions.
        drop: Dropout rate used after every stage and after flattening.

    Returns:
        The (uncompiled) Keras Model.
    """
    image = Input((img_w, img_h, channels))

    # (filter multiplier, stride) of each convolutional stage, in order.
    stages = ((1, 2), (2, 2), (4, 2), (8, 1))

    features = image
    for multiplier, stride in stages:
        features = Conv2D(filters * multiplier, kernel_size=kernel,
                          strides=stride, padding='same')(features)
        features = SineReLU(0.0025)(features)
        features = Dropout(drop)(features)

    features = Flatten()(features)
    features = Dropout(drop)(features)

    score = Dense(1, activation='sigmoid')(features)

    net = Model(inputs=image, outputs=score)
    net.summary()

    return net

In [None]:
# Instantiate the discriminator with the default hyper-parameters (this also prints its summary).
discriminator = discriminator_builder()

#### Compile Model

In [None]:
# The discriminator is trained directly on real and generated images, so it
# needs its own compiled model.
d_optimizer = RMSprop(lr=0.0008, decay=6e-8, clipvalue=1.0)

d_model = Sequential([discriminator])
d_model.compile(loss='binary_crossentropy',
                optimizer=d_optimizer,
                metrics=['accuracy'])

#### Build the Generator Network

In [None]:
def generator_builder(latent_space=100, filters=64, kernel=5, drop=0.4):
    """Build the generator that maps a noise vector to a 28x28x1 image.

    A dense layer projects the noise to a 7x7 feature map with ``filters``
    channels; two upsampling stages bring it to 28x28, and further transposed
    convolutions refine it before a final 1-filter sigmoid convolution.

    Args:
        latent_space: Length of the input noise vector.
        filters: Channel count of the initial 7x7 feature map; the transposed
            convolutions use filters/2, /4, /8 and /16.
        kernel: Kernel size shared by all (transposed) convolutions.
        drop: Dropout rate applied after the initial projection.

    Returns:
        The (uncompiled) Keras Model; its summary is printed as a side effect.
    """
    inputs = Input((latent_space,))

    # Project the noise onto a 7x7 map whose depth matches the filter count
    # of the de-convolutional stack that follows.
    X = Dense(7 * 7 * filters)(inputs)
    X = BatchNormalization(momentum=0.9)(X)
    X = SineReLU(epsilon=0.0025)(X)
    X = Reshape((7, 7, filters))(X)
    X = Dropout(drop)(X)

    # De-convolutional stages with upsampling (7x7 -> 14x14 -> 28x28).
    # The convolutions carry no activation of their own so that
    # BatchNormalization runs before the activation.
    for divisor in (2, 4):
        X = UpSampling2D()(X)
        X = Conv2DTranspose(int(filters / divisor), kernel, padding='same')(X)
        X = BatchNormalization(momentum=0.9)(X)
        X = SineReLU(epsilon=0.0025)(X)

    # De-convolutional stages at full resolution (no upsampling).
    for divisor in (8, 16):
        X = Conv2DTranspose(int(filters / divisor), kernel, padding='same')(X)
        X = BatchNormalization(momentum=0.9)(X)
        X = SineReLU(epsilon=0.0025)(X)

    # Convolutional output layer: a single filter yields the 28x28x1 image,
    # and 'sigmoid' keeps every pixel between 0 and 1.
    output = Conv2D(1, kernel, padding='same', activation='sigmoid')(X)

    model = Model(inputs=inputs, outputs=output)
    model.summary()

    return model

In [None]:
# The generator is never trained on its own: it is only updated through the
# adversarial model, so it does not need to be compiled here.
generator = generator_builder()

#### Build the Adversarial Network 

In [None]:
def adversarial_builder(latent_space=100):
    """Stack the module-level generator and discriminator into one model.

    Args:
        latent_space: Length of the noise vector accepted by the generator.

    Returns:
        A Keras Model mapping noise to the discriminator's output for the
        generated image; its summary is printed as a side effect.
    """
    noise = Input((latent_space,))

    generated = generator(noise)
    verdict = discriminator(generated)

    stacked = Model(inputs=noise, outputs=verdict)
    stacked.summary()

    return stacked

In [None]:
# Chain generator -> discriminator using the default latent size (this also prints the summary).
adversarial_model = adversarial_builder()

#### Compile the Model

In [None]:
# Keras records which weights a model updates when compile() is called, so the
# discriminator must be frozen *before* compiling the stacked model. Otherwise
# every generator step would also push the discriminator towards labelling
# fakes as real. d_model was compiled while the discriminator was still
# trainable, so it keeps training the discriminator on its own batches.
discriminator.trainable = False

# Optimizer arguments are half of what we have in the Discriminator Network.
adversarial_model.compile(loss='binary_crossentropy',
                          optimizer=RMSprop(lr=0.0004, decay=3e-8, clipvalue=1.0),
                          metrics=['accuracy'])

#### Train

In [None]:
def set_trainability(model, should_train=False):
    """Switch the ``trainable`` flag of a model and of each of its layers.

    The training loop uses it to mark the discriminator as trainable while it
    is fitted on real/generated images, and as frozen while the adversarial
    model is fitted, when only the generator should learn.

    Args:
        model: Object exposing ``trainable`` and an iterable ``layers``.
        should_train: Value assigned to every ``trainable`` flag.
    """
    model.trainable = should_train
    for layer in model.layers:
        layer.trainable = should_train
    

In [None]:
def train(latent_space=100, epochs=2000, batch_size=128):
    """Alternately fit the discriminator and the adversarial model.

    Each iteration draws ``batch_size`` real images from the module-level
    ``data`` array and the same number of generated images, fits ``d_model``
    on the labelled mix (real = 1, fake = 0), then fits ``adversarial_model``
    on fresh noise labelled as real. Every 100 iterations the running mean
    loss/accuracy of both models is printed and a 4x4 grid of generated
    samples is displayed.

    Args:
        latent_space: Length of the noise vector fed to the generator.
        epochs: Number of training iterations (one batch per iteration).
        batch_size: Number of real (and of generated) images per iteration.

    Returns:
        Tuple ``(a_metrics, d_metrics)`` holding the per-iteration
        ``train_on_batch`` results of the adversarial model and of the
        discriminator model respectively.
    """
    sample_size = 16

    d_metrics = []
    a_metrics = []

    running_d_loss = 0
    running_d_acc = 0
    running_a_loss = 0
    running_a_acc = 0

    for i in range(epochs):
        if i % 50 == 0:
            print('Epoch --> %s...' % str(i))

        # Get random 128 (batch_size) real images and as many generated ones.
        random_real_imgs = data[np.random.choice(data.shape[0], batch_size, replace=False)]
        fake_imgs = generator.predict_on_batch(np.random.uniform(-1.0, 1.0, size=[batch_size, latent_space]))

        X = np.concatenate((random_real_imgs, fake_imgs))
        y = np.ones([2 * batch_size, 1])
        # Make the second half of the y vector 0, because they are all fake images.
        y[batch_size:, :] = 0

        # NOTE(review): Keras appears to fix the set of updated weights at
        # compile() time, so these flag flips may not change what the
        # already-compiled models train -- confirm for the Keras version in use.
        set_trainability(d_model, should_train=True)

        metrics = d_model.train_on_batch(X, y)
        d_metrics.append(metrics)
        running_d_loss += d_metrics[-1][0]
        running_d_acc += d_metrics[-1][1]

        set_trainability(d_model)

        # The generator is rewarded when the discriminator labels its output as real.
        A_noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_space])
        y_noise = np.ones([batch_size, 1])

        a_metrics.append(adversarial_model.train_on_batch(A_noise, y_noise))
        running_a_loss += a_metrics[-1][0]
        running_a_acc += a_metrics[-1][1]

        # Number of iterations completed so far; the running sums cover exactly
        # this many batches, so it is the correct divisor for the averages.
        step = i + 1
        if step % 100 == 0:

            print('Epoch #{}'.format(step))
            log_mesg = "%d: [D loss: %f, acc: %f]" % (step, running_d_loss / step, running_d_acc / step)
            log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, running_a_loss / step, running_a_acc / step)
            print(log_mesg)

            noise = np.random.uniform(-1.0, 1.0, size=[sample_size, latent_space])
            gen_imgs = generator.predict_on_batch(noise)

            plt.figure(figsize=(5, 5))

            for k in range(gen_imgs.shape[0]):
                plt.subplot(4, 4, k + 1)
                plt.imshow(gen_imgs[k, :, :, 0], cmap='gray')
                plt.axis('off')

            plt.tight_layout()
            plt.show()

    return a_metrics, d_metrics
        

In [None]:
# Run 3000 training iterations and keep the per-iteration metrics of both models.
a_metrics_complete, d_metrics_complete = train(epochs=3000)

In [None]:
# Entry 0 of every recorded metric pair is the loss of that training step.
loss_history = pd.DataFrame({
    'Generator': [step[0] for step in a_metrics_complete],
    'Discriminator': [step[0] for step in d_metrics_complete],
})

# Log-scaled y axis, one curve per network.
ax = loss_history.plot(title='Training Loss', logy=True)
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")

In [None]:
# Entry 1 of every recorded metric pair is the accuracy of that training step.
accuracy_history = pd.DataFrame({
    'Generator': [step[1] for step in a_metrics_complete],
    'Discriminator': [step[1] for step in d_metrics_complete],
})

# One curve per network.
ax = accuracy_history.plot(title='Training Accuracy')
ax.set_xlabel("Epochs")
ax.set_ylabel("Accuracy")