In [1]:
from transformers import BertForQuestionAnswering, BertTokenizer
import torch

# Checkpoint used for both the tokenizer and the model: a BERT checkpoint
# fine-tuned for question answering (see the checkpoint name).
model_name = 'bert-large-uncased-whole-word-masking-finetuned-squad'

# The tokenizer and the QA model are loaded from the same checkpoint so that
# the token ids produced by one match the vocabulary expected by the other.
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForQuestionAnswering.from_pretrained(model_name)

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Context passage for question answering: an excerpt of a Snow White short story.
# The text is kept verbatim because it is handed unchanged to robot() as the
# context to extract answers from.
# NOTE(review): the first line and the "Don't forget to check out ..." sentence
# look like leftovers from the page the text was copied from - confirm whether
# they should stay in the context.
story = """
snow white short story Snow White Short Story
The Witch’s Evil Plan
One day, the witch asked her magic mirror a question. Mirror! Mirror! on the wall who is the most beautiful woman in this kingdom? And the mirror said, “Snow White” it is. The witch queen got very angry on hearing this. She ordered a soldier to take her away from the kingdom and kill her. Don’t forget to check out more Snow White Videos & Fun Facts.
The witch said, “Beautiful child! take this delicious apple. I am sure that you have not tried anything like this before.” She took the apple and bit it without knowing the evil witch has poisoned the apple. She fell on the floor after taking the bite and never woke up. The little dwarfs got very sad and decided to place her on the flower bed.
"""

# Sample question about the story.
# NOTE(review): the robot(...) calls visible in this notebook pass literal
# question strings instead of this variable - confirm whether it is still needed.
question = "Who ordered the soldier to take Snow White away from the kingdom?"

In [7]:
def robot(contex, question):
    """Answer ``question`` by extracting a span of text from ``contex``.

    Relies on the module-level ``tokenizer`` and ``model`` (a
    ``BertForQuestionAnswering`` instance and its matching tokenizer).

    Args:
        contex: Passage in which to look for the answer. The (misspelled)
            parameter name is kept unchanged for backward compatibility.
        question: Natural-language question about the passage.

    Returns:
        The predicted answer span decoded back to a string. An empty string
        is returned when the predicted end position precedes the predicted
        start position (i.e. the model produced no valid span).
    """
    # Encode the (question, context) pair. Only the context is truncated, so
    # an over-long passage is clipped to the model's position-embedding limit
    # instead of crashing the forward pass; the question is left intact.
    inputs = tokenizer(
        question,
        contex,
        return_tensors="pt",
        truncation="only_second",
        max_length=model.config.max_position_embeddings,
    )

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(**inputs)

    # Per-token scores for being the first / last token of the answer.
    answer_start_scores = output.start_logits
    answer_end_scores = output.end_logits

    # Most likely start token, and one past the most likely end token so the
    # pair can be used directly as a half-open slice. The batch holds a single
    # sequence, so argmax over the whole tensor is the token index.
    answer_start = int(torch.argmax(answer_start_scores))
    answer_end = int(torch.argmax(answer_end_scores)) + 1

    # Predicted end before start means there is no valid span to decode.
    if answer_end <= answer_start:
        return ""

    # Map the selected token ids back to word pieces, then to a plain string.
    answer_ids = inputs["input_ids"][0][answer_start:answer_end]
    answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids)
    return tokenizer.convert_tokens_to_string(answer_tokens)

In [12]:
# Demo query: ask which fruit the witch gives to Snow White; the cell's value
# (the extracted answer string) is displayed as the cell output.
robot(story, "what's kind of fruit that witch give to snow white?")

'apple'

In [14]:
# Demo query: ask whom the witch put her question to.
robot(story, "with who the witch asked for?")


'magic mirror'

In [15]:
# Demo query: ask who the most beautiful woman is according to the story.
robot(story, "who is the most beautiful woman?")

'snow white'