In [1]:
import torch
import torch.nn as nn

In [6]:
class RNN(nn.Module):
    """
    Implementation of a (deep) Recurrent Neural Network with tanh activation.

    Each layer l computes, at timestep t:
        h(l, t) = tanh(W_ih x(l, t) + W_hh h(l, t-1) + b_hh)
    where x(0, t) is the model input at timestep t and x(l, t) = h(l-1, t) for l > 0.

    __init__()
    input_size[int]: features per timestep, if > 1 the input is multivariate
    hidden_size[int]: hidden size, normally between (64, 2048)
    num_layers[int]: number of stacked rnn cells (>= 1), if > 1 it is a Deep RNN, normally between (1, 8)

    forward()
    x[torch.Tensor]: model input of size [batch_size, seq_len, input_size]
    h[torch.Tensor]: initial hidden state of size [num_layers, batch_size, hidden_size], if None it will be zeros.
                     The tensor passed in is never modified.
    returns[torch.Tensor]: last layer's hidden state at every timestep, of size [batch_size, seq_len, hidden_size]
    """
    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Layer 0 maps input_size -> hidden_size; deeper layers map hidden_size -> hidden_size.
        # Only w_hh carries a bias, so each layer has exactly one bias vector.
        self.layers = nn.ModuleList()
        for l in range(num_layers):
            in_features = input_size if l == 0 else hidden_size
            self.layers.append(nn.ModuleDict({
                'w_ih' : nn.Linear(in_features, hidden_size, bias=False),
                'w_hh' : nn.Linear(hidden_size, hidden_size)
            }))

        # Activation function
        self.act = nn.Tanh()

    def forward(self, x, h=None):
        """
        In deep rnns (num_layers > 1), hidden state is propagated as:
        - input of the next layer at the current timestep; h(l, t) -> h(l+1, t)
        - hidden state of the current layer at the next timestep; h(l, t) -> h(l, t+1)

        Hidden states are held in a Python list whose entries are replaced, never written
        in place. This keeps every timestep's output a distinct tensor and lets autograd
        backpropagate through time. Writing into a preallocated tensor would alias all
        outputs to the last step and invalidate tensors saved for backward.
        """
        batch_size, seq_len, _ = x.size()

        if h is None: # by default, hidden state is initialized to 0s
            h = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=x.device, dtype=x.dtype)
        elif h.size(0) != self.num_layers:
            raise ValueError(
                f"expected initial hidden state with {self.num_layers} layers, got {h.size(0)}"
            )

        if seq_len == 0: # nothing to unroll; torch.stack would fail on an empty list
            return x.new_zeros(batch_size, 0, self.hidden_size)

        # h_layers[l] is the lth layer's hidden state for the most recent timestep.
        # unbind() yields views of h, but they are only read, so the caller's h stays intact.
        h_layers = list(h.unbind(0))

        output = []
        for t in range(seq_len):
            x_t = x[:, t] # input at timestep t: [batch_size, input_size]
            for l, layer in enumerate(self.layers):
                h_layers[l] = self.act(
                    layer['w_ih'](x_t) + layer['w_hh'](h_layers[l])
                )
                x_t = h_layers[l] # also the input for the l+1th layer at timestep t

            output.append(h_layers[-1]) # last layer's hidden state for timestep t

        return torch.stack(output, dim=1) # seq_len x [batch_size, hidden_size] -> [batch_size, seq_len, hidden_size]

In [5]:
torch.manual_seed(0) # reproducible weight init and inputs on Restart & Run All

input_size = 3
hidden_size = 128
num_layers = 3

batch_size = 2
seq_len = 10

model = RNN(input_size, hidden_size, num_layers=num_layers)

x = torch.randn(batch_size, seq_len, input_size)
y = model(x)

# Sanity check: one hidden_size vector per (sample, timestep).
assert y.size() == (batch_size, seq_len, hidden_size)

x.size(), y.size()

(torch.Size([2, 10, 3]), torch.Size([2, 10, 128]))