# Homework 6

In this homework, you will train and experiment with a "char-RNN", a language model that predicts the next character in a sequence (see [this famous blog post by Andrej Karpathy](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)). Before you start on the rest of the homework, please read the blog post! Unlike Karpathy’s original implementation, you will build your own version using modern PyTorch modules. Rather than relying on nn.RNN or nn.GRU, you will implement the recurrent computations manually to see exactly how hidden states are updated over time.

For this homework, we are going to implement the following:

**Data class**
1. Vocab: builds a character-level vocabulary for encoding and decoding the given text.
1. CharDataset: cleans the raw text and converts it into sequences of fixed length, using the Vocab class for tokenization.

**Model class**
1. RNNScratch: a “vanilla” RNN layer that updates its hidden state one timestep at a time.
1. GRUScratch: a gated recurrent unit (GRU) layer that adds update and reset gates to the vanilla recurrence.
1. RNNLMScratch: a wrapper that combines the recurrent layer (either vanilla RNN or GRU) with an output layer for predicting the next character.
1. Trainer: a simple class that handles the training loop, optimization, loss recording, and evaluation.

For RNNScratch and GRUScratch, compare their final training/validation losses and the quality of the text samples they generate. Which model has the lower final training loss? Are the samples from one model more realistic than those from the other?

In [None]:
import re
import torch
import random
import collections
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import matplotlib.pyplot as plt
from datasets import load_dataset
from typing import Dict, Tuple, List

import os
os.environ["HF_TOKEN"] = ""

device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
######################## DATA LOADING (DO NOT DELETE THE CELL) ##########################
if __name__ == '__main__':

    ds = load_dataset("r-three/shakespeare-sonnet-dialogue-blob")
    shakes_ds = ' '.join(ds['train']['text'])

    print(shakes_ds[:499])
    print("\n")

    print('shakes distinct characters: {}'.format(set(shakes_ds)))
    print('shakes # distinct characters: {}'.format(len(set(shakes_ds))))

## Part 1: Text Preprocessing (1.5 points)

To train a language model, we first need to convert raw text into a numerical form that the RNN can process. In this homework, since we are building a char-RNN, each individual character is treated as a token. For example, the sentence "good morning" is a 12-token sequence (including the space): ['g','o','o','d',' ','m','o','r','n','i','n','g']. Each distinct character is mapped to a unique integer.
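As a quick illustration, the sketch below splits "good morning" into character tokens and maps each distinct character to an integer. Assigning indices in sorted order is an arbitrary choice made here for the example (`char_to_idx` is a hypothetical name); your Vocab class may order tokens differently and will also reserve an `<unk>` entry.

```python
text = "good morning"

# Character-level tokenization: every character, including the space, is one token.
tokens = list(text)
print(tokens)       # ['g', 'o', 'o', 'd', ' ', 'm', 'o', 'r', 'n', 'i', 'n', 'g']
print(len(tokens))  # 12

# Map each distinct character to an integer (sorted order is an arbitrary choice).
char_to_idx = {ch: i for i, ch in enumerate(sorted(set(tokens)))}
ids = [char_to_idx[ch] for ch in tokens]
print(ids)          # [2, 6, 6, 1, 0, 4, 6, 7, 5, 3, 5, 2]
```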


**TODO**
1. Complete the Vocab class: builds a mapping between each character (token) and its integer index.
2. Complete the CharDataset class: processes the raw text using the Vocab class; converts the text into sequences of fixed length (num_steps) to form training examples. This can be wrapped by a PyTorch DataLoader to generate mini-batches during training.

Usage:

```
vocab = Vocab(tokens=['a','b','c'])
a_index = vocab['a']
a_index_tok = vocab.to_tokens(a_index)
'a' == a_index_tok # this should return True

mydataset = CharDataset(num_steps, raw_text)
loader = DataLoader(mydataset, batch_size=4)
for b in loader:
    # b is a batch of size 4
    ...
```

### Part 1.a: Vocab class

Important considerations when designing Vocab:
* The `unk` property returns the index of the `<unk>` token, which denotes an unknown token.
* In your Vocab class, treat any token whose frequency count is less than min_freq as the `<unk>` token. Hint: you can use collections.Counter.
* This allows a model that uses the Vocab class to handle unseen tokens at inference time.
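One possible way to satisfy these requirements, sketched as two standalone helper functions rather than as the Vocab class itself. The names `build_vocab` and `lookup`, placing `<unk>` at index 0, and ordering the remaining tokens by descending frequency are all assumptions for illustration; the simple check cell below only requires that looking up a kept token and mapping it back round-trips.

```python
import collections

def build_vocab(tokens, min_freq=1):
    """Hypothetical helper: returns (idx_to_token, token_to_idx)."""
    counter = collections.Counter(tokens)
    # Keep tokens seen at least min_freq times; sort by frequency (desc), then alphabetically.
    kept = sorted((tok for tok, freq in counter.items() if freq >= min_freq),
                  key=lambda tok: (-counter[tok], tok))
    # Reserve index 0 for <unk> so rare and unseen tokens have a fallback index.
    idx_to_token = ['<unk>'] + [tok for tok in kept if tok != '<unk>']
    token_to_idx = {tok: idx for idx, tok in enumerate(idx_to_token)}
    return idx_to_token, token_to_idx

def lookup(token_to_idx, tokens):
    """Map a token (str) or list of tokens to index/indices, falling back to <unk>."""
    unk = token_to_idx['<unk>']
    if isinstance(tokens, str):
        return token_to_idx.get(tokens, unk)
    return [token_to_idx.get(tok, unk) for tok in tokens]

idx_to_token, token_to_idx = build_vocab(['a', 'b', 'c', 'd', 'd'], min_freq=2)
print(token_to_idx)                           # {'<unk>': 0, 'd': 1}
print(lookup(token_to_idx, ['d', 'a', 'z']))  # [1, 0, 0]
```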

In [None]:
class Vocab:
    """Vocabulary for raw text data"""
    def __init__(self, tokens=[], min_freq=1):
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]

        ###################### YOUR CODE ####################
        # TODO: Map a token to a specific index.
        # If the token occurrence is less than min_freq, exclude it from the list of tokens
        self.idx_to_token = None
        self.token_to_idx = {}
        #####################################################

    def __len__(self):
        ###################### YOUR CODE ####################
        # TODO: Return the total token count
        return 0
        #####################################################

    def __getitem__(self, tokens: List[str] | str) -> List[int] | int:
        ###################### YOUR CODE ####################
        # TODO: Return the corresponding index or list of indices for the given tokens
        # If input is List[str], return List[int]. If input is str, return int.
        return []
        #####################################################

    def to_tokens(self, indices: List[int] | int) -> List[str] | str:
        ###################### YOUR CODE ####################
        # TODO: Given the index or list of indices, map the index back to token(s)
        # If input is List[int], return List[str]. If input is int, return str.
        return ''
        #####################################################

    @property
    def unk(self):  # Index for the unknown token
        return self.token_to_idx['<unk>']

In [None]:
######## SIMPLE CHECK ########
voc = Vocab(tokens=['a','b','c','d','d'], min_freq=2)
print(voc.token_to_idx)
# should be True
print('d' == voc.to_tokens(voc['d']))

### Part 1.b: CharDataset class

Now that you have implemented the Vocab class to map characters to indices, we will create a custom Dataset class that extends torch.utils.data.Dataset. The core functionality of a PyTorch Dataset is to provide access to individual data samples and their corresponding labels. A Dataset works in conjunction with a DataLoader, which handles batching, shuffling, and data loading to feed samples to the model during training.
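A minimal sketch of the Dataset/DataLoader contract, using a hypothetical `PairDataset` over precomputed toy tensors: the Dataset only needs `__len__` and `__getitem__`, and the DataLoader's default collate function stacks the returned pairs into batch tensors. CharDataset follows the same pattern, with X and Y built from the corpus.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PairDataset(Dataset):
    """Hypothetical toy dataset holding precomputed (input, target) tensors."""
    def __init__(self, X, Y):
        assert len(X) == len(Y)
        self.X, self.Y = X, Y

    def __len__(self):
        # Number of samples; the DataLoader uses this to plan batches.
        return len(self.X)

    def __getitem__(self, idx):
        # One (input, label) pair; the DataLoader stacks these into batches.
        return self.X[idx], self.Y[idx]

X = torch.arange(30).reshape(10, 3)  # 10 samples, each a sequence of 3 indices
Y = X + 1                            # toy targets
loader = DataLoader(PairDataset(X, Y), batch_size=4, shuffle=False)
batches = list(loader)
xb, yb = batches[0]
print(len(batches), xb.shape, yb.shape)  # 3 torch.Size([4, 3]) torch.Size([4, 3])
```

The last batch holds only the 2 leftover samples, since 10 is not a multiple of 4.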

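In our case, CharDataset will construct the input and target sequences, X and Y, where Y is X shifted one position forward in the text: the target at each position is the character that immediately follows the corresponding input character.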

For instance, given the raw text "hi, how are you?", `_preprocess` and `_tokenize` will clean and tokenize it (e.g. ['h', 'i', ',', ' ', 'h', 'o', 'w', ' ', ...]), and the tokens are then mapped to their indices. For readability, label the tokens by their position in the sequence (1, 2, 3, ...). If we set num_steps = 3, the resulting input (X) and target (Y) sequences for the character-level RNN would look like this:
```
X:
1 2 3
2 3 4
3 4 5
4 5 6
...

Y:
2 3 4
3 4 5
4 5 6
5 6 7
...
```
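The windows above can be built with stride-1 slicing. The sketch below is one way to do it, assuming `corpus` is a plain list of integer indices; `make_windows` is a hypothetical helper, and `Tensor.unfold` is used here instead of an explicit Python loop. It reproduces the X and Y rows shown above for a corpus of 1 through 7.

```python
import torch

def make_windows(corpus, num_steps):
    """Build overlapping (X, Y) windows with stride 1; Y is X shifted one token forward."""
    data = torch.tensor(corpus)
    # unfold(dim, size, step) returns every length-num_steps slice taken with step 1.
    X = data[:-1].unfold(0, num_steps, 1)
    Y = data[1:].unfold(0, num_steps, 1)
    return X, Y

X, Y = make_windows([1, 2, 3, 4, 5, 6, 7], num_steps=3)
print(X)  # rows: [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]
print(Y)  # rows: [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7]
```

With stride 1, a corpus of length L yields L - num_steps windows. Non-overlapping windows (a stride of num_steps) are another common choice that produces fewer, less redundant samples.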

**NOTE**: In language modeling, a common practice is to mark the beginning or end of sentences with special tokens such as `<bos>` (beginning of sentence) and `<eos>` (end of sentence). You may also choose to clean the text in other ways depending on your needs. In our case, you can implement your own `_preprocess` logic to handle any text cleaning before building the input and target sequences.
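A minimal `_preprocess`-style sketch, written as a standalone function. Lowercasing, replacing everything except letters and whitespace with a space, and collapsing whitespace are assumptions chosen for illustration, not a required recipe. Note that this particular choice also drops apostrophes and punctuation, which shrinks the vocabulary but turns forms like "thou'rt" into two pieces.

```python
import re

def preprocess(text: str) -> str:
    """One possible cleaning strategy (an assumption, not a requirement)."""
    text = text.lower()
    # Replace anything that is not a lowercase letter or whitespace with a space.
    text = re.sub(r"[^a-z\s]", " ", text)
    # Collapse runs of whitespace (including newlines) into a single space.
    return re.sub(r"\s+", " ", text).strip()

print(preprocess("Hi, How ARE you?\nThou art!"))  # hi how are you thou art
```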

In [None]:
class CharDataset(Dataset):
    """Builds the dataset class handling raw text processing"""
    def __init__(self, num_steps, raw_text):
        self.num_steps = num_steps

        self.tokens = self._tokenize(self._preprocess(raw_text))
        self.vocab = Vocab(self.tokens)
        self.corpus = [self.vocab[token] for token in self.tokens]  # convert tokens to corresponding indices

        ########################## YOUR CODE ##########################
        # TODO: define X and Y input data using corpus
        # X is the input sequence and Y is the corresponding labels.
        # X, Y should be torch.tensor()
        self.X, self.Y = None, None
        ###############################################################

    def _preprocess(self, text: str):
        ########################## YOUR CODE ##########################
        # TODO: lowercase the letters, and do any additional cleaning of the raw text you want.
        # For instance, you can consider replacing non-alphabetic characters with whitespace (or not).
        return text
        ###############################################################

    def _tokenize(self, text: str) -> List[str]:
        ########################## YOUR CODE ##########################
        # TODO: Given an input text, return a list of tokens
        # e.g. "hello world" -> ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd']
        return ['']
        ###############################################################

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

    def __len__(self):
        return len(self.X)


## Part 2: Build RNN/GRU model from scratch (3.5 points)

Now that you have built the data processing classes to clean the raw text corpus and prepare batches, you will build your own RNN modules from scratch. In this section, you will manually implement both a vanilla RNN and a GRU, along with a simple language modeling wrapper for training the actual RNN-based language model.

TODO:
1. Complete the **RNNScratch**: A single vanilla RNN layer.
2. Complete the **GRUScratch**: A single GRU layer.
3. Complete the **RNNLMScratch**: a language model wrapper that connects the RNN/GRU layer to an embedding layer and an output layer.

NOTE:
* Throughout Part 2, make sure all tensors are on the correct device by explicitly using .to(device) when needed.

### 2.a Vanilla RNN from scratch

As a reminder, the hidden state of a vanilla RNN at timestep t is computed as:
$$\begin{aligned}
h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)
\end{aligned}
$$

**Parameter initialization**:

To initialize the parameters, use nn.Parameter(). We recommend using Glorot initialization to initialize the weights. You can also experiment with other weight initialization methods!

**NOTE**: the convention we are following for this homework uses (num_hiddens, num_inputs) for the shape of $W_{xh}$.
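A sketch of Glorot (Xavier) uniform initialization and a single batched RNN step under the (num_hiddens, num_inputs) convention. Because each row of `x_t` is one example, the product $W_{xh} x_t$ becomes `x_t @ W_xh.T` in batched form. The sizes and the `rnn_step` helper are illustrative assumptions, and the zero initial state is one common choice rather than a requirement.

```python
import torch
from torch import nn

num_inputs, num_hiddens, batch_size = 16, 32, 2

# Glorot (Xavier) uniform init; W_xh follows the (num_hiddens, num_inputs) convention.
W_xh = nn.Parameter(nn.init.xavier_uniform_(torch.empty(num_hiddens, num_inputs)))
W_hh = nn.Parameter(nn.init.xavier_uniform_(torch.empty(num_hiddens, num_hiddens)))
b_h = nn.Parameter(torch.zeros(num_hiddens))

def rnn_step(x_t, h_prev):
    # x_t: (batch_size, num_inputs), h_prev: (batch_size, num_hiddens).
    # Rows are examples, so W @ x_t becomes x_t @ W.T in batched form.
    return torch.tanh(x_t @ W_xh.T + h_prev @ W_hh.T + b_h)

x_t = torch.ones(batch_size, num_inputs)
h0 = torch.zeros(batch_size, num_hiddens)  # initial hidden state
h1 = rnn_step(x_t, h0)
print(h1.shape)  # torch.Size([2, 32])
```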

In [None]:
class RNNScratch(nn.Module):
    """The RNN model implemented from scratch."""
    def __init__(self, num_inputs, num_hiddens):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_inputs = num_inputs
        ########################## YOUR CODE ##########################
        # Hint: use nn.Parameter()
        self.W_xh = None
        self.W_hh = None
        self.b_h = None
        ################################################################

    def forward(self, inputs, state=None):
        outputs = []
        ########################## YOUR CODE ##########################
        # RNN updates the hidden state one timestep at a time.
        # Iterate over the first dimension of inputs (time steps),
        # then update the state iteratively
        # Make sure the initial hidden state has the correct shape: (batch_size, num_hiddens).
        # Shape of inputs: (num_steps, batch_size, num_inputs)
        ################################################################
        return outputs, state

In [None]:
###################### SANITY CHECK ######################
def check_len(a, n):
    """Check the length of a list."""
    assert len(a) == n, f'list\'s length {len(a)} != expected length {n}'

def check_shape(a, shape):
    """Check the shape of a tensor."""
    assert a.shape == shape, \
            f'tensor\'s shape {a.shape} != expected shape {shape}'

batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
rnn = RNNScratch(num_inputs, num_hiddens).to(device)
X = torch.ones((num_steps, batch_size, num_inputs)).to(device)
outputs, state = rnn(X)

check_len(outputs, num_steps)
check_shape(outputs[0], (batch_size, num_hiddens))
check_shape(state, (batch_size, num_hiddens))
###########################################################

### 2.b GRU from scratch

GRU (Gated Recurrent Unit) is a type of recurrent neural network designed to better handle long-term dependencies and to mitigate the vanishing gradient problem that affects vanilla RNNs. It does so by introducing an update gate and a reset gate, which control how much past information to keep and how much new information to add at each time step. The hidden state $h_t$ at timestep t is computed as:

$$
\begin{aligned}
z_t &= \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z) && \text{(update gate)} \\
r_t &= \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r) && \text{(reset gate)} \\
\tilde{h}_t &= \tanh(W_{xh} x_t + W_{hh} (r_t \odot h_{t-1}) + b_h) && \text{(candidate state)} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t && \text{(new hidden state)}
\end{aligned}
$$

where $\sigma$ is the sigmoid function and $\odot$ denotes element-wise multiplication.

At a high level, the **update gate** ($z_t$) controls how much of the previous state to keep: $(1 - z_t) \odot h_{t-1}$ retains part of the past memory, and $z_t \odot \tilde{h}_t$ adds new information from the candidate state. The **reset gate** ($r_t$) controls how much the previous hidden state $h_{t-1}$ influences the candidate state $\tilde{h}_t$. Notice that with $r_t = 1$ the candidate state $\tilde{h}_t$ is exactly the vanilla RNN update, and if additionally $z_t = 1$, then $h_t = \tilde{h}_t$ and the GRU reduces to a vanilla RNN. Conversely, $r_t$ close to 0 down-weights the influence of the previous hidden state on the candidate, thereby "resetting" the memory.

*To think about: Why is GRU better for handling long-term dependency? Why does it help with vanishing gradient problem?*

**NOTE**: the convention we are following for this homework uses (num_hiddens, num_inputs) for the shape of $W_{xh}$, $W_{xr}$, $W_{xz}$.
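The gate equations above translate almost line for line into a single GRU step. In this sketch the parameters live in a plain dictionary `p` (a hypothetical layout, not the GRUScratch attributes), and the final check illustrates the remark about $r_t = 1$ and $z_t = 1$: pushing both gate biases very high saturates the sigmoids at 1, and the update then matches a vanilla RNN step.

```python
import torch

def gru_step(x_t, h_prev, p):
    """One GRU update; p maps names like 'W_xz' to tensors (hypothetical layout).
    Shapes: x_t (batch, num_inputs), h_prev (batch, num_hiddens),
    W_x* (num_hiddens, num_inputs), W_h* (num_hiddens, num_hiddens), b_* (num_hiddens,)."""
    z = torch.sigmoid(x_t @ p['W_xz'].T + h_prev @ p['W_hz'].T + p['b_z'])  # update gate
    r = torch.sigmoid(x_t @ p['W_xr'].T + h_prev @ p['W_hr'].T + p['b_r'])  # reset gate
    h_tilde = torch.tanh(x_t @ p['W_xh'].T + (r * h_prev) @ p['W_hh'].T + p['b_h'])  # candidate
    return (1 - z) * h_prev + z * h_tilde  # new hidden state

torch.manual_seed(0)
num_inputs, num_hiddens, batch_size = 4, 3, 2
p = {}
for g in ('z', 'r', 'h'):
    p[f'W_x{g}'] = torch.randn(num_hiddens, num_inputs) * 0.1
    p[f'W_h{g}'] = torch.randn(num_hiddens, num_hiddens) * 0.1
    p[f'b_{g}'] = torch.zeros(num_hiddens)

x_t = torch.randn(batch_size, num_inputs)
h_prev = torch.randn(batch_size, num_hiddens)
h_t = gru_step(x_t, h_prev, p)

# With very large gate biases, z_t = r_t = 1 and the update reduces to a vanilla RNN step.
p['b_z'] = torch.full((num_hiddens,), 50.0)
p['b_r'] = torch.full((num_hiddens,), 50.0)
vanilla = torch.tanh(x_t @ p['W_xh'].T + h_prev @ p['W_hh'].T + p['b_h'])
print(torch.allclose(gru_step(x_t, h_prev, p), vanilla))  # True
```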

In [None]:
class GRUScratch(nn.Module):
    def __init__(self, num_inputs, num_hiddens):
        super().__init__()

        self.num_hiddens = num_hiddens
        self.num_inputs = num_inputs
        ########################## YOUR CODE ##########################
        self.W_xz, self.W_hz, self.b_z = None, None, None  # Update gate
        self.W_xr, self.W_hr, self.b_r = None, None, None  # Reset gate
        self.W_xh, self.W_hh, self.b_h = None, None, None  # Candidate hidden state
        ################################################################

    def forward(self, inputs, state=None):
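        outputs = []
        ########################## YOUR CODE ##########################
        # TODO: Same structure as RNNScratch.forward: iterate over the time steps,
        # apply the update gate, reset gate, and candidate state equations above,
        # and append each new hidden state to outputs.
        # Shape of inputs: (num_steps, batch_size, num_inputs)
        ################################################################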
        return outputs, state


In [None]:
###################### SANITY CHECK ######################
batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
gru = GRUScratch(num_inputs, num_hiddens).to(device)
X = torch.ones((num_steps, batch_size, num_inputs)).to(device)
outputs, state = gru(X)

check_len(outputs, num_steps)
check_shape(outputs[0], (batch_size, num_hiddens))
check_shape(state, (batch_size, num_hiddens))
###########################################################

### 2.c RNN LM wrapper

Now we implement the RNN LM wrapper, which takes an RNNScratch or GRUScratch instance built in the previous sections and adds embedding and output layers. This class handles input embedding (one_hot()), the end-to-end forward pass (forward()), and prediction given a prefix (predict()).
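A shape-level sketch of the embedding and output layers, using random stand-ins for the RNN outputs. It assumes `W_hq` follows the same (out, in) convention as $W_{xh}$, i.e. shape (vocab_size, num_hiddens); that shape is an assumption, not something the skeleton fixes. Stacking the per-step logits along `dim=1` makes them batch-major, so `loss_fn`'s `reshape(-1, ...)` lines up element by element with `Y.reshape(-1)` when Y has shape (batch_size, num_steps).

```python
import torch
from torch.nn import functional as F

batch_size, num_steps, vocab_size, num_hiddens = 2, 5, 7, 4
X = torch.randint(0, vocab_size, (batch_size, num_steps))  # token indices

# One-hot "embedding": transpose to time-major, then cast to float32
# (F.one_hot returns int64, which does not match float32 weights in a matmul).
emb = F.one_hot(X.T, vocab_size).to(torch.float32)
print(emb.shape)  # torch.Size([5, 2, 7]) = (num_steps, batch_size, vocab_size)

# Stand-in for the RNN: a list of num_steps hidden states, each (batch_size, num_hiddens).
rnn_outputs = [torch.randn(batch_size, num_hiddens) for _ in range(num_steps)]

# Output layer, assuming W_hq has shape (vocab_size, num_hiddens).
W_hq = torch.randn(vocab_size, num_hiddens) * 0.01
b_q = torch.zeros(vocab_size)
logits = torch.stack([H @ W_hq.T + b_q for H in rnn_outputs], dim=1)
print(logits.shape)  # torch.Size([2, 5, 7]) = (batch_size, num_steps, vocab_size)
```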

NOTE:
* predict() should generate text tokens autoregressively, starting from a given prefix. During the initial warm-up period, the model is fed the ground-truth prefix tokens instead of its own predictions; this lets the RNN build a hidden-state representation of the prefix before it begins free generation.
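The warm-up/generation split can be sketched independently of any particular model. Here `step_fn` is a hypothetical single-step interface returning `(logits, state)`, and the next token is chosen greedily with `argmax`; sampling from the softmax is a common alternative. The loop bounds mirror the `predict()` skeleton: one initial token, then `len(prefix) + num_preds - 1` steps.

```python
import torch

def generate(step_fn, prefix_ids, num_preds):
    """Warm up on prefix_ids, then greedily generate num_preds tokens.
    step_fn(token_id, state) -> (logits, state) is a hypothetical single-step interface."""
    state, outputs = None, [prefix_ids[0]]
    for t in range(len(prefix_ids) + num_preds - 1):
        logits, state = step_fn(outputs[-1], state)
        if t < len(prefix_ids) - 1:
            # Warm-up: ignore the prediction and feed the ground-truth next prefix token.
            outputs.append(prefix_ids[t + 1])
        else:
            # Generation: feed back the model's own most likely token (greedy decoding).
            outputs.append(int(logits.argmax()))
    return outputs

VOCAB_SIZE = 10

def toy_step(token_id, state):
    # Toy "model" that always predicts token_id + 1 (mod VOCAB_SIZE).
    logits = torch.zeros(VOCAB_SIZE)
    logits[(token_id + 1) % VOCAB_SIZE] = 1.0
    return logits, state

print(generate(toy_step, [0, 1, 2], num_preds=3))  # [0, 1, 2, 3, 4, 5]
print(generate(toy_step, [5, 0, 7], num_preds=2))  # [5, 0, 7, 8, 9]
```

In the second call the toy model would have predicted 6 after 5, but during warm-up the prefix token 0 is fed instead, so generation only takes over after the last prefix token.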

In [None]:
class RNNLMScratch(nn.Module):
    """The RNN-based language model implemented from scratch."""
    def __init__(self, rnn, vocab_size, lr):
        super().__init__()
        self.rnn = rnn
        self.vocab_size = vocab_size
        self.lr = lr
        self.train_loss = []
        self.valid_loss = []

        ########################## YOUR CODE #########################
        # TODO: initialize the embedding layer / output layer
        self.W_hq = None
        self.b_q = None
        ###############################################################

    def one_hot(self, X):
        ########################## YOUR CODE #########################
        # Input shape : (batch_size, num_steps)
        # Output shape: (num_steps, batch_size, vocab_size)
        # NOTE: F.one_hot returns an int64 tensor; cast the result to torch.float32
        #       so that it matches the dtype of the model's float parameters
        return X
        ###############################################################

    def forward(self, X, state=None):
        ########################## YOUR CODE #########################
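        # TODO: Apply the one-hot embedding,
        #       pass the embedding through the RNN,
        #       apply the output layer (W_hq, b_q) to the RNN outputs, and
        #       return the output logits and the final RNN state.
        #       Return logits of shape (batch_size, num_steps, vocab_size) so that
        #       loss_fn's reshape lines up with Y of shape (batch_size, num_steps).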
        ###############################################################
        return None, None

    def predict(self, prefix, num_preds, vocab, device=None):

        state, outputs = None, [vocab[prefix[0]]]
        for t in range(len(prefix) + num_preds - 1):
            ########################## YOUR CODE #########################
            # TODO: Given the prefix, make num_preds many following predictions
            # Implement a separate logic for the initial warm-up period and
            # the actual model generation phase
            pass
            ###############################################################
        return ''.join([vocab.idx_to_token[i] for i in outputs])

    def loss_fn(self, y_hat, Y):
        return F.cross_entropy(y_hat.reshape(-1, y_hat.shape[-1]), Y.reshape(-1))

## Part 3: Simple trainer for model training (1.5 points)

Next, you will implement a lightweight Trainer class to train the model. While popular libraries such as HuggingFace’s Trainer automate this process, you will build your own simplified version with core functionalities.

TODO:
* Complete the **Trainer**: the trainer class should manage the training and validation loops, handle data loading, and track training and validation losses.
* Train your model until it achieves a validation loss below 1.8 on the dataset within 10 epochs.

Note:
* `clip_gradients` is provided to keep the gradients from growing too large. Consider using this during your `Trainer.fit` function.
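A sketch of one epoch of the loop that `Trainer.fit` asks for, written as a standalone `run_epoch` function on a hypothetical toy task (predict x + 1 mod V) with a toy embedding-plus-linear model. Unlike this toy model, `RNNLMScratch.forward` returns a `(logits, state)` tuple, so in the real trainer you would unpack it first. Gradients are clipped after `backward()` and before `optimizer.step()`, which is where `clip_gradients` fits.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

def run_epoch(model, train_loader, valid_loader, optimizer, clip_val, device):
    """One training pass plus one validation pass; returns average batch losses."""
    model.train()
    train_loss = 0.0
    for X, Y in train_loader:
        X, Y = X.to(device), Y.to(device)
        logits = model(X)  # (batch_size, num_steps, vocab_size)
        loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), Y.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)  # clip before stepping
        optimizer.step()
        train_loss += loss.item()

    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for X, Y in valid_loader:
            X, Y = X.to(device), Y.to(device)
            logits = model(X)
            valid_loss += F.cross_entropy(logits.reshape(-1, logits.shape[-1]),
                                          Y.reshape(-1)).item()
    return train_loss / len(train_loader), valid_loss / len(valid_loader)

# Hypothetical toy task: predict (x + 1) mod V at every position.
torch.manual_seed(0)
V = 10
X = torch.randint(0, V, (256, 8))
Y = (X + 1) % V
train_loader = DataLoader(TensorDataset(X[:200], Y[:200]), batch_size=32, shuffle=True)
valid_loader = DataLoader(TensorDataset(X[200:], Y[200:]), batch_size=32)
model = nn.Sequential(nn.Embedding(V, 16), nn.Linear(16, V))
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

history = [run_epoch(model, train_loader, valid_loader, optimizer, 1.0, "cpu")
           for _ in range(10)]
print(history[0], history[-1])  # (avg train loss, avg valid loss) for the first and last epochs
```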

In [None]:
class Trainer:
    def __init__(self, max_epochs, batch_size, device, gradient_clip_val=1):
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.gradient_clip_val = gradient_clip_val
        self.device = device
        self.train_loss = [] # record the avg. batch loss every epoch
        self.valid_loss = [] # record the avg. batch loss every epoch

    @staticmethod
    def clip_gradients(model, max_norm):
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    def get_dataloader(self, data):
        train_size = int(0.8 * len(data))
        train_data, val_data = random_split(data, [train_size, len(data) - train_size])
        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        valid_loader = DataLoader(val_data, batch_size=self.batch_size)

        return train_loader, valid_loader

    def fit(self, model, data, optimizer=None):
        model.to(self.device)
        if optimizer is None:
            optimizer = torch.optim.SGD(model.parameters(), lr=model.lr)
        train_loader, valid_loader = self.get_dataloader(data)

        for epoch in range(self.max_epochs):
            model.train()
            train_loss = 0
            valid_loss = 0
            ########################### YOUR CODE ###################################
            # TODO: Train the model for one epoch over train_loader.
            # For each training batch, complete a single forward and backward pass
            # (clip the gradients before the optimizer step).
            # Accumulate the batch loss into train_loss.
            ########################################################################
            self.train_loss.append(train_loss / len(train_loader))

            model.eval()
            with torch.no_grad():
                ########################### YOUR CODE ###################################
                # TODO: at the end of each epoch, evaluate the model on the validation set.
                # Complete a single forward pass on a given validation batch
                # Record the validation loss
                pass
                ########################################################################
            self.valid_loss.append(valid_loss / len(valid_loader))

            print(f"Epoch {epoch+1} train loss: {self.train_loss[-1]}, validation loss {self.valid_loss[-1]}")

### Verify your results!

Note: this will not run until Parts 1, 2, and 3 are completed.

In [None]:
################# DO NOT CHANGE #################
def train_rnnlm(data, rnn_class, num_hiddens, lr, batch_size, num_epochs=10):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    rnn = rnn_class(num_inputs=len(data.vocab), num_hiddens=num_hiddens).to(device)
    model = RNNLMScratch(rnn, vocab_size=len(data.vocab), lr=lr).to(device)
    trainer = Trainer(batch_size=batch_size, max_epochs=num_epochs, gradient_clip_val=1, device=device)
    trainer.fit(model, data)
    return trainer, model

In [None]:
################ TO SUBMIT ################
# Feel free to tweak the hyperparameters to achieve a better loss!
RNN_HYPERPARAM = {'num_hiddens': 64, 'lr': 2, 'batch_size': 128}
GRU_HYPERPARAM = {'num_hiddens': 64, 'lr': 3, 'batch_size': 128}

In [None]:
################# CHECK YOUR RESULT! #################
if __name__ == "__main__":
    seed=123
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    shakes_data = CharDataset(num_steps=32, raw_text=shakes_ds[:50000])
    rnn_trainer, rnn_model = train_rnnlm(data=shakes_data,
                                         rnn_class=RNNScratch,
                                         num_hiddens=RNN_HYPERPARAM['num_hiddens'],
                                         lr=RNN_HYPERPARAM['lr'],
                                         batch_size=RNN_HYPERPARAM['batch_size'],
                                         num_epochs=10
                                         )
    print('\n')
    gru_trainer, gru_model = train_rnnlm(data=shakes_data,
                                         rnn_class=GRUScratch,
                                         num_hiddens=GRU_HYPERPARAM['num_hiddens'],
                                         lr=GRU_HYPERPARAM['lr'],
                                         batch_size=GRU_HYPERPARAM['batch_size'],
                                         num_epochs=10
                                         )

    inputs = ['the mind in',
              'he saw a fox ',
              'thou art ',
              'to be or not ',
              'she stood on ',
              'before the start ',
              'thou speakest']

    print("\n")
    print("RNN predictions:")
    for i in range(len(inputs)):
        print(rnn_model.predict(inputs[i], 30, shakes_data.vocab, device=device))

    print('\n')
    print("GRU predictions:")
    for i in range(len(inputs)):
        print(gru_model.predict(inputs[i], 30, shakes_data.vocab, device=device))

In [None]:
if __name__ == "__main__":
    f, ax = plt.subplots(ncols=2, figsize=(8,4))
    ax[0].plot(rnn_trainer.train_loss, label='rnn')
    ax[0].plot(gru_trainer.train_loss, label='gru')
    ax[0].set_title("Train loss")
    ax[1].plot(rnn_trainer.valid_loss, label='rnn')
    ax[1].plot(gru_trainer.valid_loss, label='gru')
    ax[1].set_title("Validation loss")
    ax[0].legend()
    ax[1].legend()

# Collaboration / External Help
Disclose any help you used (LLM usage, blogs, search, GitHub links, etc.) and any collaborations with your classmates. If you completed the homework on your own, you can leave this part empty.

> TODO