In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class SimpleRNN(nn.Module):
    """Minimal Elman-style RNN, unrolled one timestep at a time.

    At each step t the hidden state is updated as
        h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
    and an output y_t = W_ho h_t + b_ho is produced.

    Args:
        input_size: number of elements in each timestep's input tensor.
        hidden_size: size of the hidden state vector.
        output_size: size of each timestep's output vector.
    """

    def __init__(self, input_size=9, hidden_size=4, output_size=3):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # nn.Linear(in, out) computes x @ W.T + b, with W stored as (out, in).
        self.input_to_hidden = nn.Linear(input_size, hidden_size)    # input_size -> hidden_size
        self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size)  # hidden_size -> hidden_size
        self.hidden_to_output = nn.Linear(hidden_size, output_size)  # hidden_size -> output_size

    def forward(self, inputs, debug_steps=2):
        """Run the RNN over a sequence of timestep inputs.

        Args:
            inputs: indexable sequence (e.g. a list) of tensors, one per
                timestep; each must contain exactly ``input_size`` elements,
                since it is reshaped to ``(1, input_size)``.
            debug_steps: print intermediate tensor shapes for timesteps
                ``t < debug_steps`` (default 2, matching the original
                behaviour). Pass 0 to disable printing.

        Returns:
            (steps_output, hidden_states): two dicts keyed by timestep index.
            ``steps_output[t]`` has shape (1, output_size);
            ``hidden_states[t]`` has shape (1, hidden_size), and
            ``hidden_states[-1]`` holds the all-zero initial state.
        """
        steps_output, hidden_states = {}, {}

        # Build the initial state from a parameter tensor so it matches the
        # model's dtype/device (e.g. after .to("cuda") or .double()).
        h_prev = self.hidden_to_hidden.weight.new_zeros((1, self.hidden_size))
        hidden_states[-1] = h_prev

        for t in range(len(inputs)):
            # Use the configured input size instead of a hard-coded 9, so the
            # model works for any input_size passed to __init__.
            x = inputs[t].reshape(1, self.input_size)        # (1, input_size)

            hidden_cur = self.input_to_hidden(x)             # (1, hidden_size)
            h_prev = self.hidden_to_hidden(h_prev)           # (1, hidden_size)
            h_prev = torch.tanh(hidden_cur + h_prev)         # (1, hidden_size)

            y_t = self.hidden_to_output(h_prev)              # (1, output_size)

            hidden_states[t] = h_prev
            steps_output[t] = y_t

            if t < debug_steps:
                print(f"\nt = {t}")
                print("x shape:", x.shape)
                print("hidden_cur shape:", hidden_cur.shape)
                print("h_prev shape:", h_prev.shape)
                print("y_t shape:", y_t.shape)
                print("----------------------------")

        return steps_output, hidden_states

In [4]:
# Configuration for the demo run; SEED makes weight init and inputs reproducible
# so Restart & Run All gives the same tensors every time.
SEED = 0
sequence_length = 10
input_size = 9
hidden_size = 4
output_size = 3

torch.manual_seed(SEED)

model = SimpleRNN(input_size, hidden_size, output_size)

# One random input vector of length input_size per timestep.
inputs = [torch.randn(input_size) for _ in range(sequence_length)]

output, hidden_states = model(inputs)
print("\nFinal output keys:", output.keys())


t = 0
x shape: torch.Size([1, 9])
hidden_cur shape: torch.Size([1, 4])
h_prev shape: torch.Size([1, 4])
y_t shape: torch.Size([1, 3])
----------------------------

t = 1
x shape: torch.Size([1, 9])
hidden_cur shape: torch.Size([1, 4])
h_prev shape: torch.Size([1, 4])
y_t shape: torch.Size([1, 3])
----------------------------

Final output keys: dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
