In [None]:
from __future__ import print_function
from __future__ import absolute_import

import os
import sys
import tensorflow as tf
import numpy as np

import datetime

from matplotlib import pyplot as plt
from IPython import display
from functools import partial
# iPython specific
%matplotlib inline

In [None]:
# DLTK imports
from dltk.models.gan.wdcgan import WDCGAN
from dltk.core.modules import *

# MNIST data import
from tensorflow.examples.tutorials.mnist import input_data
import shutil

# Load MNIST from the local data directory. Labels are not used by the GAN,
# so one-hot encoding is disabled.
mnist = input_data.read_data_sets('../../data/MNIST_data', one_hot=False)

# Start from a clean output directory. shutil.rmtree avoids spawning a shell
# with a string-concatenated command; ignore_errors=True mirrors `rm -rf`
# silently succeeding when the path does not exist.
# NOTE(review): save_path is only cleared here; the Supervisor below is built
# with logdir=None, so nothing is written to it in the visible code.
save_path = '/tmp/WGAN_MNIST'
shutil.rmtree(save_path, ignore_errors=True)

# Set environmental variables to select CUDA device (must happen before the
# first TensorFlow session is created to take effect).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
tf.logging.set_verbosity(tf.logging.ERROR)

# Training settings
disc_updates = 5    # critic (discriminator) steps per training iteration
gen_updates = 1     # generator steps per training iteration
improved = True     # True: improved-WGAN settings (Adam, clip_val=10.0); False: original WGAN (RMSProp, clip_val=0.01)
batch_size = 50


with tf.Graph().as_default():

    # Define inputs: flat 784-value MNIST rows reshaped to 28x28x1 images,
    # plus a 100-d latent noise vector for the generator.
    xp = tf.placeholder(tf.float32, shape=[None, 784])
    x_in = tf.reshape(xp, [-1, 28, 28, 1])
    z_in = tf.placeholder(tf.float32, shape=[None, 100])

    # Clip value depends on the method: 10.0 for the improved WGAN,
    # 0.01 for the original weight-clipping WGAN. How WDCGAN uses it is
    # defined inside DLTK.
    clip_val = 10.0 if improved else 0.01

    # Build a WGAN model
    net = WDCGAN(discriminator_filters=(16, 32, 64, 128),
                 discriminator_strides=((1, 1), (2, 2), (1, 1), (2, 2)),
                 generator_filters=(256, 128, 64, 1),
                 generator_strides=((7, 7), (2, 2), (2, 2), (1, 1)),
                 relu_leakiness=0.2,
                 improved=improved,
                 clip_val=clip_val)

    # Build the computation graph with given inputs
    train_out = net(z_in, x_in)

    # One optimiser factory per method; each call yields a fresh optimiser
    # instance so generator and critic keep separate optimiser state.
    if improved:
        make_optimizer = partial(tf.train.AdamOptimizer,
                                 learning_rate=1e-4, beta1=0.5, beta2=0.9)
    else:
        make_optimizer = partial(tf.train.RMSPropOptimizer, learning_rate=5e-5)
    g_op = make_optimizer().minimize(train_out['g_loss'],
                                     var_list=net.gen.get_variables())
    d_op = make_optimizer().minimize(train_out['d_loss'],
                                     var_list=net.disc.get_variables())

    global_step = tf.Variable(0, name='global_step', trainable=False)
    # NOTE(review): the op returned here is discarded and never passed to
    # s.run, so these numeric checks are not actually executed.
    tf.add_check_numerics_ops()

    # Weight-clipping ops (None when the model provides none); loop-invariant.
    clip_ops = train_out['clip_ops']

    # Per-iteration mean critic scores, for plotting.
    gen_moving = []
    sample_moving = []

    # Set up a supervisor
    sv = tf.train.Supervisor(logdir=None,
                             is_chief=True,
                             summary_op=None,
                             save_summaries_secs=30,
                             save_model_secs=60,
                             global_step=global_step)

    s = sv.prepare_or_wait_for_session(config=tf.ConfigProto())

    try:
        # Main training loop
        for step in range(2000):

            # Get a real example batch and normalise to [-1, 1]
            batch = (mnist.train.next_batch(batch_size)[0] * 2) - 1

            # Update the discriminator (critic) several times per iteration
            for _ in range(disc_updates):
                noise = np.random.uniform(-1, 1, [batch_size, 100]).astype(np.float32)
                s.run([d_op], feed_dict={xp: batch, z_in: noise, net.batch_size: batch_size})
                if clip_ops is not None:
                    s.run([clip_ops])

            # Update the generator, fetching the critic's scores on the way
            for _ in range(gen_updates):
                noise = np.random.uniform(-1, 1, [batch_size, 100]).astype(np.float32)
                _, gen_score, sample_score = s.run(
                    [g_op, train_out['disc_gen_logits'], train_out['disc_sample_logits']],
                    feed_dict={xp: batch, z_in: noise, net.batch_size: batch_size})

            # Collect one scalar per iteration: the generated-data score goes to
            # the "generated score" curve, the sample score to "sample score".
            # np.mean reduces per-example logits (if any) to a single point.
            gen_moving.append(np.mean(gen_score))
            sample_moving.append(np.mean(sample_score))

            # Plot a sample and the score curves every 20 steps
            if step % 20 == 0:
                plt.close()
                display.clear_output(wait=True)

                # Generate a sample from the last generator noise; feed
                # batch_size like every other run does.
                sample = s.run(train_out['gen']['gen'],
                               feed_dict={z_in: noise, net.batch_size: batch_size})

                f, axarr = plt.subplots(1, 2, figsize=(16, 4))
                axarr[0].imshow(np.squeeze(sample[0, :, :]), cmap='gray', vmin=-1, vmax=1)
                axarr[0].set_title('Sample')
                axarr[0].axis('off')

                axarr[1].plot(gen_moving, label='generated score')
                axarr[1].plot(sample_moving, label='sample score')
                axarr[1].set_title('Loss')
                axarr[1].axis('on')
                axarr[1].legend()
                display.display(plt.gcf())
    finally:
        # Stop supervisor services and release the session even if training
        # is interrupted (e.g. a notebook KeyboardInterrupt).
        sv.stop()
        s.close()

# Close plot for cleanup
plt.close()