In [1]:
from datetime import datetime
from packaging import version
import os
import sys
import logging

In [2]:
# Make the packages under ./metalearning importable by the import cell below.
# The path is relative, so it resolves against the kernel's current working directory.
sys.path.append("./metalearning")

In [3]:
%load_ext tensorboard

In [4]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from metalearning_tasks.associative_recall import AssociativeRecallTask
from models.model_builder import create_algorithmic_task_model

In [5]:
# --- Task / batching -------------------------------------------------------
batch_size = 32
bits_per_vector = 8
max_len = 20                 # handed to the task generator as `max_items`
sizes = [6, 10, 15, 20]      # sizes passed to `generate_batch` during evaluation

# --- Architecture settings -------------------------------------------------
# NOTE(review): none of these five values is referenced by the cells visible
# in this notebook; presumably they mirror the training configuration - confirm.
nb_reads = 1
controller_size = 100
memory_size = 128
memory_dim = 20
num_layers = 1

# --- Checkpoints to evaluate -----------------------------------------------
lr = 0.001                   # appears as a path component of the weights directory
cells = ['LSTM', 'NTMv2', 'LRUold', 'DNC']
save_dir = './associative_recall/associative_recall'

# Filled by the weight-loading cell: cell name -> model.
models = {}

In [6]:
# One running-mean metric per entry of `sizes`, named 'error_<size>'.
# NOTE: `test()` builds its own local list under the same name, so it does not
# read this module-level list.
test_error = [
    tf.keras.metrics.Mean(f'error_{size}', dtype=tf.float32)
    for size in sizes
]

In [7]:
# Batch generator shared by every evaluation run in this notebook.
data_generator = AssociativeRecallTask(batch_size=batch_size,
                                       max_items=max_len,
                                       bits_per_vector=bits_per_vector,
                                       lines_per_item=3)

# Build one model per cell type and restore its saved weights into it.
for cell in cells:
  # Called `weights_dir` (not `dir`) so the builtin `dir()` is not shadowed for
  # the rest of the kernel session.
  # NOTE(review): the trailing "111" looks like a run/seed identifier - confirm.
  weights_dir = f"{save_dir}/{cell}/{lr}/111"
  try:
    print(f'loading weights from {weights_dir} for {cell}')
    model = create_algorithmic_task_model(2 + bits_per_vector,
                                          bits_per_vector,
                                          batch_size,
                                          cell)

    model.load_weights(weights_dir + "/model.")
    # Register the model only after its weights were restored successfully.
    models[cell] = model
    model.summary()
  except Exception:
    # Best-effort loading: log the traceback and move on to the next cell type.
    # Catching `Exception` instead of using a bare `except:` lets
    # KeyboardInterrupt / SystemExit propagate, so the loop can be interrupted.
    logging.exception("failed to load weights")

loading weights from ./associative_recall/associative_recall/LSTM/0.001/111 for LSTM
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(32, None, 10)]          0         
_________________________________________________________________
tf.stop_gradient (TFOpLambda (32, None, 10)            0         
_________________________________________________________________
rnn (RNN)                    (32, None, 256)           273408    
_________________________________________________________________
dense (Dense)                (32, None, 8)             2056      
Total params: 275,464
Trainable params: 275,464
Non-trainable params: 0
_________________________________________________________________
loading weights from ./associative_recall/associative_recall/NTMv2/0.001/111 for NTMv2
Model: "model_1"
________________________________________________________________

In [8]:
# Display the registry of models whose weights were loaded (cell name -> model).
models

{'LSTM': <tensorflow.python.keras.engine.functional.Functional at 0x7f4b2c1b5a60>,
 'NTMv2': <tensorflow.python.keras.engine.functional.Functional at 0x7f4b2c14e9d0>,
 'LRUold': <tensorflow.python.keras.engine.functional.Functional at 0x7f4b24515910>,
 'DNC': <tensorflow.python.keras.engine.functional.Functional at 0x7f4b245e0e50>}

In [9]:
def test(model, data_generator, sizes, num_batches=100):
    """Evaluate `model` on generated batches and print a per-size result table.

    For every entry of `sizes`, `num_batches` batches are drawn from
    `data_generator`, the model output is passed through a sigmoid, scored with
    `metric_accuracy`, and the scores are averaged in a running-mean metric.

    Args:
        model: Callable invoked as `model(x)` on the generated inputs.
        data_generator: Object whose `generate_batch(size)` returns a
            `(x, y, seq_len)` triple.
        sizes: Iterable of sizes handed to `generate_batch`; one table column
            per entry.
        num_batches: Number of batches averaged per size. Defaults to 100, the
            value that used to be hard-coded.

    Returns:
        A list of `tf.keras.metrics.Mean` objects, one per entry of `sizes`
        (same order), each holding the mean score over its batches.
    """
    # `sizes` is traversed several times below; materialise it so that a
    # one-shot iterable (e.g. a generator) is not exhausted after the first pass.
    sizes = list(sizes)

    # The metric names keep their existing 'train_error_' prefix so that the
    # returned objects are named exactly as before.
    test_error = [tf.keras.metrics.Mean(f'train_error_{s}', dtype=tf.float32) for s in sizes]

    print("Test Result\n")
    print('size', end='\t')
    for size in sizes:
        print('%d' % size, end='\t')
    print('')

    # `metric_accuracy` scores the time steps from index `seq_len + 1` onwards,
    # so -4 selects the last three steps of each sequence.
    # NOTE(review): presumably those are the answer item (the task is built
    # with lines_per_item=3) - confirm.
    score_from = -4

    for metric, size in zip(test_error, sizes):
        for _ in range(num_batches):
            # The sequence length returned by the generator is not needed here.
            x, y, _seq_len = data_generator.generate_batch(size)
            out_sig = tf.sigmoid(model(x))
            metric(metric_accuracy(y, out_sig, score_from))

    print('', end='\t')
    for metric in test_error:
        print('%.2f' % metric.result(), end='\t')
    # Terminate the result row; without this, the heading printed by the next
    # call is appended to the end of this row.
    print()
    return test_error

In [10]:
def metric_accuracy(labels, outputs, seq_len):
    """Count mismatches between labels and rounded outputs, per first-axis entry.

    Both inputs are indexed with three axes; only positions from index
    `seq_len + 1` onwards along the second axis are compared.

    Despite the name, the value returned is an error measure: the sum of
    absolute differences between `labels` and the rounded `outputs` over the
    compared region, divided by the size of the first axis of that region.

    Args:
        labels: Target values.
        outputs: Predicted values; they are rounded before the comparison.
        seq_len: Index along the second axis after which scoring starts
            (a negative value counts from the end of the axis).

    Returns:
        The summed absolute difference divided by `labels`' first-axis size.
    """
    start = seq_len + 1
    target_tail = labels[:, start:, :]
    predicted_tail = tf.round(outputs[:, start:, :])
    mismatches = tf.abs(target_tail - predicted_tail)
    # Normalise by the static size of the first axis.
    return tf.reduce_sum(mismatches) / target_tail.shape[0]

In [11]:
# Evaluate every loaded model on all sizes and keep the resulting list of
# per-size metrics under the model's name.
errors = {}
for name, model in models.items():
  errors[name] = test(model, data_generator, sizes)

Test Result

size	6	10	15	20	
	3.80	8.91	9.95	10.27	Test Result

size	6	10	15	20	
	10.20	11.11	11.54	11.73	Test Result

size	6	10	15	20	
	5.92	9.16	10.63	11.10	Test Result

size	6	10	15	20	
	0.02	0.02	0.06	0.08	

In [12]:
# Display the names under which models are registered.
models.keys()

dict_keys(['LSTM', 'NTMv2', 'LRUold', 'DNC'])

In [13]:
# Tab-separated summary: a header row with the sizes, then one row per model
# with the unrounded mean of each per-size metric. Every cell, including the
# last one of a row, is followed by a tab.
header_cells = ['Model'] + [str(size) for size in sizes]
print('\t'.join(header_cells) + '\t')
for cell_name, cell_metrics in errors.items():
  row_cells = [str(cell_name)]
  row_cells.extend(str(float(metric.result())) for metric in cell_metrics)
  print('\t'.join(row_cells) + '\t')


Model	6	10	15	20	
LSTM	3.8009374141693115	8.9140625	9.949999809265137	10.2709379196167	
NTMv2	10.195937156677246	11.112812042236328	11.541250228881836	11.725312232971191	
LRUold	5.923749923706055	9.158437728881836	10.628437042236328	11.1040620803833	
DNC	0.019687499850988388	0.019999999552965164	0.06031249836087227	0.08093749731779099	
