# Unidirectional RNN

In [9]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np

In [10]:
# Problem dimensions shared by every cell below
BATCH_SIZE = 2  # number of sequences in the input batch
SEQ_LEN = 3  # time steps per sequence
INPUT_SIZE = 4  # features per time step
HIDDEN_SIZE = 2  # units in the RNN hidden state

In [11]:
# Seed NumPy's global generator so that a fresh "Restart & Run All" draws the
# same input and weights every time; the reverse-direction weights drawn later
# from the same generator become reproducible as well.
np.random.seed(0)

# Random input batch plus the SimpleRNN parameters, all float32 so they can be
# loaded into the Keras layer and reused by the from-scratch implementation.
input_tensor = np.random.randn(BATCH_SIZE, SEQ_LEN, INPUT_SIZE).astype(np.float32)
w_ih = np.random.randn(INPUT_SIZE, HIDDEN_SIZE).astype(np.float32)  # input -> hidden
w_hh = np.random.randn(HIDDEN_SIZE, HIDDEN_SIZE).astype(np.float32)  # hidden -> hidden
b = np.random.randn(HIDDEN_SIZE).astype(np.float32)  # bias

In [12]:
def tf_unidirectional_rnn(input_, w_ih_, w_hh_, b_):
    """Run a Keras ``SimpleRNN`` loaded with the given weights over ``input_``.

    This is the reference result the from-scratch implementation is compared
    against.

    Args:
        input_: batch of sequences, shape (batch, seq_len, input_size).
        w_ih_: input-to-hidden kernel, shape (input_size, hidden_size).
        w_hh_: hidden-to-hidden (recurrent) kernel, shape (hidden_size, hidden_size).
        b_: bias, shape (hidden_size,).

    Returns:
        Tensor of shape (batch, seq_len, hidden_size) with the hidden state at
        every time step (the layer is built with ``return_sequences=True``).
    """
    # Read the sizes off the arguments rather than the notebook-level
    # constants, so the function works for any consistently shaped input.
    seq_len, input_size = np.shape(input_)[1:]
    hidden_size = np.shape(w_hh_)[0]

    inputs = tf.keras.layers.Input(shape=(seq_len, input_size))
    rnn_layer = tf.keras.layers.SimpleRNN(hidden_size, return_sequences=True)
    outputs = rnn_layer(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    # Replace the randomly initialised weights with the caller's. The layer
    # object is addressed directly instead of through model.layers[1], which
    # silently depends on the layer's position inside the model.
    rnn_layer.set_weights([w_ih_, w_hh_, b_])
    return model(input_)

# Run the Keras reference model on the shared random input and weights, then show the result.
tf_output = tf_unidirectional_rnn(input_tensor, w_ih, w_hh, b)
print(tf_output)

tf.Tensor(
[[[ 0.99367493 -0.9263422 ]
  [-0.9670784  -0.9899978 ]
  [-0.9992837  -0.9921832 ]]

 [[ 0.8910041  -0.02237447]
  [ 0.75538677 -0.99566084]
  [-0.99960285 -0.98956466]]], shape=(2, 3, 2), dtype=float32)


In [13]:
def scratch_unidirectional_rnn(input_, w_ih_, w_hh_, b_):
    """Run a tanh RNN over a batch of sequences using plain NumPy.

    Starting from an all-zero hidden state, every time step computes
    ``h_t = tanh(x_t @ w_ih_ + h_{t-1} @ w_hh_ + b_)``.

    Args:
        input_: array-like of shape (batch, seq_len, input_size).
        w_ih_: input-to-hidden weights, shape (input_size, hidden_size).
        w_hh_: hidden-to-hidden weights, shape (hidden_size, hidden_size).
        b_: bias, shape (hidden_size,).

    Returns:
        float32 ``np.ndarray`` of shape (batch, seq_len, hidden_size) holding
        the hidden state after every time step.
    """
    # np.asarray accepts NumPy arrays, nested lists and eager TensorFlow
    # tensors alike, so callers may pass e.g. the result of tf.reverse.
    input_ = np.asarray(input_, dtype=np.float32)
    w_ih_ = np.asarray(w_ih_, dtype=np.float32)
    w_hh_ = np.asarray(w_hh_, dtype=np.float32)
    b_ = np.asarray(b_, dtype=np.float32)

    # Take every size from the arguments instead of notebook-level constants.
    batch_size, seq_len, _ = input_.shape
    hidden_size = w_hh_.shape[0]

    hidden = np.zeros((batch_size, hidden_size), dtype=np.float32)
    output = np.zeros((batch_size, seq_len, hidden_size), dtype=np.float32)
    for seq_idx in range(seq_len):
        pre_activation = input_[:, seq_idx, :] @ w_ih_ + hidden @ w_hh_ + b_
        hidden = np.tanh(pre_activation)
        output[:, seq_idx, :] = hidden
    return output
# Run the from-scratch RNN on the same input and weights that were given to the Keras model.
scratch_output = scratch_unidirectional_rnn(input_tensor, w_ih, w_hh, b)
print(scratch_output)

[[[ 0.99367493 -0.9263422 ]
  [-0.9670784  -0.9899978 ]
  [-0.9992837  -0.9921832 ]]

 [[ 0.8910041  -0.02237447]
  [ 0.7553868  -0.99566084]
  [-0.99960285 -0.98956466]]]


In [14]:
# Sanity check: the Keras and from-scratch outputs should agree element-wise
# within np.allclose's default tolerances (prints True when they do).
print(np.allclose(tf_output.numpy(), scratch_output))

True


# Bidirectional RNN

In [16]:
# Separate random parameters for the backward (time-reversed) direction,
# with the same shapes and dtype as w_ih, w_hh and b.
w_ih_reverse = np.random.randn(INPUT_SIZE, HIDDEN_SIZE).astype(np.float32)
w_hh_reverse = np.random.randn(HIDDEN_SIZE, HIDDEN_SIZE).astype(np.float32)
b_reverse = np.random.randn(HIDDEN_SIZE).astype(np.float32)

In [27]:
def tf_bidirectional_rnn(input_, w_ih_, w_hh_, b_, w_ih_reverse_, w_hh_reverse_, b_reverse_):
    """Run a Keras bidirectional ``SimpleRNN`` loaded with the given weights.

    Args:
        input_: batch of sequences, shape (batch, seq_len, input_size).
        w_ih_, w_hh_, b_: input kernel, recurrent kernel and bias of the
            forward direction.
        w_ih_reverse_, w_hh_reverse_, b_reverse_: the same three parameters
            for the backward direction.

    Returns:
        Tensor of shape (batch, seq_len, 2 * hidden_size): per time step, the
        forward hidden state followed by the backward hidden state.
    """
    # Read the sizes off the arguments rather than the notebook-level
    # constants, so the function works for any consistently shaped input.
    seq_len, input_size = np.shape(input_)[1:]
    hidden_size = np.shape(w_hh_)[0]

    inputs = tf.keras.layers.Input(shape=(seq_len, input_size), dtype=tf.float32)
    rnn_layer = tf.keras.layers.Bidirectional(
        tf.keras.layers.SimpleRNN(hidden_size, return_sequences=True)
    )
    outputs = rnn_layer(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    # The wrapper takes the forward layer's weights first, then the backward layer's.
    rnn_layer.set_weights([w_ih_, w_hh_, b_, w_ih_reverse_, w_hh_reverse_, b_reverse_])
    return model(input_)

# Keras reference for the bidirectional case; this rebinds tf_output, replacing the unidirectional result.
tf_output = tf_bidirectional_rnn(input_tensor, w_ih, w_hh, b, w_ih_reverse, w_hh_reverse, b_reverse)

<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[ 0.99367493, -0.9263422 , -0.09472295, -0.07357155],
        [-0.9670784 , -0.9899978 , -0.6786357 , -0.996604  ],
        [-0.9992837 , -0.9921832 ,  0.8756912 ,  0.990144  ]],

       [[ 0.8910041 , -0.02237447,  0.91867334,  0.41186035],
        [ 0.75538677, -0.99566084, -0.7954552 , -0.9999865 ],
        [-0.99960285, -0.98956466,  0.6120095 ,  0.99115074]]],
      dtype=float32)>

In [32]:
def scratch_bidirectional_rnn(input_, w_ih_, w_hh_, b_, w_ih_reverse_, w_hh_reverse_, b_reverse_):
    """Build a bidirectional RNN out of two unidirectional passes.

    The forward pass reads the sequence as given. The backward pass reads it
    flipped along the time axis with its own weights, and its output is
    flipped back so that both results line up step by step before they are
    joined along the feature axis.
    """
    time_axis = [1]

    # Forward direction: the sequence in its original order.
    fwd = scratch_unidirectional_rnn(input_, w_ih_, w_hh_, b_)

    # Backward direction: flip time, run the RNN, then undo the flip.
    bwd_flipped = scratch_unidirectional_rnn(
        tf.reverse(input_, axis=time_axis),
        w_ih_reverse_,
        w_hh_reverse_,
        b_reverse_,
    )
    bwd = tf.reverse(bwd_flipped, axis=time_axis)

    return tf.concat([fwd, bwd], axis=2)


# From-scratch bidirectional result, computed with the same arguments as the Keras call; rebinds scratch_output.
scratch_output = scratch_bidirectional_rnn(input_tensor, w_ih, w_hh, b, w_ih_reverse, w_hh_reverse, b_reverse)

<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[ 0.99367493, -0.9263422 , -0.09472291, -0.07357161],
        [-0.9670784 , -0.9899978 , -0.6786357 , -0.996604  ],
        [-0.9992837 , -0.9921832 ,  0.8756912 ,  0.990144  ]],

       [[ 0.8910041 , -0.02237447,  0.91867334,  0.4118603 ],
        [ 0.7553868 , -0.99566084, -0.7954552 , -0.9999865 ],
        [-0.99960285, -0.98956466,  0.6120095 ,  0.99115074]]],
      dtype=float32)>

In [33]:
# Sanity check for the bidirectional case: prints True when both outputs agree within np.allclose's default tolerances.
print(np.allclose(tf_output, scratch_output))

True
