In [1]:
!pip install transformers datasets



In [2]:
import os
os.environ["WANDB_DISABLED"] = "true"

In [3]:
from datasets import load_dataset
raw_datasets = load_dataset("squad")
raw_datasets

README.md:   0%|          | 0.00/7.62k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [4]:
raw_datasets["train"][1]["title"]

'University_of_Notre_Dame'

In [5]:
raw_datasets["train"][1]["context"]

'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.'

In [6]:
raw_datasets["train"][1]["question"]

'What is in front of the Notre Dame Main Building?'

In [7]:
raw_datasets["train"][1]["answers"]

{'text': ['a copper statue of Christ'], 'answer_start': [188]}

In [8]:
# for train set, ensure that there's always 1 answer
# not multiple answers, or no answers
raw_datasets["train"].filter(lambda x: len(x["answers"]["text"]) != 1)

Filter:   0%|          | 0/87599 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 0
})

In [9]:
# for validation set, there may be multiple answers
raw_datasets["validation"][2]["answers"]

{'text': ['Santa Clara, California',
  "Levi's Stadium",
  "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."],
 'answer_start': [403, 355, 355]}

In [10]:
# why are there multiple answers?
raw_datasets["validation"][2]["context"]

'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.'

In [11]:
raw_datasets["validation"][2]["question"]

'Where did Super Bowl 50 take place?'

In [12]:
# they may even be the same!
raw_datasets["validation"][0]["answers"]

{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 'answer_start': [177, 177, 177]}

In [13]:
from transformers import AutoTokenizer

model_checkpoint = "distilbert-base-cased"
# model_checkpoint = "bert-base-cased" # try it yourself
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]



In [14]:
context = raw_datasets["train"][1]["context"]
question = raw_datasets["train"][1]["question"]

inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

'[CLS] What is in front of the Notre Dame Main Building? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building \' s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

In [15]:
# what if the context is really long?
# split it into multiple samples
inputs = tokenizer(
  question,
  context,
  max_length=100,
  truncation="only_second",
  stride=50,
  return_overflowing_tokens=True,
)

for ids in inputs["input_ids"]:
  print(tokenizer.decode(ids))

[CLS] What is in front of the Notre Dame Main Building? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building ' s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the G [SEP]
[CLS] What is in front of the Notre Dame Main Building? [SEP] facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernade [SEP]
[CLS] What is in front of the Notre Dame Main Building? [SEP] of the Sacred Heart. Immediately behind the basilica is the Grotto

In [16]:
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'overflow_to_sample_mapping'])

In [17]:
# what's the new key?
inputs['overflow_to_sample_mapping']

[0, 0, 0, 0]

In [18]:
inputs = tokenizer(
    raw_datasets["train"][:3]["question"],
    raw_datasets["train"][:3]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
inputs['overflow_to_sample_mapping']

[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]

In [19]:
# it points to the original sample index
for ids in inputs["input_ids"]:
  print(tokenizer.decode(ids))

[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building ' s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the

In [20]:
# recreate inputs for just a single context-question pair
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

In [21]:

# what is this (weirdly named) offset_mapping?
# it tells us the location of each token
# notes:
# special tokens take up 0 space - (0, 0)
# the question portion is the same for each sample
# the context portion starting point inceases in each sample
inputs['offset_mapping']

[[(0, 0),
  (0, 4),
  (5, 7),
  (8, 10),
  (11, 16),
  (17, 19),
  (20, 23),
  (24, 29),
  (30, 34),
  (35, 39),
  (40, 48),
  (48, 49),
  (0, 0),
  (0, 13),
  (13, 15),
  (15, 16),
  (17, 20),
  (21, 27),
  (28, 31),
  (32, 33),
  (34, 42),
  (43, 52),
  (52, 53),
  (54, 56),
  (56, 58),
  (59, 62),
  (63, 67),
  (68, 76),
  (76, 77),
  (77, 78),
  (79, 83),
  (84, 88),
  (89, 91),
  (92, 93),
  (94, 100),
  (101, 107),
  (108, 110),
  (111, 114),
  (115, 121),
  (122, 126),
  (126, 127),
  (128, 139),
  (140, 142),
  (143, 148),
  (149, 151),
  (152, 155),
  (156, 160),
  (161, 169),
  (170, 173),
  (174, 180),
  (181, 183),
  (183, 184),
  (185, 187),
  (188, 189),
  (190, 196),
  (197, 203),
  (204, 206),
  (207, 213),
  (214, 218),
  (219, 223),
  (224, 226),
  (226, 229),
  (229, 232),
  (233, 237),
  (238, 241),
  (242, 248),
  (249, 250),
  (250, 251),
  (251, 254),
  (254, 256),
  (257, 259),
  (260, 262),
  (263, 264),
  (264, 265),
  (265, 268),
  (268, 269),
  (269, 270),
 

In [22]:
len(inputs['offset_mapping'])

4

In [23]:
len(inputs['offset_mapping'][0])

100

In [24]:
print(inputs.sequence_ids(0))

[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None]


In [25]:
# problem: the position of the answer will change in each
# window of the context
# the answer is also the target for the neural network
# how can we recompute the targets for each context window?

# since we took the question and context from this sample earlier
answer = raw_datasets["train"][1]["answers"]
answer

{'text': ['a copper statue of Christ'], 'answer_start': [188]}

In [26]:
type(inputs.sequence_ids(0))

list

Extract for Char to token index 

In [27]:
# find the start and end of the context (the first and last '1')
sequence_ids = inputs.sequence_ids(0)

ctx_start = sequence_ids.index(1) # first occurrence
ctx_end = len(sequence_ids) - sequence_ids[::-1].index(1) - 1 # last occurrence

ctx_start, ctx_end

(13, 98)

In [28]:
# check whether or not the answer is fully contained within the context
# if not, target is (start, end) = (0, 0)

ans_start_char = answer['answer_start'][0]
ans_end_char = ans_start_char + len(answer['text'][0])

offset = inputs['offset_mapping'][0]

start_idx = 0
end_idx = 0

if offset[ctx_start][0] > ans_start_char or offset[ctx_end][1] < ans_end_char:
  print("target is (0, 0)")
  # nothing else to do
else:
  # find the start and end TOKEN positions

  # the 'trick' is knowing what is in units of tokens and what is in
  # units of characters

  # recall: the offset_mapping contains the character positions of each token

  i = ctx_start
  for start_end_char in offset[ctx_start:]:
    start, end = start_end_char
    if start == ans_start_char:
      start_idx = i
      # don't break yet
    
    if end == ans_end_char:
      end_idx = i
      break

    i += 1

start_idx, end_idx

(53, 57)

In [29]:
# check
input_ids = inputs['input_ids'][0]
input_ids[start_idx : end_idx + 1]

[170, 7335, 5921, 1104, 4028]

In [30]:
tokenizer.decode(input_ids[start_idx : end_idx + 1])

'a copper statue of Christ'

In [31]:
def find_answer_token_idx(
    ctx_start,
    ctx_end,
    ans_start_char,
    ans_end_char,
    offset):
  
  start_idx = 0
  end_idx = 0

  if offset[ctx_start][0] > ans_start_char or offset[ctx_end][1] < ans_end_char:
    pass
    # print("target is (0, 0)")
    # nothing else to do
  else:
    # find the start and end TOKEN positions

    # the 'trick' is knowing what is in units of tokens and what is in
    # units of characters

    # recall: the offset_mapping contains the character positions of each token

    i = ctx_start
    for start_end_char in offset[ctx_start:]:
      start, end = start_end_char
      if start == ans_start_char:
        start_idx = i
        # don't break yet
      
      if end == ans_end_char:
        end_idx = i
        break

      i += 1
  return start_idx, end_idx

In [32]:
# try it on all context windows
# sometimes, the answer won't appear!

start_idxs = []
end_idxs = []

for i, offset in enumerate(inputs['offset_mapping']):
  # the final window may not be full size - can't assume 100
  sequence_ids = inputs.sequence_ids(i)

  # find start + end of context (first 1 and last 1)
  ctx_start = sequence_ids.index(1)
  ctx_end = len(sequence_ids) - sequence_ids[::-1].index(1) - 1

  start_idx, end_idx = find_answer_token_idx(
    ctx_start,
    ctx_end,
    ans_start_char,
    ans_end_char,
    offset)

  start_idxs.append(start_idx)
  end_idxs.append(end_idx)

start_idxs, end_idxs

([53, 17, 0, 0], [57, 21, 0, 0])

👆🏼 Token for train Neural Network

In [33]:
# some questions have leading and/or trailing whitespace
for q in raw_datasets["train"]["question"][:1000]:
  if q.strip() != q:
    print(q)

In what city and state did Beyonce  grow up? 
 The album, Dangerously in Love  achieved what spot on the Billboard Top 100 chart?
Which song did Beyonce sing at the first couple's inaugural ball? 
What event did Beyoncé perform at one month after Obama's inauguration? 
Where was the album released? 
What movie influenced Beyonce towards empowerment themes? 


In [34]:
# now we are ready to process (tokenize) the training data
# (i.e. expand question+context pairs into question+smaller context windows)

# Google used 384 for SQuAD
max_length = 384
stride = 128

def tokenize_fn_train(batch):
  # some questions have leading and/or trailing whitespace
  questions = [q.strip() for q in batch["question"]]

  # tokenize the data (with padding this time)
  # since most contexts are long, we won't bother to pad per-minibatch
  inputs = tokenizer(
    questions,
    batch["context"],
    max_length=max_length,
    truncation="only_second",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  # we don't need these later so remove them
  offset_mapping = inputs.pop("offset_mapping")
  orig_sample_idxs = inputs.pop("overflow_to_sample_mapping")
  answers = batch['answers']
  start_idxs, end_idxs = [], []

  # same loop as above
  for i, offset in enumerate(offset_mapping):
    sample_idx = orig_sample_idxs[i]
    answer = answers[sample_idx]

    ans_start_char = answer['answer_start'][0]
    ans_end_char = ans_start_char + len(answer['text'][0])

    sequence_ids = inputs.sequence_ids(i)

    # find start + end of context (first 1 and last 1)
    ctx_start = sequence_ids.index(1)
    ctx_end = len(sequence_ids) - sequence_ids[::-1].index(1) - 1

    start_idx, end_idx = find_answer_token_idx(
      ctx_start,
      ctx_end,
      ans_start_char,
      ans_end_char,
      offset)

    start_idxs.append(start_idx)
    end_idxs.append(end_idx)
  
  inputs["start_positions"] = start_idxs
  inputs["end_positions"] = end_idxs
  return inputs

In [35]:
train_dataset = raw_datasets["train"].map(
  tokenize_fn_train,
  batched=True,
  remove_columns=raw_datasets["train"].column_names,
)
len(raw_datasets["train"]), len(train_dataset)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

(87599, 88729)

In [36]:
# note: we'll keep these IDs for later
raw_datasets["validation"][0]

{'id': '56be4db0acb8001400a502ec',
 'title': 'Super_Bowl_50',
 'context': 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.',
 'question': 'Which NFL team represented the AFC at Super Bowl 50?',
 'answers': {'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],


In [37]:
# tokenize the validation set differently
# we won't need the targets since we will just compare with the original answer
# also: overwrite offset_mapping with Nones in place of question
def tokenize_fn_validation(batch):
  # some questions have leading and/or trailing whitespace
  questions = [q.strip() for q in batch["question"]]

  # tokenize the data (with padding this time)
  # since most contexts are long, we won't bother to pad per-minibatch
  inputs = tokenizer(
    questions,
    batch["context"],
    max_length=max_length,
    truncation="only_second",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  # we don't need these later so remove them
  orig_sample_idxs = inputs.pop("overflow_to_sample_mapping")
  sample_ids = []

  # rewrite offset mapping by replacing question tuples with None
  # this will be helpful later on when we compute metrics
  for i in range(len(inputs["input_ids"])):
    sample_idx = orig_sample_idxs[i]
    sample_ids.append(batch['id'][sample_idx])

    sequence_ids = inputs.sequence_ids(i)
    offset = inputs["offset_mapping"][i]
    inputs["offset_mapping"][i] = [
      x if sequence_ids[j] == 1 else None for j, x in enumerate(offset)]
    
  inputs['sample_id'] = sample_ids
  return inputs

In [38]:
validation_dataset = raw_datasets["validation"].map(
  tokenize_fn_validation,
  batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)
len(raw_datasets["validation"]), len(validation_dataset)

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

(10570, 10822)

# Metric

In [39]:
!pip install evaluate

  pid, fd = os.forkpty()
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.3


In [40]:
from evaluate import load

metric = load("squad")

Downloading builder script:   0%|          | 0.00/4.53k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.32k [00:00<?, ?B/s]

In [41]:
predicted_answers = [
  {'id': '1', 'prediction_text': 'Albert Einstein'},
  {'id': '2', 'prediction_text': 'physicist'},
  {'id': '3', 'prediction_text': 'general relativity'},
]
true_answers = [
  {'id': '1', 'answers': {'text': ['Albert Einstein'], 'answer_start': [100]}},
  {'id': '2', 'answers': {'text': ['physicist'], 'answer_start': [100]}},
  {'id': '3', 'answers': {'text': ['special relativity'], 'answer_start': [100]}},
]

# id and answer_start seem superfluous but you'll get an error if not included
# exercise: remove them (one at a time) and see!
metric.compute(predictions=predicted_answers, references=true_answers)

{'exact_match': 66.66666666666667, 'f1': 83.33333333333333}

In [42]:
# next problem: how to go from logits to prediction text?
# to make it easier, let's work on an already-trained question-answering model
small_validation_dataset = raw_datasets["validation"].select(range(100))
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer2 = AutoTokenizer.from_pretrained(trained_checkpoint)

# temporarily assign tokenizer2 to tokenizer since it's used as a global
# in tokenize_fn_validation
old_tokenizer = tokenizer
tokenizer = tokenizer2

small_validation_processed = small_validation_dataset.map(
    tokenize_fn_validation,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)

# change it back
tokenizer = old_tokenizer

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/473 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]



Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [43]:
# get the model outputs
import torch
from transformers import AutoModelForQuestionAnswering

# the trained model doesn't use these columns
small_model_inputs = small_validation_processed.remove_columns(
  ["sample_id", "offset_mapping"])
small_model_inputs.set_format("torch")

# get gpu device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# move tensors to gpu device
small_model_inputs_gpu = {
  k: small_model_inputs[k].to(device) for k in small_model_inputs.column_names
}

# download the model
trained_model = AutoModelForQuestionAnswering.from_pretrained(
  trained_checkpoint).to(device)

# get the model outputs
with torch.no_grad():
  outputs = trained_model(**small_model_inputs_gpu)

model.safetensors:   0%|          | 0.00/261M [00:00<?, ?B/s]

In [44]:
outputs

QuestionAnsweringModelOutput(loss=None, start_logits=tensor([[ -2.2607,  -5.1783,  -5.2709,  ...,  -9.5243,  -9.5183,  -9.5288],
        [ -2.5961,  -5.5482,  -5.5313,  ...,  -9.9598,  -9.9533,  -9.9860],
        [ -3.7127,  -7.1848,  -8.5388,  ..., -11.6557, -11.6571, -11.6505],
        ...,
        [ -2.0260,  -4.4167,  -4.4980,  ...,  -8.1479,  -8.1530,  -8.1760],
        [ -4.1553,  -5.8304,  -7.1643,  ..., -10.5255, -10.5251, -10.4890],
        [ -3.2000,  -5.8162,  -6.7249,  ...,  -9.4935,  -9.5038,  -9.4871]],
       device='cuda:0'), end_logits=tensor([[ -0.7353,  -4.9236,  -5.1048,  ...,  -8.8734,  -8.8916,  -8.8550],
        [ -1.3056,  -5.3870,  -5.4945,  ...,  -9.4895,  -9.5039,  -9.4958],
        [ -2.7649,  -7.2201,  -9.0916,  ..., -11.3106, -11.3414, -11.2702],
        ...,
        [ -0.0768,  -4.8210,  -4.4374,  ...,  -8.0483,  -8.0502,  -7.9903],
        [ -2.7347,  -5.3650,  -7.2549,  ..., -10.0498, -10.0661,  -9.9886],
        [ -1.0991,  -4.2569,  -6.1267,  ...,  -8

In [45]:
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

In [46]:
small_validation_processed['sample_id'][:5]

['56be4db0acb8001400a502ec',
 '56be4db0acb8001400a502ed',
 '56be4db0acb8001400a502ee',
 '56be4db0acb8001400a502ef',
 '56be4db0acb8001400a502f0']

In [47]:
len(validation_dataset['sample_id'])

10822

In [48]:
len(set(validation_dataset['sample_id']))

10570

In [49]:
# example: {'56be4db0acb8001400a502ec': [0, 1, 2, 3], ...}
sample_id2idxs = {}
for i, id_ in enumerate(small_validation_processed['sample_id']):
  if id_ not in sample_id2idxs:
    sample_id2idxs[id_] = [i]
  else:
    print("here")
    sample_id2idxs[id_].append(i)

In [50]:
start_logits.shape, end_logits.shape

((100, 384), (100, 384))

In [51]:
# reminder of how to find indices with the largest values
(-start_logits[0]).argsort()

array([ 46,  57,  47,  38,  39,  58,  50,  43,  45,  54,  56,  49,  13,
        42,  40,  35,  27,  31,  48,  41,  53,  44,  37,  59,  78,  15,
         0,  52,  24,  65,  81,  70,  18,  51,  55,  26,  69,  29,  28,
        75,  61,  64,  23,  36,  32,  11, 101,  62,  66,  34,  95,  30,
        63,  21,  19,  20,  17,  14,  22,  33,  68,  87, 171,  12,  76,
        71,  73,  92, 110,  84, 151,   1,  74,   2,   6,  16,  80,  79,
       105,  98,  10,  96, 136, 169, 106, 100,  93, 165,  67, 109,   8,
        90,   3, 115,  60,   5,  97,   7, 103, 102,  86,  72, 111,  89,
       108,   4,  88,  25, 132,  77, 123, 150, 124, 153,  83, 118,  82,
        85, 107, 114, 143, 164, 137, 130, 166, 159, 131,  91,   9, 144,
       139, 160,  94, 141, 128, 112, 134, 152, 170, 154, 117, 127, 104,
       140, 157, 155, 133, 145, 119, 162, 138, 135, 156, 167, 168, 126,
       148, 163, 161, 116,  99, 120, 142, 158, 125, 146, 113, 121, 147,
       149, 129, 122, 311, 312, 304, 309, 313, 310, 300, 307, 31

In [52]:
start_logits[0][(-start_logits[0]).argsort()]

array([10.694445  ,  9.803685  ,  4.459973  ,  4.400487  ,  2.9437785 ,
        2.7017367 ,  2.0126448 ,  1.5780739 ,  0.52237445,  0.02073721,
       -0.02802688, -0.04971648, -0.38573122, -0.6945363 , -0.7979508 ,
       -0.86780477, -0.87220925, -1.3516886 , -1.3703715 , -1.3878827 ,
       -1.5135094 , -1.7355472 , -1.8827027 , -1.8932863 , -1.9078972 ,
       -1.9304978 , -2.2607322 , -2.2983866 , -2.3069332 , -2.5027428 ,
       -2.510063  , -2.530842  , -2.5399983 , -2.6718144 , -2.732354  ,
       -2.7710216 , -2.7713673 , -2.9521358 , -3.0604653 , -3.1706066 ,
       -3.204542  , -3.569336  , -3.5798059 , -3.6668851 , -3.7250628 ,
       -3.7498565 , -3.7632205 , -3.996814  , -4.0113316 , -4.0688004 ,
       -4.0944853 , -4.195475  , -4.2383103 , -4.3323617 , -4.352419  ,
       -4.3879614 , -4.38861   , -4.396615  , -4.6790547 , -4.7030315 ,
       -4.7757587 , -4.7778134 , -4.788218  , -4.7882495 , -4.8221273 ,
       -4.872539  , -4.8849363 , -4.8981495 , -5.072096  , -5.10

In [53]:
# reminder: in offset_mapping we store None everywhere except the context window
# in the context window we store tuples for each token containing:
# (start_character_position, end_character_position)
small_validation_processed['offset_mapping'][0]

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 [0, 5],
 [6, 10],
 [11, 13],
 [14, 17],
 [18, 20],
 [21, 29],
 [30, 38],
 [39, 43],
 [44, 46],
 [47, 56],
 [57, 60],
 [61, 69],
 [70, 72],
 [73, 76],
 [77, 85],
 [86, 94],
 [95, 101],
 [102, 103],
 [103, 106],
 [106, 107],
 [108, 111],
 [112, 115],
 [116, 120],
 [121, 127],
 [127, 128],
 [129, 132],
 [133, 141],
 [142, 150],
 [151, 161],
 [162, 163],
 [163, 166],
 [166, 167],
 [168, 176],
 [177, 183],
 [184, 191],
 [192, 200],
 [201, 204],
 [205, 213],
 [214, 222],
 [223, 233],
 [234, 235],
 [235, 238],
 [238, 239],
 [240, 248],
 [249, 257],
 [258, 266],
 [267, 269],
 [269, 270],
 [270, 272],
 [273, 275],
 [276, 280],
 [281, 286],
 [287, 292],
 [293, 298],
 [299, 303],
 [304, 309],
 [309, 310],
 [311, 314],
 [315, 319],
 [320, 323],
 [324, 330],
 [331, 333],
 [334, 342],
 [343, 344],
 [344, 345],
 [346, 350],
 [350, 351],
 [352, 354],
 [355, 359],
 [359, 360],
 [360, 361],
 [362, 369],
 [370, 37

In [54]:
n_largest = 20
max_answer_length = 30
predicted_answers = []

# we are looping through the original (untokenized) dataset
# because we need to grab the answer from the original string context
for sample in small_validation_dataset:
  sample_id = sample['id']
  context = sample['context']

  # update these as we loop through candidate answers
  best_score = float('-inf')
  best_answer = None

  # now loop through the *expanded* input samples (fixed size context windows)
  # from here we will pick the highest probability start/end combination
  for idx in sample_id2idxs[sample_id]:
    start_logit = start_logits[idx] # (384,) vector
    end_logit = end_logits[idx] # (384,) vector
    offsets = small_validation_processed[idx]['offset_mapping']

    start_indices = (-start_logit).argsort()
    end_indices = (-end_logit).argsort()

    for start_idx in start_indices[:n_largest]:
      for end_idx in end_indices[:n_largest]:

        # skip answers not contained in context window
        # recall: we set entries not pertaining to context to None earlier
        if offsets[start_idx] is None or offsets[end_idx] is None:
          continue
        
        # skip answers where end < start
        if end_idx < start_idx:
          continue
        
        # skip answers that are too long
        if end_idx - start_idx + 1 > max_answer_length:
          continue
        
        # see theory lecture for score calculation
        score = start_logit[start_idx] + end_logit[end_idx]
        if score > best_score:
          best_score = score

          # find positions of start and end characters
          # recall: offsets contains tuples for each token:
          # (start_char, end_char)
          first_ch = offsets[start_idx][0]
          last_ch = offsets[end_idx][1]

          best_answer = context[first_ch:last_ch]

  # save best answer
  predicted_answers.append({'id': sample_id, 'prediction_text': best_answer})

In [55]:
small_validation_dataset['answers'][0]

{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 'answer_start': [177, 177, 177]}

In [56]:
# now test it!

true_answers = [
  {'id': x['id'], 'answers': x['answers']} for x in small_validation_dataset
]
metric.compute(predictions=predicted_answers, references=true_answers)

{'exact_match': 83.0, 'f1': 88.25000000000004}

In [57]:
# now let's define a full compute_metrics function
# note: this will NOT be called from the trainer

from tqdm.autonotebook import tqdm

def compute_metrics(start_logits, end_logits, processed_dataset, orig_dataset):
  # map sample_id ('56be4db0acb8001400a502ec') to row indices of processed data
  sample_id2idxs = {}
  for i, id_ in enumerate(processed_dataset['sample_id']):
    if id_ not in sample_id2idxs:
      sample_id2idxs[id_] = [i]
    else:
      sample_id2idxs[id_].append(i)

  predicted_answers = []
  for sample in tqdm(orig_dataset):

    sample_id = sample['id']
    context = sample['context']

    # update these as we loop through candidate answers
    best_score = float('-inf')
    best_answer = None

    # now loop through the *expanded* input samples (fixed size context windows)
    # from here we will pick the highest probability start/end combination
    for idx in sample_id2idxs[sample_id]:
      start_logit = start_logits[idx] # (T,) vector
      end_logit = end_logits[idx] # (T,) vector

      # note: do NOT do the reverse: ['offset_mapping'][idx]
      offsets = processed_dataset[idx]['offset_mapping']

      start_indices = (-start_logit).argsort()
      end_indices = (-end_logit).argsort()

      for start_idx in start_indices[:n_largest]:
        for end_idx in end_indices[:n_largest]:

          # skip answers not contained in context window
          # recall: we set entries not pertaining to context to None earlier
          if offsets[start_idx] is None or offsets[end_idx] is None:
            continue
          
          # skip answers where end < start
          if end_idx < start_idx:
            continue
          
          # skip answers that are too long
          if end_idx - start_idx + 1 > max_answer_length:
            continue
          
          # see theory lecture for score calculation
          score = start_logit[start_idx] + end_logit[end_idx]
          if score > best_score:
            best_score = score

            # find positions of start and end characters
            # recall: offsets contains tuples for each token:
            # (start_char, end_char)
            first_ch = offsets[start_idx][0]
            last_ch = offsets[end_idx][1]

            best_answer = context[first_ch:last_ch]

    # save best answer
    predicted_answers.append({'id': sample_id, 'prediction_text': best_answer})
  
  # compute the metrics
  true_answers = [
    {'id': x['id'], 'answers': x['answers']} for x in orig_dataset
  ]
  return metric.compute(predictions=predicted_answers, references=true_answers)

In [58]:
# run our function on the same mini dataset as before
compute_metrics(
    start_logits,
    end_logits,
    small_validation_processed,
    small_validation_dataset,
)

  0%|          | 0/100 [00:00<?, ?it/s]

{'exact_match': 83.0, 'f1': 88.25000000000004}

# Train and Evaluate

In [59]:
# now load the model we want to fine-tune
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

model.safetensors:   0%|          | 0.00/263M [00:00<?, ?B/s]

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [60]:
from transformers import TrainingArguments

args = TrainingArguments(
    "finetuned-squad",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True,
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [61]:
from transformers import Trainer

# takes ~2.5h with bert on full dataset
# ~1h 15min with distilbert

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    # train_dataset=train_dataset.shuffle(seed=42).select(range(1_000)),
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
)
trainer.train()

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Step,Training Loss
500,3.4866
1000,2.3204
1500,1.9965
2000,1.7501
2500,1.6528
3000,1.5394
3500,1.5064
4000,1.4323
4500,1.3752
5000,1.4113


TrainOutput(global_step=33276, training_loss=1.0458161709099123, metrics={'train_runtime': 5826.5998, 'train_samples_per_second': 45.685, 'train_steps_per_second': 5.711, 'total_flos': 2.608361755366349e+16, 'train_loss': 1.0458161709099123, 'epoch': 3.0})

In [62]:
trainer_output = trainer.predict(validation_dataset)

In [63]:
type(trainer_output)

transformers.trainer_utils.PredictionOutput

In [64]:
trainer_output

PredictionOutput(predictions=(array([[ -7.078125 , -10.8515625, -10.7578125, ..., -11.6171875,
        -11.6171875, -11.6328125],
       [ -7.2851562, -10.8515625, -10.78125  , ..., -11.6015625,
        -11.6015625, -11.609375 ],
       [ -7.4921875, -11.0859375, -11.375    , ..., -11.625    ,
        -11.609375 , -11.6328125],
       ...,
       [ -4.71875  , -11.1640625, -11.6796875, ..., -11.6015625,
        -11.59375  , -11.6171875],
       [ -3.9726562, -10.5546875, -10.9453125, ..., -11.515625 ,
        -11.484375 , -11.5078125],
       [ -4.5234375, -10.9765625, -11.5859375, ..., -11.546875 ,
        -11.5390625, -11.5625   ]], dtype=float32), array([[ -6.5351562, -10.6171875,  -9.9921875, ..., -11.7265625,
        -11.734375 , -11.7265625],
       [ -6.734375 , -10.59375  ,  -9.984375 , ..., -11.7265625,
        -11.7421875, -11.734375 ],
       [ -6.828125 , -10.7421875, -11.       , ..., -11.7421875,
        -11.765625 , -11.734375 ],
       ...,
       [ -4.578125 , -11.0390

In [65]:
predictions, _, _ = trainer_output

In [66]:
predictions

(array([[ -7.078125 , -10.8515625, -10.7578125, ..., -11.6171875,
         -11.6171875, -11.6328125],
        [ -7.2851562, -10.8515625, -10.78125  , ..., -11.6015625,
         -11.6015625, -11.609375 ],
        [ -7.4921875, -11.0859375, -11.375    , ..., -11.625    ,
         -11.609375 , -11.6328125],
        ...,
        [ -4.71875  , -11.1640625, -11.6796875, ..., -11.6015625,
         -11.59375  , -11.6171875],
        [ -3.9726562, -10.5546875, -10.9453125, ..., -11.515625 ,
         -11.484375 , -11.5078125],
        [ -4.5234375, -10.9765625, -11.5859375, ..., -11.546875 ,
         -11.5390625, -11.5625   ]], dtype=float32),
 array([[ -6.5351562, -10.6171875,  -9.9921875, ..., -11.7265625,
         -11.734375 , -11.7265625],
        [ -6.734375 , -10.59375  ,  -9.984375 , ..., -11.7265625,
         -11.7421875, -11.734375 ],
        [ -6.828125 , -10.7421875, -11.       , ..., -11.7421875,
         -11.765625 , -11.734375 ],
        ...,
        [ -4.578125 , -11.0390625, -10.

In [67]:
start_logits, end_logits = predictions

In [68]:
compute_metrics(
    start_logits,
    end_logits,
    validation_dataset, # processed
    raw_datasets["validation"], # orig
)

  0%|          | 0/10570 [00:00<?, ?it/s]

{'exact_match': 77.06717123935667, 'f1': 85.17397963585955}

In [69]:
trainer.save_model('my_saved_model')

In [70]:
from transformers import pipeline

qa = pipeline(
  "question-answering",
  model='my_saved_model',
  device=0,
)

In [71]:
context = "Today I went to the store to purchase a carton of milk."
question = "What did I buy?"

qa(context=context, question=question)

{'score': 0.7199029922485352,
 'start': 38,
 'end': 54,
 'answer': 'a carton of milk'}