In [None]:
# For a fair comparison with the KB we used, scispacy==0.4.0 should be installed.

In [None]:
import json

from scispacy.linking import EntityLinker
import spacy
import scispacy
import pandas as pd
from helpers import *
from tqdm import tqdm

# The MeSH linker setup below is based on https://github.com/allenai/scispacy/issues/355

# Options handed to the "scispacy_linker" pipe when it is added below.
# NOTE(review): no abbreviation-detector pipe is added in this cell — confirm
# whether resolve_abbreviations has any effect without one.
config = dict(
    resolve_abbreviations=True,
    linker_name="mesh",
    max_entities_per_mention=1,
)

# Load the "en_ner_bc5cdr_md" pipeline and append the linker to it.
nlp = spacy.load("en_ner_bc5cdr_md")
nlp.add_pipe("scispacy_linker", config=config)

In [None]:
# Look up the pipe named "scispacy_linker" on the pipeline.
# The other cells visible in this notebook do not reference this handle.
linker = nlp.get_pipe("scispacy_linker")

def extract_mesh_ids(text):
    """Run the module-level ``nlp`` pipeline on ``text`` and pair each entity with an identifier.

    Args:
        text: Raw text to pass to ``nlp``.

    Returns:
        A list with one dict per entity in ``doc.ents``, in order. Each dict has
        the keys ``"text"`` (the entity's surface text) and ``"identifier"``
        (element ``[0][0]`` of the entity's ``kb_ents``, or the string ``"None"``
        when ``kb_ents`` is empty).
    """
    # kb_ents is presumably a list of (identifier, score) pairs — verify against
    # the linker's documentation; only the first element of the first entry is used.
    return [
        {
            "text": mention.text,
            "identifier": mention._.kb_ents[0][0] if mention._.kb_ents else "None",
        }
        for mention in nlp(text).ents
    ]

In [None]:
# Read "test_set_subsample.jsonl" and keep only the first 100 items of the result.
# read_jsonl_file is not defined in this notebook; presumably it comes from
# `from helpers import *` — verify.
test_set_subsample = read_jsonl_file("test_set_subsample.jsonl")[0:100]


In [None]:
# One prediction list per record, in the same order as test_set_subsample.
# The linker input is the record's title and abstract joined by a single space.
all_mesh_ids = [
    extract_mesh_ids(record["title"] + " " + record["abstract"])
    for record in tqdm(test_set_subsample)
]

In [None]:
# Score each record's predictions against that record's "annotations".
entity_scores = [
    calculate_entity_metrics(gold_record["annotations"], predicted)
    for gold_record, predicted in zip(test_set_subsample, all_mesh_ids)
]

# Macro average: unweighted mean of the per-record values. Positions 0, 1 and 2
# of each score are the ones reported as precision, recall and F1 below.
macro_precision_entity = sum(doc_score[0] for doc_score in entity_scores) / len(entity_scores)
macro_recall_entity = sum(doc_score[1] for doc_score in entity_scores) / len(entity_scores)
macro_f1_entity = sum(doc_score[2] for doc_score in entity_scores) / len(entity_scores)

In [None]:
# Score each record's predictions against that record's "annotations" with the
# MeSH-level metric.
mesh_scores = [
    calculate_mesh_metrics(gold_record["annotations"], predicted)
    for gold_record, predicted in zip(test_set_subsample, all_mesh_ids)
]

# Macro average: unweighted mean of the per-record values. Positions 0, 1 and 2
# of each score are the ones reported as precision, recall and F1 below.
# The denominator is len(mesh_scores) — the list actually being averaged — so
# this cell no longer depends on entity_scores from the previous cell.
macro_precision_mesh = sum(doc_score[0] for doc_score in mesh_scores) / len(mesh_scores)
macro_recall_mesh = sum(doc_score[1] for doc_score in mesh_scores) / len(mesh_scores)
macro_f1_mesh = sum(doc_score[2] for doc_score in mesh_scores) / len(mesh_scores)

In [None]:
# Summary table: one row per task, one column per macro-averaged metric.
scores_df = pd.DataFrame(
    [
        ["Entity Extraction", macro_precision_entity, macro_recall_entity, macro_f1_entity],
        ["Entity Linking", macro_precision_mesh, macro_recall_mesh, macro_f1_mesh],
    ]
)
scores_df.columns = ["Task", "Macro Precision", "Macro Recall", "Macro F1"]

# Render every metric as a percentage string with two decimals (e.g. 0.5 -> "50.00%").
for metric_column in ("Macro Precision", "Macro Recall", "Macro F1"):
    scores_df[metric_column] = scores_df[metric_column].apply(
        lambda score: f"{score * 100:.2f}%"
    )

scores_df.to_csv("scispacy_scores.csv", index=False)

In [None]:
# Persist the per-record predictions as {"predictions": [...]} so they can be
# inspected or re-scored without re-running the pipeline.
# `json` is imported explicitly in the import cell rather than relied upon to
# arrive through `from helpers import *`.
with open("scispacy_predictions.json", "w", encoding="utf-8") as file:
    json.dump({"predictions": all_mesh_ids}, file)