# Neural Turing Machine

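## Memory read and write operations

The two functions below implement the Neural Turing Machine memory-access rules (Graves, Wayne & Danihelka, 2014). They are followed by quick sanity tests on random inputs:

- **Read:** $r_t = \sum_i w_t(i)\, M_t(i)$, i.e. $r_t = w_t^\top M_t$.
- **Write (erase, then add):** $\tilde{M}_t(i) = M_{t-1}(i) \odot \left(\mathbf{1} - w_t(i)\, e_t\right)$, then $M_t(i) = \tilde{M}_t(i) + w_t(i)\, a_t$.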

In [2]:
import tensorflow as tf
import numpy as np

In [59]:
def ReadVector(M_t, w_t, tol=0.01):
    '''
    NTM read operation: the read vector is the w_t-weighted sum of memory rows,
    r_t = sum_i w_t(i) * M_t(i)  (equivalently r_t = w_t^T M_t).

    All arithmetic uses TensorFlow ops (no NumPy round-trip), so gradients can
    flow from r_t back to M_t and w_t under tf.GradientTape.

    M_t: Memory of size (N,M) at time t.
    w_t: (N,), Weighting generated by READ HEAD at time t for reading memory locations.
         Must sum to 1 within `tol`.
    tol: Allowed absolute deviation of sum(w_t) from 1.0 (default 0.01).

    RETURNS:

    r_t: The Read Vector, a tf.Tensor of shape (M,) with M_t's dtype.

    RAISES:

    AssertionError: if w_t does not have N elements or does not sum to 1 within tol.
    '''
    M_t = tf.convert_to_tensor(M_t)
    # Flatten and match the memory's dtype so the contraction below is well-typed
    # even when w_t arrives as, e.g., a float64 NumPy array.
    w_t = tf.cast(tf.reshape(w_t, (-1,)), M_t.dtype)
    n_slots = M_t.shape[0]

    assert w_t.shape[0] == n_slots, 'w_t must have one weight per memory row'
    # float() reads the eager value; it is used only for validation, not the result.
    assert abs(float(tf.reduce_sum(w_t)) - 1.0) <= tol, 'w_t must sum to 1'

    # Contract w_t's single axis with M_t's row axis: (N,) x (N,M) -> (M,).
    return tf.tensordot(w_t, M_t, axes=1)

In [77]:
def WriteOnMemory(M_prev, w_t, e_t, a_t, tol=0.01):
    '''
    NTM write operation: erase, then add.

        M_hat_t(i) = M_prev(i) * (1 - w_t(i) * e_t)
        M_t(i)     = M_hat_t(i) + w_t(i) * a_t

    Implemented entirely with TensorFlow ops (no NumPy round-trip), so the update
    stays differentiable with respect to every input under tf.GradientTape.

    M_prev: Memory Matrix at the previous time step of size (N,M)
    w_t: (N,), Weighting generated by WRITE HEAD at time t for Writing to the memory locations.
         Must sum to 1 within `tol`.
    e_t: (M,), Erase vector generated by WRITE HEAD; every element must lie in [0, 1].
    a_t: (M,), Add vector generated by WRITE HEAD.
    tol: Allowed absolute deviation of sum(w_t) from 1.0 (default 0.01, same as ReadVector).

    RETURNS:

    M_t: New Memory Matrix (tf.Tensor of shape (N,M)) after Erasing/Adding new instances.

    RAISES:

    AssertionError: on a shape mismatch, an erase value outside [0, 1], or a
    weighting that does not sum to 1 within tol.
    '''
    M_prev = tf.convert_to_tensor(M_prev)
    dtype = M_prev.dtype
    # Bring every head output to the memory's dtype so the elementwise ops type-check.
    w_t = tf.cast(w_t, dtype)
    e_t = tf.cast(e_t, dtype)
    a_t = tf.cast(a_t, dtype)

    (N, M) = M_prev.shape

    assert w_t.shape == (N,)
    assert e_t.shape == (M,)
    assert a_t.shape == (M,)

    # Closed interval: a saturated sigmoid can legitimately emit exactly 0.0 or 1.0.
    assert bool(tf.reduce_all((e_t >= 0) & (e_t <= 1))), 'e_t values must lie in [0, 1]'
    # Tolerance instead of exact equality: a normalized float32 weighting rarely sums to exactly 1.0.
    assert abs(float(tf.reduce_sum(w_t)) - 1.0) <= tol, 'w_t must sum to 1'

    # (N,1) column so each row's weight broadcasts across the M memory columns.
    w_col = tf.reshape(w_t, (N, 1))

    # Erase: scale row i by (1 - w_t(i) * e_t).
    M_hat_t = M_prev * (1 - w_col * tf.reshape(e_t, (1, M)))
    # Add: accumulate w_t(i) * a_t into row i.
    M_t = M_hat_t + w_col * tf.reshape(a_t, (1, M))

    assert M_t.shape == M_prev.shape

    return M_t

## Testing

In [38]:
# Experiment configuration
N = 10     # number of memory locations (rows of the memory matrix)
M = 5      # width of each memory location (columns of the memory matrix)
SEED = 0   # fixes the tf.random.uniform draws below so Restart & Run All is reproducible
tf.random.set_seed(SEED)

In [49]:
# Random positive scores, normalized into a read weighting that sums to 1.
a = tf.random.uniform((N,))
a /= np.sum(a)

In [40]:
mem = tf.random.uniform(shape=(N, M))  # memory matrix of shape (N, M), entries in [0, 1)

In [50]:
# Display the normalized read weighting drawn above.
a

<tf.Tensor: id=44, shape=(10,), dtype=float32, numpy=
array([0.08523568, 0.03617236, 0.06858236, 0.06843124, 0.13535374,
       0.12855709, 0.12970088, 0.14419153, 0.08366798, 0.12010714],
      dtype=float32)>

In [43]:
# Display the random (N, M) memory matrix.
mem

<tf.Tensor: id=35, shape=(10, 5), dtype=float32, numpy=
array([[0.8119643 , 0.60137   , 0.23534846, 0.44938838, 0.43760514],
       [0.8016536 , 0.58185935, 0.6212783 , 0.8808694 , 0.27468324],
       [0.5584111 , 0.8749579 , 0.04088831, 0.63813174, 0.9779513 ],
       [0.4518342 , 0.1791066 , 0.4633193 , 0.28779483, 0.4036492 ],
       [0.3037132 , 0.31695628, 0.22742224, 0.20577002, 0.36980045],
       [0.16494429, 0.14293563, 0.30317807, 0.30822635, 0.54937124],
       [0.97792554, 0.40442812, 0.9620074 , 0.662657  , 0.59774125],
       [0.38029218, 0.7453276 , 0.17781484, 0.40123653, 0.97076035],
       [0.74211204, 0.27332675, 0.00297976, 0.78554   , 0.35972118],
       [0.38652027, 0.77993643, 0.51441574, 0.8110522 , 0.8275595 ]],
      dtype=float32)>

In [60]:
ReadVector(M_t=mem, w_t=a)  # read from memory with the normalized weighting

<tf.Tensor: id=62, shape=(5,), dtype=float32, numpy=
array([0.5199238 , 0.4823144 , 0.35924798, 0.50804204, 0.60960335],
      dtype=float32)>

In [61]:
# Random positive scores, normalized into a write weighting that sums to 1.
w = tf.random.uniform((N,))
w /= np.sum(w)

In [62]:
# Display the normalized write weighting.
w

<tf.Tensor: id=71, shape=(10,), dtype=float32, numpy=
array([0.00227009, 0.01585456, 0.14528295, 0.05652477, 0.05780202,
       0.15667509, 0.14441091, 0.1346978 , 0.12770519, 0.15877664],
      dtype=float32)>

In [63]:
# NOTE(review): this rebinds `a` (previously the (N,) read weighting) to the (M,) add vector.
a = tf.random.uniform(shape=(M,))

In [64]:
# Display the add vector; `a` now holds the (M,) add vector, not the read weighting.
a

<tf.Tensor: id=78, shape=(5,), dtype=float32, numpy=
array([0.19372547, 0.8073455 , 0.5351058 , 0.52876794, 0.5324855 ],
      dtype=float32)>

In [65]:
e = tf.random.uniform(shape=(M,))  # erase vector, entries in [0, 1)

In [66]:
# Display the erase vector.
e

<tf.Tensor: id=85, shape=(5,), dtype=float32, numpy=
array([0.8540882 , 0.62966216, 0.8816339 , 0.52414644, 0.94140685],
      dtype=float32)>

In [76]:
WriteOnMemory(M_prev=mem, w_t=w, e_t=e, a_t=a)  # erase-then-add write into memory

<tf.Tensor: id=73210, shape=(10, 5), dtype=float32, numpy=
array([[0.81082976, 0.60234314, 0.23609218, 0.45005402, 0.43787873],
       [0.7938697 , 0.58885074, 0.62107795, 0.8819326 , 0.27902573],
       [0.517266  , 0.912211  , 0.11339283, 0.6663593 , 0.9215576 ],
       [0.44097123, 0.21836694, 0.47047693, 0.30915675, 0.41226852],
       [0.2999172 , 0.35208663, 0.24676293, 0.23009971, 0.38045642],
       [0.17322433, 0.25532562, 0.3451378 , 0.36575937, 0.5517689 ],
       [0.88528466, 0.48424292, 0.91680205, 0.68885875, 0.5933754 ],
       [0.36263633, 0.790861  , 0.22877616, 0.44413257, 0.91938734],
       [0.6859085 , 0.35445043, 0.07098006, 0.8004853 , 0.38447574],
       [0.3648636 , 0.83014935, 0.52736866, 0.82751065, 0.78840756]],
      dtype=float32)>

In [74]:
%%timeit 
WriteOnMemory(mem,w,e,a)

7.89 ms ± 232 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
