In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import matmul

In [None]:
class ElmanRNNUnit(nn.Module):
    """A single Elman recurrent cell: h_t = tanh(x_t @ Uh + h_{t-1} @ Wh + b).

    The weights are stored as raw ``nn.Parameter`` matrices rather than
    ``nn.Linear`` layers. Input and hidden state share one width,
    ``embedding_dim``, so both weight matrices are square.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        square = (embedding_dim, embedding_dim)
        # Creation order (Uh, then Wh) fixes both the RNG draw order and the
        # parameter registration order, so keep it as is.
        # NOTE(review): standard-normal init is unscaled; for large embedding_dim
        # the tanh pre-activations may saturate — consider scaling by 1/sqrt(dim).
        self.Uh = nn.Parameter(torch.randn(*square))  # input -> hidden weights
        self.Wh = nn.Parameter(torch.randn(*square))  # hidden -> hidden weights
        self.b = nn.Parameter(torch.zeros(embedding_dim))  # bias, starts at zero

    def forward(self, x, h):
        """Compute the next hidden state.

        Args:
            x: current input, (batch_size, embedding_dim).
            h: previous hidden state, (batch_size, embedding_dim).

        Returns:
            The new hidden state, (batch_size, embedding_dim).
        """
        input_term = x @ self.Uh
        recurrent_term = h @ self.Wh
        pre_activation = input_term + recurrent_term + self.b
        return pre_activation.tanh()

In [None]:
class ElmanRNN(nn.Module):
    """A stack of ``ElmanRNNUnit`` cells unrolled over the time dimension.

    At every time step, layer 0 consumes the input token and each subsequent
    layer consumes the hidden state just produced by the layer below it.
    """

    def __init__(self, emb_dim, num_layers):
        """
        Args:
            emb_dim: width of the input embeddings and of every hidden state.
            num_layers: number of stacked recurrent cells.
        """
        super().__init__()
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.rnn_units = nn.ModuleList([ElmanRNNUnit(emb_dim) for _ in range(num_layers)])

    def forward(self, x):
        """Run the stacked RNN over every time step of ``x``.

        Args:
            x: (batch_size, seq_len, embedding_dim), e.g. a batch of documents
               where seq_len is the max number of words in a document and
               embedding_dim is the size of the word embedding.

        Returns:
            (batch_size, seq_len, embedding_dim): the top layer's hidden state
            at each time step. With zero layers, this is ``x`` itself.
        """
        batch_size, seq_len, emb_dim = x.shape
        if seq_len == 0:
            # torch.stack([]) raises, so return an empty sequence of the right shape.
            return x.new_zeros(batch_size, 0, emb_dim)
        # One hidden state per layer, carried across time steps; matches x's dtype/device.
        h_prev = [x.new_zeros(batch_size, emb_dim) for _ in range(self.num_layers)]
        outputs = []
        for time_step in range(seq_len):
            layer_input = x[:, time_step, :]  # (batch_size, emb_dim)
            for layer_idx, rnn_unit in enumerate(self.rnn_units):
                h_new = rnn_unit(layer_input, h_prev[layer_idx])
                h_prev[layer_idx] = h_new  # recurrent state for this layer at t+1
                layer_input = h_new  # this layer's output is the next layer's input
            outputs.append(layer_input)  # top layer's output at this time step
        return torch.stack(outputs, dim=1)  # (batch_size, seq_len, emb_dim)