In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import pandas as pd
from helpers import *

## Load Model

In [None]:
# Single source of truth for the checkpoint so the model and tokenizer cannot drift apart.
MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"

# Load the weights in bfloat16, then move the model to the GPU.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
model = model.cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Put the model in evaluation mode; as the last expression this also displays the module.
model.eval()

## Load sub-sampled test set

In [None]:
# Read the sub-sampled test set with the project's JSONL helper.
test_set_file = "USMLE_test_samples_300.jsonl"
questions = read_jsonl_file(test_set_file)

## Parse ground-truth and store answers

In [None]:
# For every question, recover the option key whose text equals the reference answer.
# An empty string is stored when none of the options matches.
ground_truth = []
for question in questions:
    option_map = question["options"]
    first_match = next(
        (label for label, text in option_map.items() if text == question["answer"]),
        "",
    )
    ground_truth.append(first_match)

## Evaluate zero-shot Llama performance

In [None]:
# Greedy-decode a short answer for every question using the zero-shot prompt.
zero_shot_llama_answers = []
for item in tqdm(questions):
    zero_shot_prompt_messages = build_zero_shot_prompt(PROMPT, item)
    prompt = tokenizer.apply_chat_template(zero_shot_prompt_messages, tokenize=False)

    # The prompt is already rendered by the chat template, so special-token insertion
    # is switched off when re-tokenizing: the Hugging Face chat-template docs warn that
    # leaving it on can duplicate BOS/other special tokens the template already wrote.
    encoded = tokenizer(
        prompt, return_tensors="pt", truncation=True, add_special_tokens=False
    ).to(model.device)

    # Pass the attention mask explicitly rather than relying on generate() to infer it.
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=10,
        do_sample=False,
    )

    # generate() returns the prompt followed by the continuation; decode only the new tokens.
    # https://github.com/huggingface/transformers/issues/17117#issuecomment-1124497554
    prompt_length = encoded["input_ids"].shape[1]
    gen_text = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
    zero_shot_llama_answers.append(gen_text.strip())

In [None]:
# Run the answer parser over every raw zero-shot generation, in order.
zero_shot_llama_predictions = list(map(parse_answer, zero_shot_llama_answers))

In [None]:
# Score the zero-shot predictions against the reference option keys.
zero_shot_accuracy = calculate_accuracy(ground_truth, zero_shot_llama_predictions)
print(zero_shot_accuracy)

## Evaluate few-shot Llama performance

In [None]:
# Read the few-shot samples with the project's JSONL helper.
few_shot_file = "USMLE_few_shot_samples.jsonl"
few_shot_prompts = read_jsonl_file(few_shot_file)

In [None]:
# Greedy-decode a short answer for every question using the few-shot prompt.
few_shot_llama_answers = []
for item in tqdm(questions):
    few_shot_prompt_messages = build_few_shot_prompt(PROMPT, item, few_shot_prompts)
    prompt = tokenizer.apply_chat_template(few_shot_prompt_messages, tokenize=False)

    # The prompt is already rendered by the chat template, so special-token insertion
    # is switched off when re-tokenizing: the Hugging Face chat-template docs warn that
    # leaving it on can duplicate BOS/other special tokens the template already wrote.
    encoded = tokenizer(
        prompt, return_tensors="pt", truncation=True, add_special_tokens=False
    ).to(model.device)

    # Pass the attention mask explicitly rather than relying on generate() to infer it.
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=10,
        do_sample=False,
    )

    # generate() returns the prompt followed by the continuation; decode only the new tokens.
    prompt_length = encoded["input_ids"].shape[1]
    gen_text = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
    few_shot_llama_answers.append(gen_text.strip())

In [None]:
# Run the answer parser over every raw few-shot generation, in order.
few_shot_llama_predictions = list(map(parse_answer, few_shot_llama_answers))

In [None]:
# Score the few-shot predictions against the reference option keys.
few_shot_accuracy = calculate_accuracy(ground_truth, few_shot_llama_predictions)
print(few_shot_accuracy)

## Evaluate CoT Llama performance

In [None]:
# Greedy-decode a chain-of-thought response (up to 100 new tokens) for every question.
cot_llama_answers = []
for item in tqdm(questions):
    cot_prompt = build_cot_prompt(COT_INSTRUCTION, item, COT_EXAMPLES)
    prompt = tokenizer.apply_chat_template(cot_prompt, tokenize=False)

    # The prompt is already rendered by the chat template, so special-token insertion
    # is switched off when re-tokenizing: the Hugging Face chat-template docs warn that
    # leaving it on can duplicate BOS/other special tokens the template already wrote.
    encoded = tokenizer(
        prompt, return_tensors="pt", truncation=True, add_special_tokens=False
    ).to(model.device)

    # Pass the attention mask explicitly rather than relying on generate() to infer it.
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=100,
        do_sample=False,
    )

    # generate() returns the prompt followed by the continuation; decode only the new tokens.
    prompt_length = encoded["input_ids"].shape[1]
    gen_text = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
    cot_llama_answers.append(gen_text.strip())

In [None]:
# Run the chain-of-thought answer parser over every raw CoT generation, in order.
cot_llama_predictions = list(map(parse_answer_cot, cot_llama_answers))

In [None]:
# Score the chain-of-thought predictions against the reference option keys.
cot_accuracy = calculate_accuracy(ground_truth, cot_llama_predictions)
print(cot_accuracy)

## Dump all outputs and results

In [None]:
# Save each raw zero-shot generation next to the option extracted from it.
# Building the frame from a column dict names the columns up front (so an empty run
# still yields a valid, header-only CSV) and makes pandas raise if the two lists have
# different lengths, instead of zip() silently dropping the unmatched tail.
zero_shot_llama_df = pd.DataFrame(
    {
        "Predicted String": zero_shot_llama_answers,
        "Extracted Option": zero_shot_llama_predictions,
    }
)
zero_shot_llama_df.to_csv("llama_zero_shot_answers_dump.csv", index=False)

In [None]:
# Save each raw few-shot generation next to the option extracted from it.
# Building the frame from a column dict names the columns up front (so an empty run
# still yields a valid, header-only CSV) and makes pandas raise if the two lists have
# different lengths, instead of zip() silently dropping the unmatched tail.
few_shot_llama_df = pd.DataFrame(
    {
        "Predicted String": few_shot_llama_answers,
        "Extracted Option": few_shot_llama_predictions,
    }
)
few_shot_llama_df.to_csv("llama_few_shot_answers_dump.csv", index=False)

In [None]:
# Save each raw chain-of-thought generation next to the option extracted from it.
# Building the frame from a column dict names the columns up front (so an empty run
# still yields a valid, header-only CSV) and makes pandas raise if the two lists have
# different lengths, instead of zip() silently dropping the unmatched tail.
cot_llama_df = pd.DataFrame(
    {
        "Predicted String": cot_llama_answers,
        "Extracted Option": cot_llama_predictions,
    }
)
cot_llama_df.to_csv("llama_cot_answers_dump.csv", index=False)