In [None]:
from __future__ import print_function
from __future__ import absolute_import

import datetime
import os
import shutil
import sys

import numpy as np
import tensorflow as tf
from IPython import display
from matplotlib import pyplot as plt

# iPython specific
%matplotlib inline

In [None]:
# DLTK imports
from dltk.models.autoencoder.convolutional_autoencoder import ConvolutionalAutoencoder
from dltk.core.modules import *

# MNIST data import
from tensorflow.examples.tutorials.mnist import input_data
# Labels are kept as integer class ids (one_hot=False); the plotting code
# compares them directly against digit values.
mnist = input_data.read_data_sets('../../data/MNIST_data', one_hot=False)

# Start from a clean output directory. shutil.rmtree is the portable,
# shell-free equivalent of `rm -rf`; ignore_errors=True keeps the original
# "no error if the directory does not exist" behaviour.
save_path = '/tmp/AE_MNIST'
shutil.rmtree(save_path, ignore_errors=True)

# Set environment variables to select the CUDA device. Presumably this needs
# to happen before TensorFlow initialises CUDA (i.e. before the first session
# is created further below) -- TODO confirm if this cell is reordered.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
tf.logging.set_verbosity(tf.logging.ERROR)

# Render the 2x2 progress figure (input, reconstruction, hidden embedding, loss).
def _plot_progress(x, x_rec, hidden, labels, losses):
    """Replace the previous notebook output with a fresh progress figure.

    Args:
        x: input images as produced by ``x_in``, shape (N, 28, 28, 1).
        x_rec: sigmoid reconstructions, indexed the same way as ``x``.
        hidden: hidden representations, one row per test image. Only the
            first two columns are plotted; assumes shape (N, k) with k >= 2
            (TODO confirm against DLTK's ConvolutionalAutoencoder).
        labels: integer digit labels, one per row of ``hidden``.
        losses: training-loss values recorded so far.
    """
    plt.close()
    display.clear_output(wait=True)

    fig, axarr = plt.subplots(2, 2, figsize=(10, 10))

    axarr[0, 0].imshow(np.squeeze(x[0, :, :, 0]), cmap='gray', vmin=0, vmax=1)
    axarr[0, 0].set_title('Input: x')
    axarr[0, 0].axis('off')

    axarr[0, 1].imshow(np.squeeze(x_rec[0, :, :, 0]), cmap='gray', vmin=0, vmax=1)
    axarr[0, 1].set_title('Reconstruction: x_')
    axarr[0, 1].axis('off')

    # One scatter series per digit class; axis decoration is applied once
    # after all series exist so the legend lists every digit.
    for lbl in range(10):
        selector = (labels == lbl)
        axarr[1, 0].plot(hidden[selector, 0], hidden[selector, 1], 'x',
                         label='{}'.format(lbl))
    axarr[1, 0].set_title('Hidden')
    axarr[1, 0].axis('off')
    axarr[1, 0].legend(numpoints=1)

    axarr[1, 1].plot(losses)
    axarr[1, 1].set_title('Loss')
    axarr[1, 1].axis('on')

    display.display(fig)


with tf.Graph().as_default():

    # Define inputs: flattened 784-pixel MNIST images, reshaped to
    # (batch, 28, 28, 1) for the convolutional network.
    xp = tf.placeholder(tf.float32, shape=[None, 784])
    x_in = tf.reshape(xp, [-1, 28, 28, 1])

    # Build a convolutional autoencoder
    net = ConvolutionalAutoencoder(strides=((1, 1), (2, 2), (1, 1), (2, 2)))

    # Build the training computation graph with the given inputs
    train_out = net(x_in)

    # Reconstruction loss on the sigmoid of the decoder output 'x_'
    # (presumably 'x_' is unbounded, hence the sigmoid -- TODO confirm).
    loss_ = mse(tf.nn.sigmoid(train_out['x_']), x_in)

    # Build the test network for reconstruction (is_training=False)
    test_out = net(x_in, is_training=False)
    test_rec = tf.nn.sigmoid(test_out['x_'])

    # Use ADAM as an optimiser. Passing global_step makes minimize() increment
    # it on every update; without it the counter the Supervisor is given
    # would stay at 0 forever.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = tf.train.AdamOptimizer().minimize(loss_, global_step=global_step)

    # add_check_numerics_ops() only builds the NaN/Inf checks and returns a
    # group op; the checks fire only when that op is run, so it is executed
    # together with every training step below.
    check_numerics_op = tf.add_check_numerics_ops()

    loss_moving = []

    # NOTE(review): with logdir=None the Supervisor has nowhere to write, so
    # save_summaries_secs / save_model_secs have no effect and save_path is
    # never used. Pass logdir=save_path if checkpoints are wanted.
    sv = tf.train.Supervisor(logdir=None,
                             is_chief=True,
                             summary_op=None,
                             save_summaries_secs=30,
                             save_model_secs=60,
                             global_step=global_step)

    # managed_session() stops the supervisor and closes the session on exit,
    # including when the loop is interrupted.
    with sv.managed_session(config=tf.ConfigProto()) as s:
        test_labels = mnist.test.labels

        # Main training loop
        for step in range(5000):

            # Get a training batch of 100 images (batch[0]); labels unused
            batch = mnist.train.next_batch(100)

            # Run the training op plus the numerics checks, record the loss
            _, c, _ = s.run([train_op, loss_, check_numerics_op],
                            feed_dict={xp: batch[0]})
            loss_moving.append(c)

            # Every 20 steps, reconstruct the full (unseen) test set and plot
            if step % 20 == 0:
                x, x_, hidden = s.run([x_in, test_rec, test_out['hidden']],
                                      feed_dict={xp: mnist.test.images})
                _plot_progress(x, x_, hidden, test_labels, loss_moving)

# Close the plot for cleanup (avoids a duplicate inline render at cell end)
plt.close()