In [1]:
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

# Mini-batch hyperparameters: batch size and number of time steps per sequence.
batch_size = 32
num_steps = 35

# d2l.load_data_time_machine() is implemented in language-model.ipynb;
# its two return values are unpacked here as the data iterator and the vocabulary.
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

In [3]:
# Token indices are turned into one-hot vectors of length len(vocab).
# Illustration with len(vocab) = 3:
#   0 -> [1, 0, 0], 1 -> [0, 1, 0], 2 -> [0, 0, 1]
# Below, indices 0 and 2 are encoded against the actual vocabulary size.
F.one_hot(torch.tensor([0, 2]), num_classes=len(vocab))

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]])

In [6]:
# Initialize the parameters of the RNN
def get_params(vocab_size, num_hiddens, device):
    """
    Create and initialize the trainable parameters of a single-hidden-layer RNN.

    Weights are drawn from a normal distribution scaled by 0.01 and biases
    start at zero. Every returned tensor has requires_grad enabled.

    Args:
    - vocab_size: length of our input vocabulary list. eg: ["I", "like", "fruit"] has a length of 3.
      It is used as both the input and the output dimension.
    - num_hiddens: number of hidden neurons
    - device: device on which the tensors are allocated (CPU or GPU)

    Returns:
    [W_xh, W_hh, b_h, W_hq, b_q] with shapes
    - W_xh: (vocab_size, num_hiddens)
    - W_hh: (num_hiddens, num_hiddens)
    - b_h:  (num_hiddens,)
    - W_hq: (num_hiddens, vocab_size)
    - b_q:  (vocab_size,)
    """
    num_inputs = num_outputs = vocab_size

    # Nested helper: small random normal values for a weight of the given shape
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    # Hidden layer: H_t = tanh(X_t @ W_xh + H_{t-1} @ W_hh + b_h)
    # The shape must be passed as ONE tuple; normal() takes a single argument.
    W_xh = normal((num_inputs, num_hiddens))   # input -> hidden weights
    W_hh = normal((num_hiddens, num_hiddens))  # previous hidden -> current hidden weights
    b_h = torch.zeros(num_hiddens, device=device)  # hidden bias (torch.zeros, not torch.zero)

    # Output layer: Y_t = H_t @ W_hq + b_q
    W_hq = normal((num_hiddens, num_outputs))  # hidden -> output weights
    b_q = torch.zeros(num_outputs, device=device)  # output bias
    params = [W_xh, W_hh, b_h, W_hq, b_q]

    # Set .requires_grad_(True) so gradients are tracked during training
    for param in params:
        param.requires_grad_(True)
    return params

In [4]:
# Define the RNN forward function (the model)
def rnn(inputs, state, params):
    """
    Forward pass of a simple Recurrent Neural Network (RNN) over a sequence.

    Notation used below:
        a = number of time steps, n = batch size,
        d = input feature dimension, h = number of hidden units,
        q = output dimension.

    Args:
    - inputs: Iterable over time steps; each item is a 2-D tensor of shape (n, d)
        (for example a single tensor of shape (a, n, d)).
    - state: 1-tuple (H,) holding the initial hidden state, H of shape (n, h).
    - params: Sequence of five tensors, in this order:
        W_xh: input-to-hidden weights (shape: d x h)
        W_hh: hidden-to-hidden weights (shape: h x h)
        b_h: hidden bias (broadcast against n x h)
        W_hq: hidden-to-output weights (shape: h x q)
        b_q: output bias (broadcast against n x q)

    Returns:
    - outputs: Tensor of shape (a * n, q): the per-step outputs concatenated
        along the first dimension, in time order.
    - final_state: 1-tuple (H,) with the hidden state after the last step,
        H of shape (n, h).
    """
    W_xh, W_hh, b_h, W_hq, b_q = params
    (hidden,) = state  # shape (n, h)

    step_outputs = []  # one (n, q) tensor per time step
    for step_input in inputs:
        # Hidden-state update: H = tanh(X @ W_xh + H @ W_hh + b_h)
        #   X @ W_xh projects the current input into the hidden space  -> (n, h)
        #   H @ W_hh carries over the previous hidden state            -> (n, h)
        #   b_h is broadcast across the batch dimension
        pre_activation = torch.mm(step_input, W_xh) + torch.mm(hidden, W_hh) + b_h
        hidden = torch.tanh(pre_activation)

        # Output for this step: Y = H @ W_hq + b_q                     -> (n, q)
        step_outputs.append(torch.mm(hidden, W_hq) + b_q)

    # Concatenating along dim 0 flattens the time dimension: (a * n, q)
    return torch.cat(step_outputs, dim=0), (hidden,)

In [None]:
class RNNModel:
    """Thin wrapper that bundles RNN parameters with pluggable state-init and forward functions."""

    def __init__(self, vocab_size, num_hiddens, device, get_params, init_state, forward_fn):
        """
        Build the model: store the sizes, create the parameters, keep the callbacks.

        Args:
        - vocab_size: Integer, the size of the vocabulary (number of unique tokens).
        - num_hiddens: Integer, the number of hidden units in the RNN.
        - device: The device (CPU/GPU) passed on to `get_params`.
        - get_params: Callable `(vocab_size, num_hiddens, device) -> params`.
        - init_state: Callable `(batch_size, num_hiddens, device) -> state`.
        - forward_fn: Callable `(inputs, state, params) -> (output, new_state)`.
        """
        # Hyperparameters
        self.vocab_size = vocab_size
        self.num_hiddens = num_hiddens

        # Model parameters (e.g., weight matrices and biases) come from the supplied factory
        self.params = get_params(vocab_size, num_hiddens, device)

        # Callbacks used later by begin_state() and __call__()
        self.init_state = init_state
        self.forward_fn = forward_fn

    # __call__ runs when an *instance* is called like a function, e.g.
    # net = RNNModel(...); net(X, state). It is not run at instantiation.
    def __call__(self, X, state):
        """
        One-hot encode a batch of token indices and run the forward function on it.

        Args:
        - X: Integer tensor of shape (batch_size, sequence_length) holding token indices.
        - state: The hidden state(s), passed through to `forward_fn` unchanged.

        Returns:
        - Whatever `forward_fn(inputs, state, params)` returns, where `inputs` is the
          float32 one-hot tensor of shape (sequence_length, batch_size, vocab_size).
        """
        # Transpose first so that the leading axis is time:
        # (batch_size, sequence_length) -> (sequence_length, batch_size, vocab_size)
        one_hot_steps = F.one_hot(X.T, self.vocab_size)
        # Cast to float32 so the encoding can take part in matrix multiplications
        one_hot_steps = one_hot_steps.float()

        return self.forward_fn(one_hot_steps, state, self.params)

    def begin_state(self, batch_size, device):
        """
        Create the initial hidden state by delegating to `init_state`.

        Args:
        - batch_size: Integer, the number of sequences in the batch.
        - device: The device (CPU/GPU) passed on to `init_state`.

        Returns:
        - The value of `init_state(batch_size, num_hiddens, device)`.
        """
        return self.init_state(batch_size, self.num_hiddens, device)
