## CS310 Natural Language Processing
## Lab 13: Explore Question-Answering Models and Datasets

In this lab, we will practice running pretrained models on question-answering tasks. The model we demonstrate with is `distilbert-base-uncased`, a smaller, distilled version of BERT.

We will use the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) dataset provided in the [Datasets](https://github.com/huggingface/datasets) library. Make sure to install the library:

```bash
pip install datasets
```

In [56]:
from pprint import pprint

### T1. Explore the SQuAD dataset

First, let's load the SQuAD dataset:

In [57]:
from datasets import load_dataset, load_metric

squad_dataset = load_dataset('./squad')

The `squad_dataset` object is a `DatasetDict` that contains keys for the train and validation splits.

In [58]:
squad_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

To access a data instance, you can specify the split and index:

In [59]:
squad_dataset['train'][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

We can see that the answer is indicated by its span start index (character `515`) in the context.
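
As a quick check of this convention, the sketch below slices a context string using `answer_start` and the length of the answer text. The record is a hand-made toy in the SQuAD layout (not a real dataset row), and `answer_span` is a hypothetical helper name.

```python
def answer_span(example):
    """Return (start, end) character indices of the first gold answer."""
    start = example["answers"]["answer_start"][0]
    end = start + len(example["answers"]["text"][0])
    return start, end

# Toy record in the SQuAD layout (illustrative, not taken from the dataset)
context = "The Grotto is a replica of the grotto at Lourdes, France."
toy_example = {
    "context": context,
    "answers": {"text": ["Lourdes"], "answer_start": [context.index("Lourdes")]},
}

start, end = answer_span(toy_example)
print(toy_example["context"][start:end])  # Lourdes
```

The same slicing should work on `squad_dataset['train'][0]`, since `answer_start` is a character offset into `context`.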

To get a sense of what the data looks like, the following function shows some examples picked randomly from the dataset:

In [60]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [61]:
show_random_elements(squad_dataset["train"], num_examples=3)

Unnamed: 0,id,title,context,question,answers
0,572f8123a23a5019007fc6ad,Hyderabad,"The jurisdictions of the city's administrative agencies are, in ascending order of size: the Hyderabad Police area, Hyderabad district, the GHMC area (""Hyderabad city"") and the area under the Hyderabad Metropolitan Development Authority (HMDA). The HMDA is an apolitical urban planning agency that covers the GHMC and its suburbs, extending to 54 mandals in five districts encircling the city. It coordinates the development activities of GHMC and suburban municipalities and manages the administration of bodies such as the Hyderabad Metropolitan Water Supply and Sewerage Board (HMWSSB).",Which Hyderabad agency is responsible for the largest area?,"{'text': ['Hyderabad Police'], 'answer_start': [93]}"
1,5706b0a62eaba6190074ac3b,House_music,"The house scene in cities such as Birmingham, Leeds, Sheffield and London were also provided with many underground Pirate Radio stations and DJs alike which helped bolster an already contagious, but otherwise ignored by the mainstream, music genre. The earliest and influential UK house and techno record labels such as Warp Records and Network Records (otherwise known as Kool Kat records) helped introduce American and later Italian dance music to Britain as well as promoting select UK dance music acts.",what helped to bolster house music in the uk?,"{'text': ['underground Pirate Radio stations and DJs'], 'answer_start': [103]}"
2,57109ebeb654c5140001f9db,Age_of_Enlightenment,"The most influential publication of the Enlightenment was the Encyclopédie, compiled by Denis Diderot and (until 1759) by Jean le Rond d'Alembert and a team of 150 scientists and philosophers. It was published between 1751 and 1772 in thirty-five volumes, and spread the ideas of the Enlightenment across Europe and beyond. Other landmark publications were the Dictionnaire philosophique (Philosophical Dictionary, 1764) and Letters on the English (1733) written by Voltaire; Rousseau's Discourse on Inequality (1754) and The Social Contract (1762); and Montesquieu's Spirit of the Laws (1748). The ideas of the Enlightenment played a major role in inspiring the French Revolution, which began in 1789. After the Revolution, the Enlightenment was followed by an opposing intellectual movement known as Romanticism.",What year did the French Revolution begin?,"{'text': ['1789'], 'answer_start': [697]}"


### T2. Preprocess the data

Before we feed the data to a model for fine-tuning, there is some preprocessing needed: 
- Tokenize the input text
- Put it in the format expected by the model
- Generate other inputs the model requires

To do all of this, we need to instantiate a tokenizer that is compatible with the model we want to use, i.e., `distilbert-base-uncased`.

In [62]:
from transformers import AutoTokenizer

model_checkpoint = "./distilbert-base-uncased" # If loaded locally, make sure you have the model downloaded first
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

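You can call this tokenizer directly on a pair of sentences. (Here the context sentence is passed first for illustration; for QA we will pass the question first and the context second.)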

In [63]:
tokenizer('Architecturally, the school has a Catholic character.', 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?')

{'input_ids': [101, 6549, 2135, 1010, 1996, 2082, 2038, 1037, 3234, 2839, 1012, 102, 2000, 3183, 2106, 1996, 6261, 2984, 9382, 3711, 1999, 8517, 1999, 10223, 26371, 2605, 1029, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

An important step in QA is dealing with very **long documents**. If an input is longer than the model's maximum input size, simply truncating the context might remove the answer.

To handle this, we allow a long document to produce several input *features*, each no longer than the maximum size.

Also, in case the answer is split between two features, we allow some overlap between features, controlled by `doc_stride`.

In [64]:
max_length = 384 # The maximum length of a feature (question and context)
doc_stride = 128 # The allowed overlap between two parts of the context when it has to be split

Let's examine one long example:

In [65]:
for i, example in enumerate(squad_dataset["train"]):
    if len(tokenizer(example["question"], example["context"])["input_ids"]) > max_length:
        break
example = squad_dataset["train"][i]

Without truncation, its length is:

In [66]:
len(tokenizer(example['question'], example['context'])['input_ids'])

396

If we truncate, the resulting length is:

In [67]:
len(tokenizer(example["question"], example["context"], max_length=max_length, truncation="only_second")["input_ids"])

384

Note that we never want to truncate the question, so we specify `truncation='only_second'`.

Now, we further tell the tokenizer to return the overlapping features by setting `return_overflowing_tokens=True` and `stride=doc_stride`.

In [68]:
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    stride=doc_stride
)

print([len(x) for x in tokenized_example["input_ids"]])

[384, 157]
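
These two lengths can be reproduced with simple arithmetic. The sketch below is not the tokenizer's implementation. It assumes a BERT-style pair layout (`[CLS] question [SEP] context [SEP]`, i.e. 3 special tokens), a 14-token question, and 379 context tokens (396 total minus 14 minus 3). `window_lengths` is a hypothetical helper.

```python
def window_lengths(n_question, n_context, max_length, stride, n_special=3):
    """Feature lengths when only the context is split into overlapping windows."""
    room = max_length - n_question - n_special  # context tokens per feature
    assert room > stride, "stride must be smaller than the room left for context"
    lengths, start = [], 0
    while True:
        end = min(start + room, n_context)
        lengths.append(n_question + n_special + (end - start))
        if end == n_context:
            break
        start = end - stride  # the next window re-reads the last `stride` tokens
    return lengths

print(window_lengths(n_question=14, n_context=379, max_length=384, stride=128))
# [384, 157]
```

Under these assumptions the first feature holds 367 context tokens, and the second holds the 128 overlapping tokens plus the 12 that did not fit, giving 17 + 140 = 157.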


We can look at the two features decoded:

In [69]:
for x in tokenized_example["input_ids"][:2]:
    pprint(tokenizer.decode(x))

("[CLS] how many wins does the notre dame men's basketball team have? [SEP] "
 "the men's basketball team has over 1, 600 wins, one of only 12 schools who "
 'have reached that mark, and have appeared in 28 ncaa tournaments. former '
 'player austin carr holds the record for most points scored in a single game '
 'of the tournament with 61. although the team has never won the ncaa '
 'tournament, they were named by the helms athletic foundation as national '
 'champions twice. the team has orchestrated a number of upsets of number one '
 "ranked teams, the most notable of which was ending ucla's record 88 - game "
 'winning streak in 1974. the team has beaten an additional eight number - one '
 "teams, and those nine wins rank second, to ucla's 10, all - time in wins "
 'against the top team. the team plays in newly renovated purcell pavilion ( '
 'within the edmund p. joyce center ), which reopened for the beginning of the '
 '2009 – 2010 season. the team is coached by mike brey, who,

Now, we need to find out which of the two features contains the answer, and where exactly it starts and ends.

Thankfully, the tokenizer can help us by returning the `offset_mapping` that gives the start and end character of each token:

In [70]:
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    stride=doc_stride
)

offsets = tokenized_example["offset_mapping"][0]
print(offsets[:10])

[(0, 0), (0, 3), (4, 8), (9, 13), (14, 18), (19, 22), (23, 28), (29, 33), (34, 37), (37, 38)]


In the above output, the very first token (`[CLS]`) has `(0, 0)` because it doesn't correspond to any part of the question or context.

The second token corresponds to the span from character 0 to 3 in the question, and so on.

In [71]:
token_id = tokenized_example["input_ids"][0][1]
print(tokenizer.convert_ids_to_tokens(token_id))

token_offsets = tokenized_example["offset_mapping"][0][1]
print(example["question"][token_offsets[0]:token_offsets[1]])

how
How
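
To make the idea of offsets concrete without a real tokenizer, here is a toy regex tokenizer that records the character span of each piece. It is only an illustration: `toy_offsets` is a hypothetical helper, it does not split words into WordPieces, and the question string is typed in by hand.

```python
import re

def toy_offsets(text):
    """Return (start, end) character spans of words and punctuation marks."""
    return [(m.start(), m.end()) for m in re.finditer(r"\w+|[^\w\s]", text)]

question = "How many wins does the Notre Dame men's basketball team have?"
spans = toy_offsets(question)
print(spans[:3])                              # [(0, 3), (4, 8), (9, 13)]
print([question[s:e] for s, e in spans[:3]])  # ['How', 'many', 'wins']
```

Slicing the original string with an offset pair recovers the original casing, which is why the cell above prints `how` for the token but `How` for the slice.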


Before going on to the next step, we just have to distinguish between the offsets for `question` and those for `context`. The `sequence_ids` method can be helpful:

In [72]:
sequence_ids = tokenized_example.sequence_ids()

print('len(sequence_ids):', len(sequence_ids))
print(sequence_ids)

len(sequence_ids): 384
[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

It returns `None` for special tokens, `0` for tokens from the first sequence (the `question`), and `1` for tokens from the second sequence (the `context`).

This tells us that the answer span must be searched for among the tokens labeled `1`.

Now, we are ready to use `offset_mapping` to find the position of the start and end tokens of the `answer` in a given feature.

In [73]:
answers = example["answers"]
ans_start = answers["answer_start"][0]
ans_end = ans_start + len(answers["text"][0])

print(answers)
print('ans_start:', ans_start)
print('end_char:', ans_end)

{'text': ['over 1,600'], 'answer_start': [30]}
ans_start: 30
end_char: 40


Let `token_start_index` and `token_end_index` delimit the initial search range for the answer span. Initialize them to the first and last context tokens:

In [74]:
# Find the position of the first `1` token
### START YOUR CODE ###
token_start_index = 0
for i, sequence_id in enumerate(sequence_ids):
    if sequence_id == 1:
        token_start_index = i
        break
### END YOUR CODE ###

print('token_start_index:', token_start_index)
print('offsets[token_start_index]:', offsets[token_start_index])
# Expected output
# token_start_index: 16
# offsets[token_start_index]: (0, 3)

token_start_index: 16
offsets[token_start_index]: (0, 3)


In [75]:
# Find the position of the last `1` token
### START YOUR CODE ###
token_end_index = None
for i, sequence_id in reversed(list(enumerate(sequence_ids))):
    if sequence_id == 1:
        token_end_index = i
        break
### END YOUR CODE ###

print('token_end_index:', token_end_index)
print('offsets[token_end_index]:', offsets[token_end_index])
# Expected output
# token_end_index: 382
# offsets[token_end_index]: (1665, 1669)

token_end_index: 382
offsets[token_end_index]: (1665, 1669)


First, detect whether `ans_start` and `ans_end` are within the initial search range.

If they are, then find the start and end indices of the tokens whose offsets encompass `ans_start` and `ans_end`, respectively.

In [76]:
tokenized_example.keys()

dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

In [77]:
offsets = tokenized_example["offset_mapping"][0]
token_ids = tokenized_example['input_ids'][0]

In [78]:

token_start_index = 16
token_end_index = 382 # reset

# Detect if the answer is within the initial search range.
# Note: ans_start/ans_end are character positions, so compare them with
# character offsets, not with token indices.
### START YOUR CODE ###
if offsets[token_start_index][0] > ans_start or offsets[token_end_index][1] < ans_end:
    print('The answer is not in this feature.')
### END YOUR CODE ###
else:
    # Find the start and end indices of the tokens, whose offsets encompass the ans_start and ans_end
    ### START YOUR CODE ###
    # Include token_end_index itself in the scan
    for i in range(token_start_index, token_end_index + 1):
        if offsets[i][0] <= ans_start:
            start_position = i
        if offsets[i][1] >= ans_end:
            end_position = i
            break
    ### END YOUR CODE ###

# Test
print(start_position, end_position)
print(offsets[start_position], offsets[end_position])

# Expected output
# 23 26
# (30, 34) (37, 40)

23 26
(30, 34) (37, 40)


We can double check that it is indeed the answer:

In [79]:
print(tokenizer.decode(tokenized_example["input_ids"][0][start_position: end_position+1]))
print(answers["text"][0])

over 1, 600
over 1,600
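
Putting the steps of this section together, the following sketch wraps the search in one function and runs it on hand-made toy inputs rather than real tokenizer output. `find_answer_tokens` and the toy `offsets`/`sequence_ids` are illustrative assumptions; returning `None` when the answer lies outside the feature is also a choice made here, not something prescribed by the lab.

```python
def find_answer_tokens(offsets, sequence_ids, ans_start, ans_end):
    """Return (start_token, end_token) covering [ans_start, ans_end), or None."""
    context = [i for i, sid in enumerate(sequence_ids) if sid == 1]
    if not context:
        return None
    first, last = context[0], context[-1]
    # The answer must lie fully inside the characters covered by this feature
    if offsets[first][0] > ans_start or offsets[last][1] < ans_end:
        return None
    # Move right while the next token still starts at or before the answer start
    start_token = first
    while start_token < last and offsets[start_token + 1][0] <= ans_start:
        start_token += 1
    # Move left while the previous token still ends at or after the answer end
    end_token = last
    while end_token > first and offsets[end_token - 1][1] >= ans_end:
        end_token -= 1
    return start_token, end_token

# Toy feature: [CLS] how many [SEP] over 1 , 600 wins [SEP]
# Context string is "over 1,600 wins"; the answer "1,600" spans characters 5..10
toy_sequence_ids = [None, 0, 0, None, 1, 1, 1, 1, 1, None]
toy_offsets = [(0, 0), (0, 3), (4, 8), (0, 0),
               (0, 4), (5, 6), (6, 7), (7, 10), (11, 15), (0, 0)]

print(find_answer_tokens(toy_offsets, toy_sequence_ids, 5, 10))   # (5, 7)
print(find_answer_tokens(toy_offsets, toy_sequence_ids, 20, 25))  # None
```

Tokens 5 to 7 of the toy feature are `1`, `,` and `600`, which together cover the answer text.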
