In [1]:
import torch
from torch import nn

In [41]:
# RNN forward
input_size = 2
hidden_size = 3
batch_size = 4
sequence_length = 5
rnn = nn.RNN(
    input_size=input_size,
    hidden_size=hidden_size,
    batch_first=True,
)


w_ih = rnn.weight_ih_l0.detach().numpy()  # (3, 2) = (hidden_size, input_size)
w_hh = rnn.weight_hh_l0.detach().numpy()
b_ih = rnn.bias_ih_l0.detach().numpy() # (hidden_size,) = (3)
b_hh = rnn.bias_hh_l0.detach().numpy()


print("PARAMS:")
print(f"{w_ih=}\n{w_hh=}\n{b_ih=}\n{b_hh=}")
print()


print("FORWARD PASS:")
# input.shape = (N, L, H_in) = (batch_size, sequence_length, input_size)
inputs = torch.randn(batch_size, sequence_length, input_size)
print(f"{inputs=}")

# hidden_initial.shape = (num_layers, N, H_out) = (num_layers, batch_size, hidden_size)
hidden_initial = torch.zeros(1, batch_size, hidden_size)
print(f"{hidden_initial=}")

outputs, hidden_final = rnn(inputs, hidden_initial)
# output.shape = (N, L, H_out) = (batch_size, sequence_length, hidden_size)
# hidden_final.shape = (num_layers, N, H_out) = (num_layers, batch_size, hidden_size)
print(f"{outputs=}\n{hidden_final=}")

PARAMS:
w_ih=array([[-0.36250335, -0.06201804],
       [ 0.19342226, -0.4159969 ],
       [-0.2725783 ,  0.3993578 ]], dtype=float32)
w_hh=array([[ 0.06113869,  0.17437363,  0.3263952 ],
       [ 0.16002077,  0.30942285, -0.06392378],
       [-0.38864082,  0.3857612 ,  0.04680109]], dtype=float32)
b_ih=array([-0.31550786,  0.1562801 ,  0.44562936], dtype=float32)
b_hh=array([0.39199525, 0.16152865, 0.03036499], dtype=float32)

FORWARD PASS:
inputs=tensor([[[ 1.1150, -1.9053],
         [-0.8039,  1.6239],
         [-0.7541,  0.3839],
         [-0.8075, -1.3132],
         [ 0.1741,  0.7878]],

        [[-1.3851,  0.8334],
         [-0.1918,  1.0145],
         [ 0.7142, -0.1625],
         [-0.0992,  1.4404],
         [-0.3180, -1.0389]],

        [[ 1.1328,  1.8441],
         [-0.6378, -0.2136],
         [ 0.2844,  0.4213],
         [ 1.1409,  0.4411],
         [ 0.4967, -0.8242]],

        [[ 1.2499,  0.4635],
         [ 0.0251,  0.1511],
         [-0.8519,  1.0837],
         [-0.9049, -

In [42]:
print(outputs[:, -1, :] == hidden_final)

tensor([[[True, True, True],
         [True, True, True],
         [True, True, True],
         [True, True, True]]])


In [49]:
print("MANUAL PASS")
output = torch.zeros(batch_size, sequence_length, hidden_size)
hidden_initial = torch.zeros(1, batch_size, hidden_size)

for batch_idx in range(batch_size):
    h = hidden_initial[:, batch_idx, :]
    for seq_idx in range(sequence_length):
        x = inputs[batch_idx, seq_idx, :]
        o = torch.tanh(x @ w_ih.T + b_ih + h @ w_hh.T + b_hh)
        h = o

        output[batch_idx, seq_idx] = o

hidden_final = output[:, -1, :]
print(f"{output=}")
print(f"{hidden_final=}")

MANUAL PASS
output=tensor([[[-0.2065,  0.8683, -0.5291],
         [ 0.2291, -0.2391,  0.9395],
         [ 0.5406, -0.0849,  0.6028],
         [ 0.5821,  0.6228, -0.0429],
         [ 0.0944,  0.3026,  0.6382]],

        [[ 0.4830, -0.2884,  0.8294],
         [ 0.3213, -0.2034,  0.5871],
         [ 0.0035,  0.4418,  0.0406],
         [ 0.1131, -0.1644,  0.8481],
         [ 0.4709,  0.5381,  0.0799]],

        [[-0.4207, -0.2263,  0.7181],
         [ 0.4543,  0.0998,  0.5878],
         [ 0.1823,  0.2576,  0.4269],
         [-0.1674,  0.4108,  0.3711],
         [ 0.1294,  0.6823,  0.2471]],

        [[-0.3845,  0.3511,  0.3099],
         [ 0.1944,  0.2795,  0.6799],
         [ 0.5375, -0.2200,  0.8352],
         [ 0.6595,  0.7232, -0.2975],
         [ 0.1658,  0.9126, -0.3636]]])
hidden_final=tensor([[ 0.0944,  0.3026,  0.6382],
        [ 0.4709,  0.5381,  0.0799],
        [ 0.1294,  0.6823,  0.2471],
        [ 0.1658,  0.9126, -0.3636]])
