In [1]:
from datasets import load_dataset
# Load the SQuAD question-answering dataset through `load_dataset`.
raw_datasets = load_dataset("squad")
# Display the available splits with their columns and row counts.
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

### Explore the dataset

In [2]:
# Look at a single training example, one field per line.
train_example = raw_datasets["train"][1]
for field in ("title", "question", "answers", "context"):
  print (train_example[field])

University_of_Notre_Dame
What is in front of the Notre Dame Main Building?
{'text': ['a copper statue of Christ'], 'answer_start': [188]}
Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.


In [3]:
print (len(raw_datasets['train']))

# Keep only the questions that have exactly one answer
# (just in case, there are none with more or fewer).
# `filter` keeps the rows for which the predicate is True and returns a NEW
# dataset, so the predicate must describe the rows to keep and the result
# must be assigned back for the check to have any effect.
raw_datasets["train"] = raw_datasets["train"].filter(
    lambda x: len(x["answers"]["text"]) == 1)
print (len(raw_datasets['train']))


87599
87599


In [4]:
# Look at a single validation example; unlike the training split, its
# 'answers' field can hold several reference answers.
validation_example = raw_datasets["validation"][2]
for field in ("title", "question", "answers", "context"):
  print (validation_example[field])

Super_Bowl_50
Where did Super Bowl 50 take place?
{'text': ['Santa Clara, California', "Levi's Stadium", "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."], 'answer_start': [403, 355, 355]}
Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.


### Tokenizer

In [5]:
from transformers import AutoTokenizer

# Checkpoint name; the tokenizer below is loaded from it.
model_checkpoint = "distilbert-base-cased"
# This `tokenizer` is read as a global by the tokenize_fn_* functions below.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


In [6]:
# Take one training example and encode its question/context pair.
train_example = raw_datasets["train"][1]
question = train_example["question"]
context = train_example["context"]

inputs = tokenizer(question, context)
# Decode the ids back to text to see where the special tokens were inserted.
tokenizer.decode(inputs["input_ids"])

'[CLS] What is in front of the Notre Dame Main Building? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

In [7]:
# Split one question/context pair into several overlapping windows.
sliding_window_args = dict(
  max_length=100,                 # token budget per window
  truncation="only_second",       # truncate the context, never the question
  stride=50,                      # overlap between consecutive windows
  return_overflowing_tokens=True, # return every window, not just the first
)
inputs = tokenizer(question, context, **sliding_window_args)

# Print each window as text.
for window_ids in inputs["input_ids"]:
  print(tokenizer.decode(window_ids))

[CLS] What is in front of the Notre Dame Main Building? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the G [SEP]
[CLS] What is in front of the Notre Dame Main Building? [SEP] facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernade [SEP]
[CLS] What is in front of the Notre Dame Main Building? [SEP] of the Sacred Heart. Immediately behind the basilica is the Grotto, 

In [8]:
# Fields returned by the tokenizer for the windowed encoding.
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'overflow_to_sample_mapping'])

In [9]:
# `overflow_to_sample_mapping` holds, for every window, the index of the
# original sample that window was cut from. Only one sample was passed
# here, so every entry is 0.
inputs['overflow_to_sample_mapping']

[0, 0, 0, 0]

In [10]:
# Tokenize several samples in one call to see how the windows map back
# to the samples they came from.
first_three = raw_datasets["train"][:3]
inputs = tokenizer(
    first_three["question"],
    first_three["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

for window_ids in inputs["input_ids"]:
  print(tokenizer.decode(window_ids))

# One entry per window: the index of the sample it belongs to.
inputs['overflow_to_sample_mapping']

[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the B

[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]

In [11]:
# Back to a single question/context pair, this time also requesting the
# character offsets of every token.
offset_window_args = dict(
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
inputs = tokenizer(question, context, **offset_window_args)
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

In [12]:
# offset_mapping gives, per window, one (start_char, end_char) pair per token.
# Offsets restart at 0 for the context: they index into the question string
# for question tokens and into the context string for context tokens.
# note: special tokens take up 0 space - (0, 0)
print (inputs['offset_mapping'])

[[(0, 0), (0, 4), (5, 7), (8, 10), (11, 16), (17, 19), (20, 23), (24, 29), (30, 34), (35, 39), (40, 48), (48, 49), (0, 0), (0, 13), (13, 15), (15, 16), (17, 20), (21, 27), (28, 31), (32, 33), (34, 42), (43, 52), (52, 53), (54, 56), (56, 58), (59, 62), (63, 67), (68, 76), (76, 77), (77, 78), (79, 83), (84, 88), (89, 91), (92, 93), (94, 100), (101, 107), (108, 110), (111, 114), (115, 121), (122, 126), (126, 127), (128, 139), (140, 142), (143, 148), (149, 151), (152, 155), (156, 160), (161, 169), (170, 173), (174, 180), (181, 183), (183, 184), (185, 187), (188, 189), (190, 196), (197, 203), (204, 206), (207, 213), (214, 218), (219, 223), (224, 226), (226, 229), (229, 232), (233, 237), (238, 241), (242, 248), (249, 250), (250, 251), (251, 254), (254, 256), (257, 259), (260, 262), (263, 264), (264, 265), (265, 268), (268, 269), (269, 270), (271, 275), (276, 278), (279, 282), (283, 287), (288, 296), (297, 299), (300, 303), (304, 312), (313, 315), (316, 319), (320, 326), (327, 332), (332, 333

In [13]:
# How many windows were produced, and how many tokens each one holds.
offset_windows = inputs['offset_mapping']
print ('Nr splits:', len(offset_windows))
tokens_per_split = [
  f'Nr tokens in {i} split: {len(x)}. ' for i, x in enumerate(offset_windows)
]
print (tokens_per_split)

Nr splits: 4
['Nr tokens in 0 split: 100. ', 'Nr tokens in 1 split: 100. ', 'Nr tokens in 2 split: 100. ', 'Nr tokens in 3 split: 69. ']


In [14]:
def find_answer_token_idx(
    ctx_start,
    ctx_end,
    ans_start_char,
    ans_end_char,
    offset):
  """Convert an answer's character span into token indices for one window.

  Args:
    ctx_start: index of the first context token in the window.
    ctx_end: index of the last context token in the window (inclusive).
    ans_start_char: character position in the context where the answer starts.
    ans_end_char: character position in the context just past the answer's end.
    offset: the window's offset mapping, one (start_char, end_char) pair per
      token; only the entries in [ctx_start, ctx_end] are read.

  Returns:
    (start_idx, end_idx): token indices of the first and last answer token.
    (0, 0) is returned when the answer is not fully contained in the window.
  """
  # the 'trick' is knowing what is in units of tokens and what is in
  # units of characters: offset[i] holds CHARACTER positions of TOKEN i

  # answer not fully inside this window -> target is (0, 0)
  if offset[ctx_start][0] > ans_start_char or offset[ctx_end][1] < ans_end_char:
    return 0, 0

  # Start token: the last context token that begins at or before the answer
  # start. Requiring an exact match instead would silently leave the index at
  # 0 whenever the answer begins in the middle of a token.
  start_idx = ctx_start
  while start_idx < ctx_end and offset[start_idx + 1][0] <= ans_start_char:
    start_idx += 1

  # End token: the first context token that ends at or after the answer end.
  # Scanning backwards from ctx_end keeps the search inside the context, so
  # padding / special tokens after it are never inspected.
  end_idx = ctx_end
  while end_idx > ctx_start and offset[end_idx - 1][1] >= ans_end_char:
    end_idx -= 1

  return start_idx, end_idx

In [15]:
# now we are ready to process (tokenize) the training data
# (i.e. expand question+context pairs into question+smaller context windows)

# Both values are read as globals by tokenize_fn_train / tokenize_fn_validation
# and passed to the tokenizer as `max_length` and `stride`.
# Google used 384 for SQuAD
max_length = 384
stride = 128

def tokenize_fn_train(batch):
  """Tokenize a batch of training examples into padded, overlapping windows.

  Reads the 'question', 'context' and 'answers' columns of `batch` and uses
  the global `tokenizer`, `max_length` and `stride`. Returns the tokenizer
  output with two extra entries, 'start_positions' and 'end_positions', that
  hold the answer's token span for every window as computed by
  find_answer_token_idx.
  """
  # strip: some questions have leading and/or trailing whitespace
  stripped_questions = [text.strip() for text in batch["question"]]

  # every window is padded to max_length (no per-minibatch padding)
  inputs = tokenizer(
    stripped_questions,
    batch["context"],
    max_length=max_length,
    truncation="only_second",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  # only needed to compute the targets below, so take them out of the output
  window_offsets = inputs.pop("offset_mapping")
  window_to_sample = inputs.pop("overflow_to_sample_mapping")

  start_positions = []
  end_positions = []
  for window, (offset, sample_idx) in enumerate(
      zip(window_offsets, window_to_sample)):
    # the answer belongs to the original sample this window was cut from;
    # only the first answer is used
    answer = batch['answers'][sample_idx]
    first_answer_start = answer['answer_start'][0]
    first_answer_end = first_answer_start + len(answer['text'][0])

    # context tokens are the ones whose sequence id is 1
    seq_ids = inputs.sequence_ids(window)
    ctx_start = seq_ids.index(1)
    ctx_end = max(pos for pos, seq_id in enumerate(seq_ids) if seq_id == 1)

    start_idx, end_idx = find_answer_token_idx(
      ctx_start, ctx_end, first_answer_start, first_answer_end, offset)
    start_positions.append(start_idx)
    end_positions.append(end_idx)

  inputs["start_positions"] = start_positions
  inputs["end_positions"] = end_positions
  return inputs

In [16]:
# Turn every training example into one or more tokenized windows; the
# original columns are dropped because the row count changes.
train_split = raw_datasets["train"]
train_dataset = train_split.map(
  tokenize_fn_train,
  batched=True,
  remove_columns=train_split.column_names,
)
# (number of original examples, number of windows)
len(train_split), len(train_dataset)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

(87599, 88729)

In [17]:
# tokenize the validation set differently
# we won't need the targets since we will just compare with the original answer
# also: overwrite offset_mapping with Nones in place of question
def tokenize_fn_validation(batch):
  """Tokenize a batch of validation examples into padded, overlapping windows.

  Reads the 'question', 'context' and 'id' columns of `batch` and uses the
  global `tokenizer`, `max_length` and `stride`. No targets are computed.
  In the returned 'offset_mapping' every token that is not a context token is
  replaced by None, and a 'sample_id' entry records, per window, the id of
  the example it was cut from.
  """
  # strip: some questions have leading and/or trailing whitespace
  stripped_questions = [text.strip() for text in batch["question"]]

  # every window is padded to max_length (no per-minibatch padding)
  inputs = tokenizer(
    stripped_questions,
    batch["context"],
    max_length=max_length,
    truncation="only_second",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  # only used here to look up the example id of each window
  window_to_sample = inputs.pop("overflow_to_sample_mapping")

  sample_ids = []
  context_only_offsets = []
  for window, offsets in enumerate(inputs["offset_mapping"]):
    sample_ids.append(batch['id'][window_to_sample[window]])

    # keep offsets of context tokens (sequence id 1), blank out the rest;
    # this will be helpful later on when we compute metrics
    seq_ids = inputs.sequence_ids(window)
    context_only_offsets.append([
      pair if seq_id == 1 else None
      for seq_id, pair in zip(seq_ids, offsets)
    ])

  inputs["offset_mapping"] = context_only_offsets
  inputs['sample_id'] = sample_ids
  return inputs

# Turn every validation example into one or more tokenized windows.
validation_split = raw_datasets["validation"]
validation_dataset = validation_split.map(
  tokenize_fn_validation,
  batched=True,
  remove_columns=validation_split.column_names,
)
# (number of original examples, number of windows)
len(validation_split), len(validation_dataset)

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

(10570, 10822)

### Metrics

In [19]:
from datasets import load_metric

# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package - confirm against the
# installed version before re-running this cell.
metric = load_metric("squad")

# Toy example: predictions and references are matched through their 'id'.
predicted_answers = [
  {'id': '1', 'prediction_text': 'Albert Einstein'},
  {'id': '2', 'prediction_text': 'physicist'},
  {'id': '3', 'prediction_text': 'general relativity'},
]
true_answers = [
  {'id': '1', 'answers': {'text': ['Albert Einstein'], 'answer_start': [100]}},
  {'id': '2', 'answers': {'text': ['physicist'], 'answer_start': [100]}},
  {'id': '3', 'answers': {'text': ['special relativity'], 'answer_start': [100]}},
]

# id and answer_start seem superfluous but you'll get an error if not included
# exercise: remove them (one at a time) and see!
# Two of the three predictions match exactly; the third only shares a word.
metric.compute(predictions=predicted_answers, references=true_answers)

{'exact_match': 66.66666666666667, 'f1': 83.33333333333333}

In [20]:
# next problem: how to go from logits to prediction text?
# to make it easier, let's work on an already-trained question-answering model
small_validation_dataset = raw_datasets["validation"].select(range(100))
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer2 = AutoTokenizer.from_pretrained(trained_checkpoint)

# temporarily assign tokenizer2 to tokenizer since it's used as a global
# in tokenize_fn_validation
old_tokenizer = tokenizer
tokenizer = tokenizer2
try:
  small_validation_processed = small_validation_dataset.map(
      tokenize_fn_validation,
      batched=True,
      remove_columns=small_validation_dataset.column_names,
  )
finally:
  # change it back, even if the mapping above fails; otherwise the rest of
  # the notebook would silently keep using the wrong tokenizer
  tokenizer = old_tokenizer

Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/473 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [21]:
# get the model outputs
import torch
from transformers import AutoModelForQuestionAnswering

# drop the bookkeeping columns: the trained model doesn't use them
small_model_inputs = small_validation_processed.remove_columns(
  ["sample_id", "offset_mapping"])
small_model_inputs.set_format("torch")

# use the gpu when one is available, otherwise stay on the cpu
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# move every input column to the chosen device
small_model_inputs_gpu = {}
for column in small_model_inputs.column_names:
  small_model_inputs_gpu[column] = small_model_inputs[column].to(device)

# download the model and place it on the same device as the inputs
trained_model = AutoModelForQuestionAnswering.from_pretrained(
  trained_checkpoint)
trained_model = trained_model.to(device)

# forward pass only, no gradients needed
with torch.no_grad():
  outputs = trained_model(**small_model_inputs_gpu)

Downloading model.safetensors:   0%|          | 0.00/261M [00:00<?, ?B/s]

In [22]:
# Inspect the raw model output: it carries `start_logits` and `end_logits`.
outputs

QuestionAnsweringModelOutput(loss=None, start_logits=tensor([[ -2.2607,  -5.1783,  -5.2709,  ...,  -9.5243,  -9.5183,  -9.5288],
        [ -2.5961,  -5.5482,  -5.5313,  ...,  -9.9598,  -9.9533,  -9.9860],
        [ -3.7127,  -7.1848,  -8.5388,  ..., -11.6557, -11.6571, -11.6505],
        ...,
        [ -2.0260,  -4.4167,  -4.4980,  ...,  -8.1479,  -8.1530,  -8.1760],
        [ -4.1553,  -5.8304,  -7.1643,  ..., -10.5255, -10.5251, -10.4890],
        [ -3.2000,  -5.8162,  -6.7249,  ...,  -9.4935,  -9.5038,  -9.4871]],
       device='cuda:0'), end_logits=tensor([[ -0.7353,  -4.9236,  -5.1048,  ...,  -8.8734,  -8.8916,  -8.8550],
        [ -1.3056,  -5.3870,  -5.4945,  ...,  -9.4895,  -9.5039,  -9.4959],
        [ -2.7649,  -7.2201,  -9.0916,  ..., -11.3106, -11.3414, -11.2702],
        ...,
        [ -0.0768,  -4.8210,  -4.4374,  ...,  -8.0483,  -8.0502,  -7.9903],
        [ -2.7347,  -5.3650,  -7.2549,  ..., -10.0498, -10.0661,  -9.9886],
        [ -1.0991,  -4.2569,  -6.1267,  ...,  -8

In [23]:
# Bring the logits back to the cpu as numpy arrays for post-processing.
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

# Peek at the example ids recorded for the first few windows.
small_validation_processed['sample_id'][:5]

['56be4db0acb8001400a502ec',
 '56be4db0acb8001400a502ed',
 '56be4db0acb8001400a502ee',
 '56be4db0acb8001400a502ef',
 '56be4db0acb8001400a502f0']

In [25]:
# Number of windows vs. number of distinct examples: an id repeats whenever
# its example was split into more than one window.
print (len(validation_dataset['sample_id']))
print (len(set(validation_dataset['sample_id'])))

10822
10570


In [26]:
# Group the window indices by the id of the example they were cut from.
# example: {'56be4db0acb8001400a502ec': [0, 1, 2, 3], ...}
sample_id2idxs = {}
for i, id_ in enumerate(small_validation_processed['sample_id']):
  # setdefault creates the list the first time an id is seen
  sample_id2idxs.setdefault(id_, []).append(i)

In [27]:
# Shapes of the logit arrays; presumably (number of windows, max_length) - verify.
start_logits.shape, end_logits.shape

((100, 384), (100, 384))

### Not finished