This is an illustration of a very simple generative adversarial network, built with TensorFlow. It generates images that look like handwritten digits from the MNIST dataset.

For the greatest possible clarity, I've adapted two well-documented networks as the discriminator and the generator. [The convolutional neural network from TensorFlow's documentation](https://www.tensorflow.org/tutorials/mnist/pros/) serves as the discriminator, and [Tim O'Shea's Keras model](http://www.kdnuggets.com/2016/07/mnist-generative-adversarial-model-keras.html) as the generator.

Other crucial insights come from papers by [Ian Goodfellow](https://arxiv.org/abs/1701.00160) and [Alec Radford](https://arxiv.org/abs/1511.06434), and from [Soumith Chintala's `ganhacks` repository](https://github.com/soumith/ganhacks).

**This is a work in progress**, and is full of all manner of hacks and hard-coded shortcuts that will disappear or (hopefully) become more elegant as I make revisions.

The code here is written for TensorFlow v0.12, but can be made to run on earlier versions with some quick changes—in particular, replacing `tf.global_variables_initializer()` with `tf.initialize_all_variables()`. This script sends very helpful output to TensorBoard; to make it work with TensorFlow v0.11 and earlier, replace `tf.summary.scalar()` and `tf.summary.image()` with `tf.scalar_summary()` and `tf.image_summary()`, respectively.

In [None]:
%matplotlib inline
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow.contrib.learn.python.learn.datasets.mnist as mn
mnist = input_data.read_data_sets('MNIST_data/', one_hot=False)
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import datetime

In [None]:
# Fix disused scope stuff here

def weight_variable(shape, name, sc='discriminator'):
    """Build a weight variable through ``tf.get_variable``.

    The variable is initialized from a truncated normal distribution with
    mean 0 and standard deviation 0.1.

    Args:
        shape: Shape handed to ``tf.get_variable``.
        name: Variable name handed to ``tf.get_variable``.
        sc: Not used by this function; retained so existing calls keep working.

    Returns:
        The object returned by ``tf.get_variable``.
    """
    normal_init = tf.truncated_normal_initializer(mean=0, stddev=0.1)
    return tf.get_variable(name, shape, initializer=normal_init)

def bias_variable(shape, name, sc='discriminator'):
    """Build a bias variable through ``tf.get_variable``.

    Every element of the variable is initialized to the constant 0.1.

    Args:
        shape: Shape handed to ``tf.get_variable``.
        name: Variable name handed to ``tf.get_variable``.
        sc: Not used by this function; retained so existing calls keep working.

    Returns:
        The object returned by ``tf.get_variable``.
    """
    constant_init = tf.constant_initializer(value=0.1)
    return tf.get_variable(name, shape, initializer=constant_init)

def conv2d(x, W):
    """Apply ``tf.nn.conv2d`` to ``x`` with filter ``W``.

    The convolution uses a stride of 1 in every dimension and 'SAME' padding.

    Args:
        x: Input tensor passed straight to ``tf.nn.conv2d``.
        W: Filter tensor passed straight to ``tf.nn.conv2d``.

    Returns:
        The tensor produced by ``tf.nn.conv2d``.
    """
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')

def max_pool_2x2(x, name):
    """Apply ``tf.nn.max_pool`` with a 2x2 window and a stride of 2.

    Both the window size and the strides are ``[1, 2, 2, 1]``, and padding
    is 'SAME'.

    Args:
        x: Input tensor passed straight to ``tf.nn.max_pool``.
        name: Name given to the pooling op.

    Returns:
        The tensor produced by ``tf.nn.max_pool``.
    """
    window = [1, 2, 2, 1]
    step = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step,
                          padding='SAME', name=name)

In [None]:
# Bookkeeping placeholders that exist only so that training-cycle counts can
# be written to TensorBoard as scalar summaries. They take no part in the
# generator or discriminator computation.
with tf.variable_scope('meta_variables') as scope:
    # Hack to show in TensorBoard whether discriminator or generator is being trained
    # Each placeholder is fed a Python counter by the training loop whenever
    # the merged summaries are evaluated. The loop resets those counters after
    # every summary write (every 10 iterations), so the values stay small.
    g_train_count = tf.placeholder(tf.int8)
    tf.summary.scalar('Generator_training_cycles', g_train_count)
    d_generated_train_count = tf.placeholder(tf.int8)
    tf.summary.scalar('Discriminator_generated_training_cycles', d_generated_train_count)
    d_mnist_train_count = tf.placeholder(tf.int8)
    tf.summary.scalar('Discriminator_mnist_training_cycles', d_mnist_train_count)

Here's the generator network; it includes a version of the discriminator so that gradients over the discriminator with respect to the generator weights are available to the generator's optimizer.

In [None]:
# Generator graph, followed by a copy of the discriminator layers wired
# directly to the generator's output. The discriminator weights created here
# (W_conv1 ... b_fc2) are reused by the separate discriminator cell.
with tf.variable_scope('generator') as scope:
    # Noise input: one 100-element vector per image to generate.
    z = tf.placeholder(tf.float32, shape=[None, 100], name="z")

    # Dense layer from 100 inputs to 3136 (= 56 * 56) values, reshaped into a
    # single-channel 56x56 map, then batch-normalized and passed through ReLU.
    w1 = tf.Variable(tf.random_uniform([100, 3136], 0, 1), name='w1')
    b1 = tf.Variable(tf.zeros([3136]), name='b1')
    g1 = tf.matmul(z, w1) + b1
    g1 = tf.reshape(g1, [-1, 56, 56, 1])
    g1 = tf.contrib.layers.batch_norm(g1, epsilon=1e-5)
    g1 = tf.nn.relu(g1)

    # Generate 50 features
    # 3x3 convolution with stride 2, then resized back to 56x56.
    w2 = tf.Variable(tf.random_uniform([3, 3, 1, 50]))
    b2 = tf.Variable(tf.random_uniform([50]))
    g2 = tf.nn.conv2d(g1, w2, strides=[1, 2, 2, 1], padding='SAME')
    g2 = g2 + b2
    g2 = tf.contrib.layers.batch_norm(g2, epsilon=1e-5)
    g2 = tf.nn.relu(g2)
    g2 = tf.image.resize_images(g2, [56, 56])

    # Generate 25 features
    # Same pattern as the previous layer: stride-2 convolution, then resize.
    w3 = tf.Variable(tf.random_uniform([3, 3, 50, 25]))
    b3 = tf.Variable(tf.random_uniform([25]))
    g3 = tf.nn.conv2d(g2, w3, strides=[1, 2, 2, 1], padding='SAME')
    g3 = g3 + b3
    g3 = tf.contrib.layers.batch_norm(g3, epsilon=1e-5)
    g3 = tf.nn.relu(g3)
    g3 = tf.image.resize_images(g3, [56, 56])

    # Final convolution with one output channel
    # NOTE(review): b4 is created and listed as trainable below, but it is
    # never added to g4 -- confirm whether the bias was meant to be applied.
    w4 = tf.Variable(tf.random_uniform([1, 1, 25, 1]))
    b4 = tf.Variable(tf.random_uniform([1]))
    g4 = tf.nn.conv2d(g3, w4, strides=[1, 2, 2, 1], padding='SAME')
    g4 = tf.sigmoid(g4)

    # Flatten each generated image into a row of 784 (= 28 * 28) values.
    # Reshaping with -1 drops the trailing channel dimension without
    # hard-coding the batch size, so any number of noise vectors can be fed.
    generator_images = tf.reshape(g4, [-1, 784])

    generator_trainable_variables = [w1, b1, w2, b2, w3, b3, w4, b4]

    # Per Goodfellow, the last layer of the generator is not normalized.

    x = generator_images
    images_for_tensorboard = tf.reshape(generator_images, [-1, 28, 28, 1])
    tf.summary.image('Generated_images', images_for_tensorboard, max_outputs=50)

    # Let's combine the MNIST images and labels via placeholder with the generated images

    # NOTE(review): W and b are not referenced anywhere else in the visible
    # code -- presumably left over from an earlier version.
    W = tf.Variable(tf.zeros([784, 1]), name="W")
    b = tf.Variable(tf.zeros([1]), name="b")

    # Discriminator layers applied to the generator's output. This gives the
    # generator's optimizer a path from y_conv back to the generator weights.
    W_conv1 = weight_variable([5, 5, 1, 32], 'W_conv1')
    b_conv1 = bias_variable([32], 'b_conv1')

    x_image = tf.reshape(x, [-1, 28, 28, 1])

    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1, name='h_conv1')
    h_pool1 = max_pool_2x2(h_conv1, name='h_pool1')

    W_conv2 = weight_variable([5, 5, 32, 64], 'W_conv2')
    b_conv2 = bias_variable([64], name='b_conv2')

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2, name='h_conv2')
    h_pool2 = max_pool_2x2(h_conv2, name='h_pool2')

    W_fc1 = weight_variable([7 * 7 * 64, 1024], name='W_fc1')
    b_fc1 = bias_variable([1024], name='b_fc1')

    h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64], name='h_pool2_flat')
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1, name='h_fc1')

    # Dropout keep probability; the same placeholder is reused by the
    # discriminator cell, so it must be fed for both networks.
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob, name='h_fc1_drop')

    W_fc2 = weight_variable([1024, 1], name='W_fc2')
    b_fc2 = bias_variable([1], name='b_fc2')

    # Discriminator output for the generated images, squashed by a sigmoid.
    y_conv = tf.sigmoid(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

And here's the discriminator network. Note that we don't reinitialize the weights and biases; we use the same weights and biases that we initialized for the generator network, but describe new layers for them that take an input from a placeholder, `d_x`, instead of taking input straight from the generator.

In [None]:
# Stand-alone discriminator graph. It reuses the weight and bias variables
# created in the generator cell (W_conv1 ... b_fc2) and the shared keep_prob
# placeholder, but reads its images from the d_x placeholder.
with tf.variable_scope('discriminator') as scope:
    # Flattened input images, 784 values per row.
    d_x = tf.placeholder(tf.float32, shape=[None, 784], name='d_x')
    # d_y_ = tf.placeholder(tf.float32, shape=[None, 2], name='d_y_')

    d_x_image = tf.reshape(d_x, [-1,28,28,1], name='d_x_image')

    d_h_conv1 = tf.nn.relu(conv2d(d_x_image, W_conv1) + b_conv1, name='d_h_conv1')
    d_h_pool1 = max_pool_2x2(d_h_conv1, name='d_h_pool1')

    d_h_conv2 = tf.nn.relu(conv2d(d_h_pool1, W_conv2) + b_conv2, name='d_h_conv2')
    d_h_pool2 = max_pool_2x2(d_h_conv2, name='d_h_pool2')

    d_h_pool2_flat = tf.reshape(d_h_pool2, [-1, 7*7*64], name='d_h_pool2_flat')
    d_h_fc1 = tf.nn.relu(tf.matmul(d_h_pool2_flat, W_fc1) + b_fc1, name='d_h_fc1')

    d_h_fc1_drop = tf.nn.dropout(d_h_fc1, keep_prob, name='d_h_fc1_drop')

    # Sigmoid output per input image.
    d_y_conv = tf.sigmoid(tf.matmul(d_h_fc1_drop, W_fc2) + b_fc2)
    
    # There's probably a clever way to refer to the losses of G and D
    # rather than these accuracy stats.
    # These "accuracies" are means of the sigmoid output, not thresholded
    # hit rates. The fixed slices assume d_x holds 50 generated images in
    # rows 0-49 followed by 50 MNIST images in rows 50-99, which is how the
    # training helpers concatenate their batches.
    generated_accuracy = 1 - tf.reduce_mean(d_y_conv[0:50])
    tf.summary.scalar('generated_accuracy', generated_accuracy)
    mnist_accuracy = tf.reduce_mean(d_y_conv[50:100])
    tf.summary.scalar('mnist_accuracy', mnist_accuracy)
    
    # Variables updated by the discriminator's optimizer.
    discriminator_trainable_variables = [W_conv1, b_conv1, W_conv2, b_conv2,
                                     W_fc1, b_fc1, W_fc2, b_fc2]

In [None]:
# Target labels, one value per image. The same placeholder is used by both
# the discriminator loss and the generator loss.
y_ = tf.placeholder(tf.float32, shape=[None, 1], name='y_')

# NOTE(review): d_y_conv and y_conv are already sigmoid outputs, yet they are
# passed here in the logits position, so the sigmoid is presumably applied
# twice -- confirm whether the pre-sigmoid values were intended.
# NOTE(review): the positional (logits, targets) call order is assumed to
# match the TensorFlow v0.12 signature this notebook targets -- confirm
# before running on a newer release.
discriminator_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(d_y_conv, y_))

tf.summary.scalar('discriminator_loss', discriminator_loss)

# The generator minimizes the negated cross-entropy of the discriminator's
# output on generated images against the labels fed to y_.
generator_loss = -1 * tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(y_conv, y_))
tf.summary.scalar('generator_loss', generator_loss)

# Each optimizer updates only its own network's variables; the generator
# uses a learning rate 100 times larger than the discriminator's.
discriminator_optimize = tf.train.AdamOptimizer(1e-5, name='discriminator_train') \
    .minimize(discriminator_loss, var_list=discriminator_trainable_variables)
generator_optimize = tf.train.AdamOptimizer(1e-3, name='generator_train') \
    .minimize(generator_loss, var_list=generator_trainable_variables)

# Sum of the mean absolute values of the four generator weight tensors,
# logged as a rough indicator of how the generator's weights are moving.
generator_weights_mean = tf.reduce_mean(tf.abs(w1)) + tf.reduce_mean(tf.abs(w2)) + tf.reduce_mean(tf.abs(w3)) + tf.reduce_mean(tf.abs(w4))
tf.summary.scalar('generator_weights_mean', generator_weights_mean)

In [None]:
# Session shared by every training helper and display cell below.
sess = tf.InteractiveSession()

# Collect every summary op defined so far into a single op, and write the
# results to a log directory named after the current date and time.
merged = tf.summary.merge_all()
logdir = "tensorboard/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "/"
writer = tf.summary.FileWriter(logdir, sess.graph)

# Initialize all generator and discriminator variables before training.
sess.run(tf.global_variables_initializer())

We want to eventually reach a point where the discriminator correctly classifies all real MNIST images as MNIST images, and classifies generated images as MNIST images about 50% of the time. There are several failure modes that we need to avoid:
* **Discriminator accuracy goes to 100%**: this leaves practically no gradients for the generator's optimizer.
* **Discriminator accuracy goes to 0%**
* **Divergent discriminator accuracy**: the discriminator classifies generated images at about 50%, but drops toward 0% accuracy on real MNIST images.

To stay balanced between these, we have separate functions to train the generator, train the discriminator on real MNIST images, and train the discriminator on generated images.

In [None]:
def train_discriminator_generated(combined_labels, generated_labels):
    """Run one discriminator update on generated images, then re-measure.

    Args:
        combined_labels: Not used by this function; kept so existing calls
            keep working.
        generated_labels: Labels fed to ``y_`` for the 50 generated images.

    Returns:
        A ``(ga, ma)`` pair: the values of ``generated_accuracy`` and
        ``mnist_accuracy`` evaluated with dropout disabled on a fresh batch
        of 50 generated images followed by 50 MNIST validation images.
    """
    # Grab some generated images
    noise = np.random.normal(size=(50, 100))
    g_images = sess.run(generator_images, {z: noise})

    # Only generated images are needed for this update, so no MNIST training
    # batch is fetched here.
    sess.run(discriminator_optimize,
             {d_x: g_images, keep_prob: 0.5, y_: generated_labels})

    # Measure on new generated images plus held-out MNIST images.
    combined_images = get_combined_images()
    ga, ma = sess.run([generated_accuracy, mnist_accuracy],
                      {d_x: combined_images, keep_prob: 1})
    return ga, ma

def train_discriminator_mnist(combined_labels, mnist_labels):
    """Run one discriminator update on real MNIST images, then re-measure.

    Args:
        combined_labels: Not used by this function.
        mnist_labels: Labels fed to ``y_`` for the 50 MNIST training images.

    Returns:
        A ``(ga, ma)`` pair: the values of ``generated_accuracy`` and
        ``mnist_accuracy`` evaluated with dropout disabled on 50 freshly
        generated images followed by 50 MNIST validation images.
    """
    real_images = mnist.train.next_batch(50)[0]
    train_feed = {d_x: real_images, keep_prob: 0.5, y_: mnist_labels}
    sess.run(discriminator_optimize, train_feed)

    # get_combined_images() draws new noise, generates 50 images and appends
    # 50 validation images, in that order.
    eval_feed = {d_x: get_combined_images(), keep_prob: 1}
    ga, ma = sess.run([generated_accuracy, mnist_accuracy], eval_feed)
    return ga, ma

def train_discriminator_combined(combined_labels):
    """Run one discriminator update on a mixed batch, then re-measure.

    The training batch is 50 generated images followed by 50 MNIST training
    images, so ``combined_labels`` must list the generated labels first.

    Args:
        combined_labels: Labels fed to ``y_`` for the 100 combined images.

    Returns:
        A ``(ga, ma)`` pair: the values of ``generated_accuracy`` and
        ``mnist_accuracy`` evaluated with dropout disabled on 50 freshly
        generated images followed by 50 MNIST validation images.
    """
    noise = np.random.normal(size=(50, 100))
    mnist_batch = mnist.train.next_batch(50)[0]
    g_images = sess.run(generator_images, {z: noise})
    training_images = np.concatenate((g_images, mnist_batch))

    sess.run(discriminator_optimize,
             {d_x: training_images, keep_prob: 0.5, y_: combined_labels})

    # Measure on held-out validation images, matching the other training
    # helpers, rather than on another training batch.
    combined_images = get_combined_images()
    ga, ma = sess.run([generated_accuracy, mnist_accuracy],
                      {d_x: combined_images, keep_prob: 1})
    return ga, ma

def train_generator(generated_labels):
    """Run one generator update and measure the discriminator afterwards.

    Args:
        generated_labels: Labels fed to ``y_`` for the 50 generated images.

    Returns:
        A ``(g_images, ga, ma)`` triple: the images fetched in the same run
        as the update, then the values of ``generated_accuracy`` and
        ``mnist_accuracy`` evaluated with dropout disabled on those images
        followed by 50 MNIST validation images.
    """
    noise = np.random.normal(size=(50, 100))
    train_feed = {z: noise, y_: generated_labels, keep_prob: 1}
    # generator_loss is fetched alongside the update; its value is discarded.
    fetched = sess.run(
        [generator_images, generator_loss, generator_optimize], train_feed)
    g_images = fetched[0]

    validation_images = mnist.validation.next_batch(50)[0]
    eval_feed = {
        d_x: np.concatenate((g_images, validation_images)),
        keep_prob: 1,
    }
    ga, ma = sess.run([generated_accuracy, mnist_accuracy], eval_feed)
    return g_images, ga, ma

def get_combined_images():
    """Build an evaluation batch of generated and real images.

    Returns:
        An array holding 50 freshly generated images followed by 50 MNIST
        validation images, joined with ``np.concatenate``.
    """
    noise = np.random.normal(size=(50, 100))
    fake_images = sess.run(generator_images, {z: noise})
    real_images = mnist.validation.next_batch(50)[0]
    batch = (fake_images, real_images)
    return np.concatenate(batch)

If the discriminator's accuracy over generated images is less than 50%, we train the discriminator over generated images. If the discriminator's accuracy over real MNIST images is less than 80%, we train it over MNIST images. Otherwise, we train the generator.

In my experience, the generator winds up being trained almost exclusively for the first several thousand iterations, then settles into a balance where the iterations are given over about 45% to generator training, 35% to discriminator training over MNIST images, and 20% discriminator training over generated images. The greatest progress in developing recognizable images is made once the system reaches this balance.

Summary statistics illustrating the ratio of discriminator and generator training cycles are sent to TensorBoard. Recognizable digits begin to emerge after about 20,000 iterations, and improve markedly for the 50,000 or so iterations after that. Every thousand iterations takes about 10 minutes on my laptop, or one minute on an [AWS P2 GPU-enabled machine](https://aws.amazon.com/ec2/instance-types/p2/), so a full 100,000 iterations should take 16 hours on a laptop or an hour and a half on a P2 instance.

In [None]:
# It's useful to track changes in the generator's weights
# to see how quickly it's being optimized.
last_weights = sess.run([w1, w2, w3, w4])

# Target labels fed to `y_`, one row per image: 0 for generated images.
generated_labels = np.zeros((50, 1))
# We try to get the discriminator to classify MNIST images as 0.9 rather than 1.0;
# this is label smoothing, described by Goodfellow.
mnist_labels = np.full((50, 1), 0.9)
combined_labels = np.concatenate((generated_labels, mnist_labels))

# Start by running the generator
g_images, ga, ma = train_generator(generated_labels)

# ga = discriminator accuracy over generated images
# ma = discriminator accuracy over real MNIST images

# Per-phase cycle counts since the last TensorBoard write.
d_generated_train_counter = 0
d_mnist_train_counter = 0
g_train_counter = 0

for i in range(100000):
    # Pick which network to train from the most recent accuracy readings.
    if ga < 0.5:
        # Train discriminator over generated images
        ga, ma = train_discriminator_generated(combined_labels, generated_labels)
        d_generated_train_counter += 1
    elif ma < 0.8:
        # Train discriminator over real MNIST images
        ga, ma = train_discriminator_mnist(combined_labels, mnist_labels)
        d_mnist_train_counter += 1
    else:
        # Train generator
        g_images, ga, ma = train_generator(generated_labels)
        g_train_counter += 1

    if i % 10 == 0:
        # Every 10 iterations, send summary statistics to TensorBoard
        random = np.random.normal(size=(50, 100))
        g_images = sess.run(generator_images, {z: random})
        # NOTE(review): only 50 generated images are fed to d_x here, so the
        # `mnist_accuracy` summary (mean of d_y_conv[50:100]) is taken over an
        # empty slice -- presumably NaN in TensorBoard; confirm.
        summary = sess.run(merged, {d_x: g_images, keep_prob: 1, z: np.random.normal(size=(50, 100)),
                                    y_: generated_labels, g_train_count: g_train_counter,
                                    d_generated_train_count: d_generated_train_counter,
                                    d_mnist_train_count: d_mnist_train_counter})
        writer.add_summary(summary, i)
        d_mnist_train_counter, d_generated_train_counter, g_train_counter = 0, 0, 0
    if i % 100 == 0:
        combined_images = get_combined_images()

        current_weights = sess.run([w1, w2, w3, w4])
        # These readings replace ga/ma, so they also drive the next
        # iteration's choice of which network to train.
        ga, ma = sess.run([generated_accuracy, mnist_accuracy], {d_x: combined_images, keep_prob: 1})
        print(i, "MNIST acc:", ma, "Gen acc:", ga, "at", datetime.datetime.now())

        # Summarize change in the generator's weights. If this tends toward zero
        # within the first 100,000 iterations, there's something wrong.
        # Pairing the tensors with zip avoids reusing the outer loop counter
        # `i` as an inner index.
        weight_diff_sum = sum(
            np.mean(np.absolute(np.absolute(previous) - np.absolute(current)))
            for previous, current in zip(last_weights, current_weights))
        print("Weight differences:", weight_diff_sum)
        last_weights = current_weights

Now let's see some of the images produced by the generator. (The generator has also been sending its images to TensorBoard regularly; click the "images" tab in TensorBoard to see them as this runs.)

In [None]:
# Generate 50 images from fresh noise and ask the discriminator to score them.
random = np.random.normal(size=(50, 100))
raw_images = sess.run(generator_images, feed_dict={z: random})
g_images = raw_images.squeeze().reshape([50, 784])
d_classifications = sess.run(d_y_conv, feed_dict={d_x: g_images, keep_prob: 1})

for i, flat_image in enumerate(g_images[:50]):
    # Print the discriminator's classification of this image
    print(d_classifications[i])
    picture = flat_image.reshape([28, 28])
    plt.imshow(picture, cmap='Greys', interpolation='none')
    plt.show()

And, as a sanity check, let's look at some real MNIST images and make sure that the discriminator correctly classifies them as real MNIST images.

In [None]:
# Draw 50 real images from the validation set. `validation_batch` is not
# defined by any earlier cell, so it has to be fetched here before use.
validation_batch = mnist.validation.next_batch(50)[0]

# Score the real images with dropout disabled.
d_classifications = sess.run(d_y_conv, {d_x: validation_batch, keep_prob: 1})
print(d_classifications)
for i in range(50):
    img = validation_batch[i]
    # Print the discriminator's classification of this image, then show it.
    print(d_classifications[i])
    plt.imshow(img.reshape([28, 28]), cmap='Greys')
    plt.show()
    