# Visualizing NTM performance on the copy task

In [None]:
import os
import sys
import itertools
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable

from ntm import NTM
from recurrent_controller import RecurrentController

# IPython magic (not plain Python): render Matplotlib figures inline in the notebook output.
%matplotlib inline
# Default figure size (width, height) in inches; tall so visualize_op's three stacked panels stay readable.
plt.rcParams['figure.figsize'] = (10, 15)

In [None]:
def generate_data(batch_size, length, size):
    """Build one random batch for the copy task.

    Each example spans ``2 * length + 2`` time steps of ``size`` channels:

    * step 0: marker bit in channel 0;
    * steps 1..length: random payload bits in channels 1..size-2;
    * step length+1: end-symbol bit in the last channel;
    * target: the same payload bits at steps length+2 onward, zeros elsewhere.

    Args:
        batch_size: number of independent sequences in the batch.
        length: number of payload time steps to copy.
        size: channels per time step (payload width is ``size - 2``).

    Returns:
        (input_data, target_output) as float32 arrays of shape
        ``(batch_size, 2 * length + 2, size)``.
    """
    total_steps = 2 * length + 2
    data_channels = slice(1, size - 1)  # channels 0 and -1 are reserved for the markers

    input_data = np.zeros((batch_size, total_steps, size), dtype=np.float32)
    target_output = np.zeros_like(input_data)

    bits = np.random.binomial(1, 0.5, (batch_size, length, size - 2))

    input_data[:, 0, 0] = 1
    input_data[:, 1:length + 1, data_channels] = bits
    input_data[:, length + 1, -1] = 1  # the end symbol
    target_output[:, length + 2:, data_channels] = bits

    return input_data, target_output

def llprint(message):
    """Write ``message`` to stdout as-is (no newline added) and flush right away."""
    stream = sys.stdout
    stream.write(message)
    stream.flush()  # make progress text visible immediately

In [None]:
def binary_cross_entropy(predictions, targets):
    """Return the mean element-wise binary cross-entropy as a scalar tensor.

    ``tf.log`` is applied to both ``predictions`` and ``1 - predictions``, so the
    predictions must lie strictly inside (0, 1); exact 0s or 1s give infinite terms.
    The evaluation cell of this notebook clips the sigmoid output to
    [1e-6, 1 - 1e-6] before calling this.
    """
    log_p = tf.log(predictions)
    log_not_p = tf.log(1 - predictions)
    per_element = targets * log_p + (1 - targets) * log_not_p
    return tf.reduce_mean(-per_element)

In [None]:
def visualize_op(input_series, ntm_output, memory_view):
    """Plot one NTM run as three stacked panels on the current Matplotlib figure.

    The figure is split into a 20-row grid: inputs (rows 0-4), outputs (rows 5-9)
    and memory-head weightings (rows 10-19). Read weightings are drawn in the red
    channel and write weightings in the green channel of a single RGB image, so a
    location touched by both heads shows a red/green mix.

    Args:
        input_series: input array fed to the model; it is transposed and squeezed
            for display (assumed (1, time, channels) with batch size 1 — TODO confirm).
        ntm_output: model output, displayed the same way as ``input_series``.
        memory_view: mapping with ``'write_weightings'`` and ``'read_weightings'``
            arrays. After ``np.squeeze`` each must be 2-D with memory locations on
            axis 1 (axis 0 presumably time, aligned with the input panel).

    Returns:
        (colored_write, colored_read): float RGB arrays of shape ``strip.shape + (3,)``,
        with the write weighting in channel 1 and the read weighting in channel 0.
    """
    gs = gridspec.GridSpec(20, 1, hspace=0)

    ww_strip = np.squeeze(memory_view['write_weightings'])
    rw_strip = np.squeeze(memory_view['read_weightings'])

    # Each head's weighting fills one RGB channel. Vectorised assignment replaces a
    # per-pixel Python loop, and each image is sized from its own strip rather than
    # reusing the write strip's shape for the read image.
    colored_write = np.zeros(ww_strip.shape + (3,))
    colored_write[..., 1] = ww_strip  # green = write head
    colored_read = np.zeros(rw_strip.shape + (3,))
    colored_read[..., 0] = rw_strip  # red = read head

    iax = plt.subplot(gs[0:5, 0])
    oax = plt.subplot(gs[5:10, 0])
    memax = plt.subplot(gs[10:, 0])

    iax.grid(True, color='gray')
    oax.grid(True, color='gray')
    memax.grid(True, color='gray', axis='x')

    iax.imshow(np.squeeze(input_series.T), cmap=plt.cm.gray, interpolation='nearest')
    iax.set_ylabel("Inputs")
    iax.set_yticks([])

    oax.imshow(np.squeeze(ntm_output.T), cmap=plt.cm.gray, interpolation='nearest')
    oax.set_ylabel("Outputs")
    oax.set_yticks([])

    # Put memory locations on the y-axis by swapping the first two axes of the RGB image.
    memax.imshow(np.transpose(colored_write + colored_read, [1, 0, 2]), interpolation='nearest')
    memax.set_ylabel("Memory Location")
    # Legend proxies whose colours match the channels used above.
    write_legend = mpatches.Patch(color='green', label='Write Head')
    read_legend = mpatches.Patch(color='red', label='Read Head')
    memax.legend(bbox_to_anchor=(0.21, -0.1), handles=[write_legend, read_legend])

    return colored_write, colored_read

In [None]:
def hamming_distance(s1, s2):
    """Count the positions at which two equal-length sequences differ.

    Raises:
        ValueError: if the sequences have different lengths.
    """
    if len(s1) == len(s2):
        mismatches = (left != right for left, right in zip(s1, s2))
        return sum(mismatches)
    raise ValueError("Undefined for sequences of unequal length")

## Loading and running the NTM model

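The model was trained on sequences of length up to 10; below it is tested on longer sequences (length 30) to check how well it generalizes beyond the training lengths.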

In [None]:
testing_seq_length = 30    # Sequence length for testing
testing_runs = 10          # Number of testing runs
data_size = 10             # Channels per time step; shared by the NTM config and generate_data

dist = []      # per-run Hamming distance between the rounded output and the target
losses = []    # per-run binary cross-entropy
matches = []   # per-run flag: True when the rounded output reproduces the target exactly
inputs = []
outs = []
views = []

# The original os.path.join(os.path.dirname("__file__"), 'checkpoints') passed the
# *string* "__file__" (presumably because __file__ is undefined in a notebook);
# its dirname is "", so the path was always the relative 'checkpoints'. It is
# resolved against the kernel's current working directory.
ckpts_dir = 'checkpoints'

tf.reset_default_graph()

with tf.Session() as session:

    turing_machine = NTM(
        RecurrentController,
        input_size=data_size,
        output_size=data_size,
        memory_locations=128,
        memory_word_size=20,
        memory_read_heads=1,
        shift_range=1,
        batch_size=1
    )

    outputs, memory_views = turing_machine.get_outputs()
    # Keep probabilities away from 0/1 so the tf.log calls in binary_cross_entropy stay finite.
    squashed_output = tf.clip_by_value(tf.sigmoid(outputs), 1e-6, 1. - 1e-6)
    loss = binary_cross_entropy(squashed_output, turing_machine.target_output)

    session.run(tf.global_variables_initializer())

    # Restoring the provided model
    turing_machine.restore(session, ckpts_dir, 'step-100000')

    for _ in range(testing_runs):
        input_data, target_output = generate_data(1, testing_seq_length, data_size)

        loss_value, out, mem = session.run([
            loss,
            squashed_output,
            memory_views
        ], feed_dict={
            turing_machine.input_data: input_data,
            turing_machine.target_output: target_output,
            # Taken from the generated batch so it always matches generate_data's layout.
            turing_machine.sequence_length: input_data.shape[1]
        })

        rounded = np.round(out)  # threshold probabilities at 0.5 to get predicted bits
        dist.append(hamming_distance(rounded.ravel().tolist(), target_output.ravel().tolist()))
        losses.append(loss_value)
        inputs.append(input_data)
        outs.append(out)
        views.append(mem)
        matches.append(np.allclose(target_output, rounded))

# "Accuracy" here is the fraction of runs whose whole output sequence was copied exactly.
print("Avg. Accuracy: %.4f" % (np.mean(matches)))
print("Avg. Loss: %.4f" % (np.mean(losses)))
print("Avg. Dist: %.4f" % (np.mean(dist)))

## Plotting the best output (the run with the lowest Hamming distance)

In [None]:
# Select the run whose rounded output is closest to its target (fewest differing bits);
# np.argmin returns the earliest such run when there is a tie.
best_indx = np.argmin(dist)
print('Hamming distance: %d' % dist[best_indx])

best_input = inputs[best_indx]
best_output = outs[best_indx]
best_memview = views[best_indx]

# Binding the returned (colored_write, colored_read) pair presumably keeps the
# notebook from echoing the arrays as cell output.
a = visualize_op(best_input, best_output, best_memview)