In [None]:
from __future__ import print_function
from __future__ import absolute_import

import datetime
import os
import shutil
import sys

import numpy as np
import tensorflow as tf

from matplotlib import pyplot as plt
from IPython import display

# IPython-specific: render matplotlib figures inline in the notebook
%matplotlib inline

In [None]:
# DLTK imports
from dltk.models.gan.dcgan import DCGAN
from dltk.core.modules import *

# MNIST data import
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('../../data/MNIST_data', one_hot=False)

# Training hyper-parameters. The batch size is shared by the real-image batch
# and the noise batch fed to the generator.
NUM_STEPS = 20000
BATCH_SIZE = 100
NOISE_DIM = 100
PLOT_EVERY = 20
LEARNING_RATE = 0.0002
ADAM_BETA1 = 0.5
# Run tf.add_check_numerics_ops() checks with every discriminator update.
CHECK_NUMERICS = True

# Start from a clean output directory. shutil.rmtree avoids spawning a shell,
# and ignore_errors=True keeps the "rm -rf" behaviour of not failing when the
# directory does not exist.
# NOTE(review): save_path is not passed to the Supervisor below (logdir=None),
# so nothing is written here; confirm whether logdir=save_path was intended.
save_path = '/tmp/GAN_MNIST'
shutil.rmtree(save_path, ignore_errors=True)

# Set environmental variables to select CUDA device.
# NOTE(review): presumably this only takes effect if TensorFlow has not yet
# initialised the GPUs in this kernel (i.e. before the first session exists).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
tf.logging.set_verbosity(tf.logging.ERROR)


def _plot_progress(sample, d_losses, g_losses):
    """Redraw the notebook output with one generated sample and the loss curves.

    Args:
        sample: generator output batch; its first element is squeezed and shown
            as a greyscale image with the colour range fixed to [-1, 1].
        d_losses: discriminator loss history, one value per training step.
        g_losses: generator loss history, one value per training step.
    """
    # Drop the previous figure and replace the cell output once the new one
    # is ready (wait=True avoids flicker).
    plt.close()
    display.clear_output(wait=True)

    fig, axarr = plt.subplots(1, 2, figsize=(16, 4))
    axarr[0].imshow(np.squeeze(sample[0, :, :]), cmap='gray', vmin=-1, vmax=1)
    axarr[0].set_title('Sample')
    axarr[0].axis('off')

    axarr[1].plot(d_losses, label='discriminator loss')
    axarr[1].plot(g_losses, label='generator loss')
    axarr[1].set_title('Loss')
    axarr[1].axis('on')
    axarr[1].legend()
    display.display(fig)


with tf.Graph().as_default():

    # Define inputs: flattened 784-pixel images reshaped to 28x28x1, plus the
    # generator's noise vectors.
    xp = tf.placeholder(tf.float32, shape=[None, 784])
    x_in = tf.reshape(xp, [-1, 28, 28, 1])
    z_in = tf.placeholder(tf.float32, shape=[None, NOISE_DIM])

    # Build a DCGAN model
    net = DCGAN(discriminator_filters=(16, 32, 64, 128),
                discriminator_strides=((1, 1), (2, 2), (1, 1), (2, 2)),
                generator_filters=(256, 128, 64, 1),
                generator_strides=((7, 7), (2, 2), (2, 2), (1, 1)),
                relu_leakiness=0.2)

    # Build the computation graph with given inputs
    train_out = net(z_in, x_in)

    # Created before the optimisers so the generator update can increment it
    # once per training iteration (previously it was never incremented).
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Use ADAM as optimisers, each restricted to its own sub-network's variables
    g_op = tf.train.AdamOptimizer(LEARNING_RATE, beta1=ADAM_BETA1).minimize(
        train_out['g_loss'],
        var_list=net.gen.get_variables(),
        global_step=global_step)
    d_op = tf.train.AdamOptimizer(LEARNING_RATE, beta1=ADAM_BETA1).minimize(
        train_out['d_loss'],
        var_list=net.disc.get_variables())

    # add_check_numerics_ops() returns a grouped op; its checks only execute
    # when that op is run, so keep a handle to it and run it below.
    check_op = tf.add_check_numerics_ops() if CHECK_NUMERICS else tf.no_op()

    # Set up the supervisor.
    # NOTE(review): with logdir=None the Supervisor presumably writes no
    # checkpoints or summaries, so the save_*_secs settings are inert.
    sv = tf.train.Supervisor(logdir=None,
                             is_chief=True,
                             summary_op=None,
                             save_summaries_secs=30,
                             save_model_secs=60,
                             global_step=global_step)

    d_losses = []
    g_losses = []

    # managed_session stops the Supervisor and closes the session when the
    # loop finishes or raises; the original session was never closed.
    with sv.managed_session(config=tf.ConfigProto()) as sess:

        # Main training loop
        for step in range(NUM_STEPS):

            # Get a real example batch and normalise to [-1, 1]
            # (assumes pixel values in [0, 1])
            batch = (mnist.train.next_batch(BATCH_SIZE)[0] * 2) - 1

            # Get noise sample for the generator
            noise = np.random.uniform(
                -1, 1, [BATCH_SIZE, NOISE_DIM]).astype(np.float32)
            feed = {xp: batch, z_in: noise}

            # Update the discriminator (and run the numeric checks, if enabled)
            _, d_loss, _ = sess.run(
                [d_op, train_out['d_loss'], check_op], feed_dict=feed)
            d_losses.append(d_loss)

            # Update the generator
            _, g_loss = sess.run([g_op, train_out['g_loss']], feed_dict=feed)
            g_losses.append(g_loss)

            # Plot a sample and the loss every PLOT_EVERY steps
            if step % PLOT_EVERY == 0:
                sample = sess.run(train_out['gen']['gen'],
                                  feed_dict={z_in: noise})
                _plot_progress(sample, d_losses, g_losses)

# Close plot for cleanup
plt.close()