In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
# Seed the global RNG so the random input and the RNN's initial weights are reproducible.
torch.manual_seed(0)

# Configuration
input_size = 3
hidden_size = 5
seq_len = 4
batch_size = 2

# Input layout with batch_first=True:  (batch, seq_len, input_size)
# Input layout with batch_first=False: (seq_len, batch, input_size)

# Random input in the batch-first layout used throughout this cell.
x = torch.randn(batch_size, seq_len, input_size)
# Built-in RNN with ReLU nonlinearity and batch_first=True
rnn = nn.RNN(input_size, hidden_size, nonlinearity='relu', batch_first=True)
# Zero initial hidden state; the leading dimension of size 1 is the layer axis
# (it is indexed away with h0[0] by the manual implementation).
h0 = torch.zeros(1, batch_size, hidden_size)

# Run built-in RNN
out_builtin, hn_builtin = rnn(x, h0)

# Extract weights and biases of layer 0, detached from autograd so the manual
# re-implementation works on plain tensors.
W_ih = rnn.weight_ih_l0.detach()
W_hh = rnn.weight_hh_l0.detach()
b_ih = rnn.bias_ih_l0.detach()
b_hh = rnn.bias_hh_l0.detach()

def rnn_step(x_t, h_t_prev, W_ih, W_hh, b_ih, b_hh):
    """Advance a ReLU RNN by one time step.

    Computes ``relu(x_t @ W_ih.T + b_ih + h_t_prev @ W_hh.T + b_hh)``.

    Args:
        x_t: Input at the current step.
        h_t_prev: Hidden state from the previous step.
        W_ih: Input-to-hidden weight matrix (applied transposed).
        W_hh: Hidden-to-hidden weight matrix (applied transposed).
        b_ih: Input-to-hidden bias.
        b_hh: Hidden-to-hidden bias.

    Returns:
        The new hidden state.
    """
    input_term = torch.matmul(x_t, W_ih.T) + b_ih
    recurrent_term = torch.matmul(h_t_prev, W_hh.T)
    # Keep the left-to-right summation order so floating-point results are unchanged.
    pre_activation = input_term + recurrent_term + b_hh
    return pre_activation.relu()

def my_simple_rnn(x, h0, W_ih, W_hh, b_ih, b_hh):
    """Unroll a single-layer ReLU RNN over a batch-first sequence.

    Args:
        x: Input of shape (batch, seq_len, input_size).
        h0: Initial hidden state with a leading layer axis; only ``h0[0]`` is used.
        W_ih, W_hh, b_ih, b_hh: Weights and biases forwarded to ``rnn_step``.

    Returns:
        A pair ``(output_seq, h_n)``: the hidden state at every step stacked
        along dim 1, i.e. (batch, seq_len, hidden), and the final hidden state
        with the layer axis restored at dim 0.
    """
    # The three-way unpack doubles as a check that x is 3-D.
    _, num_steps, _ = x.size()
    hidden = h0[0]  # drop the layer axis
    per_step = []

    for step in range(num_steps):
        hidden = rnn_step(x[:, step, :], hidden, W_ih, W_hh, b_ih, b_hh)
        per_step.append(hidden)

    # Stacking on a new dim 1 inserts the time axis between batch and hidden.
    stacked = torch.stack(per_step, dim=1)
    return stacked, hidden.unsqueeze(0)

# Run manual RNN with same weights
out_manual, hn_manual = my_simple_rnn(x, h0, W_ih, W_hh, b_ih, b_hh)

# Verify numerical equivalence:
# first report the raw L2 norm of each difference, then a tolerance-based
# check (absolute tolerance 1e-6) on the full output and the final hidden state.
print("Output difference (L2):", torch.norm(out_builtin - out_manual).item())
print("Hidden difference (L2):", torch.norm(hn_builtin - hn_manual).item())
print("Outputs match:", torch.allclose(out_builtin, out_manual, atol=1e-6))
print("Final hidden states match:", torch.allclose(hn_builtin, hn_manual, atol=1e-6))

Output difference (L2): 1.0106459313874439e-07
Hidden difference (L2): 6.143906006172983e-08
Outputs match: True
Final hidden states match: True


In [None]:
# Configuration for the module-based re-implementation (overrides the values
# set in the previous cell).
input_size = 3
hidden_size = 4
seq_len = 5
batch_size = 2

# Input with batch_first: (batch, seq_len, input_size)
x = torch.randn(batch_size, seq_len, input_size)

# Built-in RNN with ReLU and batch_first=True
rnn = nn.RNN(input_size, hidden_size, nonlinearity='relu', batch_first=True)
# Zero initial hidden state with a leading layer axis of size 1.
h0 = torch.zeros(1, batch_size, hidden_size)

# Built-in RNN forward
out_builtin, hn_builtin = rnn(x, h0)

# Custom recurrent cell with ReLU
class RecurrentCell(nn.Module):
    """One step of a ReLU recurrent layer built from two linear maps.

    The new hidden state is ``ReLU(ih(x) + hh(h))``, where ``ih`` projects the
    input and ``hh`` projects the previous hidden state (each with its own bias).
    """

    def __init__(self, input_size, hidden_size):
        """Create the input-to-hidden and hidden-to-hidden projections.

        Args:
            input_size: Number of features in the step input.
            hidden_size: Number of features in the hidden state.
        """
        super().__init__()
        # Input projection first, then recurrent projection.
        self.ih = nn.Linear(input_size, hidden_size)
        self.hh = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.ReLU()

    def forward(self, x, h):
        """Compute the next hidden state.

        Args:
            x: Step input.
            h: Previous hidden state.

        Returns:
            A pair ``(output, new_hidden)``; both entries are the same tensor.
        """
        pre_activation = self.ih(x) + self.hh(h)
        h_next = self.activation(pre_activation)
        return h_next, h_next

# Create custom cell & copy weights from built-in RNN
cell = RecurrentCell(input_size, hidden_size)
# Copy in place under no_grad so the parameter updates are not tracked by autograd.
with torch.no_grad():
    cell.ih.weight.copy_(rnn.weight_ih_l0)
    cell.ih.bias.copy_(rnn.bias_ih_l0)
    cell.hh.weight.copy_(rnn.weight_hh_l0)
    cell.hh.bias.copy_(rnn.bias_hh_l0)

# Manual RNN forward pass (batch first)
h_manual = h0[0]  # shape: (batch, hidden_size)
outputs = []

for t in range(seq_len):
    x_t = x[:, t, :]  # shape: (batch, input_size)
    # The cell returns the same tensor twice (output, hidden); keep one copy.
    h_manual, _ = cell(x_t, h_manual)
    outputs.append(h_manual.unsqueeze(1))  # keep time dimension

out_manual = torch.cat(outputs, dim=1)      # (batch, seq_len, hidden)
hn_manual = h_manual.unsqueeze(0)           # (1, batch, hidden)

# Compare with built-in: raw L2 norm of each difference, then an allclose
# check with absolute tolerance 1e-6.
print("Output difference:", torch.norm(out_builtin - out_manual).item())
print("Hidden state difference:", torch.norm(hn_builtin - hn_manual).item())
print("Output match?", torch.allclose(out_builtin, out_manual, atol=1e-6))
print("Hidden match?", torch.allclose(hn_builtin, hn_manual, atol=1e-6))

Output difference: 3.6263745073483733e-07
Hidden state difference: 2.9802322387695312e-08
Output match? True
Hidden match? True


In [None]:
class Model(nn.Module):
    """Stack of ``RecurrentCell`` layers followed by a linear read-out.

    The input sequence is processed step by step; at each step the output of
    layer ``i`` is fed as input to layer ``i + 1``. After the last step the
    top layer's output is projected back to ``input_size`` features.

    Args:
        input_size: Number of features per time step (also the read-out size).
        hidden_size: Number of hidden features in every recurrent layer.
        num_layers: Number of stacked recurrent cells. ``forward`` requires at
            least one layer.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Layer 0 consumes the raw features; every deeper layer consumes the
        # hidden vector produced by the layer below it.
        self.recurrent_cells = nn.ModuleList([
            RecurrentCell(input_size if layer == 0 else hidden_size, hidden_size)
            for layer in range(num_layers)
        ])
        self.dense = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        """Run the stacked RNN over ``x`` and project the final top-layer output.

        Args:
            x: Batch-first input of shape (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, input_size) computed from the top layer's
            output at the last time step (zeros are projected if seq_len is 0).
        """
        batch_size, seq_len = x.size(0), x.size(1)

        # Allocate the initial states from x so they share its device and dtype
        # (a bare torch.zeros would land on the default device and break on GPU).
        # Per-layer lists avoid cloning a packed state tensor on every update.
        hidden = [x.new_zeros(batch_size, self.hidden_size) for _ in self.recurrent_cells]
        layer_out = [x.new_zeros(batch_size, self.hidden_size) for _ in self.recurrent_cells]

        for t in range(seq_len):
            out = x[:, t, :]
            for i, cell in enumerate(self.recurrent_cells):
                out, hidden[i] = cell(out, hidden[i])
                layer_out[i] = out

        # Read out from the top (last) layer only.
        return self.dense(layer_out[-1])

# Instantiate the stacked model: input_size=1, hidden_size=5, num_layers=3.
model = Model(1, 5, 3)

In [None]:
# Print a per-layer summary for a per-sample input of shape (seq_len=3, input_size=1).
summary(model, (3, 1))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                    [-1, 5]              10
            Linear-2                    [-1, 5]              30
              ReLU-3                    [-1, 5]               0
     RecurrentCell-4         [[-1, 5], [-1, 5]]               0
            Linear-5                    [-1, 5]              30
            Linear-6                    [-1, 5]              30
              ReLU-7                    [-1, 5]               0
     RecurrentCell-8         [[-1, 5], [-1, 5]]               0
            Linear-9                    [-1, 5]              30
           Linear-10                    [-1, 5]              30
             ReLU-11                    [-1, 5]               0
    RecurrentCell-12         [[-1, 5], [-1, 5]]               0
           Linear-13                    [-1, 5]              10
           Linear-14                   