# Multi-Cell RNN

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np

In [2]:
# Dimensions of the toy problem: batch size, sequence length, feature size.
BATCH_SIZE, SEQ_LEN, INPUT_SIZE = 2, 3, 4
# Hidden-state widths of the first and second recurrent layer.
HIDDEN_SIZE_1, HIDDEN_SIZE_2 = 2, 3

In [3]:
# Random float32 input batch of shape (BATCH_SIZE, SEQ_LEN, INPUT_SIZE).
input_tensor = np.random.standard_normal(
    (BATCH_SIZE, SEQ_LEN, INPUT_SIZE)
).astype(np.float32)

# First layer: input kernel, recurrent kernel and bias.
w_ih_1 = np.random.standard_normal((INPUT_SIZE, HIDDEN_SIZE_1)).astype(np.float32)
w_hh_1 = np.random.standard_normal((HIDDEN_SIZE_1, HIDDEN_SIZE_1)).astype(np.float32)
b_1 = np.random.standard_normal((HIDDEN_SIZE_1,)).astype(np.float32)

# Second layer: it consumes the first layer's output, hence HIDDEN_SIZE_1 rows
# in its input kernel.
w_ih_2 = np.random.standard_normal((HIDDEN_SIZE_1, HIDDEN_SIZE_2)).astype(np.float32)
w_hh_2 = np.random.standard_normal((HIDDEN_SIZE_2, HIDDEN_SIZE_2)).astype(np.float32)
b_2 = np.random.standard_normal((HIDDEN_SIZE_2,)).astype(np.float32)

In [4]:
def multi_cell_rnn(input_, w_ih_1_, w_hh_1_, b_1_, w_ih_2_, w_hh_2_, b_2_):
    """Run a batch through a single RNN layer made of two stacked SimpleRNN cells.

    The cell sizes and the Keras input shape are derived from the arguments
    themselves instead of from module-level constants, so the function works
    for any mutually consistent set of arrays.

    Args:
        input_: Batch of input sequences; every dimension after the first is
            used as the model's input shape.
        w_ih_1_: Input kernel of the first cell.
        w_hh_1_: Recurrent kernel of the first cell; its first dimension gives
            the number of units of that cell.
        b_1_: Bias of the first cell.
        w_ih_2_: Input kernel of the second cell.
        w_hh_2_: Recurrent kernel of the second cell; its first dimension gives
            the number of units of that cell.
        b_2_: Bias of the second cell.

    Returns:
        The tensor produced by calling the model on ``input_``.
    """
    units_1 = np.shape(w_hh_1_)[0]
    units_2 = np.shape(w_hh_2_)[0]

    # A list of cells passed to RNN is run as one stacked cell per time step.
    stacked_rnn = tf.keras.layers.RNN([
        tf.keras.layers.SimpleRNNCell(units_1),
        tf.keras.layers.SimpleRNNCell(units_2),
    ])

    inputs = tf.keras.layers.Input(
        shape=tuple(np.shape(input_)[1:]), dtype=tf.float32
    )
    outputs = stacked_rnn(inputs)
    model = tf.keras.Model(inputs, outputs)

    # Weights are assigned only after the layer has been applied to `inputs`;
    # the flat list is ordered first cell, then second cell.
    stacked_rnn.set_weights([w_ih_1_, w_hh_1_, b_1_, w_ih_2_, w_hh_2_, b_2_])
    return model(input_)

multi_cell_output = multi_cell_rnn(input_tensor, w_ih_1, w_hh_1, b_1, w_ih_2, w_hh_2, b_2)
print(multi_cell_output)

tf.Tensor(
[[-0.9998453   0.9383833   0.44583783]
 [-0.9999621   0.99910057  0.96001154]], shape=(2, 3), dtype=float32)


In [5]:
def multi_layer_rnn(input_, w_ih_1_, w_hh_1_, b_1_, w_ih_2_, w_hh_2_, b_2_):
    """Run a batch through two consecutive SimpleRNN layers.

    The layer sizes and the Keras input shape are derived from the arguments
    themselves instead of from module-level constants, so the function works
    for any mutually consistent set of arrays.

    Args:
        input_: Batch of input sequences; every dimension after the first is
            used as the model's input shape.
        w_ih_1_: Input kernel of the first layer.
        w_hh_1_: Recurrent kernel of the first layer; its first dimension gives
            the number of units of that layer.
        b_1_: Bias of the first layer.
        w_ih_2_: Input kernel of the second layer.
        w_hh_2_: Recurrent kernel of the second layer; its first dimension
            gives the number of units of that layer.
        b_2_: Bias of the second layer.

    Returns:
        The tensor produced by calling the model on ``input_``.
    """
    units_1 = np.shape(w_hh_1_)[0]
    units_2 = np.shape(w_hh_2_)[0]

    inputs = tf.keras.layers.Input(
        shape=tuple(np.shape(input_)[1:]), dtype=tf.float32
    )
    # The first layer must emit its output at every time step so that the
    # second layer receives a sequence.
    first_rnn = tf.keras.layers.SimpleRNN(units_1, return_sequences=True)
    second_rnn = tf.keras.layers.SimpleRNN(units_2)
    hidden = second_rnn(first_rnn(inputs))
    model = tf.keras.Model(inputs, hidden)

    # Weights are assigned only after both layers have been applied.
    first_rnn.set_weights([w_ih_1_, w_hh_1_, b_1_])
    second_rnn.set_weights([w_ih_2_, w_hh_2_, b_2_])
    return model(input_)

multi_layer_output = multi_layer_rnn(input_tensor, w_ih_1, w_hh_1, b_1, w_ih_2, w_hh_2, b_2)
print(multi_layer_output)

tf.Tensor(
[[-0.9998453   0.9383833   0.44583783]
 [-0.9999621   0.99910057  0.96001154]], shape=(2, 3), dtype=float32)


In [6]:
# Both constructions should agree up to float tolerance; the tolerances below
# are NumPy's defaults, spelled out for the reader.
print(
    np.allclose(
        multi_cell_output,
        multi_layer_output,
        rtol=1e-05,
        atol=1e-08,
    )
)

True
