In [None]:
# Imports: stdlib -> third-party -> project-local.
import json  # used when saving predictions at the end of the notebook
import random

import nltk
import pandas as pd
import torch
from peft import AutoPeftModelForCausalLM
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Seed every RNG we use before importing the project modules, matching the
# original ordering in case those modules draw random numbers at import time.
# Decoding below is greedy (do_sample=False); seeding torch is a safety net.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Project helpers (prompt building, JSONL I/O, metrics) and the BM25 retriever.
from helpers import *
from retriever import *

# Sentence tokenizer data used by NLTK.
nltk.download('punkt')

## Load Model

In [None]:
# Load the fine-tuned PEFT checkpoint in bfloat16 and move it to the GPU, plus the
# base Mistral-7B-Instruct-v0.2 tokenizer used to render and encode prompts.
# NOTE(review): assumes the checkpoint in "entity_finetune/" was trained on top of
# Mistral-7B-Instruct-v0.2, so this tokenizer matches the model — confirm.
model = AutoPeftModelForCausalLM.from_pretrained("entity_finetune/", torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
# Switch to inference mode. The trailing ';' keeps Jupyter from rendering the
# returned module's (very long) repr as the cell output.
model.eval();

## Load test set subsample

In [None]:
# Evaluation inputs: the sampled test documents and the few-shot prompt examples.
# NOTE(review): `few_shot_example` is not referenced by the cells visible here —
# confirm whether it is still needed.
TEST_SET_PATH = "test_set_subsample.jsonl"
FEW_SHOT_PATH = "mesh_few_shot_prompt.jsonl"

test_set_subsample = read_jsonl_file(TEST_SET_PATH)
few_shot_example = read_jsonl_file(FEW_SHOT_PATH)

# Evaluate fine-tuned Mistral performance 

In [None]:
# Run the fine-tuned model over every test document with greedy decoding and
# parse the generated text into a list of entity mentions per document.
mistral_few_shot_answers = []
for item in tqdm(test_set_subsample):
    few_shot_prompt_messages = build_entity_prompt(item)
    prompt = tokenizer.apply_chat_template(few_shot_prompt_messages, tokenize=False)
    # The rendered chat template already contains its special tokens (the leading
    # BOS for Mistral-Instruct). Tokenizing it with the default
    # add_special_tokens=True would prepend a second BOS.
    # NOTE(review): if the adapter was trained on double-BOS inputs, revert this
    # to match training.
    tensors = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = tensors.input_ids.cuda()
    attention_mask = tensors.attention_mask.cuda()
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=200,
        do_sample=False,
        # Mistral has no pad token; pass EOS explicitly (the value generate()
        # falls back to anyway) to avoid a warning on every call.
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    # https://github.com/huggingface/transformers/issues/17117#issuecomment-1124497554
    prompt_length = input_ids.shape[1]
    gen_text = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
    mistral_few_shot_answers.append(parse_entities_from_trained_model(gen_text.strip()))

## Set up the BM25 retriever over MeSH concepts

In [None]:
# Load the MeSH 2020 knowledge base and index it for BM25 lookup.
mesh_data = read_jsonl_file("mesh_2020.jsonl")
# Return value unused; presumably normalises the entries in place — TODO confirm.
process_mesh_kb(mesh_data)
# concept_id -> full KB record.
mesh_data_kb = {concept["concept_id"]: concept for concept in mesh_data}
# One [concept_id, document_text] pair per concept. The document text is the
# comma-separated alias string with commas turned into spaces, then the
# canonical name.
entity_mesh_data_dict = [
    [concept["concept_id"], concept["aliases"].replace(",", " ") + " " + concept["canonical_name"]]
    for concept in mesh_data
]
entity_ranker = BM25Retriever(entity_mesh_data_dict)

### Evaluate fine-tuned LLM + retriever performance

In [None]:
# Link each extracted mention to the single top-ranked concept from the BM25
# retriever. The result has one list of {"identifier", "text"} dicts per document.
retrieved_answers = [
    [
        {"identifier": entity_ranker.query(mention, top_n = 1)[0], "text": mention}
        for mention in document_mentions
    ]
    for document_mentions in tqdm(mistral_few_shot_answers)
]

In [None]:
# Entity-extraction scores per document, macro-averaged over the test subsample.
# zip() would silently drop documents if the predictions list were shorter than
# the test set (e.g. an interrupted generation cell), so check alignment first.
assert len(test_set_subsample) > 0, "test set is empty; macro averages are undefined"
assert len(retrieved_answers) == len(test_set_subsample), (
    f"expected one prediction per document: {len(retrieved_answers)} predictions "
    f"for {len(test_set_subsample)} documents"
)
entity_scores = [calculate_entity_metrics(gt["annotations"], pred) for gt, pred in zip(test_set_subsample, retrieved_answers)]
# Each score is indexed as (precision, recall, f1); average each position.
macro_precision_entity, macro_recall_entity, macro_f1_entity = (
    sum(score[position] for score in entity_scores) / len(entity_scores)
    for position in range(3)
)

In [None]:
# Entity-linking (MeSH identifier) scores per document, macro-averaged over the
# test subsample. Check alignment first so zip() cannot silently drop documents.
assert len(test_set_subsample) > 0, "test set is empty; macro averages are undefined"
assert len(retrieved_answers) == len(test_set_subsample), (
    f"expected one prediction per document: {len(retrieved_answers)} predictions "
    f"for {len(test_set_subsample)} documents"
)
mesh_scores = [calculate_mesh_metrics(gt["annotations"], pred) for gt, pred in zip(test_set_subsample, retrieved_answers)]
# Each score is indexed as (precision, recall, f1); average each position.
macro_precision_mesh, macro_recall_mesh, macro_f1_mesh = (
    sum(score[position] for score in mesh_scores) / len(mesh_scores)
    for position in range(3)
)

In [None]:
# Summarise both tasks in one table, render the metrics as percentage strings
# with two decimals, and save the table to CSV.
metric_columns = ["Macro Precision", "Macro Recall", "Macro F1"]
scores_df = pd.DataFrame(
    [
        ["Entity Extraction", macro_precision_entity, macro_recall_entity, macro_f1_entity],
        ["Entity Linking", macro_precision_mesh, macro_recall_mesh, macro_f1_mesh],
    ],
    columns=["Task", *metric_columns],
)
for metric in metric_columns:
    scores_df[metric] = scores_df[metric].apply(lambda value: f'{value * 100:.2f}%')
scores_df.to_csv("finetuned_model_scores.csv", index=False)

In [None]:
# Import explicitly rather than relying on a star import to provide `json`.
import json

# Persist the linked predictions (one list of {"identifier", "text"} per document).
with open("finetuned_predictions.json", "w", encoding="utf-8") as file:
    json.dump({"predictions": retrieved_answers}, file)