# Generating MNIST handwritten digits with GAN

A Generative Adversarial Network (GAN) is a class of machine learning frameworks introduced by Ian Goodfellow et al. in 2014. GANs estimate generative models via an adversarial process, in which two models are trained simultaneously: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than from G. The training procedure for G is to maximize the probability of D making a mistake.
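In the original paper, Goodfellow et al. formalize this as a two-player minimax game with value function $V(D, G)$, where $p_{\text{data}}$ is the data distribution and $p_z$ is a prior on the generator's input noise $z$:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
$$

D is trained to maximize $V$ by assigning high probability to real samples and low probability to generated ones, while G is trained to minimize it by pushing $D(G(z))$ towards 1.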

GANs rapidly gained popularity and have been used in many applications, such as:
- Learning to generate realistic images given exemplary images
- Learning to generate realistic music given exemplary recordings
- Learning to generate realistic text given an exemplary corpus

We will focus here on implementing a Deep Convolutional Generative Adversarial Network (DCGAN) to generate handwritten digits. The generator G will learn to generate new, plausible handwritten digits between 0 and 9, and the discriminator D will estimate whether images come from the dataset ("real" images) or were generated ("fake" images).

<img src="https://cdn-media-1.freecodecamp.org/images/m41LtQVUf3uk5IOYlHLpPazxI3pWDwG8VEvU" alt="Diagram of the Generative Adversarial Network framework" title="Generative Adversarial Network framework" />

Image source: https://medium.freecodecamp.org/an-intuitive-introduction-to-generative-adversarial-networks-gans-7a2264a81394

In the case of image generation, the discriminator is a Convolutional Neural Network (CNN) that classifies whether an image is real or generated, and the generator is made of transposed convolutional layers that transform a random input into an image.

### Summary
* [1. Data Preprocessing](#chapter1)
* [2. GAN model](#chapter2)
    * [a. The Discriminator model](#section_2_1)
    * [b. The Generator model](#section_2_2)
* [3. Training the model](#chapter3)
* [4. Results](#chapter4)

* [Useful Links](#links)


In [2]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense, Conv2D, Conv2DTranspose
from tensorflow.keras.layers import BatchNormalization, Flatten
from tensorflow.keras.layers import LeakyReLU, Reshape
from tensorflow.keras.optimizers import Adam

from sklearn.utils import shuffle

from timeit import default_timer as timer

# Let TensorFlow allocate GPU memory on demand (the loop is skipped when no GPU is available)
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)


## 1. Data preprocessing <a class="anchor" id="chapter1"></a>

The MNIST dataset is composed of 70,000 28x28 grayscale images of handwritten digits between 0 and 9, along with their respective labels.

The training set contains 60,000 images and the test set 10,000 images.

Let's load the dataset using the Keras mnist.load_data() function.

In [3]:
# Load the mnist dataset
from tensorflow.keras.datasets.mnist import load_data

(x_train, y_train), (x_test, y_test) = load_data()
print('Train shape:', x_train.shape, y_train.shape)
print('Test shape:', x_test.shape, y_test.shape)

Train shape: (60000, 28, 28) (60000,)
Test shape: (10000, 28, 28) (10000,)


We need to reshape the images: each image is a 2D array, but Keras convolutional layers expect each input image to be a 3D array of shape [height, width, channels].

Here we only have one grayscale channel (if the images were in color, there would be 3 channels for Red, Green, and Blue).
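To make the shape change concrete, here is a small NumPy sketch on a dummy batch (random values standing in for MNIST images). It also shows that np.expand_dims with axis=-1 gives the same result as the explicit reshape:

```python
import numpy as np

# Dummy batch of 5 grayscale 28x28 "images" (stand-in for MNIST data)
x = np.random.randint(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)
x_reshaped = x.reshape((x.shape[0], 28, 28, 1))
x_expanded = np.expand_dims(x, axis=-1)

print(x_reshaped.shape)
print(np.array_equal(x_reshaped, x_expanded))
```

Either form works; expand_dims has the small advantage of not hard-coding the image size.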

In [4]:
# Reshape images
x_train = x_train.reshape((x_train.shape[0],28,28,1))
x_test = x_test.reshape((x_test.shape[0],28,28,1))

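We must also rescale pixel values from the [0, 255] range to [-1, 1], which matches the range of the tanh activation used at the generator's output layer (Radford et al. apply the same scaling).
Then, we plot several digit images from the training set.

In [None]:
# Normalize pixel values from [0, 255] to [-1, 1]
x_train = x_train.astype('float32')
x_train = (x_train - 127.5) / 127.5
x_test = x_test.astype('float32')
x_test = (x_test - 127.5) / 127.5

# Drop the channel axis for plotting: (N, 28, 28, 1) -> (N, 28, 28)
x_plot = np.squeeze(x_train)
# Plot some image examples
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.axis('off')
    plt.grid(False)
    plt.imshow(x_plot[i], cmap=plt.cm.binary)
plt.show()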
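Because the generator's output layer uses tanh, whose values lie in [-1, 1], real images should live in the same range; otherwise the discriminator could separate real from fake images by pixel range alone. Below is a minimal sketch of this scaling and its inverse (handy for converting generated images back to 8-bit pixels). The helper names to_tanh_range and to_pixels are hypothetical, introduced only for this illustration:

```python
import numpy as np

def to_tanh_range(images):
    # Map uint8 pixels in [0, 255] to floats in [-1, 1]
    return (images.astype('float32') - 127.5) / 127.5

def to_pixels(images):
    # Map floats in [-1, 1] back to uint8 pixels in [0, 255]
    return np.clip(np.round(images * 127.5 + 127.5), 0, 255).astype(np.uint8)

pixels = np.array([0, 127, 128, 255], dtype=np.uint8)
scaled = to_tanh_range(pixels)
print(scaled)              # black maps to -1, white maps to 1
print(to_pixels(scaled))   # recovers the original pixel values
```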

## 2. GAN model <a class="anchor" id="chapter2"></a>

### a. The Discriminator model <a class="anchor" id="section_2_1"></a>
First, we implement the discriminator model. It is a binary classifier that takes an image as input and predicts whether it is real or fake. We use two convolutional layers with 64 and 128 filters respectively, a kernel size of (3,3) and strides of (2,2).

The second convolutional layer is followed by a batch normalization layer. Then, we flatten the feature maps and add the output layer, a single fully connected unit.
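As a quick sanity check on these layer sizes, the arithmetic below traces the spatial shape through the discriminator. It assumes the Keras rule for padding='same', where a convolution with stride s turns a side of length n into ceil(n / s); conv_same_output is a hypothetical helper written for this illustration:

```python
import math

def conv_same_output(n, stride):
    # Output side length of a Conv2D layer with padding='same'
    return math.ceil(n / stride)

side = 28
for filters in (64, 128):
    side = conv_same_output(side, 2)
    print(f'Conv2D({filters}) -> ({side}, {side}, {filters})')

flat = side * side * 128
dense_params = flat + 1  # one weight per flattened input plus a bias for the sigmoid unit
print('Flatten ->', flat, '| Dense(1) parameters:', dense_params)
```

These numbers should match the output shapes and parameter count reported by discriminator.summary().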

We use the LeakyReLU activation function (with a slope of 0.2 for negative inputs) for the hidden layers, and a sigmoid function at the output layer to get the probability that a sample is real.
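For reference, here is a NumPy sketch of the two activations: LeakyReLU with the 0.2 slope used in the code (alpha=0.2), and the sigmoid, which squashes the final score into a probability in (0, 1):

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    # Identity for positive inputs, small slope alpha for negative inputs
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    # Maps any real score to a probability in (0, 1)
    return 1.0 / (1.0 + np.exp(-x))

scores = np.array([-2.0, 0.0, 3.0])
print(leaky_relu(scores))
print(sigmoid(scores))
```

Unlike a plain ReLU, LeakyReLU keeps a small gradient for negative inputs, which Radford et al. found to work well in the discriminator.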

The model uses batch normalization (except for the input layer), and it is trained with the Adam optimizer (learning rate 0.0002, beta_1 = 0.5) and the binary cross-entropy loss, since this is a binary classification problem.
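To see what the discriminator's loss measures, here is the binary cross-entropy written out in NumPy, where y is the label (1 for real, 0 for fake) and p the predicted probability that the sample is real, averaged over the batch. This is a sketch of the formula, not the Keras source; Keras also clips predictions by a small epsilon to avoid log(0), which eps imitates here:

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip to avoid log(0), then average -[y log p + (1 - y) log(1 - p)] over the batch
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

y = np.array([1.0, 1.0, 0.0, 0.0])     # two real images, two fake images
good = np.array([0.9, 0.8, 0.1, 0.2])  # confident and correct discriminator
bad = np.array([0.1, 0.2, 0.9, 0.8])   # confident and wrong discriminator
print(binary_crossentropy(y, good))  # about 0.164
print(binary_crossentropy(y, bad))   # about 1.956
```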

Note: Radford et al. report that applying batch normalization directly to all layers results in sample oscillation and model instability.

These choices follow the guidelines of Radford et al. for implementing Deep Convolutional Generative Adversarial Networks (DCGANs). The link to the paper is provided below in [Useful Links](#links).

In [None]:
def get_discriminator(image_shape=(28,28,1)):
    model = Sequential()
    
    model.add(Conv2D(64, kernel_size=(3,3), strides=(2,2), padding='same', input_shape=image_shape))
    model.add(LeakyReLU(alpha=0.2))
    
    model.add(Conv2D(128, kernel_size=(3,3), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(epsilon=1e-5, momentum=0.9))
    
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(learning_rate=0.0002,beta_1=0.5),
                  metrics=['accuracy'])
    return model

In [None]:
discriminator = get_discriminator()
discriminator.summary()

### b. The Generator model <a class="anchor" id="section_2_2"></a>

As explained earlier, the generator's aim is to create new images of handwritten digits that the discriminator cannot distinguish from real handwritten digits.

The generator takes as input a random vector drawn from a standard normal distribution and upsamples it several times to obtain the output image. Here we chose a vector of size 100.

The first layer of the generator is a Dense layer whose output is reshaped into a low-resolution version of the output image. For example, to get a 7x7 image (one quarter of the output's width and height) with 512 feature maps, the Dense layer needs 7x7x512 = 25088 nodes, which are then reshaped into a (7, 7, 512) tensor.
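The arithmetic behind this first layer can be checked with NumPy alone: a flat vector of 7 x 7 x 512 activations per sample (random values standing in for the Dense layer's output) reshapes into a 7x7 image with 512 channels, as the Reshape((7, 7, 512)) layer does for each sample in the batch:

```python
import numpy as np

batch_size, side, channels = 4, 7, 512
nodes = side * side * channels  # number of units in the Dense layer
print(nodes)  # 25088

# Stand-in for the Dense layer output: one flat vector per sample
dense_out = np.random.randn(batch_size, nodes)
# Equivalent of Reshape((7, 7, 512)), which keeps the batch axis untouched
base_image = dense_out.reshape((batch_size, side, side, channels))
print(base_image.shape)  # (4, 7, 7, 512)
```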

Now that we have a low-resolution version of the output, we upsample it using transposed convolutional layers (Conv2DTranspose layers) with strides of 2. A Conv2DTranspose layer roughly reverses the spatial downsampling of a strided Conv2D layer: with strides of 2 and 'same' padding, it doubles the width and height of its input. Hence, two Conv2DTranspose layers with strides of 2 upsample the 7x7 image to the 28x28 output size, and a final Conv2DTranspose layer with stride 1 produces the single output channel. We use model.summary() to verify that images are upsampled properly; in particular, we can see below that the output layer has shape (28, 28, 1).
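The same kind of shape arithmetic applies to the generator. Assuming the Keras rule for Conv2DTranspose with padding='same', where a layer with stride s turns a side of length n into n * s, the sketch below traces the generator's spatial size (transpose_same_output is a hypothetical helper for this illustration):

```python
def transpose_same_output(n, stride):
    # Output side length of a Conv2DTranspose layer with padding='same'
    return n * stride

side = 7
layers = [('Conv2DTranspose(256)', 2, 256),
          ('Conv2DTranspose(128)', 2, 128),
          ('Conv2DTranspose(1), output', 1, 1)]
for name, stride, filters in layers:
    side = transpose_same_output(side, stride)
    print(f'{name} -> ({side}, {side}, {filters})')
```

The last layer keeps the default stride of 1, so it only changes the number of channels, not the image size.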

Once again, we follow the recommendations of Radford et al.: the generator uses the ReLU activation for all layers except the output layer, which uses tanh, and Batch Normalization after each hidden Conv2DTranspose layer.

Note: As for the discriminator, applying batch normalization directly to all layers results in sample oscillation and model instability.

In [None]:
def get_generator(random_vect_size):
    model = Sequential()
    
    # Base image of shape 7x7
    model.add(Dense(512*7*7, activation='relu', input_dim=random_vect_size))
    model.add(Reshape((7, 7, 512)))
    
    # Upsample to 14x14
    model.add(Conv2DTranspose(256, (4,4), strides=(2,2), activation='relu', padding='same'))
    model.add(BatchNormalization(epsilon=1e-5, momentum=0.9))
    
    # Upsample to 28x28
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), activation='relu', padding='same'))
    model.add(BatchNormalization(epsilon=1e-5, momentum=0.9))

    # Output layer
    model.add(Conv2DTranspose(1, (7,7), activation='tanh', padding='same'))
    
    return model

One thing to notice here is that we do not compile the generator model on its own. Instead, it is trained through a combined model defined in [3. Training the model](#chapter3), where we explain how the discriminator is used to train the generator.

In [None]:
random_vect_size = 100
generator = get_generator(random_vect_size)
generator.summary()

Let's define a function that generates the random input vectors, and a function that uses the generator model to produce fake images from them.

In [5]:
def get_random_vect(size, nb_samples):
    # Draw nb_samples latent vectors of length size from a standard normal distribution
    vect = np.random.randn(nb_samples, size)
    return vect

In [6]:
def get_fake_images(generator, random_vect_size, nb_samples):
    rand_vect = get_random_vect(random_vect_size, nb_samples)
    x_gen = generator.predict(rand_vect)
    
    # Create the label vector for these fake images
    y_gen = np.zeros(nb_samples)
    
    return x_gen, y_gen

Let's generate a few images and plot them to check that the implementation runs. Since the generator is not trained yet, we expect noise-like grayscale images.

In [None]:
random_vect_size = 100
nb_samples = 25

generator = get_generator(random_vect_size)
x_gen, _ = get_fake_images(generator, random_vect_size, nb_samples)
# Drop the channel axis for plotting: (N, 28, 28, 1) -> (N, 28, 28)
x_gen = np.squeeze(x_gen)

# Plot the generated images on a 5x5 grid
plt.figure(figsize=(10,10))
for i in range(nb_samples):
    plt.subplot(5,5,i+1)
    plt.axis('off')
    plt.grid(False)
    plt.imshow(x_gen[i], cmap=plt.cm.binary)
plt.show()

## 3. Training the model <a class="anchor" id="chapter3"></a>

Training the GAN model works as follows.

We iterate over batches, and for each batch we first update the discriminator. Then, we update the generator, using the discriminator to evaluate how well the generator performs.
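Concretely, each discriminator update sees a batch that stacks real and fake images labeled 1 and 0, while each generator update uses fresh random vectors all labeled 1. The NumPy sketch below builds these arrays with dummy data in place of MNIST images and generator outputs; the sizes are arbitrary choices for the illustration:

```python
import numpy as np

batch, latent_size = 4, 100

# Dummy stand-ins for a batch of real MNIST images and a batch of generated images
x_real = np.random.uniform(-1, 1, size=(batch, 28, 28, 1))
x_fake = np.random.uniform(-1, 1, size=(batch, 28, 28, 1))

# Discriminator batch: real images labeled 1, fake images labeled 0
D_x = np.vstack((x_real, x_fake))
D_y = np.concatenate((np.ones(batch), np.zeros(batch)))

# Generator batch: random latent vectors, all labeled 1 ("real")
G_x = np.random.randn(batch, latent_size)
G_y = np.ones(batch)

print(D_x.shape, D_y)
print(G_x.shape, G_y)
```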

The idea is that the more confidently the discriminator identifies generated images as fake, the larger the generator's loss, and thus the larger its updates. Ideally, the generator eventually becomes good enough that the discriminator can no longer distinguish fake images from real ones.

To train the generator, we create a third model that chains the generator and the discriminator, since we need the discriminator's output as a measure of how well the generator performs.

The discriminator's weights are frozen in this combined model (discriminator.trainable = False), because the discriminator is already updated on its own at each batch: training the combined model should only update the generator. In tf.keras, since the discriminator was compiled before being frozen, calling train_on_batch on the discriminator itself still updates its weights.

In [None]:
def get_GAN(generator, discriminator):
    discriminator.trainable = False
    
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    
    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(learning_rate=0.0002,beta_1=0.5),
                  metrics=['accuracy'])
    return model

In [None]:
gan = get_GAN(generator, discriminator)
gan.summary()

In [7]:
def predict_and_plot(generator, random_vect_size):
    # Generate 100 fake images and drop the channel axis for plotting
    x_gen, _ = get_fake_images(generator, random_vect_size, 100)
    x_gen = np.squeeze(x_gen)

    # Plot the images on a 10x10 grid
    plt.figure(figsize=(10,10))
    for i in range(100):
        plt.subplot(10,10,i+1)
        plt.axis('off')
        plt.grid(False)
        plt.imshow(x_gen[i], cmap=plt.cm.binary)
    plt.show()

We define the training procedure manually by iterating over a chosen number of epochs with a chosen batch size. For each epoch and each batch, we take a batch of real samples, generate a batch of fake samples, and train the discriminator on both. Then, we generate random inputs for the generator and label them as real (1), i.e. with inverted labels. This is one of the key tricks! We want the generator to fool the discriminator. If the discriminator already classifies the generated images as real, the loss is small and the generator is barely updated; conversely, if the discriminator classifies them as fake, the loss is large and the generator receives a large update.
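Why label generated images as real instead of minimizing log(1 - D(G(z))) directly, as in the minimax formulation? Goodfellow et al. note that early in training, when D rejects generated samples with high confidence, log(1 - D(G(z))) saturates, and they suggest training G to maximize log D(G(z)) instead; binary cross-entropy with label 1 computes exactly -log D(G(z)). The NumPy sketch below compares the gradients of both losses with respect to the discriminator's pre-sigmoid score a, where D(G(z)) = sigmoid(a); the score values are arbitrary examples:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Pre-sigmoid discriminator scores for generated images:
# very negative means D is confident the image is fake
a = np.array([-4.6, 0.0, 4.6])
p = sigmoid(a)  # D(G(z))

# Minimax generator loss log(1 - p): d/da = -p, which vanishes when p is near 0
grad_minimax = -p
# Non-saturating loss -log(p), i.e. BCE with label 1: d/da = -(1 - p)
grad_non_saturating = -(1.0 - p)

for score, prob, g_mm, g_ns in zip(a, p, grad_minimax, grad_non_saturating):
    print(f'a={score:+.1f}  D(G(z))={prob:.3f}  '
          f'minimax grad={g_mm:+.3f}  non-saturating grad={g_ns:+.3f}')
```

When D(G(z)) is about 0.01, the minimax gradient is about -0.01 while the non-saturating one is about -0.99, so the inverted labels keep the generator learning exactly when it is doing worst.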

In [None]:
def train(generator, discriminator, gan, x_train, random_vect_size, epochs=3, batch=128):
    # Number of full batches per epoch (the last incomplete batch is skipped)
    ba_per_ep = x_train.shape[0] // batch

    for i in range(epochs):
        x_train = shuffle(x_train)
        for j in range(ba_per_ep):
            # Take a batch of real samples, labeled 1
            x_real = x_train[batch*j: batch*(j+1)]
            y_real = np.ones(x_real.shape[0])
            # Generate a batch of fake samples, labeled 0
            x_fake, y_fake = get_fake_images(generator, random_vect_size, batch)
            # Stack real and fake samples into one training batch
            D_x_train, D_y_train = np.vstack((x_real, x_fake)), np.concatenate((y_real, y_fake))
            # Update the discriminator
            D_loss, _ = discriminator.train_on_batch(D_x_train, D_y_train)

            # Create the random inputs for the generator
            x_gen = get_random_vect(random_vect_size, batch)
            # Inverted labels: train the generator so that its fake images are classified as real
            y_gen = np.ones(batch)
            # Update the generator twice to keep the discriminator from converging too quickly
            G_loss, _ = gan.train_on_batch(x_gen, y_gen)
            G_loss, _ = gan.train_on_batch(x_gen, y_gen)

            print('>%d, %d/%d, d=%.3f, g=%.3f' % (i+1, j+1, ba_per_ep, D_loss, G_loss))
        print('Epoch %d: Gen loss %.3f, Dis loss %.3f' % (i+1, G_loss, D_loss))
        # Save the generator model to a file after each epoch, then plot samples from it
        filename = 'generator_model_%03d.h5' % (i + 1)
        generator.save(filename)
        predict_and_plot(load_model(filename), random_vect_size)

In [None]:
random_vect_size = 100

start = timer()
# train the model
train(generator, discriminator, gan, x_train, random_vect_size, epochs=50,  batch=256)
print("Training time:", timer() - start)

## 4. Results <a class="anchor" id="chapter4"></a>

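Now that the model is trained and the generator has been saved after each epoch, we can load one of the saved generators (here, the one saved after epoch 20) and use it to generate handwritten digits.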

In [1]:
final_generator = load_model('generator_model_020.h5')

predict_and_plot(final_generator, random_vect_size)


## Useful Links <a class="anchor" id="links"></a>

**GAN original paper:**

https://arxiv.org/abs/1406.2661

Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., et al. (2014). Generative adversarial networks. arXiv preprint arXiv:1406.2661.

**Tutorial for implementing a GAN for Generating MNIST Handwritten Digits:**

https://machinelearningmastery.com/how-to-develop-a-generative-adversarial-network-for-an-mnist-handwritten-digits-from-scratch-in-keras/

**Deep Convolutional Generative Adversarial Networks (DCGANs) paper:**

https://arxiv.org/pdf/1511.06434.pdf

Radford, A., Metz, L., & Chintala, S. (2015). Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.