In [33]:
!pip install transformers datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [34]:
from datasets import load_dataset
raw_datasets = load_dataset("squad")
raw_datasets



  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [35]:
raw_datasets["train"][1]["title"]

'University_of_Notre_Dame'

In [36]:
raw_datasets["train"][1]["context"]

'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.'

In [37]:
raw_datasets["train"][1]["question"]

'What is in front of the Notre Dame Main Building?'

In [38]:
raw_datasets["train"][1]["answers"]

{'text': ['a copper statue of Christ'], 'answer_start': [188]}

In [39]:
# for train set, ensure that there'e always 1 answer
# not multiple answers, or no answers
raw_datasets["train"].filter(lambda x: len(x["answers"]["text"]) != 1)



Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 0
})

In [40]:
# for validation set, there may be multiple answers
raw_datasets["validation"][2]["answers"]

{'text': ['Santa Clara, California',
  "Levi's Stadium",
  "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."],
 'answer_start': [403, 355, 355]}

In [41]:
# why are there multiple answers?
raw_datasets["validation"][2]["context"]

'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.'

In [42]:
raw_datasets["validation"][2]["question"]

'Where did Super Bowl 50 take place?'

In [43]:
# they may even be the same!
raw_datasets["validation"][0]["answers"]

{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 'answer_start': [177, 177, 177]}

In [44]:
from transformers import AutoTokenizer

model_checkpoint = "distilbert-base-cased"
# model_checkpoint = "bert-base-cased" # try it out
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [45]:
# test the tokenizer on a context-question pair
context = raw_datasets["train"][1]["context"]
question = raw_datasets["train"][1]["question"]

inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

'[CLS] What is in front of the Notre Dame Main Building? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

In [46]:
# what if the context is really long?
# split it into multiple samples
inut = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second", # truncates the context, not the question
    stride=50, # to not cut off the question
    return_overflowing_tokens=True,
)

for ids in inputs["input_ids"]:
  print(tokenizer.decode(ids))

[CLS]
What
is
in
front
of
the
Notre
Dame
Main
Building
?
[SEP]
Architectural
##ly
,
the
school
has
a
Catholic
character
.
At
##op
the
Main
Building
'
s
gold
dome
is
a
golden
statue
of
the
Virgin
Mary
.
Immediately
in
front
of
the
Main
Building
and
facing
it
,
is
a
copper
statue
of
Christ
with
arms
up
##rai
##sed
with
the
legend
"
V
##eni
##te
Ad
Me
O
##m
##nes
"
.
Next
to
the
Main
Building
is
the
Basilica
of
the
Sacred
Heart
.
Immediately
behind
the
b
##asi
##lica
is
the
G
##rot
##to
,
a
Marian
place
of
prayer
and
reflection
.
It
is
a
replica
of
the
g
##rot
##to
at
Lou
##rdes
,
France
where
the
Virgin
Mary
reputed
##ly
appeared
to
Saint
Bern
##ade
##tte
So
##ubi
##rous
in
1858
.
At
the
end
of
the
main
drive
(
and
in
a
direct
line
that
connects
through
3
statues
and
the
Gold
Dome
)
,
is
a
simple
,
modern
stone
statue
of
Mary
.
[SEP]


In [47]:
# check the keys of the result
inputs.keys()

dict_keys(['input_ids', 'attention_mask'])

In [48]:
# wha's the new key?
inputs['attention_mask']

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1]

In [49]:
# 
inputs = tokenizer(
    raw_datasets["train"][:3]["question"],
    raw_datasets["train"][:3]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

In [50]:
#5:38
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

In [51]:
# what is this offset_mapping?
# it tells us the location of each token
# notes:
# special tokens take up 0 space (0, 0)
# the question portion is the same for each sample "what is in" = (what start at 0, 4 char long etc..)
# the context portion starting point increases in each sample
inputs["offset_mapping"]

[[(0, 0),
  (0, 4),
  (5, 7),
  (8, 10),
  (11, 16),
  (17, 19),
  (20, 23),
  (24, 29),
  (30, 34),
  (35, 39),
  (40, 48),
  (48, 49),
  (0, 0),
  (0, 13),
  (13, 15),
  (15, 16),
  (17, 20),
  (21, 27),
  (28, 31),
  (32, 33),
  (34, 42),
  (43, 52),
  (52, 53),
  (54, 56),
  (56, 58),
  (59, 62),
  (63, 67),
  (68, 76),
  (76, 77),
  (77, 78),
  (79, 83),
  (84, 88),
  (89, 91),
  (92, 93),
  (94, 100),
  (101, 107),
  (108, 110),
  (111, 114),
  (115, 121),
  (122, 126),
  (126, 127),
  (128, 139),
  (140, 142),
  (143, 148),
  (149, 151),
  (152, 155),
  (156, 160),
  (161, 169),
  (170, 173),
  (174, 180),
  (181, 183),
  (183, 184),
  (185, 187),
  (188, 189),
  (190, 196),
  (197, 203),
  (204, 206),
  (207, 213),
  (214, 218),
  (219, 223),
  (224, 226),
  (226, 229),
  (229, 232),
  (233, 237),
  (238, 241),
  (242, 248),
  (249, 250),
  (250, 251),
  (251, 254),
  (254, 256),
  (257, 259),
  (260, 262),
  (263, 264),
  (264, 265),
  (265, 268),
  (268, 269),
  (269, 270),
 

In [52]:
# check size of list
len(inputs['offset_mapping'])

4

In [53]:
# check the len of the elements inside the list
len(inputs['offset_mapping'][0])

100

In [54]:
# 
print(inputs.sequence_ids(0))

[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None]


In [55]:
# problem: the position of the answer will change in each
# window of the context
# the answer is also the target for the neural network
# how can we recompute the targets for each context window?

# snce we took the question and context from this sample earlier
answer = raw_datasets["train"][1]["answers"]
answer

{'text': ['a copper statue of Christ'], 'answer_start': [188]}

In [56]:
# check the type
type(inputs.sequence_ids(0))

list

In [57]:
# find the start and end of the context (the first and last '1')
sequence_ids = inputs.sequence_ids(0)

ctx_start = sequence_ids.index(1) # first occurence
ctx_end = len(sequence_ids) - sequence_ids[::-1].index(1) -1 # last occurence

ctx_start, ctx_end

(13, 98)

In [58]:
# check whether or not the answer is fully contained within the context
# if not, target is (start, end) = (0, 0)

ans_start_char = answer['answer_start'][0]
ans_end_char = ans_start_char + len(answer['text'][0])

offset = inputs['offset_mapping'][0]

start_idx = 0
end_idx = 0

if offset[ctx_start][0] > ans_start_char or offset[ctx_end][1] < ans_end_char:
  print("target is (0, 0)")
  # nothing else to do
else:
  # find the start and end TOKEN positions
  # the trick is knowing what is in units of tokens and what is in units of charachters

  # recall: the offset_mapping contains the charachter positions of each token

  i = ctx_start
  for start_end_char in offset[ctx_start:]:
    start, end = start_end_char
    if start == ans_start_char:
      start_idx = i
      # don't break yet

    if end == ans_end_char:
      end_idx = i
      break

    i += 1

start_idx, end_idx

(53, 57)

In [59]:
# check
input_ids = inputs['input_ids'][0]
input_ids[start_idx : end_idx + 1]

[170, 7335, 5921, 1104, 4028]

In [60]:
# convert the numbers above into text
tokenizer.decode(input_ids[start_idx : end_idx + 1])

'a copper statue of Christ'

In [61]:
def find_answer_token_idx(
    ctx_start,
    ctx_end,
    ans_start_char,
    ans_end_char,
    offset):

  start_idx = 0
  end_idx = 0

  if offset[ctx_start][0] > ans_start_char or offset[ctx_end][1] < ans_end_char:
    pass
    # nothing else to do
  else:
    # find the start and end TOKEN positions
    # the trick is knowing what is in units of tokens and what is in units of charachters

    # recall: the offset_mapping contains the charachter positions of each token
    i = ctx_start
    for start_end_char in offset[ctx_start:]:
      start, end = start_end_char
      if start == ans_start_char:
        start_idx = i
      # don't break yet

      if end == ans_end_char:
        end_idx = i
        break

      i += 1

  return start_idx, end_idx

In [62]:
# try it on all context windows
# sometimes, the answer won't apear!

start_idxs = []
end_idxs = []

for i, offset in enumerate(inputs['offset_mapping']):
  # the final window may not be full size - cant't assume 100
  sequence_ids = inputs.sequence_ids(i)

  # find start + end of context (first 1 and last 1)
  ctx_start = sequence_ids.index(1)
  ctx_end = len(sequence_ids) - sequence_ids[::-1].index(1) -1

  start_idx, end_idx = find_answer_token_idx(
      ctx_start,
      ctx_end,
      ans_start_char,
      ans_end_char,
      offset
  )

  start_idxs.append(start_idx)
  end_idxs.append(end_idx)

start_idxs, end_idxs

([53, 17, 0, 0], [57, 21, 0, 0])

In [None]:
# some questions have leading and/or trailing whitespace
for q in raw_datasets["train"]["queston"][:1000]:
  if q.strip() != q:
    print(q)