# Manually computing the forward pass of a simple RNN (input-to-hidden plus hidden-to-hidden recurrence) and checking it against `nn.RNN`

In [24]:
import torch 
import torch.nn as nn
torch.manual_seed(1)

# A single-layer RNN mapping 5 input features to 2 hidden units; batch_first=True
# means inputs are laid out as (batch, seq_len, features).
rnn_layer = nn.RNN(input_size=5, hidden_size=2, num_layers=1, batch_first=True)

# Layer-0 parameters. PyTorch names them *_ih_l0 / *_hh_l0; they are renamed to the
# x->h / h->h notation of the recurrence
#   h_t = tanh(x_t @ W_xh^T + b_xh + h_{t-1} @ W_hh^T + b_hh)
w_xh = rnn_layer.weight_ih_l0
w_hh = rnn_layer.weight_hh_l0
b_xh = rnn_layer.bias_ih_l0
b_hh = rnn_layer.bias_hh_l0

print('W_xh shape:', w_xh.shape)
print('W_hh shape:', w_hh.shape)
print('b_xh shape:', b_xh.shape)
# Previously printed b_xh.shape here by mistake; the output only looked right
# because both biases happen to have the same shape.
print('b_hh shape:', b_hh.shape)


W_xh shape: torch.Size([2, 5])
W_hh shape: torch.Size([2, 2])
b_xh shape: torch.Size([2])
b_hh shape: torch.Size([2])


In [25]:
# Run the built-in forward pass of rnn_layer on a toy sequence.
# Three time steps, each a 5-feature vector filled with the step value 1, 2, 3.
x_seq = torch.tensor([[step] * 5 for step in (1.0, 2.0, 3.0)], dtype=torch.float32)
print('---x_seq---')
print(x_seq.shape)

## output of the simple RNN:
# rnn_layer was built with batch_first=True, so prepend a batch axis of size 1:
# (seq_len, features) -> (1, seq_len, features).
batched_seq = x_seq.unsqueeze(0)
output, hn = rnn_layer(batched_seq)
for header, shape in (('---output---', output.shape), ('----hn----', hn.shape)):
    print(header)
    print(shape)

---x_seq---
torch.Size([3, 5])
---output---
torch.Size([1, 3, 2])
----hn----
torch.Size([1, 1, 2])


In [28]:
## manually computing the output:
# Re-implement the recurrence step by step and compare it with nn.RNN:
#   h_t = tanh(x_t @ W_xh^T + b_xh + h_{t-1} @ W_hh^T + b_hh),  initial h = 0
# Sequence length and feature count come from x_seq instead of being hard-coded,
# so the cell keeps working if the input sequence is changed.
seq_len = x_seq.shape[0]
out_man = []
for t in range(seq_len):
    # One time step as a (1, n_features) row so it can be multiplied by W_xh^T.
    xt = x_seq[t].reshape(1, -1)
    print(f'time step {t} =>')
    print('     Input               :', xt.numpy())

    # Input-to-hidden contribution (before the recurrent term is added).
    ht = xt @ w_xh.T + b_xh
    print('     Hidden              :', ht.detach().numpy())

    # Previous hidden state; zeros at t == 0, matching nn.RNN's default when no
    # initial hidden state is passed (as in the forward call above).
    prev_h = out_man[t - 1] if t > 0 else torch.zeros(ht.shape)

    ot = torch.tanh(ht + prev_h @ w_hh.T + b_hh)
    out_man.append(ot)

    print('     Output (manual)     :', ot.detach().numpy())
    print('     RNN output          :', output[:, t].detach().numpy())
    print()

# Check the match programmatically instead of by eye: stack the per-step (1, hidden)
# results into (seq_len, hidden) and compare with batch 0 of the nn.RNN output.
torch.testing.assert_close(torch.cat(out_man).detach(), output[0].detach())

time step 0 =>
     Input               : [[1. 1. 1. 1. 1.]]
     Hidden              : [[-0.4701929  0.5863904]]
     Output (manual)     : [[-0.3519801   0.52525216]]
     RNN output          : [[-0.3519801   0.52525216]]

time step 1 =>
     Input               : [[2. 2. 2. 2. 2.]]
     Hidden              : [[-0.88883156  1.2364397 ]]
     Output (manual)     : [[-0.68424344  0.76074266]]
     RNN output          : [[-0.68424344  0.76074266]]

time step 2 =>
     Input               : [[3. 3. 3. 3. 3.]]
     Hidden              : [[-1.3074701  1.886489 ]]
     Output (manual)     : [[-0.8649416   0.90466356]]
     RNN output          : [[-0.8649416   0.90466356]]

