In [92]:
from datasets import load_dataset
raw_datasets = load_dataset("squad")
raw_datasets


KeyboardInterrupt: 

In [None]:
# for train set, ensure that there's always 1 answer
# not multiple answers, or no answers
raw_datasets["train"].filter(lambda x: len(x["answers"]["text"]) != 1)

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 0
})

In [None]:
from transformers import AutoTokenizer

model_checkpoint = "distilbert-base-cased"
# model_checkpoint = "bert-base-cased" # try it yourself
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
context = raw_datasets["train"][1]["context"]
question = raw_datasets["train"][1]["question"]

inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

'[CLS] What is in front of the Notre Dame Main Building? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

In [None]:
# what if the context is really long?
# split it into multiple samples
inputs = tokenizer(
  question,
  context,
  max_length=100,
  truncation="only_second",
  stride=50,
  return_overflowing_tokens=True,
)

for ids in inputs["input_ids"]:
  print(tokenizer.decode(ids))

[CLS] What is in front of the Notre Dame Main Building? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the G [SEP]
[CLS] What is in front of the Notre Dame Main Building? [SEP] facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernade [SEP]
[CLS] What is in front of the Notre Dame Main Building? [SEP] of the Sacred Heart. Immediately behind the basilica is the Grotto, 

In [None]:
# what's the new key?
inputs['overflow_to_sample_mapping']

[0, 0, 0, 0]

In [None]:
inputs = tokenizer(
    raw_datasets["train"][:3]["question"],
    raw_datasets["train"][:3]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
inputs['overflow_to_sample_mapping']

[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]

In [None]:
# it points to the original sample index
for ids in inputs["input_ids"]:
  print(tokenizer.decode(ids))

[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the B

In [None]:
# recreate inputs for just a single context-question pair
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

In [None]:
# what is this (weirdly named) offset_mapping?
# it tells us the location of each token
# notes:
# special tokens take up 0 space - (0, 0)
# the question portion is the same for each sample
# the context portion starting point increases in each sample
inputs['offset_mapping']

[[(0, 0),
  (0, 4),
  (5, 7),
  (8, 10),
  (11, 16),
  (17, 19),
  (20, 23),
  (24, 29),
  (30, 34),
  (35, 39),
  (40, 48),
  (48, 49),
  (0, 0),
  (0, 13),
  (13, 15),
  (15, 16),
  (17, 20),
  (21, 27),
  (28, 31),
  (32, 33),
  (34, 42),
  (43, 52),
  (52, 53),
  (54, 56),
  (56, 58),
  (59, 62),
  (63, 67),
  (68, 76),
  (76, 77),
  (77, 78),
  (79, 83),
  (84, 88),
  (89, 91),
  (92, 93),
  (94, 100),
  (101, 107),
  (108, 110),
  (111, 114),
  (115, 121),
  (122, 126),
  (126, 127),
  (128, 139),
  (140, 142),
  (143, 148),
  (149, 151),
  (152, 155),
  (156, 160),
  (161, 169),
  (170, 173),
  (174, 180),
  (181, 183),
  (183, 184),
  (185, 187),
  (188, 189),
  (190, 196),
  (197, 203),
  (204, 206),
  (207, 213),
  (214, 218),
  (219, 223),
  (224, 226),
  (226, 229),
  (229, 232),
  (233, 237),
  (238, 241),
  (242, 248),
  (249, 250),
  (250, 251),
  (251, 254),
  (254, 256),
  (257, 259),
  (260, 262),
  (263, 264),
  (264, 265),
  (265, 268),
  (268, 269),
  (269, 270),
 

In [None]:
# problem: the position of the answer will change in each
# window of the context
# the answer is also the target for the neural network
# how can we recompute the targets for each context window?

# since we took the question and context from this sample earlier
answer = raw_datasets["train"][1]["answers"]
answer

{'text': ['a copper statue of Christ'], 'answer_start': [188]}

In [None]:
# find the start and end of the context (the first and last '1')
sequence_ids = inputs.sequence_ids(0)

ctx_start = sequence_ids.index(1) # first occurrence
ctx_end = len(sequence_ids) - sequence_ids[::-1].index(1) - 1 # last occurrence

ctx_start, ctx_end

(13, 98)

In [None]:
inputs['offset_mapping']

[[(0, 0),
  (0, 4),
  (5, 7),
  (8, 10),
  (11, 16),
  (17, 19),
  (20, 23),
  (24, 29),
  (30, 34),
  (35, 39),
  (40, 48),
  (48, 49),
  (0, 0),
  (0, 13),
  (13, 15),
  (15, 16),
  (17, 20),
  (21, 27),
  (28, 31),
  (32, 33),
  (34, 42),
  (43, 52),
  (52, 53),
  (54, 56),
  (56, 58),
  (59, 62),
  (63, 67),
  (68, 76),
  (76, 77),
  (77, 78),
  (79, 83),
  (84, 88),
  (89, 91),
  (92, 93),
  (94, 100),
  (101, 107),
  (108, 110),
  (111, 114),
  (115, 121),
  (122, 126),
  (126, 127),
  (128, 139),
  (140, 142),
  (143, 148),
  (149, 151),
  (152, 155),
  (156, 160),
  (161, 169),
  (170, 173),
  (174, 180),
  (181, 183),
  (183, 184),
  (185, 187),
  (188, 189),
  (190, 196),
  (197, 203),
  (204, 206),
  (207, 213),
  (214, 218),
  (219, 223),
  (224, 226),
  (226, 229),
  (229, 232),
  (233, 237),
  (238, 241),
  (242, 248),
  (249, 250),
  (250, 251),
  (251, 254),
  (254, 256),
  (257, 259),
  (260, 262),
  (263, 264),
  (264, 265),
  (265, 268),
  (268, 269),
  (269, 270),
 

In [None]:
# check whether or not the answer is fully contained within the context
# if not, target is (start, end) = (0, 0)

# answer boundaries in CHARACTER units (end is exclusive)
ans_start_char = answer['answer_start'][0]
ans_end_char = ans_start_char + len(answer['text'][0])

# (char_start, char_end) of every token in the first context window
offset = inputs['offset_mapping'][0]

start_idx = 0
end_idx = 0

if offset[ctx_start][0] > ans_start_char or offset[ctx_end][1] < ans_end_char:
  print("target is (0, 0)")
  # nothing else to do
else:
  # find the start and end TOKEN positions

  # the 'trick' is knowing what is in units of tokens and what is in
  # units of characters

  # recall: the offset_mapping contains the character positions of each token

  # Only context tokens (ctx_start..ctx_end) are inspected: the [SEP]/padding
  # tokens after the context have offset (0, 0) and must not be matched.
  # Inequalities are used instead of exact equality so that an answer which
  # starts or ends in the middle of a token still maps to the enclosing tokens.

  # start token = last context token that begins at or before the answer start
  i = ctx_start
  while i <= ctx_end and offset[i][0] <= ans_start_char:
    i += 1
  start_idx = i - 1

  # end token = first context token that ends at or after the answer end
  i = ctx_end
  while i >= ctx_start and offset[i][1] >= ans_end_char:
    i -= 1
  end_idx = i + 1

start_idx, end_idx

(53, 57)

In [None]:
inputs['offset_mapping'][0]  # offset

[(0, 0),
 (0, 4),
 (5, 7),
 (8, 10),
 (11, 16),
 (17, 19),
 (20, 23),
 (24, 29),
 (30, 34),
 (35, 39),
 (40, 48),
 (48, 49),
 (0, 0),
 (0, 13),
 (13, 15),
 (15, 16),
 (17, 20),
 (21, 27),
 (28, 31),
 (32, 33),
 (34, 42),
 (43, 52),
 (52, 53),
 (54, 56),
 (56, 58),
 (59, 62),
 (63, 67),
 (68, 76),
 (76, 77),
 (77, 78),
 (79, 83),
 (84, 88),
 (89, 91),
 (92, 93),
 (94, 100),
 (101, 107),
 (108, 110),
 (111, 114),
 (115, 121),
 (122, 126),
 (126, 127),
 (128, 139),
 (140, 142),
 (143, 148),
 (149, 151),
 (152, 155),
 (156, 160),
 (161, 169),
 (170, 173),
 (174, 180),
 (181, 183),
 (183, 184),
 (185, 187),
 (188, 189),
 (190, 196),
 (197, 203),
 (204, 206),
 (207, 213),
 (214, 218),
 (219, 223),
 (224, 226),
 (226, 229),
 (229, 232),
 (233, 237),
 (238, 241),
 (242, 248),
 (249, 250),
 (250, 251),
 (251, 254),
 (254, 256),
 (257, 259),
 (260, 262),
 (263, 264),
 (264, 265),
 (265, 268),
 (268, 269),
 (269, 270),
 (271, 275),
 (276, 278),
 (279, 282),
 (283, 287),
 (288, 296),
 (297, 299),


In [None]:
# check
input_ids = inputs['input_ids'][0]
input_ids[start_idx : end_idx + 1]

[170, 7335, 5921, 1104, 4028]

In [None]:
tokenizer.decode(input_ids[start_idx : end_idx + 1])

'a copper statue of Christ'

In [None]:
inputs['offset_mapping'][0][ctx_start][0]

0

In [None]:
def find_answer_token_idx(
    ctx_start,
    ctx_end,
    ans_start_char,
    ans_end_char,
    offset):
  """Convert an answer's character span into token indices for one window.

  Args:
    ctx_start: index of the first context token in this window.
    ctx_end: index of the last context token in this window (inclusive).
    ans_start_char: character position where the answer starts in the context.
    ans_end_char: character position just past the end of the answer.
    offset: sequence of (char_start, char_end) pairs, one per token.

  Returns:
    (start_idx, end_idx): inclusive token indices covering the answer, or
    (0, 0) when the answer is not fully contained in this context window.
  """
  start_idx = 0
  end_idx = 0

  # answer (partly) outside this window -> target stays (0, 0)
  if offset[ctx_start][0] > ans_start_char or offset[ctx_end][1] < ans_end_char:
    return start_idx, end_idx

  # the 'trick' is knowing what is in units of tokens and what is in
  # units of characters: `offset` holds the character positions of each token.
  #
  # The scan is limited to ctx_start..ctx_end so that trailing special/padding
  # tokens (offset (0, 0)) are never matched, and it uses inequalities rather
  # than exact equality so an answer that is not aligned with token boundaries
  # is mapped to the tokens that enclose it instead of yielding a partial
  # (0, end) / (start, 0) target.

  # start token = last context token that begins at or before the answer start
  i = ctx_start
  while i <= ctx_end and offset[i][0] <= ans_start_char:
    i += 1
  start_idx = i - 1

  # end token = first context token that ends at or after the answer end
  i = ctx_end
  while i >= ctx_start and offset[i][1] >= ans_end_char:
    i -= 1
  end_idx = i + 1

  return start_idx, end_idx

In [None]:
inputs['offset_mapping'][0], inputs['offset_mapping'][1], inputs['offset_mapping'][2], inputs['offset_mapping'][3]

([(0, 0),
  (0, 4),
  (5, 7),
  (8, 10),
  (11, 16),
  (17, 19),
  (20, 23),
  (24, 29),
  (30, 34),
  (35, 39),
  (40, 48),
  (48, 49),
  (0, 0),
  (0, 13),
  (13, 15),
  (15, 16),
  (17, 20),
  (21, 27),
  (28, 31),
  (32, 33),
  (34, 42),
  (43, 52),
  (52, 53),
  (54, 56),
  (56, 58),
  (59, 62),
  (63, 67),
  (68, 76),
  (76, 77),
  (77, 78),
  (79, 83),
  (84, 88),
  (89, 91),
  (92, 93),
  (94, 100),
  (101, 107),
  (108, 110),
  (111, 114),
  (115, 121),
  (122, 126),
  (126, 127),
  (128, 139),
  (140, 142),
  (143, 148),
  (149, 151),
  (152, 155),
  (156, 160),
  (161, 169),
  (170, 173),
  (174, 180),
  (181, 183),
  (183, 184),
  (185, 187),
  (188, 189),
  (190, 196),
  (197, 203),
  (204, 206),
  (207, 213),
  (214, 218),
  (219, 223),
  (224, 226),
  (226, 229),
  (229, 232),
  (233, 237),
  (238, 241),
  (242, 248),
  (249, 250),
  (250, 251),
  (251, 254),
  (254, 256),
  (257, 259),
  (260, 262),
  (263, 264),
  (264, 265),
  (265, 268),
  (268, 269),
  (269, 270),
 

In [None]:
# try it on all context windows
# sometimes, the answer won't appear!
# (find_answer_token_idx returns (0, 0) for a window that does not fully
# contain the answer)

# one target per context window, in window order
start_idxs = []
end_idxs = []

for i, offset in enumerate(inputs['offset_mapping']):
  # the final window may not be full size - can't assume 100
  sequence_ids = inputs.sequence_ids(i)

  # find start + end of context (first 1 and last 1)
  ctx_start = sequence_ids.index(1)
  # index of the last 1: search the reversed list, then convert back
  ctx_end = len(sequence_ids) - sequence_ids[::-1].index(1) - 1

  # ans_start_char / ans_end_char are the character positions computed in an
  # earlier cell for the single example being explored
  start_idx, end_idx = find_answer_token_idx(
    ctx_start,
    ctx_end,
    ans_start_char,
    ans_end_char,
    offset)

  start_idxs.append(start_idx)
  end_idxs.append(end_idx)

start_idxs, end_idxs

([53, 17, 0, 0], [57, 21, 0, 0])

In [None]:
# some questions have leading and/or trailing whitespace
# read the first 1000 questions once (instead of re-reading the whole column
# on every iteration) and use enumerate so each hit reports its own row index,
# even if the same question text occurs more than once
first_questions = raw_datasets["train"][:1000]["question"]
for row_idx, q in enumerate(first_questions):
  if q.strip() != q:
    print(q)
    print(row_idx)

In what city and state did Beyonce  grow up? 
272
 The album, Dangerously in Love  achieved what spot on the Billboard Top 100 chart?
406
Which song did Beyonce sing at the first couple's inaugural ball? 
463
What event did Beyoncé perform at one month after Obama's inauguration? 
548
Where was the album released? 
565
What movie influenced Beyonce towards empowerment themes? 
742


In [None]:
len(inputs['attention_mask'][1])

100

In [None]:
len(inputs['attention_mask'])

4

In [None]:
inputs['attention_mask']

[[1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1]

In [None]:
raw_datasets["train"][1]

{'id': '5733be284776f4190066117f',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'What is in front of the Notre Dame Main Building?',
 'answers': {'text': ['a copper statue of Christ'], 'answer_start': [188]}}

In [None]:
inputs["offset_mapping"]

[[(0, 0),
  (0, 4),
  (5, 7),
  (8, 10),
  (11, 16),
  (17, 19),
  (20, 23),
  (24, 29),
  (30, 34),
  (35, 39),
  (40, 48),
  (48, 49),
  (0, 0),
  (0, 13),
  (13, 15),
  (15, 16),
  (17, 20),
  (21, 27),
  (28, 31),
  (32, 33),
  (34, 42),
  (43, 52),
  (52, 53),
  (54, 56),
  (56, 58),
  (59, 62),
  (63, 67),
  (68, 76),
  (76, 77),
  (77, 78),
  (79, 83),
  (84, 88),
  (89, 91),
  (92, 93),
  (94, 100),
  (101, 107),
  (108, 110),
  (111, 114),
  (115, 121),
  (122, 126),
  (126, 127),
  (128, 139),
  (140, 142),
  (143, 148),
  (149, 151),
  (152, 155),
  (156, 160),
  (161, 169),
  (170, 173),
  (174, 180),
  (181, 183),
  (183, 184),
  (185, 187),
  (188, 189),
  (190, 196),
  (197, 203),
  (204, 206),
  (207, 213),
  (214, 218),
  (219, 223),
  (224, 226),
  (226, 229),
  (229, 232),
  (233, 237),
  (238, 241),
  (242, 248),
  (249, 250),
  (250, 251),
  (251, 254),
  (254, 256),
  (257, 259),
  (260, 262),
  (263, 264),
  (264, 265),
  (265, 268),
  (268, 269),
  (269, 270),
 

In [None]:
inputs  # fruits = ['apple', 'banana', 'cherry']

{'input_ids': [[101, 1327, 1110, 1107, 1524, 1104, 1103, 10360, 8022, 4304, 4334, 136, 102, 22182, 1193, 117, 1103, 1278, 1144, 170, 2336, 1959, 119, 1335, 4184, 1103, 4304, 4334, 112, 188, 2284, 10945, 1110, 170, 5404, 5921, 1104, 1103, 6567, 2090, 119, 13301, 1107, 1524, 1104, 1103, 4304, 4334, 1105, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 1103, 171, 17506, 9538, 1110, 1103, 144, 102], [101, 1327, 1110, 1107, 1524, 1104, 1103, 10360, 8022, 4304, 4334, 136, 102, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 1103, 171, 17506, 9538, 1110, 1103, 144, 10595, 2430, 117, 1

In [None]:
# pop() removes the "offset_mapping" entry from `inputs` and returns it
# (same idea as popping an item out of a list or dict).
# NOTE: this mutates `inputs` in place, so the key is gone afterwards (see the
# keys printed by the following cells); re-running this cell on the same
# `inputs` is expected to fail because the key no longer exists.
offset_mapping = inputs.pop("offset_mapping")

In [None]:
print(inputs)  # print(fruits)

{'input_ids': [[101, 1327, 1110, 1107, 1524, 1104, 1103, 10360, 8022, 4304, 4334, 136, 102, 22182, 1193, 117, 1103, 1278, 1144, 170, 2336, 1959, 119, 1335, 4184, 1103, 4304, 4334, 112, 188, 2284, 10945, 1110, 170, 5404, 5921, 1104, 1103, 6567, 2090, 119, 13301, 1107, 1524, 1104, 1103, 4304, 4334, 1105, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 1103, 171, 17506, 9538, 1110, 1103, 144, 102], [101, 1327, 1110, 1107, 1524, 1104, 1103, 10360, 8022, 4304, 4334, 136, 102, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 1103, 171, 17506, 9538, 1110, 1103, 144, 10595, 2430, 117, 1

In [None]:
# now we are ready to process (tokenize) the training data
# (i.e. expand question+context pairs into question+smaller context windows)

# Google used 384 for SQuAD
max_length = 384
stride = 128

def tokenize_fn_train(batch):
  """Tokenize a batch of training examples into padded context windows.

  Relies on the module-level ``tokenizer``, ``max_length`` and ``stride``.
  Every question/context pair is expanded into one or more windows and, for
  each window, the token-level answer span is computed with
  ``find_answer_token_idx``.

  Returns the tokenizer output with ``offset_mapping`` and
  ``overflow_to_sample_mapping`` removed and ``start_positions`` /
  ``end_positions`` added (one entry per window).
  """
  # questions may carry leading and/or trailing whitespace
  stripped_questions = [text.strip() for text in batch["question"]]

  # pad every window to max_length: most contexts are long, so padding
  # per-minibatch would not save much
  inputs = tokenizer(
    stripped_questions,
    batch["context"],
    max_length=max_length,
    truncation="only_second",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  # Both entries are only needed to build the targets below, so they are taken
  # out of `inputs` (pop returns the value) and do not become dataset columns.
  offset_mapping = inputs.pop("offset_mapping")
  window_to_sample = inputs.pop("overflow_to_sample_mapping")

  start_positions = []
  end_positions = []

  for window_idx, (offset, sample_idx) in enumerate(
      zip(offset_mapping, window_to_sample)):
    # answer of the original sample this window was cut from
    gold = batch['answers'][sample_idx]
    ans_start_char = gold['answer_start'][0]
    ans_end_char = ans_start_char + len(gold['text'][0])

    # context tokens are the ones whose sequence id is 1
    sequence_ids = inputs.sequence_ids(window_idx)
    ctx_start = sequence_ids.index(1)
    ctx_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)

    start_idx, end_idx = find_answer_token_idx(
      ctx_start, ctx_end, ans_start_char, ans_end_char, offset)

    start_positions.append(start_idx)
    end_positions.append(end_idx)

  inputs["start_positions"] = start_positions
  inputs["end_positions"] = end_positions
  return inputs

In [None]:
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'overflow_to_sample_mapping'])

In [None]:
raw_datasets["train"].column_names

['id', 'title', 'context', 'question', 'answers']

In [None]:
train_dataset = raw_datasets["train"].map(
  tokenize_fn_train,
  batched=True,
  remove_columns=raw_datasets["train"].column_names,
)
len(raw_datasets["train"]), len(train_dataset)

(87599, 88729)

In [None]:
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 88729
})

In [None]:
# note: we'll keep these IDs for later
raw_datasets["validation"][0]

{'id': '56be4db0acb8001400a502ec',
 'title': 'Super_Bowl_50',
 'context': 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.',
 'question': 'Which NFL team represented the AFC at Super Bowl 50?',
 'answers': {'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],


In [None]:
# tokenize the validation set differently
# we won't need the targets since we will just compare with the original answer
# also: overwrite offset_mapping with Nones in place of question
def tokenize_fn_validation(batch):
  """Tokenize a batch of validation examples into padded context windows.

  Relies on the module-level ``tokenizer``, ``max_length`` and ``stride``.
  No answer targets are produced. Instead each window keeps the id of the
  sample it came from (``sample_id``), and its ``offset_mapping`` is rewritten
  so that every position that is not a context token (sequence id != 1)
  holds ``None`` instead of a character span.
  """
  # questions may carry leading and/or trailing whitespace
  stripped_questions = [text.strip() for text in batch["question"]]

  # pad every window to max_length: most contexts are long, so padding
  # per-minibatch would not save much
  inputs = tokenizer(
    stripped_questions,
    batch["context"],
    max_length=max_length,
    truncation="only_second",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  # window index -> index of the originating sample inside this batch;
  # only used here, so it is removed from the returned columns
  window_to_sample = inputs.pop("overflow_to_sample_mapping")
  num_windows = len(inputs["input_ids"])

  # keep character spans for context tokens only; this is used later when
  # turning predicted token positions back into answer text
  for window_idx in range(num_windows):
    sequence_ids = inputs.sequence_ids(window_idx)
    window_offsets = inputs["offset_mapping"][window_idx]
    inputs["offset_mapping"][window_idx] = [
      span if seq_id == 1 else None
      for seq_id, span in zip(sequence_ids, window_offsets)
    ]

  inputs['sample_id'] = [
    batch['id'][window_to_sample[window_idx]]
    for window_idx in range(num_windows)
  ]
  return inputs

In [None]:
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'overflow_to_sample_mapping'])

In [None]:
# offset = inputs['offset_mapping'][0]

In [None]:
# inputs['offset_mapping'][0]

In [None]:
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'overflow_to_sample_mapping'])

In [None]:
inputs['overflow_to_sample_mapping']

[0, 0, 0, 0]

In [None]:
# orig_sample_idxs = [0, 0, 0, 0]

In [None]:
len(inputs["input_ids"])

4

In [None]:
validation_dataset = raw_datasets["validation"].map(
  tokenize_fn_validation,
  batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)
len(raw_datasets["validation"]), len(validation_dataset)

(10570, 10822)

In [None]:
# Metrics
# NOTE(review): `load_metric` has been deprecated in newer `datasets` releases
# in favour of the separate `evaluate` package (`evaluate.load("squad")`) -
# confirm against the installed version before relying on this import.
from datasets import load_metric

# SQuAD metric; the compute() call in the next cell reports exact_match and f1
metric = load_metric("squad")

  metric = load_metric("squad")


In [None]:
predicted_answers = [
  {'id': '1', 'prediction_text': 'Albert Einstein'},
  {'id': '2', 'prediction_text': 'physicist'},
  {'id': '3', 'prediction_text': 'general relativity'},
]
true_answers = [
  {'id': '1', 'answers': {'text': ['Albert Einstein'], 'answer_start': [100]}},
  {'id': '2', 'answers': {'text': ['physicist'], 'answer_start': [100]}},
  {'id': '3', 'answers': {'text': ['special relativity'], 'answer_start': [100]}},
]

# id and answer_start seem superfluous but you'll get an error if not included
# exercise: remove them (one at a time) and see!
metric.compute(predictions=predicted_answers, references=true_answers)

{'exact_match': 66.66666666666667, 'f1': 83.33333333333333}

In [None]:
# next problem: how to go from logits to prediction text?
# to make it easier, let's work on an already-trained question-answering model
small_validation_dataset = raw_datasets["validation"].select(range(100))
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer2 = AutoTokenizer.from_pretrained(trained_checkpoint)

# temporarily assign tokenizer2 to tokenizer since it's used as a global
# in tokenize_fn_validation
old_tokenizer = tokenizer
tokenizer = tokenizer2

try:
  small_validation_processed = small_validation_dataset.map(
      tokenize_fn_validation,
      batched=True,
      remove_columns=raw_datasets["validation"].column_names,
  )
finally:
  # change it back - even if map() raised, so later cells never silently
  # keep using the swapped-in tokenizer
  tokenizer = old_tokenizer

In [None]:
small_validation_processed.features

{'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'offset_mapping': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'sample_id': Value(dtype='string', id=None)}

In [None]:
len(small_validation_dataset), len(small_validation_processed)

(100, 100)

In [None]:
raw_datasets["validation"][20]['title']

'Super_Bowl_50'

In [None]:
# get the model outputs
import torch
from transformers import AutoModelForQuestionAnswering

# the trained model doesn't use these columns
# (remove_columns returns a new dataset; small_validation_processed keeps
# sample_id / offset_mapping for the post-processing step)
small_model_inputs = small_validation_processed.remove_columns(
  ["sample_id", "offset_mapping"])
# have the dataset return torch tensors instead of Python lists
small_model_inputs.set_format("torch")

# use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# move every remaining column, as a whole, to the chosen device
small_model_inputs_gpu = {
  k: small_model_inputs[k].to(device) for k in small_model_inputs.column_names
}

# download the model
trained_model = AutoModelForQuestionAnswering.from_pretrained(
  trained_checkpoint).to(device)

# get the model outputs
# all windows go through the model in a single forward pass; no_grad because
# this is inference only
with torch.no_grad():
  outputs = trained_model(**small_model_inputs_gpu)

In [None]:
outputs

QuestionAnsweringModelOutput(loss=None, start_logits=tensor([[ -2.2607,  -5.1783,  -5.2709,  ...,  -9.5243,  -9.5183,  -9.5288],
        [ -2.5961,  -5.5482,  -5.5313,  ...,  -9.9598,  -9.9533,  -9.9860],
        [ -3.7127,  -7.1848,  -8.5388,  ..., -11.6557, -11.6571, -11.6505],
        ...,
        [ -2.0260,  -4.4167,  -4.4980,  ...,  -8.1479,  -8.1530,  -8.1760],
        [ -4.1553,  -5.8304,  -7.1643,  ..., -10.5255, -10.5251, -10.4890],
        [ -3.2000,  -5.8162,  -6.7249,  ...,  -9.4935,  -9.5038,  -9.4871]]), end_logits=tensor([[ -0.7353,  -4.9236,  -5.1048,  ...,  -8.8734,  -8.8915,  -8.8550],
        [ -1.3056,  -5.3870,  -5.4945,  ...,  -9.4895,  -9.5039,  -9.4958],
        [ -2.7649,  -7.2201,  -9.0916,  ..., -11.3106, -11.3414, -11.2702],
        ...,
        [ -0.0768,  -4.8210,  -4.4374,  ...,  -8.0483,  -8.0502,  -7.9903],
        [ -2.7347,  -5.3650,  -7.2549,  ..., -10.0498, -10.0661,  -9.9886],
        [ -1.0991,  -4.2569,  -6.1267,  ...,  -8.6882,  -8.6889,  -8.627

In [None]:
outputs["start_logits"].shape

torch.Size([100, 384])

In [None]:
outputs["start_logits"]

tensor([[ -2.2607,  -5.1783,  -5.2709,  ...,  -9.5243,  -9.5183,  -9.5288],
        [ -2.5961,  -5.5482,  -5.5313,  ...,  -9.9598,  -9.9533,  -9.9860],
        [ -3.7127,  -7.1848,  -8.5388,  ..., -11.6557, -11.6571, -11.6505],
        ...,
        [ -2.0260,  -4.4167,  -4.4980,  ...,  -8.1479,  -8.1530,  -8.1760],
        [ -4.1553,  -5.8304,  -7.1643,  ..., -10.5255, -10.5251, -10.4890],
        [ -3.2000,  -5.8162,  -6.7249,  ...,  -9.4935,  -9.5038,  -9.4871]])

In [None]:
outputs["end_logits"].shape

torch.Size([100, 384])

In [None]:
type(outputs["end_logits"])

torch.Tensor

In [None]:
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

In [None]:
start_logits

array([[ -2.2607286,  -5.178324 ,  -5.270891 , ...,  -9.524325 ,
         -9.518305 ,  -9.528757 ],
       [ -2.5960746,  -5.5482106,  -5.531334 , ...,  -9.959751 ,
         -9.953276 ,  -9.986047 ],
       [ -3.7127328,  -7.184835 ,  -8.538828 , ..., -11.655701 ,
        -11.6571455, -11.650536 ],
       ...,
       [ -2.026028 ,  -4.4166594,  -4.4980016, ...,  -8.147888 ,
         -8.153048 ,  -8.175961 ],
       [ -4.155297 ,  -5.83042  ,  -7.164263 , ..., -10.525518 ,
        -10.525085 , -10.489031 ],
       [ -3.2000246,  -5.816207 ,  -6.7249413, ...,  -9.493472 ,
         -9.503824 ,  -9.487104 ]], dtype=float32)

In [None]:
small_validation_processed['sample_id'][:5]

['56be4db0acb8001400a502ec',
 '56be4db0acb8001400a502ed',
 '56be4db0acb8001400a502ee',
 '56be4db0acb8001400a502ef',
 '56be4db0acb8001400a502f0']

In [None]:
validation_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'offset_mapping', 'sample_id'],
    num_rows: 10822
})

In [None]:
len(validation_dataset['sample_id'])

10822

In [None]:
validation_dataset['sample_id'][10:30]

['56bea9923aeaaa14008c91bb',
 '56beace93aeaaa14008c91df',
 '56beace93aeaaa14008c91e0',
 '56beace93aeaaa14008c91e1',
 '56beace93aeaaa14008c91e2',
 '56beace93aeaaa14008c91e3',
 '56bf10f43aeaaa14008c94fd',
 '56bf10f43aeaaa14008c94fe',
 '56bf10f43aeaaa14008c94ff',
 '56bf10f43aeaaa14008c9500',
 '56bf10f43aeaaa14008c9501',
 '56d20362e7d4791d009025e8',
 '56d20362e7d4791d009025e9',
 '56d20362e7d4791d009025ea',
 '56d20362e7d4791d009025eb',
 '56d600e31c85041400946eae',
 '56d600e31c85041400946eb0',
 '56d600e31c85041400946eb1',
 '56d9895ddc89441400fdb50e',
 '56d9895ddc89441400fdb510']

In [None]:
validation_dataset['sample_id'][1]

'56be4db0acb8001400a502ed'

In [None]:
validation_dataset['sample_id'][0]

'56be4db0acb8001400a502ec'

In [None]:
len(set(validation_dataset['sample_id']))

10570

In [94]:
validation_dataset['sample_id'][10:15]

['56bea9923aeaaa14008c91bb',
 '56beace93aeaaa14008c91df',
 '56beace93aeaaa14008c91e0',
 '56beace93aeaaa14008c91e1',
 '56beace93aeaaa14008c91e2']

In [96]:
small_validation_processed['sample_id'][10:15]


['56bea9923aeaaa14008c91bb',
 '56beace93aeaaa14008c91df',
 '56beace93aeaaa14008c91e0',
 '56beace93aeaaa14008c91e1',
 '56beace93aeaaa14008c91e2']

In [97]:
for i, j in enumerate(small_validation_processed['sample_id'][10:15]):
    print(i, j)

0 56bea9923aeaaa14008c91bb
1 56beace93aeaaa14008c91df
2 56beace93aeaaa14008c91e0
3 56beace93aeaaa14008c91e1
4 56beace93aeaaa14008c91e2


In [None]:
# sample_id2idxs = {key: value}  sample_id2idxs[keys] = value

In [99]:
# map each original sample id to the row indices of all of its context windows
# shape: {sample_id: [window_row_idx, ...], ...}
# (a sample that was split into several windows gets several indices)
sample_id2idxs = {}
for i, id_ in enumerate(small_validation_processed['sample_id']):
  # setdefault creates the list the first time an id is seen
  sample_id2idxs.setdefault(id_, []).append(i)

In [None]:
# map each original sample id to the row indices of all of its context windows
# example: {'56be4db0acb8001400a502ec': [0, 1, 2, 3], ...}
sample_id2idxs = {}
for i, id_ in enumerate(small_validation_processed['sample_id']):
  # setdefault creates the list the first time an id is seen
  sample_id2idxs.setdefault(id_, []).append(i)

In [None]:
start_logits.shape, end_logits.shape

((100, 384), (100, 384))

In [None]:
# reminder of how to find indices with the largest values
(-start_logits[0]).argsort()

array([ 46,  57,  47,  38,  39,  58,  50,  43,  45,  54,  56,  49,  13,
        42,  40,  35,  27,  31,  48,  41,  53,  44,  37,  59,  78,  15,
         0,  52,  24,  65,  81,  70,  18,  51,  55,  26,  69,  29,  28,
        75,  61,  64,  23,  36,  32,  11, 101,  62,  66,  34,  95,  30,
        63,  21,  19,  20,  17,  14,  22,  33,  68,  87, 171,  12,  76,
        71,  73,  92, 110,  84, 151,   1,  74,   2,   6,  16,  80,  79,
       105,  98,  10,  96, 136, 169, 106, 100,  93, 165,  67, 109,   8,
        90,   3, 115,  60,   5,  97,   7, 103, 102,  86,  72, 111,  89,
       108,   4,  88,  25, 132,  77, 123, 150, 124, 153,  83, 118,  82,
        85, 107, 114, 143, 164, 137, 130, 166, 159, 131,  91,   9, 144,
       139, 160,  94, 141, 128, 112, 134, 152, 170, 154, 117, 127, 104,
       140, 157, 155, 133, 145, 119, 162, 138, 135, 156, 167, 168, 126,
       148, 163, 161, 116,  99, 120, 142, 158, 125, 146, 113, 121, 147,
       149, 129, 122, 311, 312, 304, 309, 313, 310, 300, 307, 31

In [None]:
# index the logits with that ranking to view the values from largest to smallest
start_logits[0][(-start_logits[0]).argsort()]

array([10.694451  ,  9.803686  ,  4.4599786 ,  4.400489  ,  2.9437866 ,
        2.7017422 ,  2.012652  ,  1.5780808 ,  0.52237856,  0.02074176,
       -0.02801922, -0.04970397, -0.38572356, -0.6945309 , -0.7979434 ,
       -0.86779773, -0.8722001 , -1.3516843 , -1.3703636 , -1.3878787 ,
       -1.5135026 , -1.7355411 , -1.8827026 , -1.8932847 , -1.9078901 ,
       -1.9304947 , -2.2607286 , -2.2983828 , -2.3069277 , -2.5027347 ,
       -2.5100536 , -2.5308342 , -2.5399885 , -2.6718073 , -2.7323494 ,
       -2.7710114 , -2.7713616 , -2.952129  , -3.0604606 , -3.170597  ,
       -3.2045417 , -3.569333  , -3.5797982 , -3.6668777 , -3.7250547 ,
       -3.7498534 , -3.7632098 , -3.9968114 , -4.011324  , -4.0687957 ,
       -4.0944786 , -4.19547   , -4.238307  , -4.332359  , -4.35241   ,
       -4.3879576 , -4.388604  , -4.396608  , -4.6790495 , -4.7030234 ,
       -4.77575   , -4.777807  , -4.7882137 , -4.7882376 , -4.822118  ,
       -4.8725348 , -4.884929  , -4.898141  , -5.072089  , -5.10

In [None]:
# reminder: in offset_mapping we store None everywhere except the context window
# in the context window we store a pair for each token containing:
# (start_character_position, end_character_position)
# the positions are character indices into the original context string
small_validation_processed['offset_mapping'][0]

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 [0, 5],
 [6, 10],
 [11, 13],
 [14, 17],
 [18, 20],
 [21, 29],
 [30, 38],
 [39, 43],
 [44, 46],
 [47, 56],
 [57, 60],
 [61, 69],
 [70, 72],
 [73, 76],
 [77, 85],
 [86, 94],
 [95, 101],
 [102, 103],
 [103, 106],
 [106, 107],
 [108, 111],
 [112, 115],
 [116, 120],
 [121, 127],
 [127, 128],
 [129, 132],
 [133, 141],
 [142, 150],
 [151, 161],
 [162, 163],
 [163, 166],
 [166, 167],
 [168, 176],
 [177, 183],
 [184, 191],
 [192, 200],
 [201, 204],
 [205, 213],
 [214, 222],
 [223, 233],
 [234, 235],
 [235, 238],
 [238, 239],
 [240, 248],
 [249, 257],
 [258, 266],
 [267, 269],
 [269, 270],
 [270, 272],
 [273, 275],
 [276, 280],
 [281, 286],
 [287, 292],
 [293, 298],
 [299, 303],
 [304, 309],
 [309, 310],
 [311, 314],
 [315, 319],
 [320, 323],
 [324, 330],
 [331, 333],
 [334, 342],
 [343, 344],
 [344, 345],
 [346, 350],
 [350, 351],
 [352, 354],
 [355, 359],
 [359, 360],
 [360, 361],
 [362, 369],
 [370, 37

In [None]:
n_largest = 20          # how many top start / end positions to try per window
max_answer_length = 30  # longest allowed answer span, counted in tokens
predicted_answers = []

# we are looping through the original (untokenized) dataset
# because we need to grab the answer from the original string context
for sample in small_validation_dataset:
  sample_id = sample['id']
  context = sample['context']

  # update these as we loop through candidate answers
  best_score = float('-inf')
  # fall back to an empty string (not None) so that prediction_text is
  # always a string, even if every candidate span is rejected below
  best_answer = ''

  # now loop through the *expanded* input samples (fixed size context windows)
  # from here we will pick the highest probability start/end combination
  for idx in sample_id2idxs[sample_id]:
    start_logit = start_logits[idx] # (384,) vector
    end_logit = end_logits[idx] # (384,) vector
    offsets = small_validation_processed[idx]['offset_mapping']

    start_indices = (-start_logit).argsort()
    end_indices = (-end_logit).argsort()

    for start_idx in start_indices[:n_largest]:
      # skip starts not contained in context window
      # recall: we set entries not pertaining to context to None earlier
      # (checked once here instead of once per end candidate)
      if offsets[start_idx] is None:
        continue

      for end_idx in end_indices[:n_largest]:

        # skip ends not contained in context window
        if offsets[end_idx] is None:
          continue

        # skip answers where end < start
        if end_idx < start_idx:
          continue

        # skip answers that are too long
        if end_idx - start_idx + 1 > max_answer_length:
          continue

        # see theory lecture for score calculation
        score = start_logit[start_idx] + end_logit[end_idx]
        if score > best_score:
          best_score = score

          # find positions of start and end characters
          # recall: offsets contains tuples for each token:
          # (start_char, end_char)
          first_ch = offsets[start_idx][0]
          last_ch = offsets[end_idx][1]

          best_answer = context[first_ch:last_ch]

  # save best answer
  predicted_answers.append({'id': sample_id, 'prediction_text': best_answer})

In [None]:
# reference answers of the first sample: 'text' and 'answer_start' are parallel lists
small_validation_dataset['answers'][0]

{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 'answer_start': [177, 177, 177]}

In [None]:
# now test it!

# references for the metric: one {'id', 'answers'} record per original sample
true_answers = [
  dict(id=row['id'], answers=row['answers'])
  for row in small_validation_dataset
]
metric.compute(references=true_answers, predictions=predicted_answers)

{'exact_match': 83.0, 'f1': 88.25000000000004}

In [None]:
# now let's define a full compute_metrics function
# note: this will NOT be called from the trainer

from tqdm.autonotebook import tqdm

def compute_metrics(start_logits, end_logits, processed_dataset, orig_dataset,
                    n_largest=20, max_answer_length=30):
  """Convert start/end logits into answer strings and score them with `metric`.

  Args:
    start_logits: indexable by processed-row index; each entry is a 1-D
      array of per-token start logits.
    end_logits: same layout as start_logits, for the answer end position.
    processed_dataset: expanded (windowed) dataset. Must provide a
      'sample_id' column, and each row an 'offset_mapping' whose entries are
      None outside the context window and (start_char, end_char) inside it.
    orig_dataset: original samples providing 'id', 'context' and 'answers'.
    n_largest: number of top-scoring start and end positions tried per window.
    max_answer_length: maximum answer span length, counted in tokens.

  Returns:
    Whatever metric.compute returns for the predictions against the
    references built from orig_dataset.
  """
  # map sample_id ('56be4db0acb8001400a502ec') to row indices of processed data
  sample_id2idxs = {}
  for i, id_ in enumerate(processed_dataset['sample_id']):
    sample_id2idxs.setdefault(id_, []).append(i)

  predicted_answers = []
  for sample in tqdm(orig_dataset):

    sample_id = sample['id']
    context = sample['context']

    # update these as we loop through candidate answers
    best_score = float('-inf')
    # empty string (not None) so prediction_text is always a string, even
    # when no candidate span passes the filters below
    best_answer = ''

    # now loop through the *expanded* input samples (fixed size context windows)
    # from here we will pick the highest probability start/end combination
    for idx in sample_id2idxs[sample_id]:
      start_logit = start_logits[idx] # (T,) vector
      end_logit = end_logits[idx] # (T,) vector

      # note: do NOT do the reverse: ['offset_mapping'][idx]
      offsets = processed_dataset[idx]['offset_mapping']

      start_indices = (-start_logit).argsort()
      end_indices = (-end_logit).argsort()

      for start_idx in start_indices[:n_largest]:
        # skip starts not contained in context window
        # recall: we set entries not pertaining to context to None earlier
        if offsets[start_idx] is None:
          continue

        for end_idx in end_indices[:n_largest]:

          # skip ends not contained in context window
          if offsets[end_idx] is None:
            continue

          # skip answers where end < start
          if end_idx < start_idx:
            continue

          # skip answers that are too long
          if end_idx - start_idx + 1 > max_answer_length:
            continue

          # see theory lecture for score calculation
          score = start_logit[start_idx] + end_logit[end_idx]
          if score > best_score:
            best_score = score

            # find positions of start and end characters
            # recall: offsets contains tuples for each token:
            # (start_char, end_char)
            first_ch = offsets[start_idx][0]
            last_ch = offsets[end_idx][1]

            best_answer = context[first_ch:last_ch]

    # save best answer
    predicted_answers.append({'id': sample_id, 'prediction_text': best_answer})

  # compute the metrics
  true_answers = [
    {'id': x['id'], 'answers': x['answers']} for x in orig_dataset
  ]
  return metric.compute(predictions=predicted_answers, references=true_answers)

In [None]:
# run our function on the same mini dataset as before
compute_metrics(
    start_logits=start_logits,
    end_logits=end_logits,
    processed_dataset=small_validation_processed,
    orig_dataset=small_validation_dataset,
)

100%|██████████| 100/100 [00:00<00:00, 601.87it/s]




{'exact_match': 83.0, 'f1': 88.25000000000004}

In [None]:
# now load the model we want to fine-tune
# note: the question-answering head (qa_outputs) is not part of the checkpoint
# and is newly initialized, so expect a warning telling you to train the model
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# display the model's module structure
model

DistilBertForQuestionAnswering(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
      

In [None]:
from transformers import TrainingArguments

# Training configuration for the fine-tuning run.
# NOTE: the recorded run of this cell failed with an ImportError asking for
# accelerate>=0.20.1 -- install/upgrade it before executing.
args = TrainingArguments(
    output_dir="finetuned-squad",
    # optimisation settings
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    fp16=True,
    # no evaluation inside the training loop; save a checkpoint every epoch
    evaluation_strategy="no",
    save_strategy="epoch",
)

ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`