# Extractive Question Answering with DistilBERT Transformer model

## Introduction

### Question Answering

Question Answering (QA) is one of the central problems in Natural Language Processing (NLP). A QA system returns an answer to a given question. Virtual assistants like Alexa, Siri or Google Assistant are examples of QA systems. There are two common types of QA tasks:

- _Extractive_: extract the answer, as _spans of text_, from the given context.
- _Abstractive_: generate an answer from the context that correctly answers the question.

In this notebook we will:

- Fine-tune the DistilBERT Transformer model on the SQuAD dataset for extractive question answering.
- Use this fine-tuned model for inference.

### The DistilBERT model

The DistilBERT model was proposed in the blog post [Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT](https://medium.com/huggingface/distilbert-8cf3380435b5), and the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108). DistilBERT is a small, fast, cheap and light Transformer model trained by distilling BERT base. It has 40% fewer parameters than bert-base-uncased and runs 60% faster, while preserving over 95% of BERT’s performance as measured on the GLUE language understanding benchmark.

Some notes on the model:
1. DistilBERT doesn’t have `token_type_ids`, so you don’t need to indicate which token belongs to which segment. Just separate your segments with the separation token `tokenizer.sep_token` (or `[SEP]`).
2. DistilBERT doesn’t have options to select the input positions (`position_ids` input).
3. DistilBERT is basically the same as BERT but smaller. It is trained by distillation of the pretrained BERT model, i.e. it has been trained to predict the same probabilities as the larger model. The actual objective is a combination of: (a) matching the probabilities of the teacher model, (b) predicting the masked tokens correctly (but with no next-sentence objective), and (c) a cosine similarity between the hidden states of the student and the teacher model.
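
To make the three-part objective in note 3 more concrete, here is a minimal pure-Python sketch for a single masked token. The temperature, the loss weights, and all function names are illustrative assumptions; they are not the values or the code used to train DistilBERT.

```python
import math

def softmax(logits, temperature=1.0):
    # Numerically stable softmax with an optional temperature.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_target_loss(student_logits, teacher_logits, temperature=2.0):
    # (a) Cross-entropy between the teacher's and the student's softened distributions.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return -sum(t * math.log(s) for t, s in zip(p_teacher, p_student))

def mlm_loss(student_logits, true_token_id):
    # (b) Standard cross-entropy on the masked token.
    return -math.log(softmax(student_logits)[true_token_id])

def cosine_loss(student_hidden, teacher_hidden):
    # (c) 1 - cosine similarity between the two hidden-state vectors.
    dot = sum(s * t for s, t in zip(student_hidden, teacher_hidden))
    norm_s = math.sqrt(sum(s * s for s in student_hidden))
    norm_t = math.sqrt(sum(t * t for t in teacher_hidden))
    return 1.0 - dot / (norm_s * norm_t)

def distillation_objective(student_logits, teacher_logits, true_token_id,
                           student_hidden, teacher_hidden,
                           weights=(1.0, 1.0, 1.0)):
    # Weighted sum of the three terms; equal weights are an assumption.
    w_soft, w_mlm, w_cos = weights
    return (w_soft * soft_target_loss(student_logits, teacher_logits)
            + w_mlm * mlm_loss(student_logits, true_token_id)
            + w_cos * cosine_loss(student_hidden, teacher_hidden))

loss = distillation_objective(
    student_logits=[2.0, 0.5, -1.0],
    teacher_logits=[2.5, 0.3, -1.2],
    true_token_id=0,
    student_hidden=[0.1, 0.9, -0.3],
    teacher_hidden=[0.2, 0.8, -0.4],
)
print(f"combined loss: {loss:.4f}")
```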
  
Alternative models that can be used in this notebook instead of DistilBERT include:

[ALBERT](https://huggingface.co/docs/transformers/model_doc/albert), [BART](https://huggingface.co/docs/transformers/model_doc/bart), [BERT](https://huggingface.co/docs/transformers/model_doc/bert), [BigBird](https://huggingface.co/docs/transformers/model_doc/big_bird), [BigBird-Pegasus](https://huggingface.co/docs/transformers/model_doc/bigbird_pegasus), [BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom), [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert), [CANINE](https://huggingface.co/docs/transformers/model_doc/canine), [ConvBERT](https://huggingface.co/docs/transformers/model_doc/convbert), [Data2VecText](https://huggingface.co/docs/transformers/model_doc/data2vec-text), [DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta), [DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2), [ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra), [ERNIE](https://huggingface.co/docs/transformers/model_doc/ernie), [ErnieM](https://huggingface.co/docs/transformers/model_doc/ernie_m), [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon), [FlauBERT](https://huggingface.co/docs/transformers/model_doc/flaubert), [FNet](https://huggingface.co/docs/transformers/model_doc/fnet), [Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel), [OpenAI GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2), [GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo), [GPT NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox), [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj), [I-BERT](https://huggingface.co/docs/transformers/model_doc/ibert), [LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2), [LayoutLMv3](https://huggingface.co/docs/transformers/model_doc/layoutlmv3), [LED](https://huggingface.co/docs/transformers/model_doc/led), 
[LiLT](https://huggingface.co/docs/transformers/model_doc/lilt), [Longformer](https://huggingface.co/docs/transformers/model_doc/longformer), [LUKE](https://huggingface.co/docs/transformers/model_doc/luke), [LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert), [MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm), [mBART](https://huggingface.co/docs/transformers/model_doc/mbart), [MEGA](https://huggingface.co/docs/transformers/model_doc/mega), [Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron-bert), [MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert), [MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet), [MPT](https://huggingface.co/docs/transformers/model_doc/mpt), [MRA](https://huggingface.co/docs/transformers/model_doc/mra), [MT5](https://huggingface.co/docs/transformers/model_doc/mt5), [MVP](https://huggingface.co/docs/transformers/model_doc/mvp), [Nezha](https://huggingface.co/docs/transformers/model_doc/nezha), [Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer), [OPT](https://huggingface.co/docs/transformers/model_doc/opt), [QDQBert](https://huggingface.co/docs/transformers/model_doc/qdqbert), [Reformer](https://huggingface.co/docs/transformers/model_doc/reformer), [RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert), [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta), [RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm), [RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert), [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer), [Splinter](https://huggingface.co/docs/transformers/model_doc/splinter), [SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert), [T5](https://huggingface.co/docs/transformers/model_doc/t5), [UMT5](https://huggingface.co/docs/transformers/model_doc/umt5), 
[XLM](https://huggingface.co/docs/transformers/model_doc/xlm), [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta), [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl), [XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet), [X-MOD](https://huggingface.co/docs/transformers/model_doc/xmod), [YOSO](https://huggingface.co/docs/transformers/model_doc/yoso).

### The SQuAD Dataset

In this notebook we will fine-tune the DistilBERT model on the extractive QA task using the Hugging Face Transformers library and the [SQuAD dataset](https://huggingface.co/datasets/squad). 
The Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage. The `squad` dataset used here is version 1.1, in which every question has an answer; SQuAD 2.0 (`squad_v2`) additionally contains unanswerable questions.

### Credits

This notebook is based on a [Hugging Face tutorial on Question Answering](https://huggingface.co/docs/transformers/tasks/question_answering).

## Setup

### Installing the requirements

In [1]:
%pip install --upgrade transformers datasets evaluate accelerate

Collecting transformers
  Downloading transformers-4.33.3-py3-none-any.whl (7.6 MB)
Collecting datasets
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
Collecting accelerate
  Downloading accelerate-0.23.0-py3-none-any.whl (258 kB)
Collecting safetensors>=0.3.1
  Downloading safetensors-0.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)

### Define variables

In [2]:
# Datasets
FULL_DATASET = "squad"

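# Pre-trained models
MODEL_CHECKPOINT = "bert-base-cased" # Base checkpoint to fine-tune. Note: this is a BERT checkpoint; "distilbert-base-cased" is the DistilBERT counterpart
TRAINED_CHECKPOINT = "distilbert-base-cased-distilled-squad" # Already fine-tuned on SQuAD; used to demonstrate post-processing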

# Local directory where to save the finetuned models
MODEL_PATH = "saved_models"

# Repository name for saving models to the Hugging Face Hub
REPO_NAME = "Extr-QA-DistilBERT"

# Aux variables
DS_SAMPLE_SIZE = 2000 # Training on the full SQuAD dataset can take a few hours on an entry-level GPU, so we use a subset
TRAIN_TEST_SPLIT = 0.2 # Fraction of the dataset used for the train/test split
TOKEN_MAX_LENGTH = 384 # Maximum number of tokens per feature (question + context)
TOKEN_STRIDE = 128 # Number of overlapping tokens between two consecutive chunks
TRAIN_BATCH_SIZE = 8
EVAL_BATCH_SIZE = 8
NUM_TRAIN_EPOCHS = 4
LR = 2e-5 # Learning Rate
WD = 0.01 # Weight Decay

# Disable W&B logging
import os
os.environ["WANDB_DISABLED"] = "true"

### Load the dataset

We will start by loading the full SQuAD dataset from the Hugging Face Datasets library, so we can explore its structure.

In [3]:
from datasets import load_dataset

squad_full = load_dataset(FULL_DATASET)


We can take a look at an example in the dataset:

In [4]:
print("Context: ", squad_full["train"][0]["context"])
print("Question: ", squad_full["train"][0]["question"])
print("Answer: ", squad_full["train"][0]["answers"])

Context:  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question:  To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer:  {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}


Some of the more relevant fields are:
- _context_: background information from which the model needs to extract the answer.
- _question_: the question a model should answer.
- _answers_: the answer text and the character index at which it starts in the context.

The _context_ and _question_ fields are very straightforward to use. The _answers_ field is a bit trickier, as it contains a dictionary with two fields that are both lists. This is the format that will be expected by the squad metric during evaluation; if you are using your own data, you don’t necessarily need to worry about putting the answers in the same format.
The _text_ field holds the answer strings, and the _answer_start_ field contains the starting character index of each answer in the context.
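
As a small illustration of how `answer_start` and `text` fit together, the sketch below recovers an answer from its character span. The record is a made-up toy example in the SQuAD layout, and `answer_char_span` is a hypothetical helper, not part of the Datasets library.

```python
# Toy record in the same layout as a SQuAD example (contents are illustrative).
context = "The Grotto is a replica of the grotto at Lourdes, France."
answer_text = "Lourdes, France"
example = {
    "context": context,
    "answers": {
        "text": [answer_text],
        "answer_start": [context.index(answer_text)],
    },
}

def answer_char_span(example, answer_idx=0):
    # The end character is the start character plus the answer length.
    start = example["answers"]["answer_start"][answer_idx]
    end = start + len(example["answers"]["text"][answer_idx])
    return start, end

start_char, end_char = answer_char_span(example)
recovered = example["context"][start_char:end_char]
print(start_char, end_char, repr(recovered))
```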

During training, there is only one possible answer. We can double-check this by using the `Dataset.filter()` method:

In [5]:
squad_full["train"].filter(lambda x: len(x["answers"]["text"]) != 1)


Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 0
})

For evaluation, however, there are several possible answers for each sample, which may be the same or different:

In [6]:
squad_full["validation"].filter(lambda x: len(x["answers"]["text"]) != 1)


Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 10567
})

If we take a look at the sample at index 2, for instance:

In [7]:
print(squad_full["validation"][2]["context"])
print(squad_full["validation"][2]["question"])
print(squad_full["validation"][2]["answers"])

Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
Where did Super Bowl 50 take place?
{'text': ['Santa Clara, California', "Levi's Stadium", "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."], 'answer_start': [403, 355, 355]}


we can see that the answer can indeed be one of three possibilities.
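
Having several reference answers matters at evaluation time: a prediction is usually counted as correct if it matches any of them. The sketch below shows that idea with a simplified exact-match score. It follows the spirit of the SQuAD metric rather than reproducing the official implementation; `normalize` and `exact_match` are hypothetical helpers.

```python
import re
import string

def normalize(text):
    # Lowercase, drop punctuation and English articles, collapse whitespace.
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction, references):
    # A prediction is correct if it matches at least one reference answer.
    return int(any(normalize(prediction) == normalize(ref) for ref in references))

references = [
    "Santa Clara, California",
    "Levi's Stadium",
    "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.",
]
print(exact_match("levi's stadium", references))
print(exact_match("San Francisco", references))
```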

## Data pre-processing

First, we load the tokenizer that matches `MODEL_CHECKPOINT` to process the question and context fields:

In [8]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)


Let's take a look at how the tokenizer works:

In [9]:
context = squad_full["train"][0]["context"]
question = squad_full["train"][0]["question"]
inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

The tokenizer inserts special tokens to form a sentence with the following structure: [CLS] question [SEP] context [SEP]

The labels will then be the indices of the tokens starting and ending the answer, and the model will be tasked with predicting one start logit and one end logit per token in the input.

Some of the examples in the dataset have very long contexts that will exceed the maximum length we set (which is 384 in this example). We will deal with long contexts by creating several training features from one sample of our dataset, with a sliding window between them.

To see how this works using the current example, we can limit the length to 100 and use a sliding window of 50 tokens. As a reminder, we use:

- `max_length` to set the maximum length (here 100)
- `truncation="only_second"` to truncate the context (which is in the second position) when the question with its context is too long
- `stride` to set the number of overlapping tokens between two successive chunks (here 50)
- `return_overflowing_tokens=True` to let the tokenizer know we want the overflowing tokens
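
The interplay between `max_length` and `stride` can be sketched without a real tokenizer. The hypothetical `sliding_windows` helper below works on a plain list of context tokens and is a simplification: in the actual tokenizer, the space available for the context is `max_length` minus the question tokens and the special tokens.

```python
def sliding_windows(context_tokens, window_size, stride):
    # window_size: number of context tokens that fit in one feature
    # stride: number of tokens shared by two consecutive windows
    if not 0 <= stride < window_size:
        raise ValueError("stride must be non-negative and smaller than window_size")
    windows = []
    start = 0
    while True:
        end = min(start + window_size, len(context_tokens))
        windows.append(context_tokens[start:end])
        if end == len(context_tokens):
            break
        # Move forward so that `stride` tokens overlap with the previous window.
        start += window_size - stride
    return windows

tokens = [f"t{i}" for i in range(10)]
windows = sliding_windows(tokens, window_size=4, stride=2)
for window in windows:
    print(window)
```

Each window repeats the last `stride` tokens of the previous one, so an answer that is cut at a chunk boundary has a chance of appearing whole in the next chunk.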

In [10]:
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
)

for ids in inputs["input_ids"]:
    print(tokenizer.decode(ids))

[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the B

As we can see, our example has been split into four inputs, each of them containing the question and some part of the context. Note that the answer to the question (“Saint Bernadette Soubirous”) only appears in the third and last inputs, so by dealing with long contexts in this way we will create some training examples where the answer is not included in the context. For those examples, the labels will be start_position = end_position = 0 (so we predict the [CLS] token). We will also set those labels in the unfortunate case where the answer has been truncated so that we only have the start (or end) of it. For the examples where the answer is fully in the context, the labels will be the index of the token where the answer starts and the index of the token where the answer ends.

The dataset provides us with the start character of the answer in the context, and by adding the length of the answer, we can find the end character in the context. To map those to token indices, we will need to use offset mappings. We can have our tokenizer return these by passing along `return_offsets_mapping=True`:
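
To illustrate what an offset mapping is, the sketch below uses a toy regex tokenizer (an assumption for illustration; it is not the WordPiece tokenizer used in this notebook) that returns a `(start_char, end_char)` pair per token, and then converts a character span into a token span.

```python
import re

def tokenize_with_offsets(text):
    # Toy tokenizer: words and punctuation marks, each with its character span.
    return [(m.group(), (m.start(), m.end())) for m in re.finditer(r"\w+|[^\w\s]", text)]

def char_span_to_token_span(offsets, start_char, end_char):
    # First token that ends after start_char, last token that starts before end_char.
    token_start = next(i for i, (s, e) in enumerate(offsets) if e > start_char)
    token_end = max(i for i, (s, e) in enumerate(offsets) if s < end_char)
    return token_start, token_end

context = "It is a replica of the grotto at Lourdes, France."
answer = "Lourdes, France"
start_char = context.index(answer)
end_char = start_char + len(answer)

pairs = tokenize_with_offsets(context)
tokens = [token for token, _ in pairs]
offsets = [offset for _, offset in pairs]
token_start, token_end = char_span_to_token_span(offsets, start_char, end_char)
print(tokens[token_start : token_end + 1])
```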

In [11]:
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

As we can see, we get back the usual input IDs, token type IDs, and attention mask, as well as the offset mapping we required and an extra key, `overflow_to_sample_mapping`. The corresponding value will be of use to us when we tokenize several texts at the same time (which we should do to benefit from the fact that our tokenizer is backed by Rust). Since one sample can give several features, it maps each feature to the example it originated from. Because here we only tokenized one example, we get a list of 0s:

In [12]:
inputs["overflow_to_sample_mapping"]

[0, 0, 0, 0]

But if we tokenize more examples, this will become more useful:

In [13]:
inputs = tokenizer(
    squad_full["train"][2:6]["question"],
    squad_full["train"][2:6]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(f"The 4 examples gave {len(inputs['input_ids'])} features.")
print(f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.")

The 4 examples gave 19 features.
Here is where each comes from: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3].
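
One way to read this mapping is to group the feature indices by the sample they came from. The short sketch below does this for the list printed above; the variable names are illustrative.

```python
from collections import defaultdict

# Mapping printed above: one entry per feature, giving the index of its source sample.
overflow_to_sample_mapping = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3]

features_per_sample = defaultdict(list)
for feature_idx, sample_idx in enumerate(overflow_to_sample_mapping):
    features_per_sample[sample_idx].append(feature_idx)

for sample_idx, feature_indices in features_per_sample.items():
    print(f"sample {sample_idx}: features {feature_indices}")
```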


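This information is useful for mapping each feature to its label. As described above, the label of a feature is either (0, 0), if the answer is not fully inside that feature's chunk of the context, or (start_position, end_position), the indices of the tokens where the answer starts and ends. To determine which of these is the case and, if relevant, the positions of the tokens, we first find the indices that start and end the context in the input IDs. We could use the token type IDs to do this, but since those do not necessarily exist for all models (DistilBERT does not require them, for instance), we’ll instead use the `sequence_ids()` method of the `BatchEncoding` our tokenizer returns.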

Once we have those token indices, we look at the corresponding offsets, which are tuples of two integers representing the span of characters inside the original context. We can thus detect if the chunk of the context in this feature starts after the answer or ends before the answer begins (in which case the label is (0, 0)). If that’s not the case, we loop to find the first and last token of the answer:

In [14]:
answers = squad_full["train"][2:6]["answers"]
start_positions = []
end_positions = []

for i, offset in enumerate(inputs["offset_mapping"]):
    sample_idx = inputs["overflow_to_sample_mapping"][i]
    answer = answers[sample_idx]
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

start_positions, end_positions

([83, 51, 19, 0, 0, 64, 27, 0, 34, 0, 0, 0, 67, 34, 0, 0, 0, 0, 0],
 [85, 53, 21, 0, 0, 70, 33, 0, 40, 0, 0, 0, 68, 35, 0, 0, 0, 0, 0])

Let’s take a look at a few results to verify that our approach is correct. For the first feature we find (83, 85) as labels, so let’s compare the theoretical answer with the decoded span of tokens from 83 to 85 (inclusive):

In [15]:
idx = 0
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]
start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start : end + 1])
print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")

Theoretical answer: the Main Building, labels give: the Main Building


So that’s a match! Now let’s check index 4, where we set the labels to (0, 0), which means the answer is not in the context chunk of that feature:

In [16]:
idx = 4
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]
decoded_example = tokenizer.decode(inputs["input_ids"][idx])
print(f"Theoretical answer: {answer}, decoded example: {decoded_example}")

Theoretical answer: a Marian place of prayer and reflection, decoded example: [CLS] What is the Grotto at Notre Dame? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grot [SEP]


Indeed, we don’t see the answer inside the context.

Now that we have seen step by step how to preprocess our training data, we can group it in a function we will apply on the whole training dataset. We’ll pad every feature to the maximum length we set, as most of the contexts will be long (and the corresponding samples will be split into several features), so there is no real benefit to applying dynamic padding here. To recap, our `preprocess_train_set()` function will:

1. Deal with longer sequences in the dataset (those that have a very long context that exceeds the maximum input length of the model) by truncating only the context, using `truncation="only_second"`.
1. Map the start and end positions of the answer to the original context by setting `return_offsets_mapping=True`.
1. Find the start and end tokens of the answer, using the `sequence_ids()` method to find which part of the offset corresponds to the question and which corresponds to the context.

With this, `preprocess_train_set()` will truncate the context and map the start and end tokens of the answer to it:

In [17]:
def preprocess_train_set(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=TOKEN_MAX_LENGTH,
        truncation="only_second",
        stride=TOKEN_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

Note that we defined two constants to determine the maximum length used as well as the length of the sliding window, and that we added a tiny bit of cleanup before tokenizing: some of the questions in the SQuAD dataset have extra spaces at the beginning and the end that don’t add anything (and take up space when being tokenized if you use a model like RoBERTa), so we removed those extra spaces.

We now select a random sample from the full dataset, which we'll use for fine-tuning the model.

In [18]:
squad_train_sample = squad_full["train"].shuffle(seed=42).select(range(DS_SAMPLE_SIZE))

To apply the preprocessing function to the subset of the full dataset that we will use for training, we use the `Dataset.map()` function. We can speed up the `map` function by setting `batched=True` to process multiple elements of the dataset at once. We can also remove any columns we don’t need:

In [19]:
train_dataset_sample = squad_train_sample.map(preprocess_train_set, batched=True, remove_columns=squad_train_sample.column_names)


In [20]:
len(squad_train_sample), len(train_dataset_sample)

(2000, 2026)

As we can see, the preprocessing added 26 features (for a sample of 2,000 examples), because some long contexts were split into several chunks. Our training set is now ready to be used, so let’s move on to preprocessing the validation set.

Preprocessing the validation data will be slightly easier as we don’t need to generate labels (unless we want to compute a validation loss, but that number won’t really help us understand how good the model is). The real joy will be to interpret the predictions of the model into spans of the original context. For this, we will just need to store both the offset mappings and some way to match each created feature to the original example it comes from. Since there is an ID column in the original dataset, we’ll use that ID.

The only thing we’ll add here is a tiny bit of cleanup of the offset mappings. They will contain offsets for the question and the context, but once we’re in the post-processing stage we won’t have any way to know which part of the input IDs corresponded to the context and which part was the question (the `sequence_ids()` method we used is available for the output of the tokenizer only). So, we’ll set the offsets corresponding to the question to `None`:

In [21]:
def preprocess_validation_set(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=TOKEN_MAX_LENGTH,
        truncation="only_second",
        stride=TOKEN_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

We can apply this function on a sample of the validation dataset like before:

In [22]:
squad_validation_sample = squad_full["validation"].shuffle(seed=42).select(range(DS_SAMPLE_SIZE))
validation_dataset_sample = squad_validation_sample.map(preprocess_validation_set, batched=True, remove_columns=squad_validation_sample.column_names,)


In [23]:
len(squad_validation_sample), len(validation_dataset_sample)

(2000, 2051)

In this case we’ve added 51 features, so it appears the contexts in the validation dataset are a bit longer.

Now that we have preprocessed all the data, we can get to training the model.

## Fine-tune the model

To fine-tune the model we have to write the `compute_metrics()` function. Since we padded all the samples to the maximum length we set, there is no data collator to define, so this metric computation is really the only thing we have to worry about. The difficult part will be to post-process the model predictions into spans of text in the original examples; once we have done that, the metric from the Hugging Face Evaluate library will do most of the work for us.


### Post-processing

As we mentioned above, the model will output logits for the start and end positions of the answer in the _input IDs_. The post-processing step requires that we:

- Mask the start and end logits corresponding to tokens outside of the context.
- Then convert the start and end logits into probabilities using a softmax.
- Attribute a score to each (start_token, end_token) pair by taking the product of the corresponding two probabilities.
- Look for the pair with the maximum score that yielded a valid answer (e.g., a start_token lower than end_token).
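
The four steps above can be sketched in pure Python as follows. This is an illustrative version under stated assumptions: `best_span`, the `context_mask` argument, and the `max_answer_len` limit are hypothetical names introduced here, and single-token answers (start equal to end) are treated as valid.

```python
import math

def softmax(logits):
    # Numerically stable softmax.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def best_span(start_logits, end_logits, context_mask, max_answer_len=30):
    # Step 1: mask tokens outside the context with a very low logit.
    masked_start = [l if keep else -1e9 for l, keep in zip(start_logits, context_mask)]
    masked_end = [l if keep else -1e9 for l, keep in zip(end_logits, context_mask)]
    # Step 2: turn the logits into probabilities.
    start_probs = softmax(masked_start)
    end_probs = softmax(masked_end)
    # Steps 3 and 4: score every valid (start, end) pair and keep the best one.
    best = (None, None, 0.0)
    for s, p_start in enumerate(start_probs):
        for e in range(s, min(s + max_answer_len, len(end_probs))):
            score = p_start * end_probs[e]
            if score > best[2]:
                best = (s, e, score)
    return best

# Token 0 has the highest raw logits but lies outside the context.
start_logits = [5.0, 0.1, 0.2, 3.0, 0.5, 0.1]
end_logits = [4.0, 0.1, 0.3, 0.2, 2.5, 0.1]
context_mask = [False, False, True, True, True, True]
start, end, score = best_span(start_logits, end_logits, context_mask)
print(start, end, round(score, 3))
```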

Here we will change this process slightly because we don’t need to compute actual scores (just the predicted answer). This means we can skip the softmax step. To go faster, we also won’t score all the possible (start_token, end_token) pairs, but only the ones corresponding to the highest `n_best` logits (with `n_best=20`). Since we will skip the softmax, those scores will be logit scores, obtained by taking the sum of the start and end logits (instead of the product, because of the rule $\log(ab) = \log(a) + \log(b)$).
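
A minimal sketch of this shortcut is shown below. It assumes the logits have already been restricted to context tokens; `top_indices`, `best_span_from_logits`, and the `max_answer_len` limit are hypothetical names used for illustration only.

```python
def top_indices(logits, n_best):
    # Indices of the n_best largest logits, best first.
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:n_best]

def best_span_from_logits(start_logits, end_logits, n_best=20, max_answer_len=30):
    best = None
    for s in top_indices(start_logits, n_best):
        for e in top_indices(end_logits, n_best):
            # Skip spans that end before they start or that are too long.
            if e < s or e - s + 1 > max_answer_len:
                continue
            # Sum of logits: ranks pairs the same way as the product of probabilities.
            score = start_logits[s] + end_logits[e]
            if best is None or score > best[2]:
                best = (s, e, score)
    return best

start_logits = [0.1, 0.2, 3.0, 0.5, 0.1]
end_logits = [0.1, 0.3, 0.2, 2.5, 0.1]
result = best_span_from_logits(start_logits, end_logits, n_best=3)
print(result)
```

Because the softmax only applies a monotonic transformation and a shared normalizing constant, the pair with the largest logit sum is also the pair with the largest probability product among the candidates considered.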

To demonstrate all of this, we will need some kind of predictions. Since we have not trained our model yet, we are going to use the default model for the QA pipeline to generate some predictions on a small part of the validation set. We can use the same processing function as before; because it relies on the global `tokenizer` variable, we just have to temporarily point that variable at the tokenizer of the model we want to use:

In [24]:
small_eval_set = squad_validation_sample.select(range(100))

tokenizer = AutoTokenizer.from_pretrained(TRAINED_CHECKPOINT)
eval_set = small_eval_set.map(
    preprocess_validation_set,
    batched=True,
    remove_columns=squad_full["validation"].column_names,
)

Now that the preprocessing is done, we change the tokenizer back to the one we originally picked:

In [25]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

We then remove the columns of our eval_set that are not expected by the model, build a batch with all of that small validation set, and pass it through the model. If a GPU is available, we use it to go faster:

In [26]:
import torch
from transformers import AutoModelForQuestionAnswering

eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
eval_set_for_model.set_format("torch")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}
trained_model = AutoModelForQuestionAnswering.from_pretrained(TRAINED_CHECKPOINT).to(device)

with torch.no_grad():
    outputs = trained_model(**batch)

Since the `Trainer` will give us predictions as NumPy arrays, we grab the start and end logits and convert them to that format:

In [27]:
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

Now, we need to find the predicted answer for each example in our small_eval_set. One example may have been split into several features in eval_set, so the first step is to map each example in small_eval_set to the corresponding features in eval_set:

In [28]:
import collections

example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(eval_set):
    example_to_features[feature["example_id"]].append(idx)
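
As a small illustration of what this mapping looks like, the sketch below builds it for three made-up features. The IDs `"q1"` and `"q2"` and the name `toy_features` are hypothetical.

```python
import collections

# Hypothetical features: example "q1" was split into two windows, "q2" fits in one.
toy_features = [
    {"example_id": "q1"},
    {"example_id": "q1"},
    {"example_id": "q2"},
]

toy_example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(toy_features):
    # Group feature indices under the example they were cut from.
    toy_example_to_features[feature["example_id"]].append(idx)

print(dict(toy_example_to_features))
```

Each key is an example ID and each value lists the positions of its features, so the logits of every window belonging to one question can be looked up together.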

With this in hand, we can really get to work by looping through all the examples and, for each example, through all the associated features. As we said before, we’ll look at the logit scores for the n_best start logits and end logits, excluding positions that give:
- An answer that wouldn’t be inside the context
- An answer with negative length
- An answer that is too long (we limit the possibilities at max_answer_length=30)

Once we have all the scored possible answers for one example, we just pick the one with the best logit score:

In [29]:
import numpy as np

n_best = 20
max_answer_length = 30
predicted_answers = []

for example in small_eval_set:
    example_id = example["id"]
    context = example["context"]
    answers = []

    for feature_index in example_to_features[example_id]:
        start_logit = start_logits[feature_index]
        end_logit = end_logits[feature_index]
        offsets = eval_set["offset_mapping"][feature_index]

        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Skip answers that are not fully in the context
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                # Skip answers with a length that is either < 0 or > max_answer_length.
                if (
                    end_index < start_index
                    or end_index - start_index + 1 > max_answer_length
                ):
                    continue

                answers.append(
                    {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                )

    best_answer = max(answers, key=lambda x: x["logit_score"])
    predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})
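
To make the filtering concrete, here is a self-contained sketch of the same selection logic on a single toy feature. It is pure Python, so `top_n` is a stand-in for the `np.argsort(...)[-1 : -n_best - 1 : -1]` slice used above; the context, offsets, and logits are invented for illustration.

```python
def top_n(scores, n):
    """Indices of the n largest scores, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]


def best_span(context, offsets, start_logit, end_logit, n_best, max_answer_length):
    candidates = []
    for start_index in top_n(start_logit, n_best):
        for end_index in top_n(end_logit, n_best):
            # Skip positions outside the context (their offset is None).
            if offsets[start_index] is None or offsets[end_index] is None:
                continue
            # Skip spans with negative length or more tokens than allowed.
            if (
                end_index < start_index
                or end_index - start_index + 1 > max_answer_length
            ):
                continue
            candidates.append(
                {
                    "text": context[offsets[start_index][0] : offsets[end_index][1]],
                    "logit_score": start_logit[start_index] + end_logit[end_index],
                }
            )
    return max(candidates, key=lambda c: c["logit_score"])


toy_context = "The bridge opened in 1852."
# Character offsets per token; None marks question and special tokens.
toy_offsets = [None, None, (0, 3), (4, 10), (11, 17), (18, 20), (21, 25), (25, 26), None]
toy_start = [0.1, 5.0, 0.2, 0.3, 0.1, 0.4, 4.0, 0.2, 0.0]
toy_end = [0.0, 4.5, 0.1, 0.2, 0.3, 0.1, 3.5, 0.6, 0.2]

best = best_span(
    toy_context, toy_offsets, toy_start, toy_end, n_best=3, max_answer_length=4
)
print(best)
```

In this toy feature the largest raw logits sit on position 1, which lies outside the context, so the `None` check discards it and the best remaining span is selected instead.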

The final format of the predicted answers is the one that will be expected by the metric we will use. As usual, we can load it with the help of the Hugging Face Evaluate library:

In [30]:
import evaluate

metric = evaluate.load(FULL_DATASET)

This metric expects the predicted answers in the format we saw above (a list of dictionaries with one key for the ID of the example and one key for the predicted text) and the theoretical answers in the format below (a list of dictionaries with one key for the ID of the example and one key for the possible answers):

In [31]:
theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in small_eval_set]

We can now check that we get sensible results by looking at the first element of both lists:

In [32]:
print(predicted_answers[0])
print(theoretical_answers[0])

{'id': '572759665951b619008f8884', 'prediction_text': '1852'}
{'id': '572759665951b619008f8884', 'answers': {'text': ['1852', '1852', '1852'], 'answer_start': [158, 158, 158]}}


Not bad! Now let’s have a look at the score the metric gives us:

In [33]:
metric.compute(predictions=predicted_answers, references=theoretical_answers)

{'exact_match': 84.0, 'f1': 87.8248120300752}

Again, that’s rather good considering that according to its paper DistilBERT fine-tuned on SQuAD obtains 79.1 and 86.9 for those scores on the whole dataset.
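
For intuition about what the two numbers measure, here is a simplified sketch of SQuAD-style scoring for a single prediction. It follows the usual recipe (normalize the text, then compare exactly or by token overlap, keeping the best score over the reference answers), but the function names are hypothetical and the actual metric implementation may differ in details. The metric also averages over all examples and reports percentages.

```python
import collections
import re
import string

PUNCTUATION = set(string.punctuation)


def normalize(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> float:
    return float(normalize(prediction) == normalize(reference))


def f1(prediction: str, reference: str) -> float:
    pred_tokens = normalize(prediction).split()
    ref_tokens = normalize(reference).split()
    # Multiset intersection counts the tokens shared by both answers.
    common = collections.Counter(pred_tokens) & collections.Counter(ref_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def squad_like_scores(prediction: str, references: list) -> dict:
    """Best exact-match and F1 over all reference answers for one example."""
    return {
        "exact_match": max(exact_match(prediction, ref) for ref in references),
        "f1": max(f1(prediction, ref) for ref in references),
    }


partial = squad_like_scores("in 1852", ["1852", "the year 1852"])
perfect = squad_like_scores("The 1852.", ["1852"])
print(partial, perfect)
```

A prediction that contains the right answer plus an extra word gets no exact-match credit but still earns partial F1, which is why F1 is always at least as high as exact match.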

Now let’s put everything we just did in a `compute_metrics()` function that we will use in the `Trainer`. Normally, that `compute_metrics()` function only receives a tuple `eval_preds` with logits and labels. Here we will need a bit more, as we have to look in the dataset of features for the offset and in the dataset of examples for the original contexts, so we won’t be able to use this function to get regular evaluation results during training. We will only use it at the end of training to check the results.

The `compute_metrics()` function groups the same steps as before; we just add a small check in case we don’t come up with any valid answers (in which case we predict an empty string).

In [34]:
from tqdm.auto import tqdm

def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)
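
The explicit `len(answers) > 0` check matters because `max()` raises a `ValueError` on an empty sequence. As an alternative sketch (not what the function above uses), the `default` argument of `max()` expresses the same fallback; `pick_best_text` is a hypothetical helper name.

```python
def pick_best_text(answers):
    """Return the text of the highest-scoring candidate, or "" if there is none."""
    best = max(answers, key=lambda a: a["logit_score"], default=None)
    return best["text"] if best is not None else ""


no_candidates = pick_best_text([])
some_candidates = pick_best_text(
    [
        {"text": "in 1852", "logit_score": 3.9},
        {"text": "1852", "logit_score": 7.5},
    ]
)
print(repr(no_candidates), repr(some_candidates))
```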

We can now check whether it works on our predictions:

In [35]:
compute_metrics(start_logits, end_logits, eval_set, small_eval_set)

{'exact_match': 84.0, 'f1': 87.8248120300752}

Looking good! Now let’s use this to fine-tune our model.

### Fine-tune the model with the Trainer API

We are now ready to train our model. Let’s create it first, using the `AutoModelForQuestionAnswering` class:

In [36]:
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

model = AutoModelForQuestionAnswering.from_pretrained(MODEL_CHECKPOINT)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


As usual, we get a warning that some weights are not used (the ones from the pretraining head) and some others are initialized randomly (the ones for the question answering head). You should be used to this by now, but that means this model is not ready to be used just yet and needs fine-tuning, which is exactly what we're about to do now.

To be able to push our model to the Hugging Face Hub, we’ll need to log in to Hugging Face. If you’re running this code in a notebook, you can do so with the following utility function, which displays a widget where you can enter your login credentials:

In [37]:
#from huggingface_hub import notebook_login

#notebook_login()

Once this is done, we can define our `TrainingArguments`. As we said when we defined our function to compute the metric, we won’t be able to have a regular evaluation loop because of the signature of the `compute_metrics()` function. We could write our own subclass of `Trainer` to do this, but that’s a bit too long for this introductory notebook. Instead, we will only evaluate the model at the end of training here and show how to do a regular evaluation in the “A custom training loop” section below.

This is really where the `Trainer` API shows its limits and the Hugging Face `Accelerate` library shines: customizing the class to a specific use case can be painful, but tweaking a fully exposed training loop is easy.

Let’s take a look at our `TrainingArguments`. From here, we have to:
1. Define our training hyperparameters in `TrainingArguments`. The only required parameter is `output_dir`, which specifies where to save our model. We could also prepare to push this model to the Hugging Face Hub later by setting `push_to_hub=True` (you need to be signed in to Hugging Face to upload your model).
2. Pass the training arguments to `Trainer` along with the model, the datasets, and the tokenizer. No data collator needs to be defined, since every sample is already padded to the maximum length.
3. Call `train()` to fine-tune our model.

In [38]:
training_args = TrainingArguments(
    output_dir=MODEL_PATH,
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=LR,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    weight_decay=WD,
    fp16=True,
    push_to_hub=False
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


We set some hyperparameters (like the learning rate, the number of epochs we train for, and some weight decay) and indicate that we want to save the model at the end of every epoch, skip evaluation, and keep the results local rather than uploading them to the Model Hub (`push_to_hub=False`). We also enable mixed-precision training with `fp16=True`, as it can speed up the training nicely on a recent GPU.

If you set `push_to_hub=True` instead, the repository used will by default be in your namespace and named after the output directory you set. We can override this by passing a `hub_model_id`. If the output directory you are using already exists, it needs to be a local clone of the repository you want to push to (so set a new name if you get an error when defining your `Trainer`).

Finally, we just pass everything to the `Trainer` class and launch the training:

In [39]:
trainer_small = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_sample,
    eval_dataset=validation_dataset_sample,
    tokenizer=tokenizer,
)

In [40]:
trainer_small.train()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


| Step | Training Loss |
|------|---------------|
| 500  | 2.4089        |
| 1000 | 0.7983        |

TrainOutput(global_step=1016, training_loss=1.5868787681023906, metrics={'train_runtime': 163.4289, 'train_samples_per_second': 49.587, 'train_steps_per_second': 6.217, 'total_flos': 1588161687441408.0, 'train_loss': 1.5868787681023906, 'epoch': 4.0})
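
The `global_step=1016` in this output can be sanity-checked by hand. Without gradient accumulation, the number of optimizer steps is the number of batches per epoch times the number of epochs. The sketch below assumes the `Trainer` default per-device batch size of 8 on a single device and 4 epochs (consistent with `'epoch': 4.0` above). The feature count is not stated in this section, so it is inferred from `train_samples_per_second * train_runtime / epochs`; treat it as an estimate.

```python
import math


def expected_steps(num_features: int, batch_size: int, num_epochs: int) -> int:
    """Optimizer steps when the last, smaller batch of each epoch is kept."""
    steps_per_epoch = math.ceil(num_features / batch_size)
    return steps_per_epoch * num_epochs


# Values copied from the TrainOutput above.
train_samples_per_second = 49.587
train_runtime = 163.4289
num_epochs = 4

# Estimated number of training features (an inference, not a logged value).
inferred_features = round(train_samples_per_second * train_runtime / num_epochs)

# Assumed batch size: the Trainer default of 8 on one device.
steps = expected_steps(inferred_features, batch_size=8, num_epochs=num_epochs)
print(inferred_features, steps)
```

Under these assumptions the estimate reproduces the logged step count, which suggests the run used the default batch size.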

## Model evaluation

Once the training is complete, we can finally evaluate our model (and pray we didn’t spend all that compute time on nothing). The `predict()` method of the `Trainer` returns a tuple whose first element holds the predictions of the model (here a pair with the start and end logits). We send this to our `compute_metrics()` function:


In [42]:
predictions, _, _ = trainer_small.predict(validation_dataset_sample)
start_logits, end_logits = predictions
compute_metrics(start_logits, end_logits, validation_dataset_sample, squad_validation_sample)

{'exact_match': 56.1, 'f1': 68.15696025705903}

Now that we have verified that the fine-tuning process works, let's train the model on the full dataset.

In [43]:
train_dataset = squad_full["train"].map(preprocess_train_set, batched=True, remove_columns=squad_full["train"].column_names)
eval_dataset = squad_full["validation"].map(preprocess_validation_set, batched=True, remove_columns=squad_full["validation"].column_names)

In [44]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
trainer.train()

| Step | Training Loss |
|------|---------------|
| 500  | 1.6544        |
| 1000 | 1.4272        |
| 1500 | 1.3783        |
| 2000 | 1.324         |
| 2500 | 1.2637        |
| 3000 | 1.2255        |
| 3500 | 1.211         |
| 4000 | 1.1977        |
| 4500 | 1.1879        |
| 5000 | 1.1487        |


TrainOutput(global_step=44368, training_loss=0.6884294775389525, metrics={'train_runtime': 6737.6321, 'train_samples_per_second': 52.677, 'train_steps_per_second': 6.585, 'total_flos': 6.955379978528563e+16, 'train_loss': 0.6884294775389525, 'epoch': 4.0})

Let's see how good our fine-tuned model is:

In [45]:
predictions, _, _ = trainer.predict(eval_dataset)
start_logits, end_logits = predictions
compute_metrics(start_logits, end_logits, eval_dataset, squad_full["validation"])

{'exact_match': 80.1608325449385, 'f1': 88.0252083889449}

Great! As a comparison, the baseline scores reported in the BERT article for this model are 80.8 and 88.5, so we’re right where we should be.