You might need to modify the third line in the code cell below, to make sure you cd to the actual directory which your ipynb file is located in.

**Caution**: due to the nature of this project's setup, every time you want to rerun a code cell below, please click **Runtime -> Restart and run all**. This operation clears the computational graphs and the local variables, but allows the training and testing data that are already loaded from Google Drive to stay in the Colab runtime space. Please do **not** do the following if you just wish to rerun code: click Runtime -> Reset all runtimes, and then click Runtime -> Run all. That sequence remounts your Google Drive and removes the training and testing data already loaded in your Colab runtime space. **Runtime -> Restart and run all** automatically avoids remounting the drive after the first time you run the notebook file; the loaded data can usually stay in your Colab runtime space for many hours.

Loading the training and testing data after remounting your google drive takes 30 - 40 minutes.

In [1]:
# from google.colab import drive
# drive.mount("/content/gdrive/", force_remount=True)
# %cd gdrive/My Drive/Neural_Turing_Machine/NTM_small

In [1]:
from utils import OmniglotDataLoader, one_hot_decode, five_hot_decode
import tensorflow as tf
import argparse
import numpy as np
# %tensorflow_version 1.x
# print(tf.__version__)

Already implemented, no need to change.

This class is part of the training loop.

In [2]:
class NTMOneShotLearningModel():
  """TensorFlow 1.x graph for episodic one-shot classification with an LSTM controller.

  Constructing an instance builds the whole graph: input placeholders, a
  stacked LSTM unrolled over ``seq_length`` steps, a dense + softmax
  classification head applied at every step, a cross-entropy loss and an
  Adam training op.

  Attributes:
    x_image: float32 placeholder [batch_size, seq_length, image_width * image_height];
      x_image[i, j, :] is the flattened j-th image of the i-th episode.
    x_label: float32 placeholder [batch_size, seq_length, n_classes]; the label
      vector fed to the controller together with each image.
    y: float32 placeholder [batch_size, seq_length, n_classes]; the target labels.
    o: softmax outputs, [batch_size, seq_length, n_classes].
    state_list: the initial controller state followed by the state after each
      step; the final state appears twice (see the note in ``__init__``).
    learning_loss: scalar cross-entropy between ``y`` and ``o``.
    learning_loss_summary: scalar summary op named 'learning_loss'.
    optimizer, train_op: Adam optimizer and its minimize op for ``learning_loss``.
  """

  def __init__(self, model, n_classes, batch_size, seq_length, image_width, image_height,
                rnn_size, num_memory_slots, rnn_num_layers, read_head_num, write_head_num, memory_vector_dim, learning_rate):
    """Build the graph.

    Args:
      model: controller type; only 'LSTM' is implemented.
      n_classes: number of classes per episode (size of the label vectors).
      batch_size: number of episodes per batch.
      seq_length: number of images per episode.
      image_width, image_height: image size; images are fed flattened.
      rnn_size: number of units in each LSTM layer.
      rnn_num_layers: number of stacked LSTM layers.
      num_memory_slots, read_head_num, write_head_num, memory_vector_dim:
        accepted for interface compatibility; not used by the LSTM controller.
      learning_rate: learning rate passed to the Adam optimizer.

    Raises:
      ValueError: if ``model`` is not 'LSTM'.
    """
    # Fail fast, before any graph node is created. Previously an unsupported
    # model surfaced later as a NameError on the undefined `state`/`cell`.
    if model != 'LSTM':
      raise ValueError("Unsupported model %r: only 'LSTM' is implemented" % (model,))

    self.output_dim = n_classes

    # Images are flattened to 1D: x_image[i, j, :] is the j-th image of the
    # i-th sequence (episode).
    self.x_image = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_length, image_width * image_height])
    # Label input: x_label[i, j, :] is the label vector fed with the j-th image
    # of the i-th sequence.
    self.x_label = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_length, self.output_dim])
    # Target labels, same layout as x_label.
    self.y = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_length, self.output_dim])

    # Dense layer mapping the controller output to class logits. A single
    # layer instance is reused at every time step, so its weights are shared.
    self.controller_output_to_ntm_output = tf.keras.layers.Dense(units=self.output_dim, use_bias=True)

    # LSTM controller without external memory.
    def rnn_cell(rnn_size):
      return tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
    cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(rnn_size) for _ in range(rnn_num_layers)])
    state = cell.zero_state(batch_size=batch_size, dtype=tf.float32)

    self.state_list = [state]
    # Per-step classifier outputs, stacked into one tensor after the loop.
    self.o = []

    # Unroll the controller over every sample in the sequence.
    for t in range(seq_length):
      # Controller input at step t: flattened image concatenated with its label input.
      output, state = cell(tf.concat([self.x_image[:, t, :], self.x_label[:, t, :]], axis=1), state)
      # Affine map to class logits, then softmax over the class axis.
      output = self.controller_output_to_ntm_output(output)
      output = tf.nn.softmax(output, axis=1)
      self.o.append(output)
      self.state_list.append(state)
    self.o = tf.stack(self.o, axis=1)
    # NOTE: the final state is appended a second time here, so state_list has
    # seq_length + 2 entries. Kept as-is in case other code relies on that length.
    self.state_list.append(state)

    # eps keeps log() finite when a predicted probability is exactly 0.
    eps = 1e-8
    # Cross entropy: summed over time steps and classes, averaged over the batch.
    self.learning_loss = -tf.reduce_mean(
        tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1, 2])
    )

    self.o = tf.reshape(self.o, shape=[batch_size, seq_length, -1])
    self.learning_loss_summary = tf.summary.scalar('learning_loss', self.learning_loss)

    with tf.variable_scope('optimizer'):
      self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
      self.train_op = self.optimizer.minimize(self.learning_loss)

The training and testing functions

In [None]:
def train(learning_rate, image_width, image_height, n_train_classes, n_test_classes, restore_training, \
         num_epochs, n_classes, batch_size, seq_length, num_memory_slots, augment, save_dir, model_path, tensorboard_dir):
  """Train the one-shot learning model, periodically evaluating and checkpointing.

  Every 100 iterations one test batch is evaluated: its loss is written to
  TensorBoard and the accuracies returned by ``test`` are printed together
  with the iteration number and the loss. Every 2000 iterations (except
  iteration 0) a checkpoint is saved.

  Args:
    learning_rate: learning rate for the model's optimizer.
    image_width, image_height: image size passed to the model and the data loader.
    n_train_classes, n_test_classes: class counts passed to the data loader.
    restore_training: if True, restore variables from the latest checkpoint in
      ``save_dir/model_path``; otherwise initialize them from scratch.
    num_epochs: number of training iterations (one batch per iteration).
    n_classes: number of classes per episode.
    batch_size: number of episodes per batch.
    seq_length: number of images per episode.
    num_memory_slots: forwarded to the model constructor.
    augment: forwarded to ``data_loader.fetch_batch``.
    save_dir: root directory for checkpoints.
    model_path: sub-directory name for checkpoints and summaries; it is also
      passed to the model constructor as the model type.
    tensorboard_dir: root directory for TensorBoard summaries.

  Raises:
    FileNotFoundError: if ``restore_training`` is True but no checkpoint exists.

  Note:
    ``rnn_size``, ``rnn_num_layers``, ``read_head_num``, ``write_head_num`` and
    ``memory_vector_dim`` are not parameters; they are read from module-level
    globals and must be defined before this function is called.
  """
  import os  # local import: only needed to create the checkpoint directory

  # We always use one-hot encoding of the labels in this experiment
  label_type = "one_hot"
  checkpoint_dir = save_dir + '/' + model_path

  # Initialize the model
  model = NTMOneShotLearningModel(model=model_path, n_classes=n_classes,\
                    batch_size=batch_size, seq_length=seq_length,\
                    image_width=image_width, image_height=image_height, \
                    rnn_size=rnn_size, num_memory_slots=num_memory_slots,\
                    rnn_num_layers=rnn_num_layers, read_head_num=read_head_num,\
                    write_head_num=write_head_num, memory_vector_dim=memory_vector_dim,\
                    learning_rate=learning_rate)
  print("Model initialized")
  data_loader = OmniglotDataLoader(
      image_size=(image_width, image_height),
      n_train_classses=n_train_classes,
      n_test_classes=n_test_classes
  )
  print("Data loaded")
  # Note: our training loop is in the tensorflow 1.x style
  with tf.Session() as sess:
    if restore_training:
      saver = tf.train.Saver()
      ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
      # get_checkpoint_state yields None when there is nothing to restore;
      # report that clearly instead of failing on an attribute access.
      if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError("No checkpoint found in %s" % checkpoint_dir)
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      saver = tf.train.Saver(tf.global_variables())
      tf.global_variables_initializer().run()
    # Create the checkpoint directory up front so that a missing directory
    # cannot abort the first save after 2000 iterations of training.
    os.makedirs(checkpoint_dir, exist_ok=True)
    train_writer = tf.summary.FileWriter(tensorboard_dir + '/' + model_path, sess.graph)
    try:
      print("1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tepoch\tloss")
      for b in range(num_epochs):
        # Test the model
        if b % 100 == 0:
          # Images are flattened to 1D tensors:
          # x_image[i,j,:] = jth image in the ith sequence (episode).
          # See OmniglotDataLoader in utils.py for how episodes are sampled.
          x_image, x_label, y = data_loader.fetch_batch(n_classes, batch_size, seq_length,
                                    type='test',
                                    augment=augment,
                                    label_type=label_type)
          feed_dict = {model.x_image: x_image, model.x_label: x_label, model.y: y}
          # Fetch predictions, loss and summary in one pass so the evaluation
          # batch goes through the graph only once.
          output, learning_loss, merged_summary = sess.run(
              [model.o, model.learning_loss, model.learning_loss_summary],
              feed_dict=feed_dict)
          train_writer.add_summary(merged_summary, b)
          accuracy = test(batch_size,seq_length, y, output)
          for accu in accuracy:
            print('%.4f' % accu, end='\t')
          print('%d\t%.4f' % (b, learning_loss))

        # Save model per 2000 epochs
        if b%2000==0 and b>0:
          saver.save(sess, checkpoint_dir + '/model.tfmodel', global_step=b)

        # Train the model
        x_image, x_label, y = data_loader.fetch_batch(n_classes, batch_size, seq_length, \
                                  type='train',
                                  augment=augment,
                                  label_type=label_type)
        feed_dict = {model.x_image: x_image, model.x_label: x_label, model.y: y}
        sess.run(model.train_op, feed_dict=feed_dict)
    finally:
      # Flush and release the event file even if training is interrupted.
      train_writer.close()

# Fill in this function. You might not need seq_length (the length of an episode)
# as an input, depending on your setup 
# Note: y is the true labels, and of shape (batch_size, seq_length, 5)
# output is the network's classification labels
def test(batch_size,seq_length, y, output, num_instances=10):
  """Compute classification accuracy as a function of how often a class has been seen.

  Within each episode, the k-th time a given true label appears counts as that
  class's "k-th instance". Accuracy is aggregated per k over all episodes in
  the batch.

  Args:
    batch_size: number of episodes in ``y`` / ``output``.
    seq_length: number of samples per episode.
    y: true labels, shape (batch_size, seq_length, n_classes), decoded with
      ``one_hot_decode``.
    output: the network's class scores, same layout as ``y``, decoded with
      ``one_hot_decode``.
    num_instances: how many instance positions to report (default 10).

  Returns:
    A list of ``num_instances`` floats; entry k-1 is the accuracy on the k-th
    instance of a class, or 0.0 if no class reached a k-th instance.
  """
  decoded_y = one_hot_decode(y)
  decoded_out = one_hot_decode(output)
  # Slot k holds the statistics for the k-th instance; slot 0 is unused.
  # A class can occur up to seq_length times, and the report reads slots up to
  # num_instances, so size for whichever is larger to keep every index in range.
  num_slots = max(seq_length, num_instances) + 1
  correct = np.zeros(num_slots)
  total = np.zeros(num_slots)
  for i in range(batch_size):
    # Occurrence counter per true label, reset for every episode.
    seen = {}
    for j in range(seq_length):
      label = decoded_y[i, j]
      k = seen.get(label, 0) + 1
      seen[label] = k
      total[k] += 1
      if label == decoded_out[i, j]:
        correct[k] += 1
  return [float(correct[k]) / total[k] if total[k] > 0. else 0. for k in range(1, num_instances + 1)]

In [None]:
# ---- Run configuration -------------------------------------------------------

# Episode layout and data
n_classes = 5
seq_length = 50
batch_size = 16
image_width = 20
image_height = 20
n_train_classes = 220
n_test_classes = 60
augment = True

# Optimisation schedule
learning_rate = 1e-3
num_epochs = 100000
restore_training = False

# Controller / memory sizes. train() does not take these as arguments; it reads
# rnn_size, rnn_num_layers, read_head_num, write_head_num and memory_vector_dim
# from these module-level names.
rnn_size = 200
rnn_num_layers = 1
num_memory_slots = 128
memory_vector_dim = 40
read_head_num = 4
write_head_num = 4

# Defined here but not passed to train() below.
label_type = "one_hot"
shift_range = 1
test_batch_num = 100

# Output locations; model_path is also handed to the model as its type.
save_dir = './save/one_shot_learning'
tensorboard_dir = './summary/one_shot_learning'
model_path = 'LSTM'

train(
    learning_rate=learning_rate,
    image_width=image_width,
    image_height=image_height,
    n_train_classes=n_train_classes,
    n_test_classes=n_test_classes,
    restore_training=restore_training,
    num_epochs=num_epochs,
    n_classes=n_classes,
    batch_size=batch_size,
    seq_length=seq_length,
    num_memory_slots=num_memory_slots,
    augment=augment,
    save_dir=save_dir,
    model_path=model_path,
    tensorboard_dir=tensorboard_dir,
)


Instructions for updating:
This class is deprecated, please use tf.nn.rnn_cell.LSTMCell, which supports all the feature this cell currently has. Please replace the existing code with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').
Model initialized
Entered Dataloader
10.0% data loaded.
20.0% data loaded.
30.0% data loaded.
40.0% data loaded.
50.0% data loaded.
60.0% data loaded.
70.0% data loaded.
80.0% data loaded.
90.0% data loaded.
100.0% data loaded.
Data loaded
1st	2nd	3rd	4th	5th	6th	7th	8th	9th	10th	epoch	loss
0.1875	0.1875	0.1750	0.1250	0.1750	0.2051	0.1918	0.1364	0.2000	0.2553	0	80.6996
0.1500	0.2000	0.1500	0.2250	0.2278	0.2727	0.2105	0.1667	0.1852	0.2000	100	80.5079
0.2125	0.1875	0.2000	0.2000	0.1948	0.2568	0.1972	0.2000	0.2069	0.2128	200	80.4981
0.1750	0.1625	0.1500	0.2500	0.2308	0.1867	0.2740	0.2319	0.1346	0.3182	300	80.5099
0.2375	0.1500	0.2625	0.1875	0.1899	0.2436	0.2254	0.2586	0.3077	0.1489	400	80.4048
0.1875	0.1750	0.1625	0.1772	0.1266	0.2338	0.2083	0.1571	0.2542	0.212

0.1375	0.3000	0.4304	0.5190	0.4615	0.4231	0.4400	0.4444	0.5091	0.3902	9100	63.5443
0.1500	0.3375	0.3625	0.4000	0.4286	0.4133	0.4143	0.3175	0.4912	0.4894	9200	67.2830
0.1625	0.3000	0.3750	0.5063	0.4051	0.4737	0.3562	0.4286	0.3396	0.4565	9300	66.6240
0.1625	0.4125	0.4000	0.4875	0.4430	0.4737	0.5342	0.5000	0.4464	0.4348	9400	63.8648
0.2375	0.2875	0.4500	0.4625	0.4625	0.5195	0.4658	0.5147	0.3448	0.5349	9500	62.5310
0.1500	0.3875	0.5375	0.4051	0.4557	0.4545	0.4521	0.4000	0.4364	0.5000	9600	63.6735
0.2125	0.3125	0.3875	0.4000	0.4359	0.4211	0.5205	0.4308	0.4182	0.5532	9700	64.8643
0.1625	0.4125	0.4500	0.4430	0.4805	0.4658	0.4225	0.5231	0.4035	0.4783	9800	61.4886
0.2250	0.3875	0.4000	0.4375	0.4051	0.4051	0.4211	0.5000	0.3962	0.4222	9900	63.3724
0.1250	0.4125	0.4000	0.4500	0.4125	0.4605	0.4110	0.4688	0.3559	0.4583	10000	65.8887
0.2000	0.3375	0.4625	0.5250	0.4750	0.5443	0.4722	0.4603	0.4386	0.5227	10100	60.2589
0.2250	0.2000	0.3750	0.3625	0.5063	0.4211	0.4143	0.5000	0.3390	0.3542	10200	67.8942
0

0.2750	0.7000	0.8125	0.8228	0.7215	0.7105	0.8714	0.7705	0.7736	0.7561	18900	37.1043
0.2250	0.5375	0.6375	0.6375	0.6709	0.7662	0.7808	0.7424	0.8103	0.7234	19000	42.4656
0.2750	0.6125	0.6962	0.7051	0.7436	0.7867	0.7746	0.7612	0.8621	0.8182	19100	37.9216
0.2750	0.6125	0.6875	0.7013	0.7733	0.7534	0.7463	0.8095	0.7759	0.8085	19200	39.8701
0.3125	0.6375	0.6875	0.6750	0.7625	0.7532	0.7361	0.8060	0.7692	0.8537	19300	38.4707
0.2000	0.5000	0.5875	0.7250	0.6750	0.6923	0.7123	0.6818	0.7544	0.8605	19400	43.9810
0.1625	0.6125	0.6500	0.7125	0.7564	0.8000	0.8056	0.7463	0.7458	0.7143	19500	37.7122
0.2375	0.5500	0.7875	0.7125	0.7468	0.6974	0.7534	0.8413	0.8333	0.8182	19600	38.7144
0.1875	0.6250	0.7375	0.7500	0.7342	0.7973	0.7612	0.7692	0.7167	0.7000	19700	38.7801
0.2625	0.5750	0.6750	0.7125	0.7436	0.7662	0.7083	0.7344	0.7818	0.7917	19800	41.2507
0.2375	0.7500	0.7875	0.7500	0.6625	0.8816	0.7826	0.7969	0.7736	0.8163	19900	34.9402
0.2625	0.6125	0.7875	0.8000	0.7750	0.6883	0.8451	0.7541	0.7885	0.8261	20000	