# Bidirectional Recurrent Neural Network in PyTorch

Turning a simple recurrent neural network into a bidirectional one in PyTorch is straightforward: all you have to do is pass the `bidirectional` parameter and set it to `True`.
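As a quick illustration (a minimal sketch; the sizes below are arbitrary choices, not values from this notebook), two otherwise identical layers that differ only in the `bidirectional` flag produce outputs whose last dimension differs by a factor of two, and whose final hidden state gains one extra slice for the reverse direction:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, batch, input_size, hidden_size = 5, 4, 2, 3
x = torch.randn(seq_len, batch, input_size)

# identical configuration, only the bidirectional flag differs
uni = nn.RNN(input_size=input_size, hidden_size=hidden_size)
bi = nn.RNN(input_size=input_size, hidden_size=hidden_size, bidirectional=True)

# without an explicit h_0, PyTorch starts from an all-zero hidden state
with torch.inference_mode():
    out_uni, h_uni = uni(x)
    out_bi, h_bi = bi(x)

print(out_uni.shape, h_uni.shape)  # torch.Size([5, 4, 3]) torch.Size([1, 4, 3])
print(out_bi.shape, h_bi.shape)    # torch.Size([5, 4, 6]) torch.Size([2, 4, 3])
```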

In [1]:
import torch
import torch.nn as nn

For the most part we need the same parameters as before, but the `D` parameter is new. `D` is the number of directions: 2 for a bidirectional RNN and 1 otherwise. We use it as a multiplier when we calculate the shapes of the hidden state and the output.
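To see `D` acting as a multiplier, the sketch below (using its own small, assumed sizes) checks that the output's last dimension is `D * HIDDEN_SIZE` and then splits it into one `HIDDEN_SIZE` block per direction with `view`. The first block holds the forward pass, the second the reverse pass:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE, HIDDEN_SIZE = 5, 4, 2, 3
BIDIRECTIONAL = True
D = 2 if BIDIRECTIONAL else 1  # number of directions

rnn = nn.RNN(INPUT_SIZE, HIDDEN_SIZE, bidirectional=BIDIRECTIONAL)
with torch.inference_mode():
    output, _ = rnn(torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE))

# the last dimension of the output is D * HIDDEN_SIZE ...
print(output.shape)  # torch.Size([5, 4, 6])

# ... so it can be split into one HIDDEN_SIZE block per direction
per_direction = output.view(SEQUENCE_LENGTH, BATCH_SIZE, D, HIDDEN_SIZE)
forward_out = per_direction[:, :, 0]   # same values as output[..., :HIDDEN_SIZE]
backward_out = per_direction[:, :, 1]  # same values as output[..., HIDDEN_SIZE:]
```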

In [2]:
BATCH_SIZE=4
SEQUENCE_LENGTH=5
INPUT_SIZE=2
HIDDEN_SIZE=3
NUM_LAYERS=1
BIDERECTIONAL=True
D = 2 if BIDERECTIONAL else 1  # number of directions

We set the `bidirectional` parameter to `True`.

In [3]:
rnn = nn.RNN(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, bidirectional=BIDERECTIONAL)

This time we use just one layer, so there is only an `l0` set of weights and no `l1` set. In exchange, we gain weights and biases whose names end with `_reverse`. These are the weights and biases used to calculate the hidden state when we traverse the sequence from end to start.
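One way to see the extra parameters is to list them with `named_parameters()`. In this sketch (same sizes as the notebook), the layer should report eight tensors: the four forward parameters and their `_reverse` counterparts, each pair with identical shapes:

```python
import torch.nn as nn

rnn = nn.RNN(input_size=2, hidden_size=3, num_layers=1, bidirectional=True)

# collect name -> shape for every registered parameter
shapes = {name: tuple(param.shape) for name, param in rnn.named_parameters()}
for name, shape in shapes.items():
    print(f"{name:22s} {shape}")
```

The input-to-hidden weights have shape `(hidden_size, input_size)` and the hidden-to-hidden weights `(hidden_size, hidden_size)` in both directions; only the values differ.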

In [4]:
# ---------------------------- #
# weights forward
# ---------------------------- #

# input to hidden weights and biases
w_ih_l0 = rnn.weight_ih_l0
b_ih_l0 = rnn.bias_ih_l0

# hidden to hidden weights and biases
w_hh_l0 = rnn.weight_hh_l0
b_hh_l0 = rnn.bias_hh_l0

# ---------------------------- #
# weights reverse
# ---------------------------- #

# input to hidden weights and biases
w_r_ih_l0 = rnn.weight_ih_l0_reverse
b_r_ih_l0 = rnn.bias_ih_l0_reverse

# hidden to hidden weights and biases
w_r_hh_l0 = rnn.weight_hh_l0_reverse
b_r_hh_l0 = rnn.bias_hh_l0_reverse

The reverse traversal needs its own hidden state, so we scale the first dimension of the initial hidden state by `D`: `h_0` has shape `(D * NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)`.
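With more than one layer the same rule holds, and the first dimension interleaves layers and directions. The sketch below assumes `NUM_LAYERS = 2` purely for illustration; following the layout described in the PyTorch documentation, `h_n` can be viewed as `(num_layers, num_directions, batch, hidden_size)`, and the top layer's final states line up with the ends of `output`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE, HIDDEN_SIZE = 5, 4, 2, 3
NUM_LAYERS, D = 2, 2  # two stacked layers (illustrative), two directions

rnn = nn.RNN(INPUT_SIZE, HIDDEN_SIZE, num_layers=NUM_LAYERS, bidirectional=True)
sequence = torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE)

# one initial hidden state per (layer, direction) pair
h_0 = torch.zeros(D * NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)

with torch.inference_mode():
    output, h_n = rnn(sequence, h_0)

# h_n has the same layout as h_0; separate layers and directions with a view
h_n_split = h_n.view(NUM_LAYERS, D, BATCH_SIZE, HIDDEN_SIZE)
last_layer_forward = h_n_split[-1, 0]   # top layer, forward direction
last_layer_backward = h_n_split[-1, 1]  # top layer, reverse direction
```

The forward direction finishes at the last time step, while the reverse direction finishes at the first one, which is why the two final states come from opposite ends of `output`.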

In [5]:
sequence = torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE)
h_0 = torch.zeros(D * NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)

In [6]:
with torch.inference_mode():
    output, h_n = rnn(sequence, h_0)

Once again, we recommend working through this manual implementation of a bidirectional RNN; it will greatly improve your understanding of the inner workings.
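Written as equations, the manual implementation computes the following. Here $x_t$ stands for `sequence[t]`, $T$ for `SEQUENCE_LENGTH`, the superscript $(r)$ marks the `_reverse` parameters, and $[\,\cdot\,;\,\cdot\,]$ denotes concatenation along the feature dimension (this notation is ours, chosen to match the code):

$$
\begin{aligned}
\overrightarrow{h}_t &= \tanh\left(x_t W_{ih}^\top + b_{ih} + \overrightarrow{h}_{t-1} W_{hh}^\top + b_{hh}\right), && t = 0, \dots, T-1,\\
\overleftarrow{h}_t &= \tanh\left(x_t W_{ih}^{(r)\top} + b_{ih}^{(r)} + \overleftarrow{h}_{t+1} W_{hh}^{(r)\top} + b_{hh}^{(r)}\right), && t = T-1, \dots, 0,\\
\text{output}_t &= \left[\overrightarrow{h}_t \,;\, \overleftarrow{h}_t\right],
\end{aligned}
$$

with $\overrightarrow{h}_{-1} = h_0[0]$ and $\overleftarrow{h}_{T} = h_0[1]$. The final hidden state stacks the last state of each pass: $h_n[0] = \overrightarrow{h}_{T-1}$ and $h_n[1] = \overleftarrow{h}_0$.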

In [7]:
def manual_rnn():
    # hidden[0] holds the forward state, hidden[1] the reverse state
    hidden = h_0.clone()
    # forward outputs fill the first HIDDEN_SIZE features, reverse outputs the rest
    output = torch.zeros(SEQUENCE_LENGTH, BATCH_SIZE, D * HIDDEN_SIZE)
    with torch.inference_mode():
        for idx in range(SEQUENCE_LENGTH):
            # forward direction: step through the sequence from start to end
            hidden[0] = torch.tanh(sequence[idx] @ w_ih_l0.T + b_ih_l0 + hidden[0] @ w_hh_l0.T + b_hh_l0)
            output[idx, :, :HIDDEN_SIZE] = hidden[0]
            # reverse direction: step through the sequence from end to start
            rev_idx = SEQUENCE_LENGTH - 1 - idx
            hidden[1] = torch.tanh(sequence[rev_idx] @ w_r_ih_l0.T + b_r_ih_l0 + hidden[1] @ w_r_hh_l0.T + b_r_hh_l0)
            output[rev_idx, :, HIDDEN_SIZE:] = hidden[1]
    return output, hidden

In [8]:
manual_output, manual_h_n = manual_rnn()

We compare the results to make sure that our implementation is correct.
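Comparing printed tensors by eye works for small examples, but `torch.allclose` gives a single yes/no answer that tolerates tiny floating-point differences. The following is a self-contained sketch of that check; `manual_birnn` is a hypothetical helper that re-implements the same single-layer bidirectional recurrence with a zero initial state, and the tolerance `atol=1e-6` is our own choice:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE, HIDDEN_SIZE = 5, 4, 2, 3
rnn = nn.RNN(INPUT_SIZE, HIDDEN_SIZE, bidirectional=True)
sequence = torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE)

def manual_birnn(x):
    """Hypothetical helper: single-layer bidirectional tanh RNN, zero initial state."""
    fwd = torch.zeros(BATCH_SIZE, HIDDEN_SIZE)
    bwd = torch.zeros(BATCH_SIZE, HIDDEN_SIZE)
    out = torch.zeros(SEQUENCE_LENGTH, BATCH_SIZE, 2 * HIDDEN_SIZE)
    for t in range(SEQUENCE_LENGTH):
        r = SEQUENCE_LENGTH - 1 - t  # the reverse pass walks from the end
        fwd = torch.tanh(x[t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                         + fwd @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)
        bwd = torch.tanh(x[r] @ rnn.weight_ih_l0_reverse.T + rnn.bias_ih_l0_reverse
                         + bwd @ rnn.weight_hh_l0_reverse.T + rnn.bias_hh_l0_reverse)
        out[t, :, :HIDDEN_SIZE] = fwd
        out[r, :, HIDDEN_SIZE:] = bwd
    # final states: forward ends at t = T-1, reverse ends at t = 0
    return out, torch.stack([fwd, bwd])

with torch.inference_mode():
    output, h_n = rnn(sequence)
    manual_output, manual_h_n = manual_birnn(sequence)

print(torch.allclose(output, manual_output, atol=1e-6))
print(torch.allclose(h_n, manual_h_n, atol=1e-6))
```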

In [9]:
output

tensor([[[ 0.6119,  0.7950,  0.5467,  0.2447, -0.6007,  0.2942],
         [-0.3753,  0.1250,  0.7078,  0.1123,  0.1765, -0.0716],
         [-0.0242,  0.7054,  0.8048,  0.2648, -0.1507,  0.6702],
         [ 0.0957,  0.4837,  0.6330,  0.1857, -0.0234, -0.0326]],

        [[-0.9501, -0.5405,  0.9671, -0.3674,  0.6894,  0.5082],
         [-0.6360,  0.3203,  0.7740, -0.5183, -0.3718, -0.0748],
         [-0.6780,  0.6467,  0.9565,  0.0843,  0.4808,  0.8144],
         [-0.0305,  0.5113,  0.7426, -0.2721, -0.2520, -0.1512]],

        [[-0.6581,  0.8128,  0.7064,  0.2179,  0.0232,  0.3551],
         [-0.6379,  0.6315,  0.8972,  0.1088,  0.6091,  0.4488],
         [ 0.1069,  0.8739,  0.8683,  0.6259,  0.3998,  0.5972],
         [-0.8309, -0.1865,  0.8807, -0.2830,  0.5560, -0.2209]],

        [[ 0.5591,  0.8903,  0.8370,  0.3448, -0.5082,  0.6432],
         [ 0.1753,  0.6623,  0.7057,  0.1980, -0.2932, -0.1916],
         [ 0.3917,  0.6472,  0.7297,  0.1548, -0.5782, -0.0358],
         [ 0.8146, 

In [10]:
manual_output

tensor([[[ 0.6119,  0.7950,  0.5467,  0.2447, -0.6007,  0.2942],
         [-0.3753,  0.1250,  0.7078,  0.1123,  0.1765, -0.0716],
         [-0.0242,  0.7054,  0.8048,  0.2648, -0.1507,  0.6702],
         [ 0.0957,  0.4837,  0.6330,  0.1857, -0.0234, -0.0326]],

        [[-0.9501, -0.5405,  0.9671, -0.3674,  0.6894,  0.5082],
         [-0.6360,  0.3203,  0.7740, -0.5183, -0.3718, -0.0748],
         [-0.6780,  0.6467,  0.9565,  0.0843,  0.4808,  0.8144],
         [-0.0305,  0.5113,  0.7426, -0.2721, -0.2520, -0.1512]],

        [[-0.6581,  0.8128,  0.7064,  0.2179,  0.0232,  0.3551],
         [-0.6379,  0.6315,  0.8972,  0.1088,  0.6091,  0.4488],
         [ 0.1069,  0.8739,  0.8683,  0.6259,  0.3998,  0.5972],
         [-0.8309, -0.1865,  0.8807, -0.2830,  0.5560, -0.2209]],

        [[ 0.5591,  0.8903,  0.8370,  0.3448, -0.5082,  0.6432],
         [ 0.1753,  0.6623,  0.7057,  0.1980, -0.2932, -0.1916],
         [ 0.3917,  0.6472,  0.7297,  0.1548, -0.5782, -0.0358],
         [ 0.8146, 

In [11]:
h_n

tensor([[[ 0.0027,  0.8121,  0.9247],
         [ 0.8281,  0.8654,  0.5082],
         [ 0.1160,  0.6814,  0.8117],
         [ 0.2054, -0.0800,  0.7989]],

        [[ 0.2447, -0.6007,  0.2942],
         [ 0.1123,  0.1765, -0.0716],
         [ 0.2648, -0.1507,  0.6702],
         [ 0.1857, -0.0234, -0.0326]]])

In [12]:
manual_h_n

tensor([[[ 0.0027,  0.8121,  0.9247],
         [ 0.8281,  0.8654,  0.5082],
         [ 0.1160,  0.6814,  0.8117],
         [ 0.2054, -0.0800,  0.7989]],

        [[ 0.2447, -0.6007,  0.2942],
         [ 0.1123,  0.1765, -0.0716],
         [ 0.2648, -0.1507,  0.6702],
         [ 0.1857, -0.0234, -0.0326]]])