# GAN that generates handwritten digits

A small Jupyter Notebook implementing a generative adversarial network (GAN) in TensorFlow that learns to generate images similar to those in the MNIST dataset.


In [None]:
import tensorflow as tf

# import the MNIST dataset helper
from tensorflow.examples.tutorials.mnist import input_data

# download (if necessary) and load the dataset
mnist = input_data.read_data_sets('MNIST_data')

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt

# display one training image (index 1) in grayscale
image = mnist.train.images[1].reshape((28, 28))
plt.imshow(image, cmap='gray')

## Discriminator

The discriminator's job is to decide whether an image is a real sample from the dataset or was produced by the generator.
Discriminator architecture:


1. **```Input```**: accepts 28px x 28px single-channel (grayscale) images
2. **```1st convolution```**: filter size 5x5, 32 features, ```SAME``` padding, ReLU activation
3. **```Average Pooling```**: 2x2 kernel, 2x2 stride
4. **```2nd convolution```**: filter size 5x5, 64 features, ```SAME``` padding, ReLU activation
5. **```Average Pooling```**: 2x2 kernel, 2x2 stride
6. **```1st fully-connected```**: 1024 neurons, ReLU activation
7. **```2nd fully-connected```**: 1 neuron, no activation (it outputs a logit; the sigmoid is applied inside `tf.nn.sigmoid_cross_entropy_with_logits` when the loss is computed)
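
The `tf.reshape(p2_out, [-1, 7 * 7 * 64])` in the discriminator code follows from this architecture. As a rough sketch, the plain-Python trace below walks the spatial size through each layer, assuming TensorFlow's usual output-size rules: `ceil(n / stride)` for ```SAME``` padding, and ```VALID``` padding (the `tf.layers.average_pooling2d` default) for the 2x2 pooling layers. The helper names are illustrative and not part of the notebook.

```python
import math


def conv_same(size, stride=1):
    # SAME padding: output size is ceil(input / stride)
    return math.ceil(size / stride)


def pool_valid(size, pool=2, stride=2):
    # VALID padding: only complete pooling windows are kept
    return (size - pool) // stride + 1


size = 28                # Input: 28 x 28 x 1
size = conv_same(size)   # 1st convolution: 28 x 28 x 32
size = pool_valid(size)  # average pooling: 14 x 14 x 32
size = conv_same(size)   # 2nd convolution: 14 x 14 x 64
size = pool_valid(size)  # average pooling: 7 x 7 x 64
flat = size * size * 64  # length of the vector fed to the 1st fully-connected layer
print(size, flat)        # 7 3136
```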



In [None]:

def discriminator(image, reuseVariables):
    with tf.variable_scope("discriminator") as scope:
        if (reuseVariables):
            scope.reuse_variables()
        
        # 1st convolution
        c1_out = tf.layers.conv2d(
            inputs=image,
            filters=32,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
            name="d_c1"
        )       
        
        # Average pooling
        p1_out = tf.layers.average_pooling2d(inputs=c1_out, pool_size=[2, 2], strides=[2,2], name="d_p1")
        
        # 2nd convolution
        c2_out = tf.layers.conv2d(
            inputs=p1_out,
            filters=64,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
            name="d_c2"
        )
        
        # Average pooling
        p2_out = tf.layers.average_pooling2d(inputs=c2_out, pool_size=[2, 2], strides=[2,2], name="d_p2")

        # 1st fully-connected
        fc1_flat = tf.reshape(p2_out, [-1, 7 * 7 * 64])
        fc1_out = tf.layers.dense(inputs=fc1_flat, units=1024, activation=tf.nn.relu, name="d_fc1")

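        # 2nd fully-connected: a single unit without activation, i.e. a raw logit
        # (the sigmoid is applied inside tf.nn.sigmoid_cross_entropy_with_logits)
        fc2_out = tf.layers.dense(inputs=fc1_out, units=1, activation=None, name="d_fc2")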

        return fc2_out

## Generator
The generator's job is to produce realistic-looking images that ultimately deceive the discriminator.

Generator architecture:

1. **```Input```**: a noise vector of length $N$ (`noise_dim`), sampled from a standard normal distribution
2. **```1st fully-connected```**: maps the noise vector to $L = 3136 = 56 \times 56$ units, ReLU activation, reshaped to a 56x56x1 image
3. **```1st convolutional```**: filter size 3x3, 2x2 strides, $N/2$ features, ```SAME``` padding, ReLU activation
4. **[```batch normalization```](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm)**: using the default parameters
5. **```upsample```**: factor 2x2 (back to 56x56), bilinear
6. **```2nd convolutional```**: filter size 3x3, 2x2 strides, $N/4$ features, ```SAME``` padding, ReLU activation
7. **```batch normalization```**: using the default parameters
8. **```upsample```**: factor 2x2 (back to 56x56), bilinear
9. **```3rd convolutional```**: filter size 3x3, 2x2 strides, 1 feature, ```SAME``` padding, sigmoid activation (output: 28x28x1)
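
Each stride-2 convolution halves the spatial size and each upsampling step (`tf.image.resize_images` to 56x56 in the code) doubles it again, so the generator ends at 28x28x1, the shape the discriminator expects. The sketch below traces this for $N = 100$. It assumes TensorFlow's ```SAME```-padding rule (`ceil(n / stride)`) and writes the feature counts with integer division, since a layer has a whole number of filters.

```python
import math

noise_dim = 100  # N, the value used in "Building the GAN"


def conv_same(size, stride):
    # SAME padding: output size is ceil(input / stride)
    return math.ceil(size / stride)


shapes = [(56, 56, 1)]                   # 1st fully-connected: 3136 units reshaped to 56 x 56 x 1
h = conv_same(56, 2)                     # 1st convolution, 2x2 strides
shapes.append((h, h, noise_dim // 2))    # -> 28 x 28 x 50
shapes.append((56, 56, noise_dim // 2))  # bilinear upsampling back to 56 x 56
h = conv_same(56, 2)                     # 2nd convolution, 2x2 strides
shapes.append((h, h, noise_dim // 4))    # -> 28 x 28 x 25
shapes.append((56, 56, noise_dim // 4))  # bilinear upsampling back to 56 x 56
h = conv_same(56, 2)                     # 3rd convolution, 2x2 strides
shapes.append((h, h, 1))                 # -> 28 x 28 x 1, same shape as an MNIST image

for shape in shapes:
    print(shape)
```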


In [None]:
import tensorflow as tf
import numpy as np

def generator(batch_size, noise_dim):
    with tf.variable_scope("generator") as scope:
        
        # Input - Noise vector
        z = tf.random_normal(
            [batch_size, noise_dim],
            mean=0.0,
            stddev=1.0,
            dtype=tf.float32
        )
        
        # 1st fully-connected
        fc1_layer = tf.layers.dense(inputs=z, units=3136, activation=tf.nn.relu, name='gz_dense')
        fc1_out = tf.reshape(fc1_layer, (batch_size, 56, 56, 1)) # reshaping
        
        # 1st convolution
        c1_layer = tf.layers.conv2d(
            inputs=fc1_out,
            filters=noise_dim // 2,  # integer number of filters (N/2)
            kernel_size=[3, 3],
            strides=[2, 2],
            padding="same",
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
            activation=tf.nn.relu,
            name='gz_conv1',
            trainable=True
        )

        c1_out = tf.contrib.layers.batch_norm(inputs=c1_layer)  # batch normalization

        c1_out = tf.image.resize_images(images=c1_out, size=(56, 56)) # upsampling
        
        # 2nd convolution
        c2_layer = tf.layers.conv2d(
            inputs=c1_out,
            filters=noise_dim // 4,  # integer number of filters (N/4)
            kernel_size=[3, 3],
            strides=[2, 2],
            padding="same",
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
            activation=tf.nn.relu,
            name='gz_conv2',
            trainable=True
        )

        c2_out = tf.contrib.layers.batch_norm(inputs=c2_layer)  # batch normalization

        c2_out = tf.image.resize_images(images=c2_out, size=(56, 56)) # upsampling

        # 3rd convolution
        c3_layer = tf.layers.conv2d(
            inputs=c2_out,
            filters=1,
            kernel_size=[3, 3],
            strides=[2, 2],
            padding="same",
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
            activation=tf.nn.sigmoid,
            name='gz_conv3',
            trainable=True
        )

        return c3_layer

## Building the GAN


In [None]:
noise_dim = 100
batch_size = 50

with tf.variable_scope("input"):
    x = tf.placeholder(tf.float32, shape=(None,28,28,1))

Gz = generator(batch_size, noise_dim)
Dx = discriminator(x, False)
Dg = discriminator(Gz, True)

## Training the GAN

In [None]:
# generator loss: the generator wants D(G(z)) to be classified as real (label 1)
gz_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg, labels=tf.ones_like(Dg))
gz_loss = tf.reduce_mean(gz_loss)

# discriminator loss on real images: D(x) should be classified as real (label 1)
dx_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=Dx, labels=tf.ones_like(Dx))
dx_loss = tf.reduce_mean(dx_loss)

# discriminator loss on generated images: D(G(z)) should be classified as fake (label 0)
dg_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg, labels=tf.zeros_like(Dg))
dg_loss = tf.reduce_mean(dg_loss)

# combined loss of the discriminator
d_loss = dg_loss + dx_loss
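
To make the three loss terms concrete, here is a NumPy sketch of what `tf.nn.sigmoid_cross_entropy_with_logits` computes, using the numerically stable formula `max(x, 0) - x * z + log(1 + exp(-abs(x)))` given in the TensorFlow documentation. The logits are made-up numbers for illustration; this mirrors the TensorFlow code above and does not replace it.

```python
import numpy as np


def sigmoid_cross_entropy_with_logits(logits, labels):
    # Stable form of -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x))
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(labels, dtype=np.float64)
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))


# Hypothetical discriminator logits for real (Dx) and generated (Dg) images
Dx = np.array([2.0, 1.5, 3.0])
Dg = np.array([-1.0, 0.5, -2.0])

gz_loss = sigmoid_cross_entropy_with_logits(Dg, np.ones_like(Dg)).mean()   # G wants D(G(z)) -> 1
dx_loss = sigmoid_cross_entropy_with_logits(Dx, np.ones_like(Dx)).mean()   # D wants D(x) -> 1
dg_loss = sigmoid_cross_entropy_with_logits(Dg, np.zeros_like(Dg)).mean()  # D wants D(G(z)) -> 0
d_loss = dx_loss + dg_loss

print(gz_loss, dx_loss, dg_loss, d_loss)
```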

In [None]:
# get the trainable variables
d_variables = tf.trainable_variables("discriminator")
gz_variables = tf.trainable_variables("generator")

# initialise the Adam optimizers (separate learning rates for D and G)
d_optimizer = tf.train.AdamOptimizer(0.0003)
gz_optimizer = tf.train.AdamOptimizer(0.0001)

# each optimizer only updates the variables of its own network
d_minimizer = d_optimizer.minimize(d_loss, var_list=d_variables)
gz_minimizer = gz_optimizer.minimize(gz_loss, var_list=gz_variables)

### TensorBoard Implementation

In [None]:
from datetime import datetime

# use the current time (formatted like 2018-01-31_14-05-09) as the name of the TensorBoard log directory
time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
LOG_DIR = "./tensorboard/" + time + "/"
writer = tf.summary.FileWriter(LOG_DIR)

generatedImage = Gz

mergedSummaries_gz = tf.summary.merge([tf.summary.scalar("gz_loss", gz_loss)])
mergedSummaries_d = tf.summary.merge([tf.summary.scalar("dx_loss", dx_loss), tf.summary.scalar("dg_loss", dg_loss)])

generatedImages = tf.summary.merge([tf.summary.image("generatedImages", generatedImage, max_outputs = 5)])

### Actual Training

In [None]:
session = tf.Session()
init = tf.global_variables_initializer()
session.run(init)

## pretraining the discriminator ##
for i in range(1, 301):
    # real images (next_batch returns an (images, labels) tuple)
    x_batch_real_images, _ = mnist.train.next_batch(batch_size)
    x_batch_real_images = np.reshape(x_batch_real_images, [batch_size, 28, 28, 1])

    session.run(d_minimizer, feed_dict={x: x_batch_real_images})
    print("\r", "Pretraining Step " + str(i), end="")

print("\n pretraining completed. \n")


def plot_output(images, n=10):
    """Plot the first n generated images in a single row."""
    plt.figure(figsize=(10, 1))
    for j in range(n):
        plt.subplot(1, n, j + 1)
        plt.imshow(X=images[j].reshape(28, 28), cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.show()


## real training ##
for i in range(1, 600001):
    # real images
    x_batch_real_images, _ = mnist.train.next_batch(batch_size)
    x_batch_real_images = np.reshape(x_batch_real_images, [batch_size, 28, 28, 1])

    print("\r", "Training Step " + str(i), end="")

    # every 10th step add summaries to tensorboard
    if i % 10 == 0:

        # every 100th step also add generated images to tensorboard
        if i % 100 == 0:
            _, curr_dg_loss, curr_dx_loss, summary_d = session.run(
                [d_minimizer, dg_loss, dx_loss, mergedSummaries_d],
                feed_dict={x: x_batch_real_images})
            _, curr_gz_loss, summary_gz, summary_generatedImages, generatedImage_log = session.run(
                [gz_minimizer, gz_loss, mergedSummaries_gz, generatedImages, generatedImage])

            # add generated images to tensorboard
            writer.add_summary(summary_generatedImages, i)

            # every 500th step plot ten generated images and print the losses
            if i % 500 == 0:
                print("\n _________________________________\n")
                print("\n  Iteration:", i)
                print("\n Loss Dx: ", curr_dx_loss)
                print("\n Loss Dg: ", curr_dg_loss)
                print("\n Loss Gz: ", curr_gz_loss)
                plot_output(generatedImage_log)

        else:
            _, summary_d = session.run([d_minimizer, mergedSummaries_d], feed_dict={x: x_batch_real_images})
            _, summary_gz = session.run([gz_minimizer, mergedSummaries_gz])

        # add discriminator summary for tensorboard
        writer.add_summary(summary_d, i)
        # add generator summary for tensorboard
        writer.add_summary(summary_gz, i)

    else:
        # plain training step without logging
        session.run(d_minimizer, feed_dict={x: x_batch_real_images})
        session.run(gz_minimizer)

session.close()

---
# <font color='blue'> Summary GAN: </font>

Building the discriminator and the generator was rather straightforward, but getting good results was difficult. The results of the first test runs were not satisfying, so we tried different parameter combinations. This was tedious, since we had to wait for the GAN to deliver a representative outcome, which usually took at least 50,000 iterations.

Interpreting the three loss values we record was difficult; it was hard to tell which combinations of values would result in good images. The GAN we currently use is prone to overfitting: after a certain time it only generates one type of digit. (In the GAN literature this failure is usually called *mode collapse*; we keep the term overfitting below.)

<img src="img/overfitting.png" width="500px;">
> _Typical overfitting, which can start as early as 60k iterations_

Unfortunately, it is difficult to decide when to stop training to avoid overfitting, because overfitting often starts right when the first recognizable images appear. Since there are many parameters we could adjust, testing their impact in a structured manner would take a lot of time. Therefore we picked a few variables at random and tried to improve the result by tuning them.

**Testing different Parameters**

We tried different numbers of pretraining steps for the discriminator. Values between 300 and 500 seem to lead to the best results later on, so we left this parameter at 300, as suggested.

Next, we reduced the number of features in the convolution layers of the discriminator. This helped to counter overfitting. The overall quality of the generated images decreased slightly, and some outliers are barely recognizable. The recorded losses of the discriminator for real and fake images were significantly higher (from 0.003 in the original version to 0.3 with fewer features). The loss of the generator was noticeably smaller (from about 7.0 to 1.8) and more stable during training. At about 80,000 iterations, the GAN begins to stagnate and does not improve much.
        
<img src="img/Run1_after_about_80k_steps.png" width="500px;">

<img src="img/Run1_afters_about_150steps.png" width="500px;">

> _Output of the modified GAN at ~80k and ~150k iterations: no significant improvement, but no overfitting either_

<img src="img/losses_Run1vsOrig.jpg">
> _Different losses of the modified GAN in comparison to the 'original' GAN_
 
Decreasing the filter size from 5x5 to 3x3 in the convolution layers has no significant effect. Overfitting still occurs.
    
<img src="img/Run2_overfitting_after_80k_steps.png" width="500px;">
> _Overfitting after 80k steps_

Since decreasing the number of features proved effective against overfitting, we wanted to combine it with other changes. First, we multiplied the learning rate by 10. As expected, this led to quicker results, but overfitting occurs at about 50,000 iterations despite the changes to the discriminator. The losses lie between those of the original run and the run with fewer features in the discriminator. Then we divided the learning rate by 10. The results were worse and the generated digits barely recognizable.

<img src="img/run4.png" width="500px;">

**Conclusion**

After some experiments with the variables we identified the following patterns:

* Pattern 1: Low dg/dx-loss in combination with high gz-loss leads to perfectly recognisable digits, but is prone to overfitting
* Pattern 2: High dg/dx-loss in combination with low gz-loss leads to hardly recognisable digits, but no overfitting

Finding a balance between these patterns is very difficult but crucial for the performance of the GAN. The following changes lead to these patterns:

* More features in the discriminator - Pattern 1 (overfitting)
* Fewer features in the discriminator - Pattern 2 (low quality images)
* Increasing the learning rate of the AdamOptimizer - Pattern 1 (overfitting)
* Decreasing the learning rate of the AdamOptimizer - Pattern 2 (low quality images)
* Increasing steps of pretraining - Pattern 1 (overfitting)
* Decreasing steps of pretraining - Pattern 2 (low quality images)

<img src="img/features-losses.jpg" width="1200px;">
> _Resulting losses after running the GAN multiple times with different parameters_


**Further Ideas**

Instead of creating one GAN for all ten digits, we could create a separate GAN for every digit. This should increase the quality of the generated digits, but it is not how GANs are intended to generate images.

If overfitting occurs, we could withhold all images of that particular digit from the dataset. This might cause the discriminator to 'forget' this digit and force the generator to come up with another one. However, it will probably be difficult to obtain a functioning GAN that creates all digits with roughly equal probability. Adding dropout after the convolutional layers might also help to avoid overfitting and to generalise better.
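
Withholding a digit amounts to filtering the training set by label. A minimal NumPy sketch, assuming integer labels (the default `one_hot=False` of `read_data_sets`, so `mnist.train.labels` holds the digits 0 to 9); the function `withhold_digit` and the tiny stand-in dataset are hypothetical. Training on the filtered arrays would also need a custom batching step instead of `mnist.train.next_batch`.

```python
import numpy as np


def withhold_digit(images, labels, digit):
    # Keep only the samples whose label differs from the collapsed digit
    mask = labels != digit
    return images[mask], labels[mask]


# Tiny stand-in for MNIST: 6 flattened 28x28 images with integer labels
rng = np.random.default_rng(0)
images = rng.random((6, 784), dtype=np.float32)
labels = np.array([1, 7, 1, 3, 0, 1])

kept_images, kept_labels = withhold_digit(images, labels, digit=1)
print(kept_labels)  # [7 3 0]
```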

It would be helpful to have a classifier for the output of the generator, so that we could investigate the statistical distribution of the generated digits. Such a classifier could also be used to detect overfitting automatically.
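
Once such a classifier exists (for example a small CNN trained on the MNIST labels, which is assumed here and not part of this notebook), collapse detection reduces to looking at the histogram of its predictions. In the sketch below, the functions and the `max_share` threshold of 0.5 are hypothetical choices for illustration; a real setup would tune the threshold or use a measure such as entropy.

```python
import numpy as np


def digit_histogram(predicted_digits):
    # Fraction of generated samples assigned to each of the ten classes
    counts = np.bincount(predicted_digits, minlength=10)
    return counts / counts.sum()


def looks_collapsed(predicted_digits, max_share=0.5):
    # Flag collapse when a single digit dominates the generated samples
    return digit_histogram(predicted_digits).max() > max_share


balanced = np.arange(100) % 10               # every digit predicted 10 times
collapsed = np.array([7] * 90 + [1] * 10)    # mostly sevens

print(looks_collapsed(balanced), looks_collapsed(collapsed))  # False True
```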

In theory, the ideal equilibrium is a discriminator *output* of 0.5 for every image (rather than a loss of 0.5): the discriminator assigns a 50/50 probability to an image being real or fake. It cannot discern between them, which is what we want.
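
With the sigmoid cross-entropy losses defined in *Training the GAN*, a discriminator output of 0.5 corresponds to a logit of 0 and to the following loss values (a worked check, not a measured result):

$$
\begin{aligned}
D(x) = D(G(z)) = \sigma(0) &= 0.5 \\
\mathrm{dx\_loss} &= -\ln D(x) = \ln 2 \approx 0.693 \\
\mathrm{dg\_loss} &= -\ln\bigl(1 - D(G(z))\bigr) = \ln 2 \approx 0.693 \\
\mathrm{gz\_loss} &= -\ln D(G(z)) = \ln 2 \approx 0.693 \\
\mathrm{d\_loss} &= \mathrm{dx\_loss} + \mathrm{dg\_loss} = 2\ln 2 \approx 1.386
\end{aligned}
$$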

---

## Further Ideas

### Change the Loss Function to Wasserstein Distance
[Wasserstein GANs (WGAN)](https://arxiv.org/pdf/1701.07875.pdf) are an alternative to *classical* GANs. They use a different loss function and have proved to be more robust to the choice of hyperparameters. While the paper offers a good theoretical introduction and explains why WGANs perform better in many cases, [this article](https://wiseodd.github.io/techblog/2017/02/04/wasserstein-gan/) provides a good practical introduction.
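
As a rough sketch of what would change: the discriminator becomes a *critic* whose last layer outputs an unbounded score (no sigmoid, no cross-entropy), the critic maximises the difference between its mean score on real and on generated images, and its weights are clipped to $[-c, c]$ after each update ($c = 0.01$ in the paper, which also uses RMSProp and several critic updates per generator update). The NumPy functions below only illustrate the objectives on made-up critic outputs; their names are hypothetical and they are not a TensorFlow implementation.

```python
import numpy as np


def critic_loss(f_real, f_fake):
    # The critic maximises E[f(x)] - E[f(G(z))], so we minimise the negative
    return np.mean(f_fake) - np.mean(f_real)


def generator_loss(f_fake):
    # The generator tries to raise the critic's score for its samples
    return -np.mean(f_fake)


def clip_weights(weights, c=0.01):
    # Weight clipping keeps the critic's parameters in a compact range
    return [np.clip(w, -c, c) for w in weights]


# Made-up critic scores for one batch of real and generated images
f_real = np.array([0.8, 1.2, 1.0])
f_fake = np.array([-0.5, 0.1, -0.2])

print(critic_loss(f_real, f_fake), generator_loss(f_fake))
clipped = clip_weights([np.array([0.5, -0.003, -0.2])])
print(clipped[0])
```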