In [2]:
from datetime import datetime
from packaging import version
import os
import sys
import logging

In [3]:
sys.path.append("./metalearning")

In [4]:
%load_ext tensorboard

In [5]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from metalearning_tasks.variable_shot_classification import DatasetGenerator
from models.model_builder import create_string_model

In [6]:
exp_no = 3
_iterations = 100000
batch_size = _batch_size = 32
nb_classes = _nb_classes = 5
nb_samples_per_class = _nb_samples_per_class = 10
_input_height = _input_width = 20
img_size = (_input_height, _input_width)
colorspace = 'L'
channels = {'RGB':3, 'L':1}
input_size = _input_height * _input_width * channels[colorspace]
nb_reads = _nb_reads = 1
controller_size = _controller_size = 100
memory_size = _memory_locations = 128
memory_dim = ar_memory_word_size = 20
dataset = 'omniglot'
splits=[1200,200,0]

cells = {'LSTM':'LSTM','LRUold':'DNC','DNC':'DNC'}
save_dir='./varshot'
sizes = [5, 10, 15, 20, 25, 50, 100]
models = {}


In [7]:
test_error = [tf.keras.metrics.Mean(f'error_{i}', dtype=tf.float32) for i in sizes]

In [8]:
ds_root = f'./datasets/Records/{dataset}'
data_generator = DatasetGenerator(data_folder=ds_root,
                                  splits=splits,
                                  nb_samples_per_class=nb_samples_per_class,
                                  img_size=img_size,
                                  colorspace=colorspace,
                                  pre_scale=(60,60),
                                  augment=True
                                )




In [14]:
for folder,cell in cells.items():
  dir = f"{save_dir}/{folder}/"
  try:
    print(f'loading weights from {dir} for {cell}')
    model = create_string_model(
                                input_size+25,
                                _batch_size,
                                cell)

    model.load_weights(dir + "/model.")
    models[folder] = model
    model.summary()
  except:
    logging.exception("failed to load weights")

loading weights from ./varshot/LSTM/ for LSTM

Two checkpoint references resolved to different objects (<tensorflow.python.keras.layers.recurrent.RNN object at 0x7f7ea9955fd0> and <tensorflow.python.keras.layers.core.TFOpLambda object at 0x7f7e715fb520>).

Two checkpoint references resolved to different objects (<tensorflow.python.keras.layers.core.Dense object at 0x7f7e715fb820> and <tensorflow.python.keras.layers.recurrent.RNN object at 0x7f7ea9955fd0>).
Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(32, None, 425)]    0                                            
__________________________________________________________________________________________________
tf.stop_gradient_3 (TFOpLambda) (32, None, 425)      0           input_4[0][0]                    
___________________________


Two checkpoint references resolved to different objects (<tensorflow.python.keras.layers.core.Dense object at 0x7f7e714c5df0> and <tensorflow.python.keras.layers.recurrent.RNN object at 0x7f7e715160a0>).
Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            [(32, None, 425)]    0                                            
__________________________________________________________________________________________________
tf.stop_gradient_5 (TFOpLambda) (32, None, 425)      0           input_6[0][0]                    
__________________________________________________________________________________________________
rnn_5 (RNN)                     (32, None, 25)       230313      tf.stop_gradient_5[0][0]         
_____________________________________________________________________________________

In [15]:
models

{'LSTM': <tensorflow.python.keras.engine.functional.Functional at 0x7f7e71584880>,
 'DNC': <tensorflow.python.keras.engine.functional.Functional at 0x7f7e7146de20>,
 'LRUold': <tensorflow.python.keras.engine.functional.Functional at 0x7f7e714e6490>}

In [30]:
def labelstr(pentits):
    chars=['A','B','C','D','E']
    strchars = [chars[pentits[idx]] for idx in range(len(pentits))]
    return str(strchars)


In [87]:
def test(model, data_generator, sizes):
    test_error = [[tf.keras.metrics.Mean(f'error_{s}classes_inst{i}', dtype=tf.float32) for i in [1,2,5,10]] for s in sizes]
    print('size', end='\t')
    for size in sizes:
        print('%d' % size, end='\t\t\t')
    print('')
    print('\t', end='')
    for size in sizes:
        for instance in [1,2,5,10]:
            print(f'{instance}    ', end='')
        print(f'', end='\t')
    for i,s in enumerate(sizes):
      for ep in range(10):
              image, label = data_generator.generate_batch("train",
                                                           s,
                                                           32)
              splits = tf.split(label,5,axis=-1)
              one_hot_targets = [tf.one_hot(split, _nb_classes, axis=-1) for split in splits]
              labels_onehot = tf.squeeze(tf.concat(one_hot_targets,axis=-1))
              offset_target_var = tf.concat([tf.zeros_like(tf.expand_dims(
              labels_onehot[:, 0], 1)), labels_onehot[:, :-1]], axis=1)
              ntm_input = tf.concat([image, offset_target_var], axis=2)
              output = model(ntm_input)
              accuracy = metric_accuracy(s, 10, label, output)
              for ixx, inst in enumerate([1,2,5,9]):
                  test_error[i][ixx](accuracy[inst])
    print('', end='\t')
    for size in test_error:
        for ixx in range(4):
            print('%.1f' % (size[ixx].result()*100.0), end=' ')
        print('\t')
    print('')
            
    return test_error

In [35]:
def metric_accuracy(_nb_classes,_nb_samples_per_class,labels, outputs):
    seq_length = _nb_classes * _nb_samples_per_class
    outputs_split = tf.split(outputs,5,axis=-1)
    outputs = [tf.argmax(split, axis=-1) for split in outputs_split]
    outputs = np.stack(outputs,axis=-1)

    correct = [0] * seq_length
    total = [0] * seq_length
    for i in range(np.shape(labels)[0]):
        label = labels[i]
        output = outputs[i]
        class_count = {}
        for j in range(seq_length):
            label_str = labelstr(label[j])
            output_str = labelstr(output[j])
            class_count[label_str] = class_count.get(label_str, 0) + 1
            total[class_count[label_str]] += 1
            if label_str == output_str:
                correct[class_count[label_str]] += 1
    return [float(correct[i]) / total[i] if total[i] > 0. else 0.
            for i in range(1, _nb_samples_per_class + 1)]

In [36]:
models.keys()

dict_keys(['LSTM', 'DNC', 'LRUold'])

In [88]:
errors = {}
for name, model in models.items():
  err = test(model, data_generator, sizes)
  errors[name] = err

size	5			10			15			20			25			50			100			
	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    		23.2 31.2 37.8 49.3 	
13.9 16.4 19.5 25.1 	
8.0 10.0 12.6 15.9 	
6.5 8.0 9.6 12.1 	
4.7 6.3 7.2 9.6 	
2.1 2.7 3.2 4.1 	
1.0 1.0 1.2 1.6 	

size	5			10			15			20			25			50			100			
	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    		80.9 88.5 93.1 95.6 	
71.8 79.1 86.9 90.7 	
63.2 71.6 80.0 86.5 	
57.5 66.4 76.8 82.1 	
52.2 61.9 70.5 76.7 	
37.2 43.9 49.2 55.2 	
23.6 26.9 28.5 32.5 	

size	5			10			15			20			25			50			100			
	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    		77.3 87.4 92.1 95.3 	
69.2 77.8 86.8 90.2 	
63.2 73.1 82.7 87.2 	
57.5 68.1 78.0 83.7 	
51.3 62.0 72.5 81.5 	
35.4 

In [None]:
print('Model',end='\t')
for s in sizes:
  print(s,end='\t')
print()
for model, error in errors.items():
  print(model,end='\t')
  for err in error:
    print(float(err.result()),end='\t')
  print()


In [83]:
    print('size', end='\t')
    for size in sizes:
        print('%d' % size, end='\t\t\t')
    print('')
    print('\t', end='')
    for size in sizes:
        for instance in [1,2,5,10]:
            print(f'{instance}    ', end='')
        print(f'', end='\t')

size	5			10			15			20			25			50			100			
	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	1    2    5    10    	