In [0]:
import zipfile
with zipfile.ZipFile('NTM_Student.zip', 'r') as zip_ref:
    zip_ref.extractall()

In [3]:
from utils import OmniglotDataLoader, one_hot_decode, five_hot_decode
import tensorflow as tf
import argparse
import numpy as np
%tensorflow_version 1.x
print(tf.__version__)


1.15.0


The following class `MANNCell` is the core of the memory-augmented neural network (MANN). You will implement the main parts of it in Tensorflow 2.0.

Before any technical discussion of how the MANNCell should operate, let us look at what it should do on a general level. Suppose we have an input batch of 16 episodes of image samples, with each episode being of equal length of 50. Based on the design of the rest of the project (which we have already implemented for you), MANNCell should be called 50 times, each time having 16 input samples (along with the offseted labels), and outputting 16 output labels. More specifically, the MANNCell should produce classification labels $[\hat{y}_0^t, ..., \hat{y}_{15}^t]$ for all 16 iteration-$t$ image samples batch $[x^t_0+\text{null}, x^t_1+y_0^t, ..., x^t_{15}+y_{14}^t]$ ("+" means concatenation) every time it is called; for your information, it is the class NTMOneShotLearningModel (already implemented below) that actually calls MANNCell 50 times. Your job is to make sure that at a single iteration $t$ (where $t=0,1,2,...,49$), MANNCell correctly parses the input arguments, produce the correct read and write weights $w^r_t, w^w_t$, correctly retrieve from and write to the memory to form $M_t$, and use the right material to get the logits for classification (they will be used for computing the labels and cross-entropy values in NTMOneShotLearningModel), and return the right states that will be used in the next iteration $t+1$. 

Let us look at the input arguments of the method `call(self, inputs, states)`  of this class first:
*   The `inputs` variable shall have the following shape: 
    `(batch_size, image_size+num_classes)`. 
  *   It corresponds to the $[x^t_0+\text{null}, x^t_1+y^t_0, ..., x^t_{15}+y^t_{14}]$ above, for some iteration $t=0,1,...,49$.
  *   `inputs[p,:]` is the $p$-th image in the batch `inputs` (note that the images are flattened to 1D tensors, and the labels are one-hot encoded).
*   The `states` variable is a dictionary that has the following set of keys:`{'controller_state', 'read_vector_list', 'w_r_list', 'w_u', 'M'}`
  *   `controller_state` is the state of the controller in iteration $t-1$; if $t-1 < 0$, then it is just zero-filled. As it is an LSTM cell, `controller_state` is of the form `[(batch_size, rnn_size),(batch_size, rnn_size)]` (technically speaking its shape is `(2, batch_size, rnn_size)`). The two `(batch_size, rnn_size)`-shaped entries in it correspond to the cell state and the hidden state of the LSTM. We will mostly be treating the LSTM controller as a black-box in this project, so we do not need to pay much attention to the details of its states. If interested, you can read about the LSTM cell's technical details in [tf.keras.layers.LSTMCell](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTMCell).
  *   `read_vector_list` is the list of read vectors $r_{t-1}$ which we obtained in the previous iteration $t-1$ in the episode; if $t-1 < 0$, then the read vector list is initialized to be an arbitrary one-hot vector. It is of the shape `(head_num, batch_size, memory_vector_dim)`. Basically, `read_vector_list[i,p,:]` is the $(t-1)$-th-iteration read vector of the $i$-th read head for the $p$-th input sample in the batch.
  *   `w_r_list` is the list of read weights $w^r_{t-1}$ which we obtained in the previous iteration $t-1$ in the episode; if $t-1 < 0$, then the read weights list is initialized to be an arbitrary one-hot vector. It is of the shape `(head_num, batch_size, num_memory_slots)`. Basically, `w_r_list[i,p,:]` is the $(t-1)$-th-iteration read weight of the $i$-th read head for the $p$-th input sample in the batch.
  *   `w_u` is the list of memory usage weights $w^u_{t-1}$ which we obtained in the previous iteration $t-1$ in the episode; if $t-1 < 0$, then the usage weights list is is initialized to be an arbitrary one-hot vector. It is of the shape `(batch_size, num_memory_slots)`. Basically, `w_u[p,:]` is the $(t-1)$-th-iteration memory usage weight of the $p$-th input sample in the batch.
  *   `M` is the memory content from the previous iteration $t-1$; if $t-1 < 0$, then the memory is just zero-filled. It is of shape `(batch_size, num_memory_slots, memory_vector_dim)`. Basically, `M[p,j,:]` is the $j$-th memory vector in the memory block for the $p$-th sample in the batch from iteration $t-1$, and `M[p,:,:]` is the memory block for the $p$-th sample in the batch, where the memory block is a 2D structure that has `num_memory_slots` memory vectors, each vector of length `memory_vector_dim`.




    




---



Now let us look at some of the technical details of the MANNCell. First, we discuss the main ingredients of the MANNCell, and initialization of the relevant units.
*   The input arguments of the class initialization method `__init__` have already been specified, they will be used to initialize relevant structures in the class.
*   `self.controller`: this is the controller of the MANN cell that is responsible for interfacing with the memory $M$. We recommend using `tf.keras.layers.LSTMCell` with `units=rnn_size` for initialization. For its technical details, see [tf.keras.layers.LSTMCell
](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTMCell).
*   `self.controller_output_to_read_keys`, `self.controller_output_to_write_keys`, `self.controller_output_to_alphas`: the LSTM controller's output structure (we will discuss what its inputs should be later) is of the form [controller_output, controller_cell_and_hidden_states]. We need a mapping that maps the controller_output to the read keys, write keys and the interpolation coefficient $\alpha_t$'s, which will then be used for interacting with the memory. Three `tf.keras.layers.Dense` layers (one for producing read keys, one for write keys, one for the $\alpha_t$'s) are sufficient, though you are welcome to try out more complicated structures. 
  *  **Note**: each access to memory involves `head_num` number of heads, if you wish, you could just initialize `self.controller_output_to_read_keys` with `units=self.memory_vector_dim*self.head_num` and apply `tf.split` to the output of the dense layer along `axis=1` and `num_or_size_splits=head_num` in the `call` method (similar for the other two dense layers).
  
*    `self.controller_output_to_logits`: it should be a dense layer that will be used to map the concatenated controller_output + read_vector_list to the logits that will be used for obtaining the classification labels of the inputs and computing the cross entropy values. Thus, initialize it with `units=self.num_classes`.

---

*   **Caution**: even though most of the discussion below that involve tensors are treated either element-wise or vector-wise, in your implementation please utitlize tensorflow matrix operations as much as possible, as it can avoid strange bugs and increase the speed of your model.
*   As described before, the input arguments of the `call` method are `inputs` and `states`. 
  *  Parse `state` to obtain `prev_controller_state`, `prev_read_vector_list`, `prev_w_r_list`, `prev_w_u`, `prev_M` that come from the previous iteration $t-1$. You may assume that they are zero-filled if $t=0$.
*  Constructing the controller's input had been implemented for you. 
  *  The controller's output will be of the form `(controller_output, controller_states)`.
  *  Why do you think  we should involve `prev_read_vector_list` in the controller's input?
*  Now pass `controller_output` to the dense layers we discussed before, and obtain the read keys, write keys and the interpolation coefficients.
  * Following the suggestion in the Remark 1 above, after applying `tf.split` to the dense layers' outputs, the shapes of your `read_key_list` and `write_key_list` should both be `(head_num, batch_size, memory_vector_dim)`, and the shape of `alpha_list` should be `(head_num, batch_size, 1)`. As an example, `read_key_list[i,p,:]` should be the memory read key for the $i$-th read head for the $p$-th sample in the batch.

*  Before computing the read and write weights and interact with memory, we need to compute `prev_w_lu`, the least used weights from the previous iteration $t-1$. 
  *  To compute `prev_w_lu`, note that for the $p$-th sample in the batch in the previous iteration $t-1$, `prev_w_lu[p,:]` is a vector of binary values with length `num_memory_slots`: defining 
     \begin{equation}
      s(\text{prev_w_u}[p,:], k)= \text{the $k$-th smallest entry in prev_w_u}[p,:]
     \end{equation}
     we have 
     \begin{equation}
      \text{prev_w_lu}[p,i] = 0, \;\; \text{if prev_w_u}[p,i] > s(\text{prev_w_u}[p,:], \text{head_num})
     \end{equation}
     and 
     \begin{equation}
      \text{prev_w_lu}[p,i]=1 \;\; \text{otherwise}
     \end{equation}
  *   Here is one way to implement `compute_w_lu`. Given input argument `prev_w_u` the usage weight from the previous iteration $t-1$ (it has shape `(batch_size, num_memory_slots)`), use `tf.math.top_k` to obtain the desired set of indices from `prev_w_u` (so you should have a `batch_size` number of index sets, each set is of size `head_num`; the overall structure should be of shape `(batch_size, head_num)`). Then use `tf.one_hot` and `tf.reduce_sum` to expand these indices into `prev_w_lu`, which should have shape `(batch_size, num_memory_slots)`. 
    *  From the set of indices with size `(batch_size, head_num)` you used for computing `prev_w_lu`, remember to also construct and return the index corresponding to *the smallest* entry in `prev_w_u[p,:]` for every $p$ (this index also correspond to the memory slot that was least used for the $p$-th sample in the previous iteration); so your returned indices will have size `(batch_size, 1)`.
    *  You may find [tf.math.top_k
](https://www.tensorflow.org/api_docs/python/tf/math/top_k), [tf.one_hot
](https://www.tensorflow.org/api_docs/python/tf/one_hot) and [tf.reduce_sum
](https://www.tensorflow.org/api_docs/python/tf/math/reduce_sum)  useful.
    

*  Now we proceed to compute the read and write weights $w^r_t$ and $w^w_t$.
  *  For the $p$-th sample in the batch, recall that the read key `read_key_list[m,p,:]` is for the $m$-th read head for that sample, and `prev_M[p,j,:]` is the $j$-th memory vector for the $p$-th sample from the previous interation $t-1$ . Then the memory **read** weight `w_r_list[m,p,:]` for the $m$-th read head for the $p$-th sample is a 1D tensor with length `num_memory_slots`, with entries
  \begin{equation}
    \text{w_r_list}[m,p,i] = \frac{\exp(K(\text{prev_M}[p,i,:],\text{read_key_list}[m,p,:]))}{\sum_{j=0}^{\text{num_memory_slots}-1}\exp(K(\text{prev_M}[p,j,:], \text{read_key_list}[m,p,:]))}
  \end{equation}
  where $i=0,1,...,\text{num_memory_slots}-1\$, and
  \begin{equation}
    K(x, y) = \frac{x\cdot y}{\Vert x \Vert_2 \Vert y \Vert_2 + \epsilon}
  \end{equation}
    *  $\epsilon$ is there to ensure numerical stability. $\epsilon=10^{-8}$ seems to be a good choice.
    *  You might find some of the following tensorflow operations useful: [tf.matmul
](https://www.tensorflow.org/api_docs/python/tf/linalg/matmul), [tf.norm
](https://www.tensorflow.org/api_docs/python/tf/norm), [tf.expand_dims
](https://www.tensorflow.org/api_docs/python/tf/expand_dims), [tf.squeeze
](https://www.tensorflow.org/api_docs/python/tf/squeeze), [tf.math.exp
](https://www.tensorflow.org/api_docs/python/tf/math/exp) 

    *  In the suggested setup, the method `compute_read_weights`'s return shape should be `(batch_size, num_memory_slots)`, and `w_r_list` should have shape `(head_num, batch_size, num_memory_slots)`.

  *  Given the $p$-th sample in the batch, the memory **write** weight `w_w_list[m,p,:]` for the $m$-th write head for that sample is of the general form:
     \begin{equation}
      \text{w_w_list}[m,p,i] = \text{Sigmoid}(\text{alpha_list}[m,p,0])\times\text{prev_w_r_list}[m,p,i] + (1 - \text{Sigmoid}(\text{alpha_list}[m,p,0]))\times\text{prev_w_lu}[p,i]
     \end{equation}
     where $i=0,...,\text{num_memory_slots-1}$.
    *  In our suggested setup, method `compute_write_weights`'s return shape should be `(batch_size, num_memory_slots)`, so `w_w_list` should have shape `(head_num, batch_size, num_memory_slots)`.

*  Let us read from memory `prev_M` now.
    *  As we have `w_r_list` with shape `(head_num, batch_size, num_memory_slots)`, to obtain the read vectors, simply carry out the following: for the $m$-th read head for the $p$-th sample, 
      \begin{equation}
        \text{read_vector_list}[m,p,:] = \sum_{j=0}^{\text{num_memory_slots}-1}\text{w_r_list}[m,p,j]\times\text{prev_M}[p,j,:]
      \end{equation}
      where `read_vector_list` has shape `(head_num, batch_size, memory_vector_dim)`.
      *  Please remember that computing with matrices (in contrast to using some kind of for loop) can usually make you code run faster.

* Having obtained the write weights `w_w_list`, we are closer to accessing the content of the memory now. But before that, rememeber that we got a set of indices of size `(batch_size, 1)` from the method `compute_w_lu` that indicated the least used memory slot in the previous iteration $t-1$? We are going to use them to zero out *the least used slot* in the memory first, before the writing operations.
  *  One way of implementation: apply `tf.one_hot` to the set of indices of size `(batch_size, 1)` to obtain a matrix `E` of size `(batch_size, num_memory_slots)` containing one-hot vectors, where `E[p,j]` is 1 if the $j$-th memory slot for the $p$-th sample in the previous iteration was least used. Then we just need to compute the new memory along the line of $M*(1-E)$. So we have obtained `M_erased`, with shape `(batch_size, num_memory_slots, memory_vector_dim)`.

* Now we can write to memory:
  *  Recall that we have already computed `write_key_list` and `w_w_list` with shapes `(head_num, batch_size, memory_vector_dim)` and `(head_num, batch_size, num_memory_slots)` respectively. To write to `M_erased` with the $m$-th write head for the $p$-th sample, simply compute
     \begin{equation}
      \text{M_written}[p,i,:] = \text{M_erased}[p,i,:] + \text{w_w_list}[m,p,i]\times\text{write_key_list}[m,p,:]
     \end{equation}
      *  You might find [tf.matmul
](https://www.tensorflow.org/api_docs/python/tf/linalg/matmul) and [tf.expand_dims
](https://www.tensorflow.org/api_docs/python/tf/expand_dims) useful here.
  

*  Finally, update the usage weight $w^u_t$ following the formula: for the $p$-th sample in the batch,
   \begin{equation}
     \text{w_u}[p,:] = \text{self.gamma}\times\text{prev_w_u}[p,:] + \sum_{i=0}^{\text{head_num}-1}\text{w_r_list}[i,p,:] + \sum_{i=0}^{\text{head_num}-1}\text{w_w_list}[i,p,:]
   \end{equation}
   where `w_u` has shape `(batch_size, num_memory_slots)`, and `self.gamma` is a manually defined free parameter of the model, which we have already set for you.
*  Finally, we update the `state` dictionary , and feed [controller's output + the read vector list] to `self.controller_output_to_logits` which will be used for obtaining the labels for the input samples (already written for you) . Please ensure that all the relevant tensors have the correct shape and content.


In [0]:
class MANNCell(tf.keras.layers.AbstractRNNCell):
  def __init__(self, rnn_size, num_memory_slots, memory_vector_dim, head_num, num_classes=5, gamma=0.95, **kwargs):
    super().__init__(**kwargs)
    ################ Setup ###############################################
    self.rnn_size = rnn_size
    # number of memory slots
    self.num_memory_slots = num_memory_slots
    # size of each memory slot
    self.memory_vector_dim = memory_vector_dim
    self.head_num = head_num
    # memory access head number is the same for both read and 
    # write in our setup  
    self.write_head_num = head_num
    # decay parameter for computing the usage weights
    self.gamma = gamma

    self.num_classes = num_classes
    ########################################################################

    # Controller RNN layer, we use an LSTM
    # Recommended: tf.keras.layers.LSTMCell
    self.controller = tf.keras.layers.LSTMCell(units=self.rnn_size)
    # controller_output 
    #          -> read_key (batch_size, head_num*memory_vector_dim)
    #          -> write_key (batch_size, head_num*memory_vector_dim)
    #          -> alpha (batch_size, head_num), interpolation coefficient for writing to memory
    #
    # units=self.memory_vector_dim*self.head_num for initializing the dense layers
    # for read key and write keys, and units=self.head_num for the dense layer for alpha,
    # and applying tf.split along axis=1 in the call method
    self.controller_output_all=tf.keras.layers.Dense(units=self.memory_vector_dim*self.head_num, use_bias=True)
    
    # This is the dense layer for mapping the controller output + read vector list to 
    # logits (which will then be used for computing the labels and cross-entropy values
    # in NTMOneShotLearningModel). So initialize it with units=self.num_classes.
    self.controller_output_to_logits = tf.keras.layers.Dense(units=self.num_classes, use_bias=True)

  @property
  def state_size(self):
    return self.rnn_size

  # This initializes the dictionary states in MANNCell, and returns the initial state.
  # Please do not change it.
  def zero_state(self, batch_size, rnn_size, dtype):
    one_hot_weight_vector = np.zeros([batch_size, self.num_memory_slots])
    one_hot_weight_vector[..., 0] = 1
    one_hot_weight_vector = tf.constant(one_hot_weight_vector, dtype=tf.float32)
    initial_state = {
            'controller_state': [tf.zeros((batch_size, rnn_size)), tf.zeros((batch_size, rnn_size))],
            'read_vector_list': [tf.zeros([batch_size, self.memory_vector_dim])
                                  for _ in range(self.head_num)],
            'w_r_list': [one_hot_weight_vector for _ in range(self.head_num)],
            'w_u': one_hot_weight_vector,
            'M': tf.constant(np.ones([batch_size, self.num_memory_slots, self.memory_vector_dim]) * 1e-6, dtype=tf.float32)
        }
    return initial_state

  def call(self, inputs, states):
    # read vectors from the previous iteration, extract from states
    prev_read_vector_list = states['read_vector_list'] 
    # state of controller from previous iteration t-1, extract from states
    prev_controller_state = states['controller_state']  
    # Obtain the list of w^r_{t-1}, M_{t-1}, and w^u_{t-1}, extract from states
    prev_w_r_list = states['w_r_list']
    prev_M = states['M']
    prev_w_u = states['w_u']

    # Controller output form the parameters of the read and write vectors
    controller_input = tf.concat([inputs] + prev_read_vector_list, axis=1)
    controller_output, controller_state = self.controller(inputs=controller_input, states=prev_controller_state)

    # Map the controller_output to the read_keys, write_keys, and alphas
    output_parameters = self.controller_output_all(controller_output)
    parameter_list = tf.split(output_parameters, self.head_num, axis=1)
    
    read_key_list=[]
    write_key_list=[]
    sig_alpha=[]
    
    for i, j in enumerate(parameter_list):
        read_keys = j[:, 0:self.memory_vector_dim]
        write_keys = j[:, self.memory_vector_dim:self.memory_vector_dim * 2]
        alphas = j[:, -1:]

        read_key_list.append(tf.tanh(read_keys))
        write_key_list.append(tf.tanh(write_keys))#
        sig_alpha.append(tf.sigmoid(alphas))

    # For every p-th sample in the batch (from iteration t-1), compute the index 
    # corresponding to least used memory slot in prev_M[p,:,:], return as prev_indices.
    # Also compute w^lu_{t-1}, return as prev_w_lu.
    
    prev_indices, prev_w_lu = self.compute_w_lu(prev_w_u)

    # Setup read and write weights
    w_r_list = []
    w_w_list = []
    # We obtain read and write weights for each head
    for i in range(self.head_num):
      # Obtain READ weights
      w_r = self.compute_read_weights(read_key_list[i], prev_M)
      # Obtain WRITE weights
      w_w = self.compute_write_weights(sig_alpha[i], prev_w_r_list[i], prev_w_lu)
      # w_r_list is of shape (head_num, batch_size, num_memory_slots), 
      # and same for w_w_list
      w_r_list.append(w_r)
      w_w_list.append(w_w)

    # Set least used memory slot in prev_M to ZERO, make use of prev_indices!
    M_erased = prev_M * tf.expand_dims(1. - tf.one_hot(prev_indices[:, -1], self.num_memory_slots), dim=2)

    # Read from memory M_{t-1}, using the w_r_list
    read_vector_list = []
    # Iterate over each head
    for i in range(self.head_num):
      # compute read_vector
      read_vector = tf.reduce_sum(tf.expand_dims(w_r_list[i], dim=2) * M_erased, axis=1)
      # read_vector_list should have shape (head_num, batch_size, memory_vector_dim)
      read_vector_list.append(read_vector)


    # Write to memory, form M_t, using the w_w_list and write_keys
    # Iterate over each head
    for i in range(self.head_num):
      
      w = tf.expand_dims(w_w_list[i], axis=2)
      k = tf.expand_dims(read_key_list[i], axis=1)
      M_written = M_erased + tf.matmul(w, k)
    
    # Compute usage weights w^u_t for the current iteration
    w_u = self.gamma * prev_w_u + tf.add_n(w_r_list) + tf.add_n(w_w_list)

    # Concatenate controller's output and the read memory
    # content, they are then fed into a dense layer to obtain the logits,
    # which will be used for obtaininig labels and computing the  cross-entropy 
    # values in NTMOneShotLearningModel below
    mann_output = tf.concat([controller_output] + read_vector_list, axis=1)
    logits = self.controller_output_to_logits(mann_output) 

    state = {
        'controller_state': controller_state,
        'read_vector_list': read_vector_list,
        'w_r_list': w_r_list,
        'w_w_list': w_w_list,
        'w_u': w_u,
        'M': M_written,
    }

    return logits, state

  def compute_read_weights(self, read_key, prev_M):
     
    # Compute the inner products, norms
    read_key = tf.expand_dims(read_key, axis=2)
    M_key_product = tf.matmul(prev_M, read_key)
    read_key_norm = tf.sqrt(tf.reduce_sum(tf.square(read_key), axis=1, keep_dims=True))
    prev_M_norm = tf.sqrt(tf.reduce_sum(tf.square(prev_M), axis=2, keep_dims=True))
    norm_product = prev_M_norm * read_key_norm
    K = tf.squeeze(M_key_product / (norm_product + 1e-8))    
    
    # Compute the exp(K(M,key))'s
    K_exp = tf.exp(K)
     
    # Obtain read weights
    w_r = K_exp / tf.reduce_sum(K_exp, axis=1, keep_dims=True) 
    return w_r

  def compute_write_weights(self, sig_alpha, prev_w_r, prev_w_lu):
    # Compute the write weights
    w_w=sig_alpha * prev_w_r + (1. - sig_alpha) * prev_w_lu
    return w_w

  def compute_w_lu(self, prev_w_u):
    _, indices = tf.nn.top_k(prev_w_u, k=self.num_memory_slots)
    prev_w_lu = tf.reduce_sum(tf.one_hot(indices[:, -self.head_num:], depth=self.num_memory_slots), axis=1)
    return indices, prev_w_lu
  

Already implemented, no need to change.

This class is part of the training loop.

In [0]:
class NTMOneShotLearningModel():
  def __init__(self, model, n_classes, batch_size, seq_length, image_width, image_height,
                rnn_size, num_memory_slots, rnn_num_layers, read_head_num, write_head_num, memory_vector_dim, learning_rate):
    self.output_dim = n_classes

    # Note: the images are flattened to 1D tensors
    # The input data structure is of the following form:
    # self.x_image[i,j,:] = jth image in the ith sequence (or, episode)
    self.x_image = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_length, image_width * image_height])
    # Model's output label is one-hot encoded
    # The data structure is of the following form:
    # self.x_label[i,j,:] = one-hot label of the jth image in 
    #             the ith sequence (or, episode)
    self.x_label = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_length, self.output_dim])
    # Target label is one-hot encoded
    self.y = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_length, self.output_dim])

    if model == 'LSTM':
      # Using a LSTM layer to serve as the controller, no memory
      def rnn_cell(rnn_size):
        return tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
      cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell(rnn_size) for _ in range(rnn_num_layers)])
      state = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
    elif model == 'MANN':
      # Using a MANN network as the controller, with memory
      cell = MANNCell(rnn_size, num_memory_slots, memory_vector_dim,
                                head_num=read_head_num)
      state = cell.zero_state(batch_size=batch_size, rnn_size=rnn_size, dtype=tf.float32)
    
    
    self.state_list = [state]
    # Setup the NTM's output
    self.o = []
    
    # Now iterate over every sample in the sequence 
    for t in range(seq_length):
      output, state = cell(tf.concat([self.x_image[:, t, :], self.x_label[:, t, :]], axis=1), state)
      output = tf.nn.softmax(output, axis=1)
      self.o.append(output)
      self.state_list.append(state)
    # post-process the output of the classifier
    self.o = tf.stack(self.o, axis=1)
    self.state_list.append(state)

    eps = 1e-8
    # cross entropy, between model output labels and target labels
    self.learning_loss = -tf.reduce_mean(  
        tf.reduce_sum(self.y * tf.log(self.o + eps), axis=[1, 2])
    )
    
    self.o = tf.reshape(self.o, shape=[batch_size, seq_length, -1])
    self.learning_loss_summary = tf.summary.scalar('learning_loss', self.learning_loss)

    with tf.variable_scope('optimizer'):
      self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
      self.train_op = self.optimizer.minimize(self.learning_loss)

The training and testing functions

In [0]:
def train(learning_rate, image_width, image_height, n_train_classes, n_test_classes, restore_training, \
         num_epochs, n_classes, batch_size, seq_length, num_memory_slots, augment, save_dir, model_path, tensorboard_dir):
  
  # We always use one-hot encoding of the labels in this experiment
  label_type = "one_hot"

  # Initialize the model
  model = NTMOneShotLearningModel(model=model_path, n_classes=n_classes,\
                    batch_size=batch_size, seq_length=seq_length,\
                    image_width=image_width, image_height=image_height, \
                    rnn_size=rnn_size, num_memory_slots=num_memory_slots,\
                    rnn_num_layers=rnn_num_layers, read_head_num=read_head_num,\
                    write_head_num=write_head_num, memory_vector_dim=memory_vector_dim,\
                    learning_rate=learning_rate)
  print("Model initialized")
  data_loader = OmniglotDataLoader(
      image_size=(image_width, image_height),
      n_train_classses=n_train_classes,
      n_test_classes=n_test_classes
  )
  print("Data loaded")
  # Note: our training loop is in the tensorflow 1.x style
  with tf.Session() as sess:
    if restore_training:
      saver = tf.train.Saver()
      ckpt = tf.train.get_checkpoint_state(save_dir + '/' + model_path)
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      saver = tf.train.Saver(tf.global_variables())
      tf.global_variables_initializer().run()
    train_writer = tf.summary.FileWriter(tensorboard_dir + '/' + model_path, sess.graph)
    print("1st\t2nd\t3rd\t4th\t5th\t6th\t7th\t8th\t9th\t10th\tepoch\tloss")
    for b in range(num_epochs):
      # Test the model
      if b % 100 == 0:
        # Note: the images are flattened to 1D tensors
        # The input data structure is of the following form:
        # x_image[i,j,:] = jth image in the ith sequence (or, episode)
        # And the sequence of 50 images x_image[i,:,:] constitute
        # one episode, and each class (out of 5 classes) has around 10
        # appearances in this sequence, as seq_length = 50 and 
        # n_classes = 5, as specified in the code block below
        # See the details in utils.py, OmniglotDataLoader class
        x_image, x_label, y = data_loader.fetch_batch(n_classes, batch_size, seq_length,
                                  type='test',
                                  augment=augment,
                                  label_type=label_type)
        feed_dict = {model.x_image: x_image, model.x_label: x_label, model.y: y}
        output, learning_loss = sess.run([model.o, model.learning_loss], feed_dict=feed_dict)
        merged_summary = sess.run(model.learning_loss_summary, feed_dict=feed_dict)
        train_writer.add_summary(merged_summary, b)
        accuracy = test(seq_length, y, output)
        for accu in accuracy:
          print('%.4f' % accu, end='\t')
        print('%d\t%.4f' % (b, learning_loss))

      # Save model per 2000 epochs
      if b%2000==0 and b>0:
        saver.save(sess, save_dir + '/' + model_path + '/model.tfmodel', global_step=b)

      # Train the model
      x_image, x_label, y = data_loader.fetch_batch(n_classes, batch_size, seq_length, \
                                type='train',
                                augment=augment,
                                label_type=label_type)
      feed_dict = {model.x_image: x_image, model.x_label: x_label, model.y: y}
      sess.run(model.train_op, feed_dict=feed_dict)

# as an input, depending on your setup 
# Note: y is the true labels, and of shape (batch_size, seq_length, 5)
# output is the network's classification labels
def test(seq_length, y, output):
    accuracy=[]
    for i in range(len(y)):
        c=0
        for j in range(seq_length):
            if np.argmax(y[i][j][:])==np.argmax(output[i][j][:]):
                c+=1
        accuracy.append((c*100.0)/(seq_length))
    return accuracy

In [7]:
restore_training = False
label_type = "one_hot"
n_classes = 5
seq_length = 50
augment = True
read_head_num = 4
batch_size = 16
num_epochs = 100000
learning_rate = 1e-3
rnn_size = 200
image_width = 20
image_height = 20
rnn_num_layers = 1
num_memory_slots = 128
memory_vector_dim = 40
shift_range = 1
write_head_num = 4
test_batch_num = 100
n_train_classes = 220
n_test_classes = 60
save_dir = './save/one_shot_learning'
tensorboard_dir = './summary/one_shot_learning'
model_path = 'MANN'
train(learning_rate, image_width, image_height, n_train_classes, n_test_classes, restore_training, \
         num_epochs, n_classes, batch_size, seq_length, num_memory_slots, augment, save_dir, model_path, tensorboard_dir)


Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
Instructions for updating:
Use the `axis` argument instead
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model initialized
Entered Dataloader
10.0% data loaded.
20.0% data loaded.
30.0% data loaded.
40.0% data loaded.
50.0% data loaded.
60.0% data loaded.
70.0% data loaded.
80.0% data loaded.
90.0% data loaded.
100.0% data loaded.
Data loaded
1st	2nd	3rd	4th	5th	6th	7th	8th	9th	10th	epoch	loss
16.0000	26.0000	16.0000	14.0000	16.0000	20.0000	20.0000	18.0000	22.0000	20.0000	14.0000	18.0000	22.0000	20.0000	16.0000	24.0000	0	80.8889
18.0000	20.0000	28.0000	8.0000	16.0000	20.0000	24.0000	26.0000	38.0000	34.0000	14.0000	12.0000	22.0000	14.0000	20.0000	28.0000	100	80.6712
12.0000	38.0000	20.0000	14.0000	20.0000	30.0000	18.0000	22.0000	24.0000	16.0000	8.0000	20.0000	22.0000	26.0000	24.0000	24.00

The best accuracy reached by the Neural Turing Machine is 98%

Yes the steps involved in this exercise make sense. One thing that coulc be improved in my opinion is the cosine similarity. The cosine similarity used in Neural Turing Machines is for content based addressing where values of the vectors in memory are compared to the values emitted by the controller. Since the values are more important, I think, normalizing the vectors and using a Minkowski Distance could be better though I'm not sure about it.

Coming to the other steps like zeroing out the least recently used(LRU) memory loations from before, I don't think there should be a very huge difference in performance by how we manipulate the LRU memory locations as they have not been used very frequently in the previous iterations.


