<a href="https://colab.research.google.com/github/DanielHolzwart/RNN-dimension-check/blob/main/RNN_dimension_check.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#quick check to evaluate the dimensions of an RNN with 2 layers
import torch
import torch.nn as nn

In [79]:
# Build a reproducible 2-layer RNN and inspect the shape of every learnable parameter.
torch.manual_seed(1)
rnn_layer = nn.RNN(
    input_size=5,
    hidden_size=4,
    num_layers=2,
    batch_first=True,
)

# PyTorch stores the parameters of layer k as weight_ih_l{k}, weight_hh_l{k},
# bias_ih_l{k} and bias_hh_l{k}; keep them under the x->h / h->h naming used below.
w_xh0, w_hh0, b_xh0, b_hh0 = (
    rnn_layer.weight_ih_l0, rnn_layer.weight_hh_l0,
    rnn_layer.bias_ih_l0, rnn_layer.bias_hh_l0,
)
w_xh1, w_hh1, b_xh1, b_hh1 = (
    rnn_layer.weight_ih_l1, rnn_layer.weight_hh_l1,
    rnn_layer.bias_ih_l1, rnn_layer.bias_hh_l1,
)

# The first four printed lines describe layer 0, the last four describe layer 1.
# Only W_xh differs between them: layer 0 consumes the 5 input features,
# layer 1 consumes the 4 hidden units produced by layer 0.
captions = ('W_xh shape:', 'W_hh shape:', 'b_xh shape:', 'b_hh shape:')
for layer_params in ((w_xh0, w_hh0, b_xh0, b_hh0), (w_xh1, w_hh1, b_xh1, b_hh1)):
    for caption, tensor in zip(captions, layer_params):
        print(caption, tensor.shape)

W_xh shape: torch.Size([4, 5])
W_hh shape: torch.Size([4, 4])
b_xh shape: torch.Size([4])
b_hh shape: torch.Size([4])
W_xh shape: torch.Size([4, 4])
W_hh shape: torch.Size([4, 4])
b_xh shape: torch.Size([4])
b_hh shape: torch.Size([4])


In [80]:
# Put the weights and biases into dictionaries keyed by layer index for looping in the next step.
# They are read straight from rnn_layer (weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k},
# bias_hh_l{k}), so this works for any num_layers instead of hard-coding two layers.
# For layers 0 and 1 the entries are the very same tensors as w_xh0/w_xh1, w_hh0/w_hh1, etc.
layer_ids = range(rnn_layer.num_layers)
w_xh = {k: getattr(rnn_layer, f'weight_ih_l{k}') for k in layer_ids}
w_hh = {k: getattr(rnn_layer, f'weight_hh_l{k}') for k in layer_ids}
b_xh = {k: getattr(rnn_layer, f'bias_ih_l{k}') for k in layer_ids}
b_hh = {k: getattr(rnn_layer, f'bias_hh_l{k}') for k in layer_ids}

In [87]:
# A single toy sequence: 3 time steps with 5 features each.
# For words, x_seq would be (sequence_length, embedding_dimension).
x_seq = torch.tensor([[1.0] * 5, [2.0] * 5, [3.0] * 5]).float()
seq_len, n_features = x_seq.shape

# Output of the PyTorch RNN. rnn_layer was built with batch_first=True, so add a leading
# batch axis of size 1: (seq_len, n_features) -> (1, seq_len, n_features).
output, hn = rnn_layer(x_seq.unsqueeze(0))

## Manually computing the output, one layer at a time:
##   h_t = tanh(x_t W_xh^T + h_{t-1} W_hh^T + b_hh + b_xh),  with h_{-1} = 0
## Layer 0 reads x_seq; every higher layer reads the hidden states of the layer below.
# out_man[layer][t] holds h_t of that layer, shape (1, hidden_size).
out_man = [[] for _ in range(rnn_layer.num_layers)]
for layer in range(rnn_layer.num_layers):
    for t in range(seq_len):
        xt = x_seq[t].reshape(1, n_features) if layer == 0 else out_man[layer - 1][t]
        print(f'Time step {t} at layer {layer} =>')
        print('Input:', xt.detach().numpy())
        # Input projection x_t W_xh^T only; the recurrent term and biases are added below.
        ht = torch.matmul(xt, torch.transpose(w_xh[layer], 0, 1))
        print('Hidden:', ht.detach().numpy())
        # Initial hidden state is zero; zeros_like keeps ht's dtype and device.
        prev_h = out_man[layer][t - 1] if t > 0 else torch.zeros_like(ht)
        ot = ht + torch.matmul(prev_h, torch.transpose(w_hh[layer], 0, 1)) + b_hh[layer] + b_xh[layer]
        ot = torch.tanh(ot)
        out_man[layer].append(ot)
        print('Output (manual) :', ot.detach().numpy())
        if layer == rnn_layer.num_layers - 1:
            # Only the top layer's hidden states are returned in `output`.
            print('RNN output', output[:, t].detach().numpy())
        print()

Time step 0 at layer 0 =>
Input: [[1. 1. 1. 1. 1.]]
Hidden: [[-0.29602224  0.45965433  0.11465555  0.6181393 ]]
Output (manual) : [[0.56055534 0.16334416 0.6066464  0.3728287 ]]

Time step 1 at layer 0 =>
Input: [[2. 2. 2. 2. 2.]]
Hidden: [[-0.5920445   0.91930866  0.22931111  1.2362787 ]]
Output (manual) : [[0.05261802 0.67545795 0.6944291  0.6304203 ]]

Time step 2 at layer 0 =>
Input: [[3. 3. 3. 3. 3.]]
Hidden: [[-0.88806677  1.378963    0.34396666  1.8544179 ]]
Output (manual) : [[-0.350783    0.8996122   0.90739036  0.69391143]]

Time step 0 at layer 1 =>
Input: [[0.56055534 0.16334416 0.6066464  0.3728287 ]]
Hidden: [[-0.21185923 -0.47037655  0.5837881   0.17492902]]
Output (manual) : [[-0.3641145  -0.5368899   0.14999884  0.32872772]]
RNN output [[-0.3641145  -0.5368899   0.14999884  0.32872772]]

Time step 1 at layer 1 =>
Input: [[0.05261802 0.67545795 0.6944291  0.6304203 ]]
Hidden: [[ 0.06813398 -0.528718    0.53069365  0.08260797]]
Output (manual) : [[ 0.0329843  -0.453678  

In [114]:
# Finally, compare the manual results with both return values of rnn_layer:
#   - `output` holds the top layer's hidden state at every time step,
#   - `hn` holds the final hidden state of every layer.
# Loop bounds come from the tensors themselves, so this works for any sequence length
# and any number of layers.
print('--- checking out_man to outputs --- \n')
top_layer = rnn_layer.num_layers - 1
for t in range(output.shape[1]):  # batch_first=True -> dim 1 is time
    print(torch.allclose(out_man[top_layer][t], output[:, t, :]))
print()
print('--- checking hn to last output --- \n')
for layer in range(rnn_layer.num_layers):
    # The last manual hidden state of each layer must equal that layer's entry in hn.
    print(torch.allclose(out_man[layer][-1], hn[layer]))

--- checking out_man to outputs --- 

True
True
True

--- checking hn to last output --- 

True
True
