In [20]:
!pip install transformers



### Load the question-answering model and tokenizer


In [21]:
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

#Model
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

#Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')


Note: The BertForQuestionAnswering class supports fine-tuning, so we can fine-tune this model on our own dataset.

Create a QA example and encode it with the function encode_plus(). It returns a dictionary that contains input_ids, token_type_ids, and attention_mask, but for a single QA example we only need input_ids and token_type_ids.

In [22]:
question = '''What is minimum spanning tree?'''

paragraph = ''' In computer science, Prim's algorithm is a greedy algorithm that finds a minimum spanning tree for a weighted undirected graph. 
                This means it finds a subset of the edges that forms a tree that includes every vertex, where the total weight of all the edges 
                in the tree is minimized.'''
            
encoding = tokenizer.encode_plus(text=question,text_pair=paragraph, add_special_tokens=True)

inputs = encoding['input_ids']  # Token ids (indices into the vocabulary)
sentence_embedding = encoding['token_type_ids']  # Segment ids: 0 = question, 1 = paragraph
tokens = tokenizer.convert_ids_to_tokens(inputs)  # Input tokens as strings
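
To make the layout of the encoded pair concrete, here is a small hand-built sketch. The token lists are made up for illustration and are not real tokenizer output; only the arrangement of the special tokens and segment ids mirrors what BERT expects for a question/paragraph pair.

```python
# Illustrative only: hand-written tokens, not produced by BertTokenizer.
question_tokens = ["what", "is", "a", "tree", "?"]
paragraph_tokens = ["a", "tree", "is", "a", "connected", "acyclic", "graph", "."]

# BERT pair layout: [CLS] question [SEP] paragraph [SEP]
pair_tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + paragraph_tokens + ["[SEP]"]

# Segment ids: 0 for [CLS], the question and the first [SEP]; 1 for the paragraph and final [SEP].
segment_ids = [0] * (len(question_tokens) + 2) + [1] * (len(paragraph_tokens) + 1)

for token, segment in zip(pair_tokens, segment_ids):
    print(f"{segment}  {token}")
```

Printing `tokens` next to `sentence_embedding` from the cell above should show the same pattern for the real example.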

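Note: To run several QA examples in one batch, all the input vectors must have the same length, so shorter sequences are padded with the token id 0. In that case the attention_mask is also needed, because it tells the model to ignore the padded positions.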
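
The padding step can be sketched by hand as follows. The helper name `pad_batch` and the toy ids are assumptions made for illustration, not part of the transformers library; recent versions of the tokenizer can also pad for you, so treat this only as a picture of what happens.

```python
def pad_batch(batch_input_ids, batch_token_type_ids, pad_id=0):
    """Right-pad every example to the longest one and build attention masks."""
    max_len = max(len(ids) for ids in batch_input_ids)
    padded_ids, padded_types, masks = [], [], []
    for ids, types in zip(batch_input_ids, batch_token_type_ids):
        n_pad = max_len - len(ids)
        padded_ids.append(ids + [pad_id] * n_pad)      # pad token ids with 0
        padded_types.append(types + [0] * n_pad)       # segment id for padding is 0
        masks.append([1] * len(ids) + [0] * n_pad)     # 1 = real token, 0 = padding
    return padded_ids, padded_types, masks

# Toy ids chosen for readability, not real vocabulary entries.
batch_ids, batch_types, batch_masks = pad_batch(
    [[101, 7, 102, 8, 9, 102], [101, 7, 102, 8, 102]],
    [[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1]],
)
print(batch_ids[1])    # second example now has the same length as the first
print(batch_masks[1])
```

The three padded lists can then be wrapped in torch.tensor(...) and passed to the model as input_ids, token_type_ids and attention_mask.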

Run the QA example through the loaded model.

In [23]:
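outputs = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))
start_scores, end_scores = outputs[0], outputs[1]  # start/end logits; indexing works for both tuple and ModelOutput returns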



Now that we have the start scores and end scores, we can take the highest-scoring start index and end index and use the two indices to predict the answer span.


In [24]:
start_index = torch.argmax(start_scores)

end_index = torch.argmax(end_scores)

answer = ' '.join(tokens[start_index:end_index+1])
answer

'a subset of the edges that forms a tree that includes every vertex'



Note: Taking the two argmax values independently can produce an end index that comes before the start index. A more robust approach is to pick the span for which the total score (start score + end score) is maximum, subject to end_index ≥ start_index.
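
A minimal sketch of that constrained search, written in plain Python so it can be read without the model loaded. The function name `best_span`, the `max_answer_len` cap and the toy scores are assumptions for illustration; they are not taken from the transformers library.

```python
def best_span(start_scores, end_scores, max_answer_len=30):
    """Return (start, end, score) maximizing start + end score with end >= start."""
    best_start, best_end = 0, 0
    best_score = float("-inf")
    for i, s in enumerate(start_scores):
        # Only consider end positions at or after the start, up to a length cap.
        for j in range(i, min(i + max_answer_len, len(end_scores))):
            score = s + end_scores[j]
            if score > best_score:
                best_start, best_end, best_score = i, j, score
    return best_start, best_end, best_score

# Toy scores where independent argmax would give start=2, end=1 (an invalid span).
toy_start = [0.1, 0.2, 5.0, 0.3]
toy_end = [0.2, 6.0, 0.1, 4.0]
print(best_span(toy_start, toy_end))
```

With the tensors from the notebook, the same function could be called as best_span(start_scores[0].tolist(), end_scores[0].tolist()). Note that this simple version searches every position, including the question tokens.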

Note: BERT uses WordPiece tokenization. WordPiece can split a word such as “playing” into “play” and “##ing”, where the “##” prefix marks a piece that continues the previous token. This lets the model handle a wide range of Out-Of-Vocabulary (OOV) words.

With a little more work, we can rejoin any words that were broken down into subwords:
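
The splitting idea can be illustrated with a toy greedy longest-match-first routine. This is a simplified sketch, not the tokenizer's actual implementation, and `toy_vocab` is a made-up vocabulary; the real BERT vocabulary may well keep a common word like "playing" as a single token.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedily split one word into the longest matching vocabulary pieces."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate from the right until it is found in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matched, so the whole word becomes unknown
        pieces.append(piece)
        start = end
    return pieces

toy_vocab = {"play", "##ing", "##ed", "span", "##ning"}  # hypothetical vocabulary
print(wordpiece("playing", toy_vocab))   # ['play', '##ing']
print(wordpiece("spanning", toy_vocab))  # ['span', '##ning']
```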

In [25]:
corrected_answer = ''

for word in answer.split():
    
    #If it's a subword token
    if word[0:2] == '##':
        corrected_answer += word[2:]
    else:
        corrected_answer += ' ' + word

In [29]:
print(question)
print(corrected_answer)

What is minimum spanning tree?
 a subset of the edges that forms a tree that includes every vertex
