In [45]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

In [3]:
model = AutoModelForCausalLM.from_pretrained("microsoft/BioGPT-Large-PubMedQA")
tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT-Large-PubMedQA")

In [64]:
context = (
    "Sternal fractures in childhood are rare. The aim of the study was to investigate the accident mechanism, the detection of radiological and sonographical criteria and consideration of associated injuries.In the period from January 2010 to December 2012 all inpatients and outpatients with sternal fractures were recorded according to the documentation.A total of 4 children aged 5-14\u00a0years with a sternal fracture were treated in 2\u00a0years, 2\u00a0children were hospitalized for pain management and 2 remained in outpatient care."
)
question = "Sternal fracture in growing elderly: A rare and often overlooked fracture?"

# following the prompt format in the paper, link: https://arxiv.org/pdf/2210.10341
prompt = f"question: {question} the answer to the question is"

In [65]:
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
# Forward pass to get the logits
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Extract logits at the last position (where answer starts)
logits = outputs.logits[0, -1, :]  # Shape: [vocab_size]

# Get token IDs for "Yes", "No", and "Maybe"
yes_token_id = tokenizer(" Yes", add_special_tokens=False)["input_ids"][0]
no_token_id = tokenizer(" No", add_special_tokens=False)["input_ids"][0]
maybe_token_id = tokenizer(" Maybe", add_special_tokens=False)["input_ids"][0]

# Extract logits for the specific tokens
answer_logits = logits[[yes_token_id, no_token_id, maybe_token_id]]
answer_labels = ["Yes", "No", "Maybe"]

# Convert logits to probabilities
answer_probs = F.softmax(answer_logits, dim=0)

# Print probabilities
for label, prob in zip(answer_labels, answer_probs):
    print(f"{label} Probability: {prob.item():.4f}")

Yes Probability: 0.6789
No Probability: 0.3209
Maybe Probability: 0.0002
