# Vectored Read and Write Vectors Generation

In [11]:
import tensorflow as tf
import numpy as np

In [58]:
def ReadVector(M_t, w_t):
    '''
    Computes the Read Vector of EACH HEAD for the entire Batch at once.
    
    M_t: Memory of size (Batch_size,N,M) at time t.
    w_t: (Batch_size,N), Weighting generated by READ HEAD at time t for reading memory location
    
    RETURNS:
    
    r_t: (Batch_size,M) The Read Vector
    '''
    batch_size,N,M = M_t.shape[0], M_t.shape[1], M_t.shape[2]
    
    #tol = 0.01
    #assert (np.sum(w_t) >= 1.0 - tol) & (np.sum(w_t) <= 1.0 + tol)
    r_t = tf.reshape(  tf.matmul(tf.reshape(w_t,(batch_size,1,N)),M_t), (batch_size,M)   )
    
    #assert r_t.shape == (batch_size,M)
    
    return r_t


## Rough Work

In [126]:
batch_size = 68
features = 8
inputs = tf.random.uniform((batch_size,features))
n_RH = 2
n_WH = 2
N = 100
M = 20
M_prev = tf.random.uniform((batch_size,N,M))
w_t = tf.random.uniform((batch_size,N))
a_t = tf.random.uniform((batch_size,M))
e_t = tf.random.uniform((batch_size,M))

In [27]:
%timeit np.array([np.dot(w_t[i],M_prev[i]) for i in range(batch_size)])

41.9 ms ± 451 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [62]:
wt_reshaped = tf.reshape(w_t,(w_t.shape[0],1,w_t.shape[1]))

In [63]:
wt_reshaped.shape

TensorShape([54, 1, 120])

In [84]:
M_prev.shape

TensorShape([54, 120, 28])

In [88]:
tf.matmul(tf.reshape(w_t,(batch_size,1,N)),M_prev)

<tf.Tensor: id=46435, shape=(54, 1, 28), dtype=float32, numpy=
array([[[32.767307, 29.67947 , 29.991352, ..., 31.570396, 33.52847 ,
         34.785275]],

       [[28.586918, 29.211546, 30.81921 , ..., 31.666348, 26.021385,
         30.774918]],

       [[32.036026, 30.057602, 30.28217 , ..., 30.137281, 30.679464,
         26.555609]],

       ...,

       [[29.969158, 30.04381 , 30.72913 , ..., 28.836475, 31.721498,
         28.783731]],

       [[31.592222, 29.157991, 31.915648, ..., 31.20196 , 27.410116,
         30.2157  ]],

       [[31.946661, 28.97745 , 26.739243, ..., 27.836405, 28.473106,
         25.06023 ]]], dtype=float32)>

## Testing 

In [59]:
%timeit ReadVector(M_prev,w_t)

3.03 ms ± 605 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Much Faster!

Write Vector Generation

In [123]:
def WriteOnMemory(M_prev, w_t, e_t, a_t):
    
    '''
    Computes the updated Memory Matrix for the each example in Batch at once
    
    M_prev: Memory Matrix at the previous time step of size (Batch_size,N,M)
    w_t: (Batch_size,N), Weighting generated by WRITE HEAD at time t for Writing to the memory locations.
    e_t: (Batch_size,M), Erase vector generated by WRITE HEAD.
    a_t: (Batch_size,M), Add vector generated by WRITE HEAD.
    
    RETURNS:
    
    M_t: (Batch_size,N,M), New Memory Matrix after Erasing/Adding/Combination of new instances.
    '''
    
    (batch_size,N,M) = M_prev.shape
    
    M_hat_t = tf.multiply( M_prev, 1-tf.multiply(tf.reshape(w_t,(batch_size,N,1)),tf.reshape(e_t,(batch_size,1,M))))
    #^Of shape [batch_size,N,M]
    
    
    assert M_hat_t.shape == M_prev.shape
    
    M_t = M_prev + tf.multiply(tf.reshape(w_t,(batch_size,N,1)),tf.reshape(a_t,(batch_size,1,M)))
    
    assert M_t.shape == M_prev.shape
    
    return M_t

## Rough Work

In [106]:
tf.multiply(np.array([[1],[2],[3]]),np.array([[1,2],[3,4],[5,6]]))

<tf.Tensor: id=46467, shape=(3, 2), dtype=int64, numpy=
array([[ 1,  2],
       [ 6,  8],
       [15, 18]])>

In [111]:
tf.reshape(w_t,(batch_size,N,1)).shape

TensorShape([54, 120, 1])

In [101]:
tf.reshape(e_t,(batch_size,1,M)).shape

TensorShape([54, 1, 28])

In [114]:
tf.multiply( M_prev, 1-tf.multiply(tf.reshape(w_t,(batch_size,N,1)),tf.reshape(e_t,(batch_size,1,M))))

<tf.Tensor: id=46512, shape=(54, 120, 28), dtype=float32, numpy=
array([[[0.12466209, 0.01185953, 0.4066372 , ..., 0.6541431 ,
         0.32235393, 0.59315   ],
        [0.38001245, 0.6197917 , 0.5254241 , ..., 0.2955763 ,
         0.611064  , 0.299763  ],
        [0.20401934, 0.37730828, 0.7601894 , ..., 0.927123  ,
         0.56642765, 0.46336251],
        ...,
        [0.7103551 , 0.2320269 , 0.621648  , ..., 0.07797381,
         0.3594908 , 0.21159324],
        [0.15581657, 0.6271317 , 0.5058738 , ..., 0.02510409,
         0.98569417, 0.34371197],
        [0.32510385, 0.41160035, 0.5350565 , ..., 0.19147713,
         0.40304554, 0.29897645]],

       [[0.2519069 , 0.08076161, 0.06650317, ..., 0.1413985 ,
         0.51315796, 0.6752432 ],
        [0.23208766, 0.01559431, 0.12047119, ..., 0.55408657,
         0.68951213, 0.1678423 ],
        [0.25133756, 0.03863765, 0.18350871, ..., 0.38185447,
         0.07543331, 0.7184113 ],
        ...,
        [0.8956722 , 0.0336101 , 0.17624776

In [121]:
M_prev + tf.multiply(tf.reshape(w_t,(batch_size,N,1)),tf.reshape(a_t,(batch_size,1,M)))

<tf.Tensor: id=46545, shape=(54, 120, 28), dtype=float32, numpy=
array([[[0.6302588 , 0.21265168, 1.085969  , ..., 0.90555835,
         0.74730134, 1.7050848 ],
        [0.6483498 , 1.046634  , 0.8121977 , ..., 0.3457772 ,
         0.80191195, 0.72178245],
        [0.25161204, 0.42268232, 0.8271928 , ..., 0.9525206 ,
         0.6070089 , 0.5571425 ],
        ...,
        [1.0095346 , 0.39979222, 0.8928581 , ..., 0.09815055,
         0.5260487 , 0.57347494],
        [0.30547768, 0.88500136, 0.69500446, ..., 0.03622218,
         1.1205815 , 0.63789976],
        [0.7252695 , 1.0276866 , 0.99679875, ..., 0.24996248,
         0.68706787, 0.9414414 ]],

       [[0.5300431 , 1.2941301 , 1.0995653 , ..., 0.80758834,
         1.4802761 , 1.5176778 ],
        [0.39700538, 0.5008979 , 0.7342089 , ..., 1.145437  ,
         1.2882457 , 0.52745974],
        [0.4561191 , 0.680219  , 1.0277897 , ..., 1.0280552 ,
         0.6370976 , 1.329173  ],
        ...,
        [1.2459085 , 1.0262616 , 1.4976895 

## Testing

In [127]:
WriteOnMemory(M_prev, w_t, e_t, a_t)

<tf.Tensor: id=46618, shape=(68, 100, 20), dtype=float32, numpy=
array([[[0.9269947 , 0.18070653, 0.56467706, ..., 0.669017  ,
         0.5463979 , 1.0143723 ],
        [0.741804  , 0.4565654 , 0.70919394, ..., 1.0868827 ,
         0.7759407 , 0.72396123],
        [0.38641325, 0.25682545, 0.47741446, ..., 0.737785  ,
         0.83834195, 0.7259721 ],
        ...,
        [0.68615866, 0.88903993, 0.25532448, ..., 0.45391294,
         0.5564223 , 0.6809142 ],
        [0.23190612, 0.13355017, 0.56690216, ..., 1.0081687 ,
         0.5462    , 0.7016797 ],
        [0.24384663, 0.4104563 , 0.44970807, ..., 0.54653764,
         0.64238185, 0.39973623]],

       [[1.564601  , 0.6283729 , 0.7955998 , ..., 1.1825564 ,
         1.0247817 , 1.2223755 ],
        [0.82689345, 1.4993862 , 1.3725848 , ..., 1.5167229 ,
         0.7009622 , 0.4465777 ],
        [1.8378738 , 1.0191371 , 0.8405704 , ..., 1.2319281 ,
         1.6295192 , 0.5357495 ],
        ...,
        [1.733243  , 1.2638831 , 1.0667794 