In [1]:
import sys

# Put the parent directory (relative to the working directory) on the import
# search path. Guarded so that re-running this cell does not keep appending
# duplicate '..' entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

from dsets import ClinicalAgeGroupDataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Single source of truth for the checkpoint id, so the model weights and the
# tokenizer vocabulary are always loaded from the same checkpoint.
MODEL_NAME = "microsoft/BioGPT-Large-PubMedQA"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [8]:
# Load the age-group QA samples from the '../data' directory.
dataset = ClinicalAgeGroupDataset('../data')

# Swap table used to build counterfactual prompts: each age-group label maps
# to the label that replaces it in the corrupted question/context.
opposite_age_group = dict(
    infant='elderly',
    children='adults',
    adults='children',
    elderly='infant',
)

Loaded dataset with 93 samples


In [9]:
# Build a clean / corrupted prompt pair from the first sample: the corrupted
# version has every occurrence of the sample's age-group label replaced by
# its counterpart from `opposite_age_group`.
data = dataset[0]
age_replacement = opposite_age_group[data['age_group']]
corrupted_question = data['question'].replace(data['age_group'], age_replacement)
corrupted_context = data['context'].replace(data['age_group'], age_replacement)

# str.replace is case-sensitive and silently returns the input unchanged when
# the label does not occur, which would make the two prompts identical and the
# comparison meaningless. Fail loudly instead.
assert (corrupted_question, corrupted_context) != (data['question'], data['context']), (
    f"age group {data['age_group']!r} not found in question or context; nothing was corrupted"
)

# One template for both prompts so their wording cannot diverge.
PROMPT_TEMPLATE = "Question: {question} Context: {context} the answer to the question given the context is"
original_prompt = PROMPT_TEMPLATE.format(question=data['question'], context=data['context'])
corrupted_prompt = PROMPT_TEMPLATE.format(question=corrupted_question, context=corrupted_context)

In [10]:
def get_answer_prob(prompt, answer_labels=("Yes", "No", "Maybe")):
    """Return the model's relative probability of each candidate answer.

    The prompt is run through the notebook-level ``model`` and the logits at
    the last position of the first sequence are read off. Only the logits of
    the candidate answers are kept and passed through a softmax, so the
    returned values sum to 1 over ``answer_labels``; they are NOT
    probabilities over the full vocabulary.

    Args:
        prompt: Prompt text ending right before the expected answer.
        answer_labels: Candidate answer words. Each one is tokenized with a
            leading space and represented by its FIRST token id.

    Returns:
        dict mapping each label to its renormalised probability (float).

    Raises:
        ValueError: If a label tokenizes to nothing.

    Warns:
        UserWarning: If a label spans several tokens (only the first is
            scored), or if the prompt may have been truncated.
    """
    import warnings

    # Resolve the candidate answers first: it is cheap and fails before the
    # expensive forward pass.
    answer_token_ids = []
    for label in answer_labels:
        token_ids = tokenizer(f" {label}", add_special_tokens=False)["input_ids"]
        if not token_ids:
            raise ValueError(f"Answer label {label!r} produced no tokens")
        if len(token_ids) > 1:
            # Scoring only the first sub-token can conflate different words
            # that share a prefix, so make this visible.
            warnings.warn(
                f"Answer label {label!r} is split into {len(token_ids)} tokens; "
                "only the first one is scored.",
                stacklevel=2,
            )
        answer_token_ids.append(token_ids[0])

    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

    # truncation=True shortens over-long prompts without raising. If the end
    # of the prompt is what gets cut (presumably, with the tokenizer's default
    # truncation side -- confirm), the answer cue is lost and the last
    # position no longer corresponds to where the answer starts.
    max_len = getattr(tokenizer, "model_max_length", None)
    if max_len is not None and inputs["input_ids"].shape[1] >= max_len:
        warnings.warn(
            f"Prompt reaches the tokenizer limit of {max_len} tokens and may have "
            "been truncated; the answer position may be wrong.",
            stacklevel=2,
        )

    # Forward pass only; no gradients are needed.
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    # Logits at the last position of the first (only) sequence.
    last_logits = outputs.logits[0, -1, :]

    # Restrict to the candidate answers and renormalise among them.
    answer_logits = last_logits[answer_token_ids]
    answer_probs = F.softmax(answer_logits, dim=0)

    return dict(zip(answer_labels, answer_probs.tolist()))

In [11]:
# Answer distributions for the original prompt, then the age-swapped prompt.
tuple(get_answer_prob(p) for p in (original_prompt, corrupted_prompt))

({'Yes': 0.9996758699417114,
  'No': 0.00032262789318338037,
  'Maybe': 1.5754862943140324e-06},
 {'Yes': 0.9999408721923828,
  'No': 5.877300282008946e-05,
  'Maybe': 3.77662559003511e-07})