<a href="https://colab.research.google.com/github/kalpanab-psg/BERT/blob/main/BERT_Q_A_checking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BERT Question Answering
Let's see how to apply a fine-tuned BERT model to question answering: given a question and a passage that contains the answer, the task is to predict the span of text in the passage that answers the question.

### Install transformers and load the question-answering model


In [1]:
!pip install transformers



In [2]:
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

#Model
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

#Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')


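Note: BertForQuestionAnswering is BERT with a span-prediction head on top: a linear layer that outputs a start score and an end score for every token. The checkpoint loaded here has already been fine-tuned on SQuAD, so it can answer questions out of the box, but the same class can also be fine-tuned further on our own dataset.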
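When fine-tuning, the head is trained so that the gold start and end positions receive the highest scores. As far as we know, BertForQuestionAnswering computes its loss as the average of two cross-entropy losses, one over start positions and one over end positions. The pure-Python sketch below reproduces that arithmetic for a single example; log_softmax, qa_span_loss and the toy logits are hypothetical illustrations, not part of the transformers API.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def qa_span_loss(start_logits, end_logits, start_pos, end_pos):
    """Average of the start and end cross-entropy losses for one example.

    start_pos / end_pos are the token indices of the gold answer span.
    """
    start_loss = -log_softmax(start_logits)[start_pos]
    end_loss = -log_softmax(end_logits)[end_pos]
    return (start_loss + end_loss) / 2


# Toy logits over 5 token positions; the gold answer spans tokens 2..3.
start = [0.1, 0.2, 3.0, 0.5, 0.1]
end = [0.0, 0.1, 0.4, 2.5, 0.2]
print(round(qa_span_loss(start, end, 2, 3), 4))
```

The loss is lowest when the gold positions already carry the largest logits, which is what gradient descent pushes the head towards.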

Next, create a QA example and encode it with encode_plus(). For a question/passage pair it returns a dictionary with input_ids, token_type_ids and attention_mask. Because we encode a single example without padding, the attention mask is all ones, so input_ids and token_type_ids are all the QA model needs here.
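For a question/passage pair, BERT's input has the layout [CLS] question [SEP] passage [SEP]. The token_type_ids are 0 for the first segment (including [CLS] and the first [SEP]) and 1 for the passage and the final [SEP], which matches the encode_plus() output printed further down. The sketch below assumes both texts are already converted to WordPiece IDs and uses the [CLS] = 101 and [SEP] = 102 IDs visible in that output; build_pair_inputs is a hypothetical helper, not the tokenizer's actual implementation.

```python
def build_pair_inputs(question_ids, passage_ids, cls_id=101, sep_id=102):
    """Lay out a question/passage pair the way BERT expects.

    Returns input_ids, token_type_ids and attention_mask for one unpadded example.
    """
    first = [cls_id] + question_ids + [sep_id]  # segment A: [CLS] question [SEP]
    second = passage_ids + [sep_id]             # segment B: passage [SEP]
    input_ids = first + second
    token_type_ids = [0] * len(first) + [1] * len(second)
    attention_mask = [1] * len(input_ids)       # no padding, so every position is real
    return input_ids, token_type_ids, attention_mask


# Question IDs taken from the output below; the passage is cut to its first 3 IDs for brevity.
q = [26927, 5244, 9896, 2573, 2006, 2054, 2828, 1997, 10629, 1029]
p = [26927, 2213, 1005]
ids, types, mask = build_pair_inputs(q, p)
print(ids)
print(types)
```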

In [3]:
paragraph = ''' Prim's (also known as Jarník's) algorithm is a greedy algorithm that finds a minimum spanning tree for a weighted undirected graph. 
                This means it finds a subset of the edges that forms a tree that includes every vertex, where the total weight of all the edges in 
                the tree is minimized. The algorithm operates by building this tree one vertex at a time, from an arbitrary starting vertex, at each
                 step adding the cheapest possible connection from the tree to another vertex.'''


question = '''prims algorithm works on what type of graph?'''


In [4]:
encoding = tokenizer.encode_plus(text=question, text_pair=paragraph, add_special_tokens=True)

inputs = encoding['input_ids']  # token IDs (indices into the WordPiece vocabulary)
sentence_embedding = encoding['token_type_ids']  # segment IDs: 0 = question, 1 = passage
tokens = tokenizer.convert_ids_to_tokens(inputs)  # IDs mapped back to token strings

print("\nToken embeddings")
print(inputs)
print("\nSentence Embedding")
print(sentence_embedding)
print("\nTokens")
print(tokens)


Token embeddings
[101, 26927, 5244, 9896, 2573, 2006, 2054, 2828, 1997, 10629, 1029, 102, 26927, 2213, 1005, 1055, 1006, 2036, 2124, 2004, 15723, 8238, 1005, 1055, 1007, 9896, 2003, 1037, 20505, 9896, 2008, 4858, 1037, 6263, 13912, 3392, 2005, 1037, 18215, 6151, 7442, 10985, 10629, 1012, 2023, 2965, 2009, 4858, 1037, 16745, 1997, 1996, 7926, 2008, 3596, 1037, 3392, 2008, 2950, 2296, 19449, 1010, 2073, 1996, 2561, 3635, 1997, 2035, 1996, 7926, 1999, 1996, 3392, 2003, 18478, 2094, 1012, 1996, 9896, 5748, 2011, 2311, 2023, 3392, 2028, 19449, 2012, 1037, 2051, 1010, 2013, 2019, 15275, 3225, 19449, 1010, 2012, 2169, 3357, 5815, 1996, 10036, 4355, 2825, 4434, 2013, 1996, 3392, 2000, 2178, 19449, 1012, 102]

Sentence Embedding
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

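Note: To run several QA examples as one batch, all sequences must have the same length, so shorter ones are padded with the [PAD] token (ID 0 in this vocabulary). The attention_mask returned by encode_plus() then becomes necessary: it tells the model which positions are real tokens and which are padding.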
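A minimal sketch of that padding step in plain Python, assuming right-padding with ID 0. The mask marks real tokens with 1 and padding with 0; token_type_ids would be padded the same way. pad_batch and the toy IDs are illustrative only; in practice the tokenizer can pad for you when called with padding=True.

```python
def pad_batch(sequences, pad_id=0):
    """Right-pad variable-length ID lists to a common length.

    Returns (padded_ids, attention_masks); a mask is 1 for real tokens, 0 for padding.
    """
    max_len = max(len(seq) for seq in sequences)
    padded, masks = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padded.append(seq + [pad_id] * n_pad)
        masks.append([1] * len(seq) + [0] * n_pad)
    return padded, masks


# Toy batch of two encoded examples of different lengths (arbitrary IDs).
batch = [[101, 2054, 102, 7592, 102], [101, 2129, 102]]
ids, masks = pad_batch(batch)
print(ids)    # second row gains two 0s
print(masks)
```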

Run the QA example through the loaded model.

In [5]:
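with torch.no_grad():  # inference only; recent transformers versions return an output object, not a tuple
    outputs = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))
start_scores, end_scores = outputs.start_logits, outputs.end_logits
print("start_scores\n ", start_scores)
print("end scores\n", end_scores)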

start_scores
  tensor([[-6.3971, -8.3209, -9.1853, -8.6808, -8.2087, -7.9485, -6.6165, -8.8133,
         -8.8498, -9.0359, -9.9972, -6.3971, -3.8731, -7.7654, -8.4206, -7.8233,
         -7.7760, -7.5092, -8.4712, -8.7257, -6.4428, -8.4916, -8.6429, -8.1336,
         -7.7421, -6.0585, -6.7516, -5.2587, -3.3205, -5.9829, -6.9598, -4.8603,
         -5.6328, -4.3468, -5.8368, -5.5401, -3.1907,  5.6977,  7.3329,  6.1563,
         -2.7407, -2.2136, -1.5076, -5.6443, -6.2299, -7.4219, -6.7675, -6.2468,
         -6.7222, -6.4435, -8.4497, -8.0244, -6.3336, -8.6385, -7.5787, -6.2134,
         -4.9520, -8.5590, -7.7317, -7.3434, -5.6952, -7.9661, -7.1608, -6.6087,
         -6.4694, -4.7526, -8.7338, -8.2227, -8.8020, -6.9529, -8.8522, -7.7709,
         -5.6715, -8.5009, -5.7772, -7.7641, -7.5371, -6.3379, -6.8356, -7.2648,
         -6.6319, -5.7156, -7.3248, -5.7208, -7.3504, -7.0822, -8.7211, -8.7740,
         -7.9832, -8.0928, -7.2924, -7.9072, -6.0518, -7.1756, -6.8299, -8.3008,
         -7.1



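Now that we have the start and end scores, we take the position with the highest start score as the start index and the position with the highest end score as the end index; the tokens between them (inclusive) form the predicted answer span.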
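The scores are unnormalized logits, one per token. A softmax turns each score vector into a probability distribution over positions without changing which position is largest, and one common way to get a rough confidence for the chosen span is the product of its start and end probabilities. The sketch assumes the logits are plain Python lists (for example start_scores[0].tolist()); softmax, argmax and the toy values are hypothetical helpers for illustration.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


# Toy logits over 6 token positions.
start_logits = [-6.4, -8.3, 5.7, 7.3, 6.2, -2.7]
end_logits = [-7.0, -8.1, -3.2, 1.0, 6.9, -1.5]

start_probs = softmax(start_logits)
end_probs = softmax(end_logits)
s, e = argmax(start_logits), argmax(end_logits)
print(s, e, round(start_probs[s] * end_probs[e], 3))
```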


In [6]:
start_index = torch.argmax(start_scores)

end_index = torch.argmax(end_scores)
print("start index: ",start_index, " End index: ", end_index)


answer = ' '.join(tokens[start_index:end_index+1])
print(answer)

start index:  tensor(38)  End index:  tensor(41)
weighted und ##ire ##cted




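Note: Because the start and end indices are chosen independently, the model can predict an end token that comes before the start token. A more robust approach is to pick the span with the largest total score (start score + end score) among all spans with end_index ≥ start_index, usually also restricting both ends to passage tokens and capping the answer length.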
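A minimal sketch of that constrained search in plain Python, over lists of logits (e.g. start_scores[0].tolist()). best_span is a hypothetical helper; the context_mask argument (assumed to be 1 on passage tokens only, e.g. token_type_ids with the final [SEP] set to 0) and the max_answer_len=30 cap are illustrative assumptions, not values from this notebook.

```python
def best_span(start_logits, end_logits, context_mask, max_answer_len=30):
    """Return (start, end) maximizing start_logits[i] + end_logits[j].

    Only spans with i <= j < i + max_answer_len are considered, and both
    ends must fall on passage tokens (context_mask[k] == 1).
    """
    best, best_score = None, float("-inf")
    n = len(start_logits)
    for i in range(n):
        if not context_mask[i]:
            continue
        for j in range(i, min(n, i + max_answer_len)):
            if not context_mask[j]:
                continue
            score = start_logits[i] + end_logits[j]
            if score > best_score:
                best, best_score = (i, j), score
    return best


# Toy example: positions 0-1 are question tokens, 2-4 are passage tokens.
start = [2.0, 0.1, 1.0, 3.0, 0.5]
end = [2.5, 0.2, 4.0, 0.3, 1.0]
mask = [0, 0, 1, 1, 1]
print(best_span(start, end, mask))  # → (2, 2)
```

Here the independent argmaxes would be start 3 and end 2, an invalid span; the constrained search returns (2, 2) instead. The double loop costs O(n · max_answer_len), which is cheap for passages of this size.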

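Note: BERT uses WordPiece tokenization: a word missing from the vocabulary is split into known subword pieces, and every piece after the first is prefixed with "##" to mark it as a continuation. That is why "undirected" came out as "und ##ire ##cted" above. Splitting words this way lets a fixed-size vocabulary cover a much wider range of otherwise out-of-vocabulary (OOV) words.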
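The splitting can be sketched as greedy longest-match-first search, the strategy BERT's WordPiece tokenizer is described as using: repeatedly take the longest prefix of the remaining characters that is in the vocabulary. The toy vocabulary and the wordpiece function below are assumptions for illustration; the real tokenizer also lowercases, splits on punctuation first and uses a ~30k-entry vocabulary.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Split one word with greedy longest-match-first WordPiece.

    Non-initial pieces are looked up with a '##' prefix. If some part of
    the word cannot be matched, the whole word becomes the unknown token.
    """
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is in the vocabulary
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]
        pieces.append(match)
        start = end
    return pieces


# Toy vocabulary (an assumption, not BERT's real vocabulary).
vocab = {"weighted", "und", "##ire", "##cted", "graph", "play", "##ing"}
print(wordpiece("undirected", vocab))  # ['und', '##ire', '##cted']
print(wordpiece("playing", vocab))     # ['play', '##ing']
```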

We can recover the original words by gluing each ##-prefixed piece onto the token before it.

In [7]:
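corrected_answer = ''

for word in answer.split():

    # A '##' token continues the previous word, so append it without a space
    if word.startswith('##'):
        corrected_answer += word[2:]
    else:
        corrected_answer += ' ' + word

corrected_answer = corrected_answer.strip()  # drop the leading space before the first word

In [8]:
print(question)
print(corrected_answer)

prims algorithm works on what type of graph?
weighted undirected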


## Questions asked on this text

The following questions were tried by setting question to each of them in turn:

- prims algorithm works on what type of graph?
- how does Jarnik algorithm build a tree?
- how does Jarnik algorithm choose the next node of a tree?
- What is a spanning tree?
- In a graph, if two nodes available how does Prims algorithm choose the next node?
- What is the other name for Prim's algorithm
- What category of algorithm does Prim fall into?

### Answers given by the model

- prims algorithm works on what type of graph?  
  Answer: weighted undirected
- how does Jarnik algorithm build a tree?  
  Answer: one vertex at a time , from an arbitrary starting vertex
- how does Jarnik algorithm choose the next node of a tree?  
  Answer: adding the cheapest possible connection from the tree to another vertex
- What is a spanning tree?  
  Answer: a subset of the edges that forms a tree that includes every vertex
- In a graph, if two nodes available how does Prims algorithm choose the next node?  
  Answer: adding the cheapest possible connection
- What is the other name for Prim's algorithm  
  Answer: jarnik ' s
- What category of algorithm does Prim fall into?  
  Answer: greedy



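## Questions not answered correctly by the model

- Is prims and Jarniks same?  
  Answer: prim ' s

  The model cannot answer "yes": it is an extractive model, so every answer it returns is a span copied from the passage, and the passage never contains the word "yes". Questions whose answers are not literally present in the given text are out of reach for this setup.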
