In [None]:
import torch
from transformers import BertForQuestionAnswering, BertTokenizer

# Load the tokenizer and the question-answering model from the same checkpoint
# so that the vocabulary matches the weights. Both objects are created once at
# module level and reused by get_answer() on every call.
model_name = 'bert-large-uncased-whole-word-masking-finetuned-squad'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForQuestionAnswering.from_pretrained(model_name)

# Function to get answer
def get_answer(question, context):
    """Extract the span of ``context`` that best answers ``question``.

    Uses the module-level ``tokenizer`` and ``model``. The question/context
    pair is encoded as a single sequence truncated to 512 tokens, so an answer
    located in the truncated tail of a long context cannot be found.

    Args:
        question: The question text.
        context: The passage the answer is extracted from.

    Returns:
        The decoded answer text (special tokens removed). An empty string is
        returned when the encoded input contains no context tokens to choose
        from (for example, an empty context).
    """
    # Calling the tokenizer directly replaces the deprecated encode_plus().
    inputs = tokenizer(question, context, return_tensors='pt', truncation=True, max_length=512)
    input_ids = inputs['input_ids'][0]

    with torch.no_grad():
        # Forward everything the tokenizer produced. token_type_ids in
        # particular tells BERT which tokens are the question and which are
        # the context; omitting it makes the model treat all tokens as
        # segment 0.
        outputs = model(**inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    seq_len = start_logits.size(0)

    # Only context tokens may be part of the answer: segment id 1, minus the
    # closing [SEP] that shares that segment id.
    in_context = inputs['token_type_ids'][0].bool() & (input_ids != tokenizer.sep_token_id)

    # Score every (start, end) pair jointly instead of taking two independent
    # argmaxes, which can yield an end before the start (empty answer) or a
    # span reaching into the question.
    span_scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)
    positions = torch.arange(seq_len, device=start_logits.device)
    valid = positions.unsqueeze(1) <= positions.unsqueeze(0)  # start <= end
    valid = valid & in_context.unsqueeze(1) & in_context.unsqueeze(0)
    if not bool(valid.any()):
        return ""

    span_scores = span_scores.masked_fill(~valid, float('-inf'))
    # argmax works on the flattened matrix; rows are starts, columns are ends.
    start_index, end_index = divmod(int(torch.argmax(span_scores)), seq_len)

    # end_index is inclusive, hence the +1 when slicing.
    answer_ids = input_ids[start_index:end_index + 1]
    return tokenizer.decode(answer_ids, skip_special_tokens=True)




In [None]:
# Main interaction: read the passage that answers will be extracted from.
# The text is stored in the module-level `context`, which the question cell
# passes to get_answer().
context = input("Enter context: ")


In [None]:
# Read one question, extract its answer from the previously entered `context`,
# and print it. The cells defining get_answer() and `context` must have been
# run first.
question = input("Enter your question: ")
answer = get_answer(question, context)
print("\nAnswer:", answer)