In [10]:
from recurrent_controller import RecurrentController
from dnc.dnc import DNC
import tensorflow as tf
import numpy as np
import pickle
import sys
import os

In [11]:
def llprint(message):
    sys.stdout.write(message)
    sys.stdout.flush()

def load(path):
    return pickle.load(open(path, 'rb'))

def onehot(index, size):
    vec = np.zeros(size, dtype=np.float32)
    vec[index] = 1.0
    return vec

def prepare_sample(sample, target_code, word_space_size):
    input_vec = np.array(sample[0]['inputs'], dtype=np.float32)
    output_vec = np.array(sample[0]['inputs'], dtype=np.float32)
    seq_len = input_vec.shape[0]
    weights_vec = np.zeros(seq_len, dtype=np.float32)

    target_mask = (input_vec == target_code)
    output_vec[target_mask] = sample[0]['outputs']
    weights_vec[target_mask] = 1.0

    input_vec = np.array([onehot(code, word_space_size) for code in input_vec])
    output_vec = np.array([onehot(code, word_space_size) for code in output_vec])

    return (
        np.reshape(input_vec, (1, -1, word_space_size)),
        np.reshape(output_vec, (1, -1, word_space_size)),
        seq_len,
        np.reshape(weights_vec, (1, -1, 1))
    )

In [13]:
ckpts_dir = './checkpoints/'
lexicon_dictionary = load('./data/en-10k/lexicon-dict.pkl')
target_code = lexicon_dictionary["-"]
test_files = []

for entryname in os.listdir('./data/en-10k/test/'):
    entry_path = os.path.join('./data/en-10k/test/', entryname)
    if os.path.isfile(entry_path):
        test_files.append(entry_path)

In [21]:
graph = tf.Graph()
with graph.as_default():
    with tf.Session(graph=graph) as session:
        
        ncomputer = DNC(
            RecurrentController,
            input_size=len(lexicon_dictionary),
            output_size=len(lexicon_dictionary),
            max_sequence_length=1920,
            memory_words_num=256,
            memory_word_size=32,
            memory_read_heads=4,
        )
        
        ncomputer.restore(session, ckpts_dir, 'step-20842')
        
        outputs, _ = ncomputer.get_outputs()
        softmaxed = tf.nn.softmax(outputs)
        
        all_tasks_results = []
        for test_file in test_files:
            test_data = load(test_file)
            task_name = os.path.basename(test_file)
            counter = 0
            results = []
            
            llprint("%s ... %d/%d" % (task_name, counter, len(test_data)))
            
            for story in test_data:
                target_mask = (np.array(story['inputs']) == target_code)
                desired_answers = np.array(story['outputs'])
                input_vec, _, seq_len, _ = prepare_sample([story], target_code, len(lexicon_dictionary))
                softmax_output = session.run(softmaxed, feed_dict={
                        ncomputer.input_data: input_vec,
                        ncomputer.sequence_length: seq_len
                })
                softmax_output = np.squeeze(softmax_output, axis=0)
                given_answers = np.argmax(softmax_output[target_mask], axis=1)
                grades = (given_answers == desired_answers)
                results.extend(grades)
                
                counter += 1
                llprint("\r%s ... %d/%d" % (task_name, counter, len(test_data)))
                
            error_rate = 1. - np.mean(results)
            all_tasks_results.append(error_rate)
            llprint("\r%s ... %.3f%% Error Rate.\n" % (task_name, error_rate * 100))
        
        print "Mean Error Rate: %.3f%" %  (np.mean(all_tasks_results) * 100)
        print "Failed Tasks (> 5%): %d" % (np.sum(all_tasks_results > 0.05))

qa14_time-reasoning_test.txt.pkl ... 0/200



qa14_time-reasoning_test.txt.pkl ... 73.600% Error Rate.
qa4_two-arg-relations_test.txt.pkl ... 31.900% Error Rate.
qa8_lists-sets_test.txt.pkl ... 47.895% Error Rate.
qa2_two-supporting-facts_test.txt.pkl ... 66.500% Error Rate.
qa9_simple-negation_test.txt.pkl ... 32.200% Error Rate.
qa15_basic-deduction_test.txt.pkl ... 75.425% Error Rate.
qa17_positional-reasoning_test.txt.pkl ... 8/125

KeyboardInterrupt: 