# Assignment 3

In this assignment, you will experiment with word embedding vectors and train a recurrent neural network for sentence classification.

## 1. Loading dataset

The dataset comes from https://www.kaggle.com/datasets/amananandrai/ag-news-classification-dataset

It contains 120,000 news articles that are labeled into four categories:

- 1: world news
- 2: sports
- 3: business
- 4: science and technology

In [2]:
# [THIS IS READ-ONLY]
import torchtext.datasets
import pandas as pd

train_iter = torchtext.datasets.AG_NEWS(root='./datasets', split='train')

train_df = pd.DataFrame(
    data=list(iter(train_iter)),
    columns=['target', 'news'],
)

print("Five randomly selected samples:")
print(train_df.sample(5, random_state=0))

TypeError: _setup_datasets() got an unexpected keyword argument 'split'

## 2. Tokenizer

🚨 Instruction:
> Load the `basic_english` tokenizer using the `get_tokenizer` from `torchtext.data`.

In [None]:
# [THIS IS READ-ONLY]
import torchtext.data

In [None]:
# [YOUR WORK HERE]
# @workUnit
from torchtext.data.utils import get_tokenizer

tokenizer = get_tokenizer("basic_english")

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: tokenizer

type(tokenizer), tokenizer.__qualname__

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: tokens of sentence
tokenizer("This is assignment 3 for csci 4050u.  It's on sequence learning.")

## 3. Vocabulary

A token sequence is a list of tokens.  We need a vocabulary to convert each
token into an integer, known as the token index.

In [None]:
# [THIS IS READ-ONLY]
# construct token sequence
# this is a collection of token sequences.
# Every sentence is converted to a token sequence by the tokenizer.

token_seq = map(tokenizer, train_df['news'])

🚨 Instruction:

> Use the `build_vocab_from_iterator` helper function from `torchtext.vocab` to construct
the vocabulary from the `token_seq`.

> Make sure you set `min_freq=5` and use the special tokens `['<unk>', '<s>']`.
The first token index, `0`, corresponds to the unknown token `<unk>`.
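
To make the roles of `min_freq` and the special tokens concrete, here is a minimal pure-Python sketch of a vocabulary builder. The helper `build_vocab` and the toy data are hypothetical and only approximate what `build_vocab_from_iterator` does; in particular, the ordering of equally frequent tokens may differ from torchtext's.

```python
from collections import Counter

def build_vocab(token_sequences, min_freq, specials):
    """Hypothetical helper: map tokens to indexes, special tokens first."""
    counts = Counter(tok for seq in token_sequences for tok in seq)
    # Keep tokens seen at least `min_freq` times, most frequent first.
    kept = [
        tok for tok, n in counts.most_common()
        if n >= min_freq and tok not in specials
    ]
    index_to_token = list(specials) + kept
    return {tok: idx for idx, tok in enumerate(index_to_token)}

# Toy token sequences (not from the AG News data).
toy_sequences = [
    ["the", "cat", "sat"],
    ["the", "dog", "sat"],
    ["the", "bird", "flew"],
]

toy_vocab = build_vocab(toy_sequences, min_freq=2, specials=["<unk>", "<s>"])
print(toy_vocab)
```

Only `the` and `sat` occur at least twice, so every other token is left out of the vocabulary and will later be mapped to the `<unk>` index `0`.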

In [None]:
# [THIS IS READ-ONLY]
import torchtext.vocab

In [None]:
# [YOUR WORK HERE]
# @workUnit
from torchtext.vocab import build_vocab_from_iterator


vocab = build_vocab_from_iterator(
    token_seq,  # token_seq is the iterator over tokenized sentences
    min_freq=5,  # Minimum frequency for tokens to be included in the vocabulary
    specials=['<unk>', '<s>']  # Special tokens: <unk> for unknown and <s> for sentence start
)

In [None]:
# [THIS IS READ-ONLY]
# if token is not in vocabulary, use the index 0.
vocab.set_default_index(0)

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: length of the vocab

len(vocab)

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: lookup token indexes using vocab

vocab.lookup_indices(tokenizer("this is an assignment for csci 4050u."))

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: lookup token string value using vocab

vocab.lookup_tokens([53, 22, 31, 10659, 12, 0, 0, 2])

## 4. Integer encoding

Now, we are ready to encode news article sentences into sequences of integers. 

🚨 Instruction:

> Create a list of `torch.int64` tensors.  Each tensor is a vector of int64 integers: the token indexes of the tokens of one news article in the
> training data.
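
The encoding step can be sketched on toy data as follows. The dictionary `toy_vocab` and the helper `encode` are hypothetical stand-ins for the real vocabulary object; the sketch assumes that unknown tokens fall back to index `0`, as described in the Vocabulary section.

```python
import torch

# Hypothetical toy vocabulary; index 0 is reserved for <unk>.
toy_vocab = {"<unk>": 0, "<s>": 1, "the": 2, "sat": 3}

def encode(tokens, vocab, unk_index=0):
    """Turn a list of tokens into a 1-D int64 tensor of token indexes."""
    indexes = [vocab.get(tok, unk_index) for tok in tokens]
    return torch.tensor(indexes, dtype=torch.int64)

# Two already-tokenized toy articles of different lengths.
toy_articles = [["the", "cat", "sat"], ["the", "sat"]]
toy_index_sequences = [encode(tokens, toy_vocab) for tokens in toy_articles]

for seq in toy_index_sequences:
    print(seq, seq.dtype)
```

The result is a plain Python list whose tensors have different lengths, which is why a padding step is needed before they can be stacked into one tensor.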

In [None]:
# [THIS IS READ-ONLY]
import torch

In [None]:
# [YOUR WORK HERE]
# @workUnit

index_sequences = [
    torch.tensor(vocab.lookup_indices(tokenizer(article)), dtype=torch.int64)
    for article in train_df['news']
]

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: return types

print(f"Type of index_sequences: {type(index_sequences)}")
print(f"Type of elements in index_sequences: {type(index_sequences[0])} with dtype {index_sequences[0].dtype}")

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: number of index sequences

len(index_sequences)

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: first three index sequences

for i in range(3):
    sentence = train_df.iloc[i].news
    index_sequence = index_sequences[i]
    print(sentence)
    print(index_sequence)

## 5. Prepare token index tensor

Now, we are ready to prepare the training and validation data.

In [None]:
# [THIS IS READ-ONLY]
from torch.nn.utils.rnn import pad_sequence

- First, we need to pad each sequence in `index_sequences` so that
they all match the length of the *longest* sequence.

- Then, we will truncate each sequence to keep only the first 100 tokens.
  This removes the noise from the few extra-long articles.  Basically,
  we will classify each article using only its first 100 tokens.
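
The two steps above can be sketched on toy sequences. The values are made up, and the truncation length is 3 instead of 100 so that the effect is visible.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three toy index sequences of lengths 4, 1 and 2.
toy_sequences = [
    torch.tensor([5, 6, 7, 8], dtype=torch.int64),
    torch.tensor([9], dtype=torch.int64),
    torch.tensor([3, 4], dtype=torch.int64),
]

# Pad with 0 (the default padding_value) up to the longest sequence.
toy_padded = pad_sequence(toy_sequences, batch_first=True)
print(toy_padded.shape)  # torch.Size([3, 4])

# Keep only the first 3 positions of every row.
toy_truncated = toy_padded[:, :3]
print(toy_truncated)
```

Note that the default padding value `0` is the same integer as the `<unk>` index in this notebook, so padded positions and unknown tokens are indistinguishable to the model.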

In [None]:
# [THIS IS READ-ONLY]
padded_sequences = pad_sequence(index_sequences, batch_first=True)
print("After padding:", padded_sequences.shape)

padded_sequences = padded_sequences[:, :100]
print("After truncation:", padded_sequences.shape)

## 6. Prepare training and validation tensors

We can now prepare training and validation datasets for RNN training.

In [None]:
# [THIS IS READ-ONLY]
from torch.utils.data import (
    TensorDataset,
    random_split,
)

In [None]:
# [THIS IS READ-ONLY]
#
# targets
#

targets = torch.tensor(train_df['target'] - 1, dtype=torch.int64)
targets.shape

🚨 Instructions:

- Create the dataset from `padded_sequences` and `targets` using `TensorDataset`.
- Create training and validation dataset using `random_split`.  Use 30% of the dataset for validation.
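
As an illustration of the two utilities, here is a sketch on a toy dataset of 10 samples. The tensors, the explicit integer lengths, and the local generator are assumptions made for the example; the source does not state which form of the split the checkpoints expect.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy inputs: 10 samples of 2 token indexes each, with labels in {0, 1, 2, 3}.
toy_x = torch.arange(20, dtype=torch.int64).reshape(10, 2)
toy_y = torch.arange(10, dtype=torch.int64) % 4

# Pair each row of toy_x with the matching entry of toy_y.
toy_dataset = TensorDataset(toy_x, toy_y)

# 30% of the samples go to validation, the rest to training.
n_val = len(toy_dataset) * 3 // 10
n_train = len(toy_dataset) - n_val

toy_train, toy_val = random_split(
    toy_dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(0),  # reproducible shuffle
)

print(len(toy_train), len(toy_val))
print(toy_train[0])  # a (sequence, target) pair
```

Each sample of the dataset is a tuple `(sequence, target)`, which is the form that the training and validation steps later in this notebook unpack.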

In [None]:
# [YOUR WORK HERE]
# @workUnit

# IMPORTANT: keep this line to pass the checkpoints.
torch.manual_seed(0)

#
# dataset for training and validation
#

dataset = ...

(train_dataset, val_dataset) = ...

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: training and validation dataset sizes

len(train_dataset), len(val_dataset)

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: training sample

print("Training sample:")
print(train_dataset[0])

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: validation sample

print("Validation sample:")
print(val_dataset[0])

# INSTRUCTION

## 📢 For the remainder of the worksheet, you must understand the code provided, but no workUnits are required.

## 🚨 You must execute all cells and obtain the performance comparison plots.

## 7. Simple RNN Module

In [None]:
# [THIS IS READ-ONLY]
import torch.nn as nn
from lightning.pytorch import LightningModule
from torchmetrics import Accuracy

vocab_size = len(vocab)
num_layers = 1
num_classes = 4

class MyRNN(nn.Module):
    def __init__(self, d_emb, d_state):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_emb)
        self.rnn = nn.RNN(
            input_size=d_emb,
            hidden_size=d_state,
            num_layers=num_layers,
            batch_first=True,
        )
        self.output = nn.Linear(d_state, num_classes)
        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)
            
    def forward(self, batch_of_sequences):
        embeddings = self.emb(batch_of_sequences)
        _, final_states = self.rnn(embeddings)
        final_state = final_states[-1]
        logits = self.output(final_state)
        return logits

Let's try out the basic RNN (not yet trained) on a sample batch.

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: untrained model checking

model = MyRNN(d_emb=128, d_state=64)
model

In [None]:
# [THIS IS READ-ONLY]
# @check
# @title: untrained model checking

model = MyRNN(d_emb=128, d_state=64)
x, target = dataset[:32]
model(x).shape
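
To follow the tensor shapes through `forward`, here is a small sketch with toy sizes (a vocabulary of 50, batches of 3 sequences with 10 tokens). The `toy_` names and the sizes are chosen for illustration and are not the values used in the assignment.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

toy_emb = nn.Embedding(num_embeddings=50, embedding_dim=8)
toy_rnn = nn.RNN(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
toy_out = nn.Linear(16, 4)

# 3 sequences of 10 token indexes each.
toy_batch = torch.randint(0, 50, (3, 10))

embeddings = toy_emb(toy_batch)                 # (3, 10, 8)
all_states, final_states = toy_rnn(embeddings)  # (3, 10, 16) and (1, 3, 16)
final_state = final_states[-1]                  # (3, 16): state of the last layer
logits = toy_out(final_state)                   # (3, 4): one score per class

print(embeddings.shape, all_states.shape, final_states.shape, logits.shape)
```

For a single-layer, unidirectional RNN, `final_states[-1]` equals the state at the last time step, `all_states[:, -1, :]`. Because the sequences are padded at the end, that last step is often a padding position.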

## 8. Simple RNN Lightning Module

Add the Lightning logging methods to `MyRNN`.

In [None]:
# [THIS IS READ-ONLY]
class MyLightning(LightningModule):        
    def training_step(self, batch_of_sequences):
        x, target = batch_of_sequences
        y = self.forward(x)
        loss = nn.functional.cross_entropy(y, target)
        self.accuracy(y, target)
        self.log('accuracy', self.accuracy, prog_bar=True)
        self.log('loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
    
    def validation_step(self, batch, batch_index):
        x, target = batch
        y = self.forward(x)
        self.accuracy(y, target)
        self.log('val_acc', self.accuracy, prog_bar=True)

In [None]:
# [THIS IS READ-ONLY]
class MyLightningRNN(MyRNN, MyLightning):
    pass
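
The class above combines a model class and a training mixin through multiple inheritance. The pattern can be sketched in plain Python; all class names below are hypothetical stand-ins and none of them come from PyTorch or Lightning.

```python
class Base:
    """Stand-in for the shared base class."""
    def __init__(self):
        self.initialized = True

class ModelPart(Base):
    """Stand-in for the model: owns the parameters and forward()."""
    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, x):
        return x * self.size

class TrainingPart(Base):
    """Stand-in for the mixin: relies on forward() from another class."""
    def training_step(self, x):
        return self.forward(x) + 1

class Combined(ModelPart, TrainingPart):
    pass

combined = Combined(size=3)
mro_names = [cls.__name__ for cls in Combined.__mro__]
print(mro_names)
print(combined.training_step(2))
```

The method resolution order places the model class first, so its `__init__` runs on construction, while `training_step` is found on the mixin and calls the model's `forward`.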

## 9. Create a trainer utility

In [None]:
# [THIS IS READ-ONLY]
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning import seed_everything
from torch.utils.data import DataLoader
import shutil, os
import time

#
# data loaders for training and validation
#

batch_size = 32
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

def train(*, name:str, model:LightningModule, epochs:int, debug=True):
    # reset the random generator
    seed_everything(0)
    
    # create CSV logger
    logger = CSVLogger('./lightning_logs/', name)
        
    # create trainer
    trainer = Trainer(
        logger = logger,
        max_epochs = epochs,
        max_steps = 100 if debug else -1
    )
    
    # remove logs from any previous run with the same name
    log_dir = f"./lightning_logs/{name}"
    shutil.rmtree(log_dir, ignore_errors=True)
    os.makedirs(log_dir, exist_ok=True)
    
    # start trainer
    start = time.time()
    trainer.fit(
            model=model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader
    )
    duration = (time.time() - start)
    print(f"Completed {epochs} epochs in {duration:0.2f} seconds.")
    print(trainer.validate(model, dataloaders = val_dataloader))

## 10. Train some RNN

📢 Instruction

- You are encouraged to play with the parameters:

> - `d_emb`
> - `d_state`
> - `epochs`

📢 Note:

- For `d_emb=8, d_state=16`, it takes 50 seconds per epoch.

In [None]:
# [YOUR WORK HERE]
# @workUnit

seed_everything(0)

train(
    name='rnn',
    model = MyLightningRNN(d_emb=8, d_state=16),
    epochs=5,
    debug=False,
)

We will now enhance the RNN classifier with a more advanced architecture for the cell -- namely the LSTM design.

### Extending RNN to LSTM 

In [None]:
# [THIS IS READ-ONLY]
class MyLSTM(nn.Module):
    def __init__(self, d_emb, d_state):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_emb)
        
        self.lstm = nn.LSTM(input_size=d_emb,
                          hidden_size=d_state,
                          num_layers=1,
                          batch_first=True)
        
        self.output = nn.Linear(d_state, num_classes)
        
        # will be monitoring accuracy
        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)
    
    def forward(self, x):
        x = self.embedding(x)
        _, (states, _) = self.lstm(x)
        states = states[-1]
        return self.output(states)

In [None]:
# [THIS IS READ-ONLY]
class MyLightningLSTM(MyLSTM, MyLightning):
    pass
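
Unlike `nn.RNN`, `nn.LSTM` returns a pair of final states: the hidden state and the cell state. This is why `forward` unpacks `_, (states, _)`. The sketch below uses toy sizes and random inputs in place of real embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

toy_lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)

# Random stand-in for a batch of embedded sequences: (batch, length, d_emb).
toy_embeddings = torch.randn(3, 10, 8)

outputs, (hidden, cell) = toy_lstm(toy_embeddings)

print(outputs.shape)  # hidden state at every time step: (3, 10, 16)
print(hidden.shape)   # final hidden state per layer:    (1, 3, 16)
print(cell.shape)     # final cell state per layer:      (1, 3, 16)
```

Only the final hidden state of the last layer, `hidden[-1]`, is passed to the linear output layer; the cell state is internal memory and is discarded by the classifier.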

📢 Instruction

- You are encouraged to play with the parameters:

> - `d_emb`
> - `d_state`
> - `epochs`

📢 Note:

- For `d_emb=8, d_state=16`, it takes 30 seconds per epoch.

In [None]:
# [YOUR WORK HERE]
# @workUnit

seed_everything(0)

train(
    name = 'lstm',
    model = MyLightningLSTM(d_emb=8, d_state=16),
    epochs = 5,
    debug = False,
)

## 11. Performance comparison

- Lightning logs the performance metrics in `./lightning_logs/{name}/{version}/metrics.csv`.
- We can load the metrics into pandas dataframes and plot the validation accuracy of each run.
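
The `dropna()` calls in the cells below matter because different metrics are typically logged in different rows. The sketch assumes that layout: the toy table is made up, not real output, and mimics a file where training rows have no `val_acc` and validation rows have no `loss`.

```python
import pandas as pd

# Hypothetical rows in the style of a metrics.csv file.
toy_metrics = pd.DataFrame({
    "step":    [49, 99, 99, 149, 199, 199],
    "loss":    [1.2, 0.9, None, 0.8, 0.7, None],
    "val_acc": [None, None, 0.61, None, None, 0.68],
})

# Keep only the rows where val_acc was logged, then renumber them 0, 1, ...
toy_val_acc = toy_metrics["val_acc"].dropna().reset_index(drop=True)
print(toy_val_acc)
```

The `reset_index(drop=True)` call is an addition for this sketch; the notebook keeps the original row numbers, which works as long as both runs log `val_acc` in the same rows.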

In [None]:
# [THIS IS READ-ONLY]
perf_rnn = pd.read_csv('./lightning_logs/rnn/version_0/metrics.csv')
perf_lstm = pd.read_csv('./lightning_logs/lstm/version_0/metrics.csv')
val_acc = pd.concat([perf_rnn.val_acc.dropna(), perf_lstm.val_acc.dropna()], axis=1)
val_acc.columns = ['rnn', 'lstm']
val_acc

In [None]:
# [THIS IS READ-ONLY]
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.plot(val_acc.index, val_acc.rnn, '--+', val_acc.index, val_acc.lstm, '-o')
plt.ylim(0, 1)
plt.title('Validation accuracy')
plt.legend(['RNN', 'LSTM']);

In [None]:
# [THIS IS READ-ONLY]
loss = pd.concat([perf_rnn.loss.dropna(), perf_lstm.loss.dropna()], axis=1)
loss.columns = ['rnn', 'lstm']

plt.figure(figsize=(10, 6))
plt.plot(loss.index, loss.rnn, '--+', loss.index, loss.lstm, '-o')
plt.title('Training loss')
plt.legend(['RNN', 'LSTM']);