In [None]:
from __future__ import print_function
from __future__ import absolute_import

import os
import sys
import shutil
import tensorflow as tf
import scipy
from sklearn import metrics
import numpy as np

from matplotlib import pyplot as plt
from IPython import display

%matplotlib inline

In [None]:
# Import the Graph CNN model from the DLTK models
import graph_utils as utils
from dltk.models.graphical.cgcnn import CGCNN

# Set the CUDA_VISIBLE_DEVICES environment variable to the GPU ids to compute on
# (here only GPU 0 is exposed to this process).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Directory for log files and model parameters. It is passed as `logdir` to the
# tf.train.Supervisor that is created further down in this notebook.
save_path =  '/tmp/MNIST_graph_cnn'
# Delete whatever is currently stored at save_path; ignore_errors=True means a
# missing directory (e.g. on the very first run) does not raise.
shutil.rmtree(save_path, ignore_errors=True)
# Only emit TensorFlow log messages of severity ERROR.
tf.logging.set_verbosity(tf.logging.ERROR)

# Load the MNIST data via the TensorFlow tutorial helpers. one_hot=False is
# requested so that labels are integer class ids (they are fed to a sparse
# softmax cross-entropy loss later on).
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('../../data/MNIST_data', one_hot=False)

In [None]:
# Graph parameters
coarsening_levels = 4   # passed as `levels` to utils.coarsen
number_edges = 8        # passed as `k` to utils.distance_sklearn_metrics in grid_graph
metric = 'euclidean'    # distance metric used for the neighbour search in grid_graph

# Network parameters
# Labels are integer class ids, so the number of classes is the largest id + 1.
# np.max is vectorised (the builtin max() iterates the array element by element)
# and int() turns the NumPy scalar into a plain Python int, so the layer sizes in
# num_fc do not depend on the label dtype (e.g. uint8).
num_classes = int(np.max(mnist.train.labels)) + 1  # number of classes

# The four lists below are passed positionally to CGCNN when the network is built.
# NOTE(review): presumably one entry per graph-convolution layer for filters /
# K_order / strides, and one entry per fully connected layer for num_fc - confirm
# against the CGCNN documentation.
filters = [32, 64]
K_order = [25, 25]
strides = [4, 4]
num_fc = [512, num_classes]

In [None]:
def grid_graph(m, corners=False):
    """Build the adjacency matrix of a nearest-neighbour graph on a grid.

    The vertices come from ``utils.grid(m)``; each one is linked to its
    neighbours as found by ``utils.distance_sklearn_metrics`` using the
    module-level ``number_edges`` (as ``k``) and ``metric`` settings.
    NOTE(review): presumably ``utils.grid(m)`` yields the coordinates of an
    m-by-m grid (the printed edge bound uses ``m**2`` vertices) - confirm.

    Args:
        m: Grid size handed to ``utils.grid``.
        corners: If True, every weight below ``max_weight / 1.5`` is zeroed
            and the matrix is rebuilt as a ``scipy.sparse.csr_matrix``.

    Returns:
        The adjacency matrix produced by ``utils.adjacency`` (re-sparsified
        when ``corners`` is True).

    Side effects:
        Prints the number of stored edges next to ``number_edges * m**2 // 2``.
    """
    coords = utils.grid(m)
    distances, neighbours = utils.distance_sklearn_metrics(
        coords, k=number_edges, metric=metric)
    adjacency = utils.adjacency(distances, neighbours)

    # Connections are only vertical or horizontal on the grid.
    # Corner vertices are connected to 2 neighbors only.
    if corners:
        import scipy.sparse
        dense = adjacency.toarray()
        cutoff = dense.max() / 1.5
        dense[dense < cutoff] = 0
        adjacency = scipy.sparse.csr_matrix(dense)
        print('{} edges'.format(adjacency.nnz))

    stored_edges = adjacency.nnz // 2
    edge_bound = number_edges * m**2 // 2
    print("{} > {} edges".format(stored_edges, edge_bound))
    return adjacency

# Build the graph for 28x28 inputs, coarsen it and compute one normalised
# Laplacian per graph returned by the coarsening. `perm` is the vertex
# permutation returned by utils.coarsen; it is applied to the image data later.
A = grid_graph(28, corners=False)
A = utils.replace_random_edges(A, 0)
graphs, perm = utils.coarsen(A, levels=coarsening_levels, self_connections=False)
# The full-resolution adjacency is no longer needed under this name.
del A
L = [utils.laplacian(graph, normalized=True) for graph in graphs]

In [None]:
# Transform data to a GCN compatible format
# (train_data is kept for inspection only; the training loop draws its batches
# from mnist.train.next_batch and permutes them itself.)
train_data = mnist.train.images.astype(np.float32)
test_data = mnist.test.images.astype(np.float32)
test_labels = mnist.test.labels

# Reorder the test inputs with the permutation returned by the graph coarsening
test_data = utils.perm_data(test_data, perm)

# Build the GCNN network graph
net = CGCNN(L, filters, K_order, strides, num_fc, bias='b1', pool='mpool', dropout=0.5)

# Create placeholders to feed input data during execution.
# Both placeholders are sized from batch_size so that changing it in one place
# keeps the graph, the training loop and predict() consistent.
batch_size = 100
M_0 = L[0].shape[0]  # first dimension of the first Laplacian in L
xp = tf.placeholder(tf.float32, (batch_size, M_0), 'data')
yp = tf.placeholder(tf.int32, (batch_size,), 'labels')

# Compute the mean categorical crossentropy as a loss function
# (sparse variant: labels are integer class ids, not one-hot vectors)
logits_ = net(xp)['logits']
labels_ = yp
crossentropy_ = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_, labels=labels_)
loss_ = tf.reduce_mean(crossentropy_, name='crossentropy')

# Employ SGD with momentum (learning rate 0.02, momentum 0.9) to minimise the
# crossentropy loss during training
train_op = tf.train.MomentumOptimizer(0.02, 0.9).minimize(loss_)

In [None]:
# Create additional ops to visualise the network output and track the training steps
# Second call of the network on the same input placeholder, this time with
# is_training=False; its 'y_' output is used as the predicted class.
# NOTE(review): presumably this call shares its weights with the training call
# above - confirm in the CGCNN implementation.
y_hat_ = net(xp, is_training=False)['y_']
# Fraction of rows in the fed batch whose prediction equals the fed label.
val_acc_ = tf.reduce_mean(tf.cast(tf.equal(tf.cast(yp, tf.int32), tf.cast(y_hat_, tf.int32)), tf.float32))
# Step counter handed to the tf.train.Supervisor below.
# NOTE(review): train_op is created with minimize(loss_) and no global_step
# argument, so nothing in this notebook increments this variable.
global_step = tf.Variable(0, name='global_step', trainable=False)

def predict(data, labels):
    """Compute the classification accuracy of the network on a labelled set.

    The samples are pushed through the graph in chunks of ``batch_size`` rows
    because the input placeholder ``xp`` has a fixed batch dimension. A last,
    smaller chunk is zero-padded to full size and only its real rows are
    scored, so padding never influences the result.

    Uses the notebook globals ``s`` (session), ``xp``, ``y_hat_`` and
    ``batch_size``.

    Args:
        data: 2-D array (or sparse matrix offering ``toarray()``) with one
            sample per row, already permuted for the network.
        labels: Integer class id for every row of ``data``.

    Returns:
        Fraction of correctly classified samples as a Python float.

    Raises:
        ValueError: If ``data`` is empty or ``labels`` does not contain exactly
            one entry per row of ``data``.
    """
    size = data.shape[0]
    labels = np.asarray(labels).reshape(-1)
    if size == 0:
        raise ValueError('predict() needs at least one sample')
    if labels.shape[0] != size:
        raise ValueError('got {} labels for {} samples'.format(labels.shape[0], size))

    correct = 0
    for begin in range(0, size, batch_size):
        end = min(begin + batch_size, size)
        n_valid = end - begin

        # Zero-pad so that the fed array always matches the placeholder shape
        batch_data = np.zeros((batch_size, data.shape[1]), dtype=np.float32)
        chunk = data[begin:end, :]
        if not isinstance(chunk, np.ndarray):
            chunk = chunk.toarray()  # convert sparse matrices
        batch_data[:n_valid] = chunk

        # Score against the labels that were passed in (not against whatever
        # training batch happens to live in the notebook namespace).
        # Assumes y_hat_ yields one class id per batch row, which is also what
        # the val_acc_ op relies on - TODO confirm.
        predicted = np.asarray(s.run(y_hat_, {xp: batch_data})).reshape(-1)
        hits = predicted[:n_valid].astype(np.int64) == labels[begin:end].astype(np.int64)
        correct += int(np.sum(hits))

    return float(correct) / size

In [None]:
# Create the supervisor that writes checkpoints and logs into save_path, then
# obtain the session used by the training loop and by predict()
sv = tf.train.Supervisor(
    logdir=save_path,
    is_chief=True,
    summary_op=None,
    save_summaries_secs=30,
    save_model_secs=60,
    global_step=global_step,
)
s = sv.prepare_or_wait_for_session(config=tf.ConfigProto())

# Bookkeeping for the training loop: iteration counter plus the history of
# training losses and test accuracies used for plotting
step = 0
loss_moving = []
acc_moving = []

In [None]:
# Training loop: runs until the supervisor requests a stop or the kernel is
# interrupted by the user
eval_every = 20  # evaluate on the test set and refresh the plots every N steps

while not sv.should_stop():

    # Get a batch of training input pairs of x (image) and y (label).
    # The batch must have exactly batch_size rows to match the placeholders.
    batch = mnist.train.next_batch(batch_size)
    batch_data = utils.perm_data(batch[0], perm)
    batch_labels = batch[1]

    # Run the training op and the loss
    _, logits, loss = s.run([train_op, logits_, loss_], feed_dict={xp: batch_data, yp: batch_labels})
    loss_moving.append(loss)

    # Visualise all inputs, outputs and losses during each training step
    if step % eval_every == 0:

        # Compute the accuracy on the test set
        val_acc = predict(test_data, test_labels)
        acc_moving.append(val_acc)

        f, axarr = plt.subplots(1, 3, figsize=(16, 4))

        # Last image of the current training batch (un-permuted pixels)
        axarr[0].imshow(np.reshape(batch[0], [-1, 28, 28])[-1], cmap='gray', vmin=0, vmax=1)
        axarr[0].set_title('Input x; Prediction = {}; Truth = {};'.format(np.argmax(logits[-1]), batch[1][-1]))
        axarr[0].axis('off')

        axarr[1].plot(loss_moving)
        axarr[1].set_title('Training loss')
        axarr[1].set_xlabel('Training step')
        axarr[1].axis('on')

        axarr[2].plot(acc_moving)
        axarr[2].set_title('Test acc')
        axarr[2].set_xlabel('Evaluation (every {} steps)'.format(eval_every))
        axarr[2].axis('on')

        display.clear_output(wait=True)
        display.display(f)
        # Release the figure right away so figures do not pile up and the
        # inline backend does not render it a second time when the cell ends
        plt.close(f)

    step += 1