# MNIST Deep Convolutional Generative Adversarial Network
A Generative Adversarial Network (GAN) learns the probability distribution of a dataset and samples from it to produce new images that closely resemble the original data.

In [None]:
import tensorflow as tf
import numpy as np
import os
from PIL import Image

## Data Preprocessing for MNIST

In [None]:
# Load the MNIST training images through the TensorFlow contrib dataset helper.

# The training images are reshaped below to 28x28 pixels with a single channel.
image_height, image_width = 28, 28
color_channels = 1

# Prefix used when naming checkpoint files.
model_name = "mnist"

mnist = tf.contrib.learn.datasets.load_dataset("mnist")

# Give the images an explicit (N, height, width, channels) layout so they can
# be fed to the convolutional layers.
train_data = np.reshape(
    mnist.train.images,
    (-1, image_height, image_width, color_channels),
)

print(train_data.shape)

## Creating the discriminator network
The discriminator network takes images as input and is trained to estimate the probability that a given image is real (drawn from the dataset) rather than fake (created by the generator). It uses convolutional layers, leaky ReLU activations, and batch normalization to increase training speed.

In [None]:
# Discriminator
def create_discriminator(image, reuse=False):
    """Build the discriminator graph and return its sigmoid output.

    Three 5x5, stride-1, "SAME" convolutions (16, 32 and 64 filters) are each
    followed by batch normalization, a leaky ReLU and 2x2 max pooling. The
    pooled features are flattened and mapped by a dense layer to a single
    sigmoid-activated value per image.

    Args:
        image: Batch of images with `color_channels` channels in the last
            axis. The dense layer expects 1024 flattened features, which is
            what 32x32 inputs yield after three 2x2 poolings (4 * 4 * 64).
        reuse: Forwarded to `tf.variable_scope`; pass True to share the
            variables created by an earlier call.

    Returns:
        Tensor of shape [batch, 1] holding the sigmoid outputs.
    """
    # (weights name, biases name, input channels, output channels)
    conv_layers = (
        ("conv1weights", "conv1biases", color_channels, 16),
        ("conv2weights", "conv2biases", 16, 32),
        ("conv3weights", "conv3biases", 32, 64),
    )

    with tf.variable_scope("discriminator", reuse=reuse):
        features = image
        for weights_name, biases_name, in_channels, out_channels in conv_layers:
            kernel = tf.get_variable(
                weights_name,
                shape=[5, 5, in_channels, out_channels],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            bias = tf.get_variable(
                biases_name,
                shape=[out_channels],
                initializer=tf.constant_initializer(0))
            convolved = tf.nn.conv2d(features, kernel, [1, 1, 1, 1], "SAME") + bias
            normalized = tf.contrib.layers.batch_norm(convolved, epsilon=1e-5)
            activated = tf.nn.leaky_relu(normalized)
            # Each pooling step halves the spatial resolution.
            features = tf.layers.max_pooling2d(activated, [2, 2], [2, 2], "SAME")

        flat_features = tf.layers.flatten(features)

        dense_weights = tf.get_variable(
            "dense1weights",
            shape=[1024, 1],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        dense_biases = tf.get_variable(
            "dense1biases",
            shape=[1, 1],
            initializer=tf.constant_initializer(0))
        probability = tf.sigmoid(tf.matmul(flat_features, dense_weights) + dense_biases)

    return probability

## Creating the generator network
The generator network takes a 100-D noise vector 'z' (sampled uniformly from [-1, 1] in the training code below) and applies a dense layer followed by transpose convolutions with ReLU and, once again, batch normalization to transform the noise distribution into the distribution of the dataset.

In [None]:
# Generator
def create_generator(z, reuse=False):
    """Build the generator graph mapping noise vectors to 32x32 images.

    The noise is projected by a dense layer to 4096 values, reshaped to
    4x4x256, and upsampled by three stride-2 transpose convolutions
    (8x8x64, 16x16x32, 32x32x16), each followed by batch normalization and
    ReLU. A final stride-1 transpose convolution with a 32x32 kernel reduces
    the result to `color_channels` channels, and a sigmoid is applied.

    Reads the module-level `noise_length` and `color_channels`.

    Args:
        z: Batch of noise vectors of shape [batch, noise_length].
        reuse: Forwarded to `tf.variable_scope`; pass True to share the
            variables created by an earlier call.

    Returns:
        Tensor of shape [batch, 32, 32, color_channels] of sigmoid outputs.
    """
    # (weights name, biases name, output channels, input channels, output size)
    upsample_layers = (
        ("conv1weights", "conv1biases", 64, 256, 8),
        ("conv2weights", "conv2biases", 32, 64, 16),
        ("conv3weights", "conv3biases", 16, 32, 32),
    )

    with tf.variable_scope("generator", reuse=reuse):
        # Dynamic batch size, needed for the transpose-convolution output shapes.
        num_samples = tf.shape(z)[0]

        projection_weights = tf.get_variable(
            "dense1weights",
            shape=[noise_length, 4096],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        projection_biases = tf.get_variable(
            "dense1biases",
            shape=[1, 4096],
            initializer=tf.constant_initializer(0))
        projected = tf.matmul(z, projection_weights) + projection_biases
        projected = tf.nn.relu(tf.contrib.layers.batch_norm(projected))

        # 4096 = 4 * 4 * 256
        features = tf.reshape(projected, shape=[-1, 4, 4, 256])

        for weights_name, biases_name, out_channels, in_channels, size in upsample_layers:
            # conv2d_transpose filters are laid out as
            # [height, width, output channels, input channels].
            kernel = tf.get_variable(
                weights_name,
                shape=[5, 5, out_channels, in_channels],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            bias = tf.get_variable(
                biases_name,
                shape=[out_channels],
                initializer=tf.constant_initializer(0))
            upsampled = tf.nn.conv2d_transpose(
                features,
                kernel,
                [num_samples, size, size, out_channels],
                [1, 2, 2, 1]) + bias
            features = tf.nn.relu(tf.contrib.layers.batch_norm(upsampled))

        final_kernel = tf.get_variable(
            "conv4weights",
            shape=[32, 32, color_channels, 16],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        final_bias = tf.get_variable(
            "conv4biases",
            shape=[color_channels],
            initializer=tf.constant_initializer(0))
        pre_activation = tf.nn.conv2d_transpose(
            features,
            final_kernel,
            [num_samples, 32, 32, color_channels],
            [1, 1, 1, 1]) + final_bias
        generated = tf.sigmoid(pre_activation)

    return generated

## Adversarial loss
The discriminator is given both real images from the MNIST dataset and fake images created by the generator. The discriminator is optimized to distinguish between real and fake, while the generator is optimized to make the discriminator classify its images as real. This adversarial setup makes convergence extremely difficult, but it allows the generator to get closer and closer to the probability distribution of the data.

In [None]:
# Building the model

tf.reset_default_graph()

# The networks operate on 32x32 images; the training loop zero-pads the
# 28x28 MNIST digits by 2 pixels on each side to reach this size.
img_height = 32
img_width = 32
color_channels = 1
noise_length = 100

# Keeps the log arguments strictly positive. Without it a saturated sigmoid
# (exactly 0.0 or 1.0 in float32) produces log(0) = -inf and the losses and
# gradients turn into inf/NaN.
log_epsilon = 1e-8

# Placeholders
x_images = tf.placeholder(dtype=tf.float32, shape=[None, img_height, img_width, color_channels])
noise = tf.placeholder(dtype=tf.float32, shape=[None, noise_length])

# Minimax: the same discriminator variables score real and generated images.
real = create_discriminator(x_images)
fake = create_discriminator(create_generator(noise), reuse=True)

# The generator is rewarded when the discriminator scores its images as real.
gen_loss = -tf.reduce_mean(tf.log(fake + log_epsilon))
# The discriminator is rewarded for scoring real images high and fakes low.
disc_loss = -tf.reduce_mean(tf.log(real + log_epsilon) + tf.log(1.0 - fake + log_epsilon))

# Each optimizer only updates the variables of its own network.
gen_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="generator")
disc_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="discriminator")

train_gen = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(gen_loss, var_list=gen_vars)
train_disc = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(disc_loss, var_list=disc_vars)

## Training the model

In [None]:
# Training
# NOTE: save_sample is defined in a later cell and must be executed before
# this one.
load_checkpoint = True
path = "GAN checkpoints/"
samples_dir = "Generated/"
saver = tf.train.Saver(max_to_keep=8)

batch_size = 100
epochs = 50000
display_step = 10

# Neither the checkpoint saver nor the image writer creates missing
# directories, so make sure both exist before training starts.
for directory in (path, samples_dir):
    if not os.path.isdir(directory):
        os.makedirs(directory)

init = tf.global_variables_initializer()
sess = tf.Session()

# get_checkpoint_state returns None when the directory holds no checkpoint
# (e.g. on the very first run); fall back to fresh initialization then.
checkpoint = tf.train.get_checkpoint_state(path) if load_checkpoint else None
if checkpoint is not None and checkpoint.model_checkpoint_path:
    saver.restore(sess, checkpoint.model_checkpoint_path)
else:
    sess.run(init)

current_batch_index = 0

# Fixed noise sample, so the saved image grids are comparable across steps.
test_gen = create_generator(np.random.uniform(-1.0, 1.0, size=[16, noise_length]).astype(np.float32), reuse=True)

# Each "epoch" here is a single mini-batch training step.
for epoch in range(epochs):
    batch_xs = np.array(train_data[current_batch_index:current_batch_index + batch_size])
    # Zero-pad the 28x28 digits to the 32x32 input size of the networks.
    batch_xs = np.pad(batch_xs, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant')

    # Advance through the dataset, wrapping around at the end.
    if current_batch_index + batch_size >= len(train_data):
        current_batch_index = 0
    else:
        current_batch_index += batch_size

    zs = np.random.uniform(-1.0, 1.0, size=[batch_size, noise_length]).astype(np.float32)

    if epoch % display_step == 0:
        a = np.array(sess.run(test_gen))
        save_sample(a, samples_dir + str(epoch) + ".bmp", [4, 4])
        saver.save(sess, path + model_name, epoch)

    # One discriminator update followed by two generator updates per step.
    sess.run(train_disc, feed_dict={x_images: batch_xs, noise: zs})
    sess.run(train_gen, feed_dict={noise: zs})
    sess.run(train_gen, feed_dict={noise: zs})

    # Fetch both losses in a single run instead of two separate passes.
    gen_loss_value, disc_loss_value = sess.run(
        [gen_loss, disc_loss],
        feed_dict={x_images: batch_xs, noise: zs})
    print("Epoch",
          epoch,
          "Generator Loss",
          gen_loss_value,
          "Discriminator Loss",
          disc_loss_value)

saver.save(sess, path + model_name, epoch)

## Image saving function

In [None]:
# Saving image samples
def save_sample(images_array, filename, shape):
    """Tile a batch of grayscale images into one grid and save it to disk.

    Args:
        images_array: Array whose first axis indexes the images and whose
            next two axes are the pixel rows and columns of each image, e.g.
            (count, rows, cols) or (count, rows, cols, 1). Values are
            expected in [0, 1]; anything outside is clipped.
        filename: Destination path handed to PIL's `Image.save`.
        shape: Two-element sequence giving the grid size as
            (tiles down, tiles across). Image `i` is placed at tile row
            `i % shape[0]` and tile column `i // shape[0]`.

    Raises:
        ValueError: If there are more images than grid cells.
    """
    tile_rows = images_array.shape[1]
    tile_cols = images_array.shape[2]

    if len(images_array) > shape[0] * shape[1]:
        raise ValueError(
            "%d images do not fit in a %dx%d grid"
            % (len(images_array), shape[0], shape[1]))

    # Size each axis from its own tile dimension so non-square images work.
    grid = np.zeros((tile_rows * shape[0], tile_cols * shape[1]))

    for index, image in enumerate(images_array):
        top = (index % shape[0]) * tile_rows
        left = (index // shape[0]) * tile_cols
        grid[top:top + tile_rows, left:left + tile_cols] = image.reshape(tile_rows, tile_cols)

    # Clip before the uint8 cast so out-of-range values saturate rather than
    # wrap around.
    pixels = (np.clip(grid, 0.0, 1.0) * 255).astype(np.uint8)
    final_img = Image.fromarray(pixels, mode="L")
    final_img.save(filename)