# Lab 5, part 2: Seq2Seq Transformers — Summarization

In [None]:
! pip install datasets transformers rouge-score nltk torch numpy matplotlib      # << Comment out once the packages are installed
import transformers
print(transformers.__version__)     # Should be >= 4.11.0
import torch
import nltk
nltk.download('punkt')

In this notebook, we will see how to fine-tune one of the [HuggingFace Transformers](https://github.com/huggingface/transformers) models for a summarization task. We will use the [XSum dataset](https://arxiv.org/pdf/1808.08745.pdf) (for extreme summarization), which contains BBC articles accompanied by single-sentence summaries.

![Widget inference on a summarization task](https://raw.githubusercontent.com/huggingface/notebooks/main/examples/images/summarization.png)

We will see how to easily load the dataset for this task using HuggingFace Datasets and how to fine-tune a model on it using the `Trainer` API.

This tutorial draws on the HuggingFace summarization tutorial.

## Loading the dataset

We will use the [HuggingFace Datasets](https://github.com/huggingface/datasets) library to download the data we need for fine-tuning and evaluation (to compare our model to the benchmark). This is easily done with the `load_dataset` function. The Datasets library is a great resource for conveniently working with common datasets.

In [None]:
from datasets import list_datasets, load_dataset

# Can see all possible datasets as follows
all_datasets = list_datasets()
print(all_datasets[:10], len(all_datasets))

raw_datasets = load_dataset("xsum")

# XSum has train, validation and test splits
print(raw_datasets.keys(), '\n', raw_datasets['train'])

# An XSum sample looks as follows
print(raw_datasets["train"][0])

### EDA

To get a sense of what the data looks like, we will write a function to display some random elements. We will also perform some basic EDA (exploratory data analysis) by computing the mean and standard deviation of the token counts for the source and target documents.

**Q1:** Complete `show_random_elements` below. It should display a table containing `num_examples` samples from the specified `dataset`, with the fields `Source Document`, `Target Document` and `Document ID`.

**Q2:** EDA: 
 - Tokenize the documents (ideally with the model's tokenizer, but here we will simply split on spaces) and print the mean and standard deviation of the token counts for the source and target documents in each split
 - Plot histograms of the source and target token counts
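A minimal sketch of the whitespace-based statistics. It assumes the source and target texts are available as plain lists of strings; in XSum these are the `document` and `summary` columns. It uses the population standard deviation (`statistics.pstdev`), which matches `np.std`'s default. `whitespace_token_stats` is a hypothetical helper name, not part of the lab's skeleton.

```python
import statistics

def whitespace_token_stats(texts):
    """Return (counts, mean, std) of whitespace token counts for a list of strings."""
    # str.split() with no argument splits on any run of whitespace
    counts = [len(text.split()) for text in texts]
    # pstdev is the population standard deviation, like np.std with ddof=0
    return counts, statistics.mean(counts), statistics.pstdev(counts)

# Toy example standing in for the "document" column of one split
toy_docs = ["the cat sat", "a dog", "birds fly south today"]
counts, mean, std = whitespace_token_stats(toy_docs)
print(counts, mean, round(std, 3))
```

You can pass the returned `counts` straight to `plt.hist(counts, bins=50)` for the histograms.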


In [None]:
import datasets
import random
import pandas as pd
import numpy as np
from IPython.display import display, HTML
import matplotlib.pyplot as plt

def show_random_elements(dataset, num_examples=5):
    ## YOUR CODE HERE ##

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))


def get_token_counts(dataset):
    ## YOUR CODE HERE ##
    return src_counts, tgt_counts


def token_counts_summary(raw_dataset):
    for name,dataset in raw_dataset.items():
        src_counts, tgt_counts = get_token_counts(dataset)
        ## YOUR CODE HERE ##


def plot_token_counts(dataset):
    src_counts, tgt_counts = get_token_counts(dataset)
    ## YOUR CODE HERE ##


show_random_elements(raw_datasets["validation"])
token_counts_summary(raw_datasets)
plot_token_counts(raw_datasets["validation"])

### Loading the metric

To evaluate our model's performance, we will use the ROUGE summarization metric. It is provided by the Datasets library and can be loaded in much the same way as the dataset above. **Note:** this is a big advantage in practice, because i) metrics can be fiddly to implement manually and ii) implementations are hard to align exactly, since choices such as lemmatization, tokenization and punctuation handling can create large discrepancies in scores.

You can call its `compute` method with your predictions and labels, which need to be lists of decoded strings:

In [None]:
from datasets import load_metric
metric = load_metric("rouge")

# help(metric)      # << Uncomment to see more about the ROUGE eval metric

# Try it out below
# fake_preds = ## YOUR CODE HERE
# fake_labels = ## YOUR CODE HERE
## COMPUTE METRIC HERE ## 
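This simplified ROUGE-1 sketch shows what the metric measures before you call `metric.compute`. It computes clipped unigram overlap, with lowercasing, whitespace tokenization and no stemming. It is only an illustration. The official scorer behind `load_metric("rouge")` tokenizes differently and can apply a stemmer (`use_stemmer=True`), so its scores will not match this sketch exactly. `rouge1_f1` is a hypothetical name.

```python
from collections import Counter

def rouge1_f1(prediction, reference):
    """Simplified ROUGE-1 F1: clipped unigram overlap between two strings."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Counter intersection keeps the minimum count of each shared token (clipping)
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)  # fraction of predicted words found in the reference
    recall = overlap / len(ref_tokens)      # fraction of reference words recovered
    return 2 * precision * recall / (precision + recall)

score = rouge1_f1("the cat sat on the mat", "the cat is on the mat")
print(round(score, 4))
```

The real metric takes one prediction per reference, for example `metric.compute(predictions=[...], references=[...])` with two equal-length lists of strings.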


## Preprocessing the data

We will preprocess the data using the HuggingFace `Tokenizer`. It tokenizes the inputs, converts the tokens to their indices in the vocabulary, and puts them in the format the model expects. It also generates the other inputs the model requires (such as the attention mask).

**Q3:** Each Huggingface model has a paired tokenizer. Why is it important to use the appropriate tokenizer?

In [None]:
from transformers import AutoTokenizer

model_checkpoint = "t5-small"       # Can find more options at https://huggingface.co/models?sort=downloads&search=t5    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

### Using the tokenizer

Try out the tokenizer in the cell below.

**Q4:** The tokenizer output has different fields for different tokenizers. 

**Q4.i:** Confirm this by writing code below that instantiates a BERT tokenizer and compares the outputs of the two tokenizers on the same string. You may find [this page](https://huggingface.co/models) helpful.

**Q4.ii:** Why does the BERT tokenizer output have an additional field compared to the T5 tokenizer?

In [None]:
# Test the tokenizer below
print(tokenizer("We love NLP!"))
# print(tokenizer(## YOUR CODE HERE))

## CODE FOR 4.i) HERE ##


### Additional tokenizer functionalities:

- The tokenizer can be fed a list of strings
- When tokenizing the target documents, use the `tokenizer.as_target_tokenizer()` context manager so that the targets receive the appropriate special tokens (for T5, the source and target happen to be tokenized identically)
- We can convert back from ids to tokens by using `tokenizer.convert_ids_to_tokens()`

**Q5:** The output of `tokenizer.convert_ids_to_tokens()` is different from the initial string. Why is this?


In [None]:
print(tokenizer(["Hello, this one sentence!", "This is another sentence."]))

with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this one sentence!", "This is another sentence."]))

print(tokenizer.convert_ids_to_tokens(tokenizer("We love NLP!")['input_ids']))
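T5 uses a SentencePiece vocabulary. In it, the character `▁` (U+2581) marks a piece that starts a new word. The sketch below joins such pieces back into text. The token list is hand-written for illustration and is not actual T5 output. `detokenize` is a hypothetical helper; in practice you would use `tokenizer.decode` or `tokenizer.convert_tokens_to_string`.

```python
WORD_START = "\u2581"  # SentencePiece's word-boundary marker "▁"

def detokenize(pieces, special_tokens=("</s>", "<pad>")):
    """Join SentencePiece-style pieces back into a plain string."""
    kept = [p for p in pieces if p not in special_tokens]  # drop special tokens such as EOS
    # Each marker stands for the space before a word; strip the leading one
    return "".join(kept).replace(WORD_START, " ").strip()

# Illustrative pieces: a rarer word split into subwords, plus an EOS token
pieces = ["\u2581We", "\u2581love", "\u2581N", "LP", "!", "</s>"]
print(detokenize(pieces))
```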

T5 was trained within a multitask framework such that it can perform multiple tasks out-of-the-box. We prefix the inputs with "summarize: " to prompt the model to deliver the correct outputs.

We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This ensures that any input longer than the selected model can handle is truncated to the maximum length accepted by the model (or to the `max_length` we pass). Padding is handled later, in a data collator, so that examples are padded to the longest sequence in each batch rather than in the whole dataset.
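Here is a minimal pure-Python sketch of these two steps, ignoring special tokens. Each sequence is first truncated to a maximum length. The batch is then padded only up to its longest member, and the matching attention mask is built. `truncate` and `pad_to_longest` are hypothetical helpers. The pad id 0 is an assumption: it is T5's `pad_token_id`, but you should always read it from `tokenizer.pad_token_id`.

```python
def truncate(ids, max_length):
    """Keep at most max_length token ids (what truncation does, ignoring special tokens)."""
    return ids[:max_length]

def pad_to_longest(batch, pad_id=0):
    """Pad every sequence to the longest one in this batch, not the whole dataset."""
    longest = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch]
    # 1 marks real tokens the model should attend to, 0 marks padding
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = [truncate(ids, max_length=4) for ids in ([5, 6, 7, 8, 9], [5, 6], [7])]
padded = pad_to_longest(batch)
print(padded["input_ids"])       # [[5, 6, 7, 8], [5, 6, 0, 0], [7, 0, 0, 0]]
print(padded["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 0, 0], [1, 0, 0, 0]]
```

Padding per batch keeps batches of short examples short, which is why padding is left to the data collator rather than done during `map`.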

Complete the `preprocess_function` below to tokenize the text (**hint**: don't forget the prefix or the context manager ;) )

In [None]:
prefix = "summarize: " if model_checkpoint.startswith("t5-") else ""

max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    ## YOUR CODE HERE ##

    return model_inputs # should be a dict containing the indices for the source and target docs.

In [None]:
# Test it using the below call
preprocess_function(raw_datasets['train'][:2])

This function can be applied to all our datasets using the `map` method of our `dataset` object. The results are automatically cached to avoid spending time on this step the next time you run your notebook. 

In [None]:
tokenized_datasets = raw_datasets.map(
    preprocess_function, 
    batched=True            # Passes batches of examples at once, so the fast tokenizer can tokenize them in parallel
)
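With `batched=True`, `map` calls the function on a batch of rows stored column-wise. This means `examples["document"]` is a list of strings, not a single string, and the function returns a dict of lists. The sketch below mimics that calling convention in plain Python. `batched_map` is a hypothetical stand-in, not the Datasets implementation, which also handles caching, Arrow storage and `num_proc`. The sketch assumes the function returns one output per input row. `Dataset.map` uses a default batch size of 1000.

```python
def batched_map(rows, fn, batch_size=1000):
    """Apply fn to column-wise batches; assumes fn returns one value per input row."""
    out_rows = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # List of row dicts -> dict of column lists, the format map(batched=True) passes in
        columns = {key: [row[key] for row in chunk] for key in chunk[0]}
        result = fn(columns)
        # Merge the returned columns back into the original rows
        for i, row in enumerate(chunk):
            out_rows.append({**row, **{key: values[i] for key, values in result.items()}})
    return out_rows

def add_prefix(examples):
    # Receives lists, so it processes the whole batch at once
    return {"source": ["summarize: " + doc for doc in examples["document"]]}

rows = [{"document": "Storm hits coast."}, {"document": "Team wins cup."}, {"document": "Prices rise."}]
mapped = batched_map(rows, add_prefix, batch_size=2)
print([row["source"] for row in mapped])
```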

## Using T5 out-of-the-box

The image below shows the idea behind T5.

![T5 img](figs/t5.png)

If you are curious, [this page](https://paperswithcode.com/method/t5) has more information on T5.

Here, we will experiment with some of T5's capabilities without fine-tuning. First, we download the model with the `from_pretrained` method of the `AutoModelForSeq2SeqLM` class (which also caches it locally).

To illustrate the impact of the prompt, we will run T5 on the same string with two different prompts. Note that we are using the small version without fine-tuning, so the results may not be great.

**Note:** If no GPU is available, the `.cuda()` calls will throw errors. In that case, remove those calls and work on the CPU. However, we do not recommend running this lab session on CPUs, because doing so takes more than 200 hours.

In [None]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model.cuda()

**Q6**. Complete the `tok_and_gen` function below. This requires the following steps:
- Tokenize the inputs
- Generate output ids (hint: use `model.generate`; see its documentation for the available arguments)
- Decode the ids into a string
- Print the string

In [None]:
input_str = 'Artificial general intelligence (AGI) is the hypothetical ability of an intelligent agent to understand or learn any intellectual task that a human being can.[1] It is a primary goal of some artificial intelligence research and a common topic in science fiction and futures studies.'

prefixes = ["translate English to German: ", "summarize: "]     # What happens if we make up a prompt? Why?

def tok_and_gen(model, tokenizer, input, prefix=''):
    ## YOUR CODE HERE ##

    print(f'Using prefix {prefix}: ', decoded)

for task_prefix in prefixes:
    tok_and_gen(model, tokenizer, input_str, prefix=task_prefix)


## Fine-tuning the model

Now that our data is ready and we have played around with our model, we can fine-tune it. 

HuggingFace provides an API for training a seq2seq model: the `Seq2SeqTrainer`. To instantiate this, we will need to define three more things. The most important is the [`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments), which is a class that contains all the attributes to customize the training. It requires a folder name for saving checkpoints of the model, and all other arguments are optional:

**Note:** Set `fp16` argument to `False` if using CPU computation.

In [None]:
batch_size = 4
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-xsum",     # Output folder
    # Eval strategy
    evaluation_strategy="steps",
    eval_steps=100,
    # Could alternatively be the following to eval every epoch:
    # evaluation_strategy="epoch",

    # LR. Should be small (<1e-4)
    learning_rate=2e-5,

    # Batch size during training and eval
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,

    # Keeps at most this many checkpoints on disk (older ones are deleted). Important to avoid filling up storage!
    save_total_limit=3,

    # Properly generate summaries during eval
    predict_with_generate=True,
    
    # Mixed precision training (faster and uses less memory on supported GPUs)
    fp16=True,

    weight_decay=0.01,
    num_train_epochs=1,
)
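Before launching, it is worth estimating how many optimizer steps and evaluations these arguments imply. The rough sketch below assumes one device, no gradient accumulation, and that the last partial batch is kept. The example uses roughly 204k training examples for XSum; check `len(tokenized_datasets["train"])` for the exact figure. `training_schedule` is a hypothetical helper.

```python
import math

def training_schedule(num_examples, batch_size, num_epochs=1, eval_steps=100):
    """Rough step counts for a single-device run without gradient accumulation."""
    steps_per_epoch = math.ceil(num_examples / batch_size)  # last partial batch still counts
    total_steps = steps_per_epoch * num_epochs
    num_evals = total_steps // eval_steps                    # one evaluation every eval_steps
    return steps_per_epoch, total_steps, num_evals

steps, total, evals = training_schedule(204_045, batch_size=4, num_epochs=1, eval_steps=100)
print(steps, total, evals)  # 51012 51012 510
```

Each of those evaluations generates summaries for the whole validation split, because `predict_with_generate=True`. With `eval_steps=100`, evaluation can therefore dominate training time. A larger `eval_steps` or `evaluation_strategy="epoch"` reduces this cost.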

Then, we need a special kind of data collator, which will not only pad the inputs to the maximum length in the batch, but also the labels:

In [None]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
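The collator pads labels with -100 rather than the pad id; -100 is the default `label_pad_token_id` of `DataCollatorForSeq2Seq`. It does this because PyTorch's cross-entropy loss ignores targets equal to -100 by default (`ignore_index=-100`), so padded positions do not contribute to the loss. This is also why `compute_metrics` below swaps -100 back to the pad id before decoding. The pure-Python sketch below shows both directions. `pad_labels` and `restore_pad_ids` are hypothetical helpers, and the pad id 0 is an assumption (read it from `tokenizer.pad_token_id`).

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def pad_labels(labels, label_pad_id=IGNORE_INDEX):
    """Pad label sequences to the longest in the batch with a value the loss will skip."""
    longest = max(len(seq) for seq in labels)
    return [seq + [label_pad_id] * (longest - len(seq)) for seq in labels]

def restore_pad_ids(labels, pad_token_id=0):
    """Undo the -100 padding so the ids can be decoded (like the np.where in compute_metrics)."""
    return [[pad_token_id if tok == IGNORE_INDEX else tok for tok in seq] for seq in labels]

padded = pad_labels([[21, 22, 1], [23, 1]])
print(padded)                   # [[21, 22, 1], [23, 1, -100]]
print(restore_pad_ids(padded))  # [[21, 22, 1], [23, 1, 0]]
```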

The last thing to define for our `Seq2SeqTrainer` is how to compute the metrics from the predictions. We need to define a function for this, which will just use the `metric` we loaded earlier, and we have to do a bit of pre-processing to decode the predictions into texts:

In [None]:
import nltk
import numpy as np

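def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Generated ids may also contain -100 where batches of different lengths were concatenated; replace it before decoding
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)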
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

Then we just need to pass all of this along with our datasets to the `Seq2SeqTrainer`. This is the general Huggingface API adapted for training seq2seq models.

**Q7:** Instantiate the trainer class and begin finetuning on the XSum dataset. You may find [this page](https://huggingface.co/docs/transformers/main_classes/trainer) helpful.

**Note:** training will likely take a while so you may want to end training once you are satisfied it is working correctly in order to progress to the next section.

In [None]:
## YOUR CODE HERE ##


## Complexity Analysis

Here we will analyze the computational complexity of the Transformer. One of the main drawbacks of this architecture is that its time and memory complexity scale poorly with input length. This is particularly problematic for summarization, where summarizing long documents can become prohibitively expensive.

**Q8:** Which component of the Transformer is responsible for the poor complexity mentioned above? Explain why this is the case.

**Q9:** We will profile the time complexity of T5 for different input sequence lengths. Complete the function below, which should plot the time of T5's forward pass against input length. You can use the built-in PyTorch profiler or a simpler method (e.g. `time.time`), as you prefer.
- We will first do this up to 512 tokens on a log base 2 scale (i.e. 1, 2, 4, 8, ... 512)
- Do this for the encoder (holding decoder input length fixed) and the decoder (holding encoder input length fixed)
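One way to organise the measurements is to build the lengths 1, 2, 4, ..., `max_len`, then time a callable several times with `time.perf_counter` and keep the median. This is a sketch with hypothetical helper names (`power_of_two_lengths`, `median_runtime`). When timing on a GPU, call `torch.cuda.synchronize()` inside the timed function. CUDA kernels run asynchronously, so without it the clock can stop before the work finishes. Running the forward pass under `torch.no_grad()` also keeps autograd bookkeeping out of the measurement.

```python
import statistics
import time

def power_of_two_lengths(max_len):
    """Return [1, 2, 4, ...] up to and including max_len when it is a power of two."""
    lengths, n = [], 1
    while n <= max_len:
        lengths.append(n)
        n *= 2
    return lengths

def median_runtime(fn, repeats=5, warmup=1):
    """Median wall-clock time of fn() in seconds, after a few untimed warm-up calls."""
    for _ in range(warmup):
        fn()  # first calls often include one-off costs such as memory allocation
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)

lengths = power_of_two_lengths(512)
print(lengths)  # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
# Stand-in workload; replace it with a model forward pass at each length
times = [median_runtime(lambda n=n: sum(range(n * 1000)), repeats=3) for n in lengths]
```

For the encoder curve, each timed call could run something like `model(input_ids=base_encoder_input.repeat(1, n), decoder_input_ids=base_decoder_input)`. For the decoder curve, swap which input is repeated.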

In [None]:
from time import time
import matplotlib.pyplot as plt

def profile_seq2seq(model, max_len, batch_size=1):
    # Create dummy input tensors
    base_encoder_input = torch.tensor([[10]]).cuda()
    base_decoder_input = torch.tensor([[10]]).cuda()
    # Make a larger batch to make trends more evident
    base_encoder_input = base_encoder_input.repeat(batch_size,1)
    base_decoder_input = base_decoder_input.repeat(batch_size,1)
    
    ## YOUR CODE HERE ##

profile_seq2seq(model, 512, 1)

The above plots may or may not show an upward trend. Try again with a larger batch size (e.g. 8): the curves should now slope upward, because the forward pass becomes expensive relative to the other fixed per-call overheads.