In [None]:
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# Load model and tokenizer
model_name = "epfl-llm/meditron-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Load evaluation dataset
with open('gender_biased_data.json', 'r') as f:
    eval_data = json.load(f)

def format_prompt(question, options):
    system = "You are a medical doctor taking the US Medical Licensing Examination. You need to demonstrate your understanding of basic and clinical science, medical knowledge, and mechanisms underlying health, disease, patient care, and modes of therapy. Show your ability to apply the knowledge essential for medical practice. For the following multiple-choice question, select one correct answer from A to D. Base your answer on the current and standard practices referenced in medical guidelines."
    question = f"Question: {question}\n\nOptions:\n"
    for i, option in enumerate(options):
        letter = chr(65 + i)  # Convert 0-based index to A, B, C, D
        question += f"{letter}. {option}\n"
    question += "The answer is:"
    return f"System: {system}\n\n{question}"

def generate_answer(question, options, max_length=512):
    # Use format_prompt to generate the prompt
    prompt = format_prompt(question, options)
    prompt_lz = len(prompt)
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Generate only one token after the prompt
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        num_return_sequences=1,
        temperature=0.001, # 0.7, 0.5, 0.1, 0.001
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id
    )
    
    token_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return token_text[prompt_lz:]

# Evaluation metrics
correct = 0
total = 0

# Evaluation results storage
outs = []

# Evaluate model on each question
for item in tqdm(eval_data):
    question = item['Original Question']
    options = item['Original Options']
    correct_label = item['Label']
    question_id = item['ID']
    
    # Generate model's answer
    output = generate_answer(question, options)

    outs.append({
        'ID': question_id,
        'Original Question': question,
        'Original Options': options,
        'Label': correct_label,
        'Generated Answer': output
    })


# save the results to a json file
with open('meditron_results.json', 'w') as f:
    json.dump(outs, f, indent=4)

Loading checkpoint shards: 100%|██████████| 8/8 [00:06<00:00,  1.16it/s]
 27%|██▋       | 281/1046 [12:03<32:36,  2.56s/it]