In [None]:
pip install transformers


In [None]:
pip install torch

In [None]:
from transformers import BertTokenizer, BertForQuestionAnswering
import torch

In [None]:
# Single source of truth for the checkpoint so the tokenizer and the model
# are always loaded from the same pretrained weights/vocabulary.
MODEL_NAME = 'bert-large-uncased-whole-word-masking-finetuned-squad'

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)

In [54]:
# Report the model size: the summed element count of every parameter tensor
total_params = sum(weights.numel() for weights in model.parameters())
print(f'number of parameters: {total_params}')


number of parameters: 334094338


In [44]:
def answer_question(question, context):
    """
    Finds an answer to a question within a given context using BERT QA.

    The question and context are encoded together as one sequence pair, the
    model scores every token position as a candidate answer start and answer
    end, and the tokens between the highest-scoring start and end positions
    are decoded back into text.

    Relies on the notebook-level `tokenizer` and `model` objects.

    NOTE(review): the input is not truncated; presumably a question/context
    pair longer than the model's maximum sequence length will fail inside the
    model -- confirm before using on long documents.

    Parameters:
    question (str): The question to be answered.
    context (str): The context containing the answer to the question.

    Returns:
    str: The answer found within the context. This is an empty string when
    the predicted end position falls before the predicted start position, or
    when the selected span contains only special tokens.
    """
    # Tokenize question and context as a pair; the returned encoding also
    # carries the attention mask, so it can be passed straight to the model.
    inputs = tokenizer(question, context, add_special_tokens=True, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # Inference only: skip autograd bookkeeping so no gradient graph is built.
    with torch.no_grad():
        outputs = model(**inputs)

    # Positions with the highest start and end scores. The end index is
    # shifted by one because Python slices exclude their upper bound.
    answer_start = int(torch.argmax(outputs.start_logits))
    answer_end = int(torch.argmax(outputs.end_logits)) + 1

    # Hand plain Python ints to the tokenizer rather than a tensor slice.
    answer_ids = input_ids[0, answer_start:answer_end].tolist()
    answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids, skip_special_tokens=True)

    # Merge word pieces back into a readable string
    return tokenizer.convert_tokens_to_string(answer_tokens)


In [53]:
# One passage, several questions: the same context is queried repeatedly to
# show how the extracted span changes with the wording of the question.
context = "A huge dog is barking in front of my door entrance. I'm terrified because I never lived with dogs before. My wolf is scary"
question1 = "why am I afraid?"
question2 = "who am I afraid of?"
question3 = "who is afraid?"

# Ask each question in turn, printing the pair as soon as it is answered
answers = []
for asked in (question1, question2, question3):
    found = answer_question(asked, context)
    print(f"Question: {asked}")
    print(f"Answer: {found}")
    answers.append(found)

# Keep the individually named results available to later cells
answer1, answer2, answer3 = answers

Question: why am I afraid?
Answer: i never lived with dogs before
Question: who am I afraid of?
Answer: my wolf
Question: who is afraid?
Answer: my wolf


Please note that the results depend on the model's size and type, and on the data it was trained on.
Feel free to try different models and see how the results vary.