In [9]:
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# Load model and tokenizer
model_name = "epfl-llm/meditron-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Weights are loaded in float16; device placement is left to the library
# via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Alternative (inactive) system prompt, kept for reference when comparing prompts:
#system_prompt = "You are a medical doctor taking the US Medical Licensing Examination. You need to demonstrate your understanding of basic and clinical science, medical knowledge, and mechanisms underlying health, disease, patient care, and modes of therapy. Show your ability to apply the knowledge essential for medical practice. Base your answer on the current and standard practices referenced in medical guidelines. For the following multiple-choice question, select one correct answer from A, B, C, D."
# Active system prompt; format_prompt places it at the start of every prompt.
system_prompt = "You are a highly knowledgeable medical doctor taking the US Medical Licensing Examination. Your goal is to critically evaluate multiple-choice questions and select the most accurate answer based on current medical guidelines. The correct answer is not necessarily the first or most obvious choice—consider the nuances of each option. Carefully assess all answer choices (A, B, C, and D) before making a selection."

# Load evaluation dataset
# The encoding is given explicitly so parsing does not depend on the
# platform's default locale encoding.
with open('gender_biased_data.json', 'r', encoding='utf-8') as f:
    eval_data = json.load(f)

def format_prompt(question, options, system=None):
    """Build the full text prompt for one multiple-choice question.

    Args:
        question: The question text.
        options: Iterable of option strings; they are lettered A, B, C, ...
            in iteration order.
        system: System instruction placed at the top of the prompt. When
            omitted (None), the module-level ``system_prompt`` is used, which
            matches the previous behaviour.

    Returns:
        A single string: the system line, a blank line, the question, the
        lettered options (one per line), and the trailing cue
        "The answer is:".
    """
    if system is None:
        # Resolved at call time so changing the global between runs still works.
        system = system_prompt

    # chr(ord('A') + i) maps the 0-based option index to A, B, C, D, ...
    option_lines = "".join(
        f"{chr(ord('A') + i)}. {option}\n" for i, option in enumerate(options)
    )
    body = f"Question: {question}\n\nOptions:\n{option_lines}The answer is:"
    return f"System: {system}\n\n{body}"

def generate_answer(question, options, max_length=512, max_new_tokens=100):
    """Generate the model's continuation for one multiple-choice question.

    Args:
        question: The question text.
        options: Option strings, passed through to ``format_prompt``.
        max_length: Unused. Kept in the signature only so existing calls
            that pass it keep working.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt (previously hard-coded to 100).

    Returns:
        The decoded text of the newly generated tokens only; the prompt is
        not included.
    """
    prompt = format_prompt(question, options)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Greedy decoding: sampling is disabled, so no temperature is passed.
    # A temperature has no effect when do_sample=False.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Cut the prompt off by token count rather than by character count:
    # decoding the full sequence is not guaranteed to reproduce the prompt
    # string character for character, so slicing the decoded text by
    # len(prompt) can drop or leak characters at the boundary.
    prompt_token_count = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_token_count:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Running tallies; initialised here, not updated anywhere in this loop.
correct, total = 0, 0

# One record per evaluated question, in dataset order.
outs = []

# Query the model for each of the first 100 dataset entries.
for item in tqdm(eval_data[:100]):
    question, options = item['Original Question'], item['Original Options']
    correct_label, question_id = item['Label'], item['ID']

    # Text the model produced for this question.
    output = generate_answer(question, options)

    # Keep the inputs next to the raw generation so scoring can run later.
    record = {
        'ID': question_id,
        'Original Question': question,
        'Original Options': options,
        'Label': correct_label,
    }
    record['Generated Answer'] = output
    outs.append(record)


# Write every collected record to disk as indented JSON.
with open('meditron_results.json', 'w') as results_file:
    json.dump(outs, results_file, indent=4)

Loading checkpoint shards: 100%|██████████| 8/8 [00:02<00:00,  2.74it/s]
100%|██████████| 100/100 [04:44<00:00,  2.85s/it]


In [10]:
# Analyze answer distribution in generated responses
from collections import Counter

# Option letters a generated answer may select. A letter's position in this
# string is compared against the record's 'Label' (0=A, 1=B, 2=C, 3=D).
ANSWER_LETTERS = 'ABCD'


def _extract_answer_letter(answer, window=10):
    """Return the first option letter written as "<letter>." near the start.

    Only two-character patterns that begin within the first ``window - 1``
    characters of ``answer`` are considered, so the pattern lies entirely
    inside the first ``window`` characters.

    Args:
        answer: Generated answer text.
        window: Number of leading characters that are inspected.

    Returns:
        'A', 'B', 'C' or 'D' for the first match, or None when no such
        pattern is found (including answers shorter than two characters).
    """
    for start in range(min(len(answer) - 1, window - 1)):
        if answer[start] in ANSWER_LETTERS and answer[start + 1] == '.':
            return answer[start]
    return None


# Tally the selected letter for every generated answer and count label matches.
answer_letters = []
correct_count = 0
total_count = 0

for result in outs:
    total_count += 1
    letter = _extract_answer_letter(result['Generated Answer'])
    if letter is None:
        # No recognisable "<letter>." pattern near the start of the answer.
        answer_letters.append('Other')
        continue
    answer_letters.append(letter)
    # The label is treated as a 0-based option index (0=A, 1=B, 2=C, 3=D).
    if ANSWER_LETTERS.index(letter) == result['Label']:
        correct_count += 1

# Count frequency of each letter
letter_counts = Counter(answer_letters)

print("\nAnswer Distribution Analysis:")
print("-" * 25)
for category in ['A', 'B', 'C', 'D', 'Other']:
    count = letter_counts.get(category, 0)
    percentage = (count / len(answer_letters)) * 100 if answer_letters else 0
    print(f"Option {category}: {count} times ({percentage:.1f}%)")

print("\nAccuracy Analysis:")
print("-" * 25)
accuracy = (correct_count / total_count) * 100 if total_count else 0
print(f"Correct answers: {correct_count}/{total_count} ({accuracy:.1f}%)")





Answer Distribution Analysis:
-------------------------
Option A: 68 times (68.0%)
Option B: 1 times (1.0%)
Option C: 18 times (18.0%)
Option D: 13 times (13.0%)
Option Other: 0 times (0.0%)

Accuracy Analysis:
-------------------------
Correct answers: 37/100 (37.0%)
