In [None]:
import json
import random

import nltk
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

random.seed(42)
# Project-local modules; the star imports presumably provide read_jsonl_file, build_*_prompt,
# parse_answer, calculate_*_metrics, process_*, BM25Retriever and the SYSTEM_* prompts — TODO confirm.
from helpers import *
from retriever import *
import matplotlib.pyplot as plt
nltk.download('punkt')

## Load Model

In [None]:
# Mistral-7B-Instruct v0.2 in bfloat16 on the GPU, plus its matching tokenizer; eval mode for inference.
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
model = model.cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model.eval()

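## Create test set subsample

One-time step kept for provenance. The cells below are commented out because the few-shot examples and the random test subsample have already been written to `mesh_few_shot_prompt.jsonl` and `test_set_subsample.jsonl`.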

In [None]:
# train_dataset = parse_dataset("CDR_TrainingSet.PubTator.txt")
# deduplicate_annotations(train_dataset)

# test_dataset = parse_dataset("CDR_TestSet.PubTator.txt")
# deduplicate_annotations(test_dataset)

In [None]:
# few_shot_prompt = [train_dataset[0] , train_dataset[10] , train_dataset[100]]
# test_set_subsample = random.sample(test_dataset, 200)

In [None]:
# write_jsonl_file("mesh_few_shot_prompt.jsonl",few_shot_prompt)
# write_jsonl_file("test_set_subsample.jsonl",test_set_subsample)

## Load test set subsample

In [None]:
# Reload the frozen few-shot examples and the saved test subsample, so every run uses the same data.
few_shot_example = read_jsonl_file("mesh_few_shot_prompt.jsonl")
test_set_subsample = read_jsonl_file("test_set_subsample.jsonl")

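## Evaluate few-shot Mistral performance (no retrieval)

Each test abstract is prompted together with the examples from `mesh_few_shot_prompt.jsonl`. No retrieved MeSH context is supplied. The output files for this baseline use a `zero_shot` prefix.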

In [None]:
# Greedy-decode one answer per test document with the few-shot chat prompt and parse it.
mistral_few_shot_answers = []
for doc in tqdm(test_set_subsample):
    messages = build_few_shot_prompt(SYSTEM_PROMPT, doc, few_shot_example)
    prompt_ids = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors = "pt").cuda()
    generated = model.generate(input_ids = prompt_ids, max_new_tokens=200, do_sample=False)
    # generate() echoes the prompt, so slice it off and decode only the new tokens; see
    # https://github.com/huggingface/transformers/issues/17117#issuecomment-1124497554
    new_token_ids = generated.detach().cpu().numpy()[:, prompt_ids.shape[1]:]
    completion = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]
    mistral_few_shot_answers.append(parse_answer(completion.strip()))

In [None]:
# Persist the few-shot (no-retrieval) predictions under the notebook's "zero_shot" file name.
with open("mistral_zero_shot_predictions.json", "w") as out_file:
    payload = {"predictions": mistral_few_shot_answers}
    out_file.write(json.dumps(payload))

## Evaluate few-shot (no-retrieval) performance

In [None]:
# Per-document entity-extraction scores (zip stops at the shorter of the two lists).
entity_scores = []
for gt, pred in zip(test_set_subsample, mistral_few_shot_answers):
    entity_scores.append(calculate_entity_metrics(gt["annotations"], pred))

In [None]:
# Macro-average items 0/1/2 (precision/recall/F1) of the per-document entity scores.
n_entity_docs = len(entity_scores)
macro_precision_entity, macro_recall_entity, macro_f1_entity = (
    sum(score[i] for score in entity_scores) / n_entity_docs for i in range(3)
)

In [None]:
# Per-document MeSH-linking scores for the few-shot predictions.
mesh_scores = []
for gt, pred in zip(test_set_subsample, mistral_few_shot_answers):
    mesh_scores.append(calculate_mesh_metrics(gt["annotations"], pred))

In [None]:
# Macro-average items 0/1/2 (precision/recall/F1) of the per-document linking scores.
n_mesh_docs = len(mesh_scores)
macro_precision_mesh, macro_recall_mesh, macro_f1_mesh = (
    sum(score[i] for score in mesh_scores) / n_mesh_docs for i in range(3)
)

In [None]:
# Two-row summary table, with scores rendered as "xx.xx%" strings, saved as CSV.
score_columns = ["Macro Precision", "Macro Recall", "Macro F1"]
scores_df = pd.DataFrame(
    [
        ["Entity Extraction", macro_precision_entity, macro_recall_entity, macro_f1_entity],
        ["Entity Linking", macro_precision_mesh, macro_recall_mesh, macro_f1_mesh],
    ],
    columns=["Task"] + score_columns,
)
for column in score_columns:
    scores_df[column] = scores_df[column].apply(lambda x: f'{x * 100:.2f}%')
scores_df.to_csv("zero_shot_entity_mesh_scores.csv", index=False)

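## Set up BM25 retrievers

Two indexes are built over MeSH 2020:
- `retriever` is built from the `process_index` output and is queried with a whole title + abstract for RAG.
- `entity_retriever` indexes each concept's aliases plus its canonical name and is queried with individual entity strings for linking.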

In [None]:
mesh_data = read_jsonl_file("mesh_2020.jsonl")
process_mesh_kb(mesh_data)

# concept_id -> concept record; used later to turn retrieved IDs into RAG context.
mesh_data_kb = {concept["concept_id"]: concept for concept in mesh_data}
# process_index gets its own freshly built id -> record dict (not the mesh_data_kb object).
mesh_data_dict = process_index({concept["concept_id"]: concept for concept in mesh_data})
# [concept_id, "<aliases with commas turned into spaces> <canonical_name>"] per concept;
# assumes "aliases" is a comma-separated string.
entity_mesh_data_dict = [
    [concept["concept_id"], concept["aliases"].replace(",", " ") + " " + concept["canonical_name"]]
    for concept in mesh_data
]

In [None]:
# Abstract-level index (queried with title + abstract for RAG context).
retriever = BM25Retriever(mesh_data_dict)
# Entity-level index over alias/name strings (queried with single entity mentions).
entity_retriever = BM25Retriever(entity_mesh_data_dict)

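### Evaluate few-shot + external retriever performance

Each entity mention extracted by the few-shot model is linked to the top-1 `entity_retriever` hit. The MeSH IDs the model generated are ignored. Entity-extraction scores therefore stay the same, and only entity linking is re-scored.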

In [None]:
# Keep only the surface text of each predicted entity, grouped per document.
parsed_entities_few_shot = []
for doc_answer in mistral_few_shot_answers:
    parsed_entities_few_shot.append([entity["text"] for entity in doc_answer])

In [None]:
# Link every extracted mention to the single best BM25 match from the entity index.
retrieved_answers = [
    [
        {"text": entity, "identifier": entity_retriever.query(entity, top_n = 1)[0]}
        for entity in doc_entities
    ]
    for doc_entities in tqdm(parsed_entities_few_shot)
]

In [None]:
# Persist the retriever-linked predictions.
with open("mistral_zero_shot_predictions_external_retriever.json", "w") as out_file:
    payload = {"predictions": retrieved_answers}
    out_file.write(json.dumps(payload))

In [None]:
# Re-score entity linking using the retriever-assigned MeSH IDs.
mesh_scores = [
    calculate_mesh_metrics(doc["annotations"], linked)
    for doc, linked in zip(test_set_subsample, retrieved_answers)
]

In [None]:
# Macro-average the per-document linking scores of the retriever-linked predictions.
# Normalise by the length of the list actually being averaged (mesh_scores). entity_scores
# belongs to another cell, and the RAG scoring cell rebinds it.
n_mesh_docs = len(mesh_scores)
macro_precision_mesh = sum(x[0] for x in mesh_scores) / n_mesh_docs
macro_recall_mesh = sum(x[1] for x in mesh_scores) / n_mesh_docs
macro_f1_mesh = sum(x[2] for x in mesh_scores) / n_mesh_docs

In [None]:
# Summary table: entity-extraction scores from the few-shot run, linking scores from the
# external-retriever run; values rendered as "xx.xx%" strings and saved as CSV.
def _as_percent(value):
    return f'{value * 100:.2f}%'

rows = [
    ("Entity Extraction", macro_precision_entity, macro_recall_entity, macro_f1_entity),
    ("Entity Linking", macro_precision_mesh, macro_recall_mesh, macro_f1_mesh),
]
scores_df = pd.DataFrame(rows, columns=["Task", "Macro Precision", "Macro Recall", "Macro F1"])
for metric_column in ("Macro Precision", "Macro Recall", "Macro F1"):
    scores_df[metric_column] = scores_df[metric_column].apply(_as_percent)
scores_df.to_csv("zero_shot_entity_external_retriever_scores.csv", index=False)

In [None]:
# Writes the same retriever-linked predictions as the earlier save cell, under a second file name.
with open("mistral_zero_shot_entity_retriever_predictions.json", "w") as out_file:
    payload = {"predictions": retrieved_answers}
    out_file.write(json.dumps(payload))

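## Evaluate RAG performance

First, measure how many gold MeSH IDs BM25 retrieval covers at k = 10, 30 and 50. Then prompt the model with the top-k retrieved concepts for each k and score the answers.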

In [None]:
# For every test document, record its gold MeSH IDs and the top-50 concepts BM25 returns
# for "title + abstract"; the next cell computes coverage@k from these two parallel lists.
# (The former unused `coverage_dict = {}` was dropped.)
ground_truth_ids = []
retrieved_ids = []

for item in tqdm(test_set_subsample):
    relevant_mesh_ids = retriever.query(item["title"] + " " + item["abstract"], top_n = 50)
    ground_truth_ids.append([x["identifier"] for x in item["annotations"]])
    retrieved_ids.append(relevant_mesh_ids)
    

In [None]:
# Retrieval coverage@k: for each document, the fraction of its *distinct* gold MeSH IDs
# found among the top-k retrieved IDs, averaged over documents and expressed in percent.
percent_coverage_dict = {10: [], 30: [], 50: []}
for gt, pred in zip(ground_truth_ids, retrieved_ids):
    # Gold IDs may repeat (e.g. several mentions linked to one MeSH ID). The numerator is a
    # set intersection, so the denominator must count unique IDs as well; otherwise a
    # perfect retrieval could never reach 100%.
    gt_unique = set(gt)
    if not gt_unique:
        # Coverage is undefined for a document with no gold IDs; skip it rather than divide by 0.
        continue
    for k, fractions in percent_coverage_dict.items():
        hits = gt_unique.intersection(pred[:k])
        fractions.append(len(hits) / len(gt_unique))

for k, fractions in percent_coverage_dict.items():
    # Replace each per-document list by its mean percentage (NaN if no document qualified).
    percent_coverage_dict[k] = (sum(fractions) / len(fractions)) * 100 if fractions else float("nan")

In [None]:
# Line plot of mean gold-ID coverage (%) against the number of retrieved IDs k.
coverage_ks = list(percent_coverage_dict.keys())
coverage_pct = list(percent_coverage_dict.values())

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(coverage_ks, coverage_pct, marker='o', linestyle='-', color='b')
ax.set_title('On average, what % of Ground Truth IDs is present in the fetched results?')
ax.set_xlabel('No. of Retrieved IDs')
ax.set_ylabel('Avg proportion of retrieved GT IDs')
ax.grid(True)
ax.set_xticks(coverage_ks)
plt.show()


In [None]:
# For each retrieval depth k, prompt the model with the top-k MeSH concepts retrieved for the
# document and greedily decode one answer per test document.
mistral_rag_answers = {10: [], 30: [], 50: []}

for k, answers_at_k in mistral_rag_answers.items():
    for doc in tqdm(test_set_subsample):
        context_ids = retriever.query(doc["title"] + " " + doc["abstract"], top_n = k)
        context_records = [mesh_data_kb[concept_id] for concept_id in context_ids]
        rag_messages = build_rag_prompt(SYSTEM_RAG_PROMPT, doc, context_records)
        prompt_ids = tokenizer.apply_chat_template(rag_messages, tokenize=True, return_tensors = "pt").cuda()
        generated = model.generate(input_ids = prompt_ids, max_new_tokens=200, do_sample=False)
        # Drop the echoed prompt and decode only the new tokens; see
        # https://github.com/huggingface/transformers/issues/17117#issuecomment-1124497554
        new_token_ids = generated.detach().cpu().numpy()[:, prompt_ids.shape[1]:]
        completion = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]
        answers_at_k.append(parse_answer(completion.strip()))

In [None]:
def _macro_prf(per_doc_scores):
    """Macro-average per-document scores.

    Args:
        per_doc_scores: list of per-document score sequences whose items 0/1/2 are
            precision/recall/F1 (as indexed throughout this notebook).

    Returns:
        (macro_precision, macro_recall, macro_f1); all NaN for an empty list instead of
        raising ZeroDivisionError.
    """
    if not per_doc_scores:
        return float("nan"), float("nan"), float("nan")
    n_docs = len(per_doc_scores)
    return tuple(sum(score[i] for score in per_doc_scores) / n_docs for i in range(3))


def _format_scores_table(entity_prf, mesh_prf):
    """Two-row (Entity Extraction / Entity Linking) table with scores as "xx.xx%" strings."""
    columns = ["Task", "Macro Precision", "Macro Recall", "Macro F1"]
    table = pd.DataFrame(
        [["Entity Extraction", *entity_prf], ["Entity Linking", *mesh_prf]],
        columns=columns,
    )
    for column in columns[1:]:
        table[column] = table[column].apply(lambda x: f'{x * 100:.2f}%')
    return table


_METRIC_KEYS = ("macro-precision", "macro-recall", "macro-f1")

# Scores per retrieval depth k (read by the plotting cells) and one formatted table per k.
entity_scores_at_k = {}
mesh_scores_at_k = {}
df_list = []

for key, value in mistral_rag_answers.items():
    entity_prf = _macro_prf([calculate_entity_metrics(gt["annotations"], pred) for gt, pred in zip(test_set_subsample, value)])
    mesh_prf = _macro_prf([calculate_mesh_metrics(gt["annotations"], pred) for gt, pred in zip(test_set_subsample, value)])
    entity_scores_at_k[key] = dict(zip(_METRIC_KEYS, entity_prf))
    mesh_scores_at_k[key] = dict(zip(_METRIC_KEYS, mesh_prf))
    df_list.append(_format_scores_table(entity_prf, mesh_prf))

# The context manager closes (and saves) the workbook even if a write raises. Sheet names
# are derived from k, so each sheet is guaranteed to hold the run it is named after.
with pd.ExcelWriter('results_rag.xlsx', engine='xlsxwriter') as writer:
    for key, scores_df in zip(mistral_rag_answers, df_list):
        scores_df.to_excel(writer, sheet_name=f'Rag@{key}', index=False)

In [None]:
# Persist RAG predictions for every k; json turns the integer k keys into strings ("10", ...).
with open("mistral_rag_predictions.json", "w") as out_file:
    payload = {"predictions": mistral_rag_answers}
    out_file.write(json.dumps(payload))

## Plot entity extraction scores as a function of the number of retrieved MeSH concepts (k)

In [None]:
# Entity-extraction macro precision/recall/F1 (in %) for each RAG depth k.
# `plt` comes from the setup cell; no mid-notebook re-import is needed.
entity_ks = list(entity_scores_at_k.keys())

fig, ax = plt.subplots(figsize=(10, 6))
for metric, marker, color, label in [
    ('macro-precision', 'o', 'r', 'Macro-Precision'),
    ('macro-recall', '^', 'g', 'Macro-Recall'),
    ('macro-f1', 's', 'b', 'Macro-F1'),
]:
    values_pct = [details[metric] * 100 for details in entity_scores_at_k.values()]
    ax.plot(entity_ks, values_pct, marker=marker, linestyle='-', color=color, label=label)

ax.set_title('Entity Extraction Performance')
ax.set_xlabel('No. of Retrieved IDs')
ax.set_ylabel('Scores (%)')
ax.grid(True)
ax.set_xticks(entity_ks)
ax.legend()
plt.show()

## Plot MeSH linking scores as a function of the number of retrieved MeSH concepts (k)

In [None]:
# Entity-linking (MeSH) macro precision/recall/F1 (in %) for each RAG depth k.
x = list(mesh_scores_at_k.keys())
series_pct = {
    metric: [details[metric] * 100 for details in mesh_scores_at_k.values()]
    for metric in ('macro-precision', 'macro-recall', 'macro-f1')
}
y_precision = series_pct['macro-precision']
y_recall = series_pct['macro-recall']
y_f1 = series_pct['macro-f1']

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y_precision, marker='o', linestyle='-', color='r', label='Macro-Precision')
ax.plot(x, y_recall, marker='^', linestyle='-', color='g', label='Macro-Recall')
ax.plot(x, y_f1, marker='s', linestyle='-', color='b', label='Macro-F1')
ax.set_title('Entity Linking Performance')
ax.set_xlabel('No. of Retrieved IDs')
ax.set_ylabel('Scores (%)')
ax.grid(True)
ax.set_xticks(x)
ax.legend()
plt.show()