In [5]:
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

In [4]:
# Checkpoint identifier; the tokenizer cell further down loads from this same name.
PRE_TRAINED_MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"
# Build the question-answering model from that checkpoint (the recorded output
# below shows a ~1.34 GB weights download).
model = BertForQuestionAnswering.from_pretrained(PRE_TRAINED_MODEL_NAME)

config.json: 100%|██████████| 443/443 [00:00<?, ?B/s] 
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
model.safetensors: 100%|██████████| 1.34G/1.34G [00:59<00:00, 22.4MB/s]
Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a

In [6]:
# Load the tokenizer from the same checkpoint name that was used for the model.
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

tokenizer_config.json: 100%|██████████| 28.0/28.0 [00:00<?, ?B/s]
vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 513kB/s]
tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 4.52MB/s]


In [10]:
# Sample question and the passage the answer should be extracted from.
question = "How many parameters does BERT large have?"
context = (
    "BERT-large is really big. "
    "It has 24-layers and an embedding size of 1,024, for a total of 340M parameters! "
    "Altogether it is 1.34GB, so expect it to take a couple minutes to download "
    "to your Colab instance."
)

In [48]:
# Encode the question/context pair into one sequence of vocabulary ids
# (the token dump further down shows [CLS]/[SEP] being included).
input_ids = tokenizer.encode(question, context)

num_tokens = len(input_ids)
print('The input has a total of {:} tokens.'.format(num_tokens))

The input has a total of 67 tokens.


In [49]:
# Map the ids back to their wordpiece strings and print one "token  id" row per
# position, with a blank line on either side of each [SEP] row.
tokens = tokenizer.convert_ids_to_tokens(input_ids)

# The loop variable is intentionally not named `id`: at notebook top level it
# stays bound after the loop and would shadow the builtin id() in later cells.
for token, token_id in zip(tokens, input_ids):
    is_separator = token_id == tokenizer.sep_token_id

    if is_separator:
        print('')

    # Token left-aligned in 12 columns, id right-aligned with thousands separators.
    print('{:<12} {:>6,}'.format(token, token_id))

    if is_separator:
        print('')

[CLS]           101
how           2,129
many          2,116
parameters   11,709
does          2,515
bert         14,324
large         2,312
have          2,031
?             1,029

[SEP]           102

bert         14,324
-             1,011
large         2,312
is            2,003
really        2,428
big           2,502
.             1,012
it            2,009
has           2,038
24            2,484
-             1,011
layers        9,014
and           1,998
an            2,019
em            7,861
##bed         8,270
##ding        4,667
size          2,946
of            1,997
1             1,015
,             1,010
02            6,185
##4           2,549
,             1,010
for           2,005
a             1,037
total         2,561
of            1,997
340          16,029
##m           2,213
parameters   11,709
!               999
altogether   10,462
it            2,009
is            2,003
1             1,015
.             1,012
34            4,090
##gb         18,259
,             1,01

In [50]:
# Build the segment (token type) ids: 0 for every position up to and including
# the first [SEP], 1 for every position after it.
sep_index = input_ids.index(tokenizer.sep_token_id)

# Lengths of the two segments; segment A includes the first [SEP] itself.
num_seq_a = sep_index + 1
num_seq_b = len(input_ids) - num_seq_a

segment_ids = [0 for _ in range(num_seq_a)]
segment_ids.extend(1 for _ in range(num_seq_b))

# There must be exactly one segment id per input token.
assert len(segment_ids) == len(input_ids)

In [51]:
# Run the model on a batch containing the single encoded example.
# This is inference only, so autograd is switched off: without no_grad() the
# returned logits would carry a computation graph that is never used.
with torch.no_grad():
    outputs = model(torch.tensor([input_ids]),
                    token_type_ids=torch.tensor([segment_ids]),
                    return_dict=True)

# Per-token scores for "answer starts here" / "answer ends here"; row 0 is the
# only example in the batch.
start_scores = outputs.start_logits
end_scores = outputs.end_logits

In [52]:
# Choose the answer span (answer_start, answer_end) that maximises
# start_score + end_score subject to answer_end >= answer_start, then rebuild
# the answer text from the wordpieces in that span.
start_logits = start_scores[0].detach()
end_logits = end_scores[0].detach()

# pair_scores[s, e] = start_logits[s] + end_logits[e] for every (s, e) pair.
pair_scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)

# Exclude spans that would end before they start.
positions = torch.arange(pair_scores.size(0), device=pair_scores.device)
ends_before_start = positions.unsqueeze(0) < positions.unsqueeze(1)
pair_scores = pair_scores.masked_fill(ends_before_start, float('-inf'))

# argmax over the flattened matrix yields the first maximum in row-major
# order, i.e. the same winner a nested start/end loop with a strict '>' picks.
best_flat_index = int(torch.argmax(pair_scores))
answer_start, answer_end = divmod(best_flat_index, pair_scores.size(1))
max_score = pair_scores[answer_start, answer_end]

answer = ""

for piece in tokens[answer_start:answer_end + 1]:

    # If it's a subword token, then recombine it with the previous token.
    if piece.startswith('##'):
        answer += piece[2:]

    # Otherwise, add a space then the token.
    else:
        answer += ' ' + piece

print('Answer: "' + answer.strip() + '"')

Answer: "340m"


In [56]:
ex_context = '''
The Indian Premier League (IPL) (also known as the TATA IPL for sponsorship reasons) is a men's Twenty20 (T20) cricket league that is annually held in India. The league, which was founded by the BCCI in 2007, is contested by ten city-based franchise teams.[3][4] The IPL is usually held in summer between March and May every year. It has an exclusive window in the ICC Future Tours Programme, meaning fewer international cricket tours happening during IPL seasons
'''

ex_ques = "When was IPL founded?"

In [57]:
def answer_question(question, context):
    """Extract the answer to `question` from `context` with the loaded QA model.

    Uses the notebook-level `tokenizer` and `model`. Everything is kept in
    local variables so the intermediate state of earlier cells (`input_ids`,
    `tokens`, `start_scores`, ...) is not overwritten.

    Args:
        question: The question text.
        context: The passage that is searched for the answer.

    Returns:
        The wordpieces of the best-scoring span joined into a string, with
        '##' continuation pieces glued onto the preceding token.
    """
    input_ids = tokenizer.encode(question, context)
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Segment ids: 0 up to and including the first [SEP], 1 after it.
    sep_index = input_ids.index(tokenizer.sep_token_id)
    num_seq_a = sep_index + 1
    num_seq_b = len(input_ids) - num_seq_a
    segment_ids = [0] * num_seq_a + [1] * num_seq_b

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(torch.tensor([input_ids]),
                        token_type_ids=torch.tensor([segment_ids]),
                        return_dict=True)

    # Row 0 is the single example in the batch.
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # pair_scores[s, e] = start_logits[s] + end_logits[e]; spans that end
    # before they start are excluded.
    pair_scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)
    positions = torch.arange(pair_scores.size(0), device=pair_scores.device)
    ends_before_start = positions.unsqueeze(0) < positions.unsqueeze(1)
    pair_scores = pair_scores.masked_fill(ends_before_start, float('-inf'))

    # First maximum in row-major order, matching a nested start/end loop that
    # only replaces the best pair on a strictly greater score.
    best_flat_index = int(torch.argmax(pair_scores))
    answer_start, answer_end = divmod(best_flat_index, pair_scores.size(1))

    answer = ""
    for piece in tokens[answer_start:answer_end + 1]:
        # If it's a subword token, then recombine it with the previous token.
        if piece.startswith('##'):
            answer += piece[2:]
        # Otherwise, add a space then the token.
        else:
            answer += ' ' + piece

    return answer.strip()


answer = answer_question(ex_ques, ex_context)

print('Answer: "' + answer + '"')

Answer: "2007"
