In [1]:
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    """A (possibly stacked) ``nn.RNN`` followed by a linear layer on the last time step.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Size of the RNN hidden state.
        output_size: Number of features produced by the final linear layer.
        num_layers: Number of stacked RNN layers. Defaults to 1.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # batch_first=True: input/output tensors are laid out as (batch, sequence, feature)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)

        # Linear head mapping a hidden state to the model output
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the sequence through the RNN and project the final time step.

        Args:
            x: Input tensor of shape (batch, sequence, input_size).

        Returns:
            Tensor of shape (batch, output_size).
        """
        # Zero initial hidden state, shape (num_layers, batch, hidden_size).
        # new_zeros takes dtype AND device from x, so the state always matches
        # the input (e.g. after model.double() or model.to("cuda")); a plain
        # torch.zeros(...) would be float32 and break non-float32 models.
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

        # out holds the top layer's features for every time step; the final
        # hidden state returned alongside it is not needed here.
        out, _ = self.rnn(x, h0)

        # out[:, -1, :] selects the last time step for every sequence in the batch
        return self.fc(out[:, -1, :])

# Demo: push one random batch through SimpleRNN and report the tensor shapes
if __name__ == '__main__':
    # Model dimensions: features per time step, hidden-state width, and
    # output width (e.g., for regression or binary classification)
    input_size, hidden_size, output_size = 10, 20, 1
    num_layers = 1  # depth of the RNN stack

    # Data dimensions: time steps per sequence and sequences per batch
    sequence_length = 5
    batch_size = 4

    # Random stand-in batch laid out as (batch_size, sequence_length, input_size)
    dummy_input = torch.randn((batch_size, sequence_length, input_size))

    # Build the model from the dimensions above
    model = SimpleRNN(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        num_layers=num_layers,
    )

    # Single forward pass over the whole batch
    output = model(dummy_input)

    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")

Input shape: torch.Size([4, 5, 10])
Output shape: torch.Size([4, 1])
