In [1]:
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

# Minibatch settings forwarded to the data loader below.
batch_size = 32  # presumably the number of subsequences per minibatch
num_steps = 35   # presumably the number of tokens per subsequence

# load_data_time_machine() is provided by the d2l package (it is developed in the
# language-model notebook); its result is unpacked here as the training iterator
# and the vocabulary.
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

In [3]:
# We use one_hot encoding
# Assume we have len(vocab) = 3,
# then 0 = [1, 0, 0], 1 = [0, 1, 0], 2 = [0, 0, 1]
F.one_hot(torch.tensor([0, 2]), len(vocab))

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]])

In [6]:
# Initialize the learnable parameters of a from-scratch RNN
def get_params(vocab_size, num_hiddens, device):
    """
    Create the learnable parameters of a single-layer RNN.

    Args:
    - vocab_size: length of the vocabulary; inputs and outputs both have this
      many features. eg: ["I", "like", "fruit"] has a length of 3
    - num_hiddens: number of hidden units
    - device: torch device (CPU or GPU) on which every tensor is allocated

    Returns:
    [W_xh, W_hh, b_h, W_hq, b_q] with shapes
        W_xh: (vocab_size, num_hiddens)
        W_hh: (num_hiddens, num_hiddens)
        b_h:  (num_hiddens,)
        W_hq: (num_hiddens, vocab_size)
        b_q:  (vocab_size,)
    Every tensor has requires_grad=True.
    """
    num_inputs = num_outputs = vocab_size

    # Nested helper: Gaussian samples scaled by 0.01 so the initial weights are small
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    W_xh = normal((num_inputs, num_hiddens))   # weight matrix: input -> hidden
    W_hh = normal((num_hiddens, num_hiddens))  # weight matrix: previous hidden -> current hidden
    # Hidden bias, used in H = tanh(X @ W_xh + H_prev @ W_hh + b_h).
    # Must be torch.zeros: torch has no `torch.zero` tensor factory.
    b_h = torch.zeros(num_hiddens, device=device)
    W_hq = normal((num_hiddens, num_outputs))  # weight matrix: hidden -> output
    b_q = torch.zeros(num_outputs, device=device)  # output bias, used in Y = H @ W_hq + b_q
    params = [W_xh, W_hh, b_h, W_hq, b_q]

    # Have autograd track every parameter so they can be trained
    for param in params:
        param.requires_grad_(True)
    return params

In [4]:
# Define the model function of the RNN
def rnn(inputs, state, params):
    """
    Forward pass of a single-layer RNN over a sequence of time steps.

    Args:
    - inputs: tensor (or list of tensors) iterated along its first axis; each
      element X has shape (n, d), where
        n = batch size (number of sequences processed in parallel)
        d = input feature dimension (vocab_size for one-hot inputs)
    - state: 1-tuple (H,) holding the incoming hidden state H of shape (n, h)
    - params: sequence [W_xh, W_hh, b_h, W_hq, b_q] with shapes
        W_xh: (d, h)  input  -> hidden
        W_hh: (h, h)  hidden -> hidden
        b_h:  (h,)    hidden bias (broadcast over the batch)
        W_hq: (h, q)  hidden -> output
        b_q:  (q,)    output bias (broadcast over the batch)

    Returns:
    - outputs: tensor of shape (a * n, q), where a = number of time steps.
      The per-step outputs are concatenated along dim 0 in time order, so
      rows [t*n:(t+1)*n] belong to time step t. For a zero-length
      sequence this is an empty (0, q) tensor.
    - state: 1-tuple (H,) with the final hidden state, shape (n, h)

    Example: with vocab_size = 5, n = 2, h = 10 and 3 one-hot time steps,
    inputs has shape (3, 2, 5) and outputs has shape (6, 5).
    """
    W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state  # (n, h)

    outputs = []
    for X in inputs:
        # H_t = tanh(X_t @ W_xh + H_{t-1} @ W_hh + b_h)  -> (n, h)
        H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)
        # Y_t = H_t @ W_hq + b_q  -> (n, q)
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)

    if not outputs:
        # No time steps: torch.cat([]) would raise, so return an empty (0, q)
        # output on the parameters' device/dtype and the unchanged state.
        return W_hq.new_zeros((0, W_hq.shape[1])), (H,)
    return torch.cat(outputs, dim=0), (H,)  # (a * n, q)

In [None]:
def init_rnn_state(batch_size, num_hiddens, device):
    """Build the initial RNN hidden state.

    Returns a 1-tuple holding an all-zeros tensor of shape
    (batch_size, num_hiddens) on `device`. The tuple wrapper matches the
    `(H,)` state format that `rnn` unpacks.
    """
    hidden = torch.zeros(batch_size, num_hiddens, device=device)
    return (hidden,)

In [None]:
class RNNModel:
    """An RNN language model assembled from plain functions.

    The model stores the parameter list produced by `get_params`. It delegates
    hidden-state creation to `init_state` and the forward computation to
    `forward_fn`.
    """

    def __init__(self, vocab_size, num_hiddens, device, get_params, init_state, forward_fn):
        """
        Initializes the RNN model.

        Args:
        - vocab_size (int): Size of the vocabulary (number of unique tokens); also
          the number of one-hot classes used by `__call__`.
        - num_hiddens (int): Number of hidden units in the RNN.
        - device (torch.device): Device passed to `get_params` when creating the
          parameters (it is not stored on the model).
        - get_params (callable): Called as `get_params(vocab_size, num_hiddens, device)`;
          its result is stored as `self.params`.
        - init_state (callable): Called as `init_state(batch_size, num_hiddens, device)`
          by `begin_state`.
        - forward_fn (callable): Called as `forward_fn(inputs, state, params)` by
          `__call__`; its return value is returned unchanged.

        Example:
        >>> model = RNNModel(vocab_size=5, num_hiddens=10, device=torch.device('cpu'),
        ...                  get_params=get_params, init_state=init_rnn_state, forward_fn=rnn)
        """
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens

        # Initialize parameters, e.g. [W_xh, W_hh, b_h, W_hq, b_q] for get_params
        self.params = get_params(vocab_size, num_hiddens, device)

        # Keep the state-initialization and forward-pass functions for later calls
        self.init_state, self.forward_fn = init_state, forward_fn

    # Invoked when a model instance is called like a function, e.g. model(X, state)
    # (instantiation runs __init__, not this method).
    def __call__(self, X, state):
        """
        Runs the forward pass of the RNN.

        Args:
        - X (torch.Tensor): 2-D integer tensor of token indices with shape
          (batch_size, num_steps). F.one_hot requires an int64 tensor.
        - state: Hidden state as produced by `begin_state`.

        Returns:
        Whatever `forward_fn` returns. With `rnn` this is
        (outputs of shape (num_steps * batch_size, vocab_size), new_state).

        Processing:
        X is transposed to (num_steps, batch_size) so the forward function can
        iterate over time steps. It is then one-hot encoded to
        (num_steps, batch_size, vocab_size) as float32.

        Example:
        vocab = ["hello", "world", "I", "love", "AI"]   # vocab_size = 5
        X = torch.tensor([[2, 3, 4],    # "I love AI"
                          [0, 1, 2]])   # "hello world I"
        # X.T -> [[2, 0], [3, 1], [4, 2]]; after one-hot the shape is (3, 2, 5),
        # e.g. "I" (index 2) becomes [0, 0, 1, 0, 0].
        state = model.begin_state(batch_size=2, device=torch.device('cpu'))
        output, new_state = model(X, state)
        """
        # Time-major one-hot inputs: (num_steps, batch_size, vocab_size), float32
        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)

        # Perform the forward pass using the stored RNN function
        return self.forward_fn(X, state, self.params)

    def begin_state(self, batch_size, device):
        """
        Initializes the hidden state for the RNN.

        Returns:
        The result of `init_state(batch_size, self.num_hiddens, device)`. With
        `init_rnn_state` this is a 1-tuple holding a zero tensor of shape
        (batch_size, num_hiddens).

        Example:
        state = model.begin_state(batch_size=2, device=torch.device('cpu'))
        print(state[0].shape)  # torch.Size([2, num_hiddens])
        """
        return self.init_state(batch_size, self.num_hiddens, device)
