In [None]:
from pathlib import Path
import json

# Root folders for generated files (PROCESSED) and downloaded inputs (RAW).
PROCESSED = Path("data/processed")
RAW = Path("data/raw")

# Both trees keep their knowledge-base files in a "kbs" sub-folder. Later cells
# read from and write to RAW / "kbs" as well, so create it here too;
# parents=True also creates PROCESSED and RAW themselves.
for kb_dir in (PROCESSED / "kbs", RAW / "kbs"):
    kb_dir.mkdir(parents=True, exist_ok=True)

## Step 1: Prepare KB for scispacy


#### Step 1.1: UMLS_SapBERT

UMLS filtered to the WikiMed entries (the same subset used in the SapBERT paper).

In [None]:
from collections import defaultdict
# Maps each CUI to every name listed for it, in file order.
cui_aliases = defaultdict(list)

# taken from sapbert repo; one "CUI||name" record per line
with open("data/umls_onto_all_lang_cased_wikimed_only_399931.txt", encoding="utf-8") as file:
    for line in file:
        line = line.strip("\n")
        if not line:
            # Tolerate blank lines (e.g. an extra newline at the end of the file).
            continue
        # Split on the first "||" only, so a name that itself contains "||"
        # is kept whole instead of breaking the two-value unpacking.
        cui, name = line.split("||", 1)
        cui_aliases[cui].append(name)
       

In [None]:
# The file contains duplicate entries. Removing the duplicates does not change the metric scores.
# cui_aliases = {i: list(set(j)) for i, j in cui_aliases.items()}

In [None]:
# One KB record per CUI: the first name seen becomes the canonical name, the
# remaining names are kept as aliases, and the definition is left empty.
kb = [
    {
        "concept_id": concept_id,
        "aliases": names[1:],
        "canonical_name": names[0],
        "definition": "",
    }
    for concept_id, names in cui_aliases.items()
]

In [None]:


# Serialise the KB as JSON Lines: one record per line, with non-ASCII
# characters written verbatim (ensure_ascii=False) instead of \u escapes.
sapbert_kb_path = PROCESSED / "kbs" / "kb_from_sapbert.jsonl"
with sapbert_kb_path.open("w", encoding="utf-8") as outfile:
    outfile.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in kb)

#### 1.2 UMLS_Wikidata

WIKIDATA using SPARQL

This step is optional. The UMLS_Wikidata KB can instead be downloaded as described in the README and then processed as per Step 1.2.1.

In [None]:
import sys
# Run pip through the interpreter that executes this kernel (sys.executable),
# so qwikidata is installed for the same Python the notebook imports from.
!{sys.executable} -m pip install qwikidata

Run the query linked below to get the QIDs that have a CUI, and save the result as `data/raw/kbs/qids_with_cui.csv`.

https://w.wiki/8Fkw

In [None]:

from concurrent.futures import ThreadPoolExecutor

from qwikidata.entity import WikidataItem
from qwikidata.linked_data_interface import get_entity_dict_from_api
from pprint import pprint

import csv
import json
from tqdm import tqdm
from itertools import islice


def prepare_and_dump(qid, q_dict, out=None):
    """Write one Wikidata entity as a single JSON line.

    The record holds the entity's German label, description and aliases plus
    every value found in its P2892 claim group (stored under "cui").

    Args:
        qid: Wikidata identifier of the entity; stored under the "qid" key.
        q_dict: Raw entity dictionary used to build a ``WikidataItem``.
        out: Writable text file the JSON line is appended to. When omitted,
            the module-level ``output`` handle is used, which must be open.
    """
    sink = output if out is None else out
    q_item = WikidataItem(q_dict)

    # A claim whose main snak is "novalue"/"somevalue" has no "datavalue"
    # entry; skip such claims instead of failing with a KeyError.
    cuis = []
    for claim in q_item.get_claim_group("P2892"):
        datavalue = claim._claim_dict["mainsnak"].get("datavalue")
        if datavalue is not None:
            cuis.append(datavalue["value"])

    sample = {
        "qid": qid,
        "label": q_item.get_label("de"),
        "description": q_item.get_description("de"),
        "cui": cuis,
        "aliases": q_item.get_aliases("de"),
    }

    # ensure_ascii=False keeps umlauts etc. readable in the output file.
    json.dump(sample, sink, ensure_ascii=False)
    sink.write("\n")


workers = 100  # entities requested concurrently per batch
n_samples = 731418  # expected number of QIDs in the CSV (progress-bar total only)
nloops = -(-n_samples // workers)  # ceiling division -> whole number of batches

# The input CSV is opened first, so a missing CSV fails before the "w" open
# truncates an existing output file. The context managers close both files and
# shut the pool down even when a request fails part-way through.
# `output` stays a module-level name because prepare_and_dump writes to it.
with open(RAW / "kbs" / "qids_with_cui.csv", newline="", encoding="utf-8") as csvfile, \
        open(RAW / "kbs" / "qids_with_cui_output.jsonl", "w", encoding="utf-8") as output, \
        ThreadPoolExecutor(max_workers=workers) as pool:
    reader = csv.reader(csvfile, delimiter=",")
    next(reader, None)  # skip the first row (presumably the CSV header)

    # iter(callable, sentinel): keep pulling up to `workers` rows until the
    # reader is exhausted and islice produces an empty list.
    for batch in tqdm(iter(lambda: list(islice(reader, workers)), []), total=nloops):
        # The first column is expected to hold the entity URI; the text after
        # its last "/" is used as the QID.
        batch = [qid[0].split("/")[-1] for qid in batch]

        # pool.map yields results in input order, so they line up with `batch`.
        results = list(pool.map(get_entity_dict_from_api, batch))

        for qid, q_dict in zip(batch, results):
            prepare_and_dump(qid, q_dict)

##### Step 1.2.1

In [None]:
# Re-shape the downloaded Wikidata records into KB records. An entity listed
# with several CUIs yields one record per distinct CUI, all sharing the
# entity's label, aliases and description. Entities without a label are not
# filtered out.
wikidata_dump_path = RAW / "kbs" / "qids_with_cui_output.jsonl"
with wikidata_dump_path.open("r", encoding="utf-8") as f:
    kb = [
        {
            "concept_id": cui,
            "aliases": entity["aliases"],
            "canonical_name": entity["label"],
            "definition": entity["description"],
        }
        for entity in map(json.loads, f)
        for cui in set(entity["cui"])
    ]

# Persist as JSON Lines, keeping non-ASCII characters unescaped.
wikidata_kb_path = PROCESSED / "kbs" / "kb_from_wikidata_sparql.jsonl"
with wikidata_kb_path.open("w", encoding="utf-8") as outfile:
    for record in kb:
        outfile.write(json.dumps(record, ensure_ascii=False))
        outfile.write("\n")

#### Step 1.3





Once the KB is prepared, the linker artifacts can be built by running one of the following commands from the project root. Each command builds the ANN index for the given KB.



```bash
python scripts/create_linker.py --kb_path "data/processed/kbs/kb_from_sapbert.jsonl" --output_path "artifacts/sapbert"
```

OR 

```bash
python scripts/create_linker.py --kb_path "data/processed/kbs/kb_from_wikidata_sparql.jsonl" --output_path "artifacts/sparql"
```


## 2: Entity Linking

In [None]:
from tqdm import tqdm

import spacy

from scispacy.linking import *

from scispacy.candidate_generation import DEFAULT_PATHS, DEFAULT_KNOWLEDGE_BASES
from scispacy.candidate_generation import (
    CandidateGenerator,
    LinkerPaths
)

from scispacy.linking_utils import KnowledgeBase

In [None]:
# Active configuration: UMLS (SapBERT) KB.
# To evaluate the Wikidata SPARQL KB instead, uncomment the following cell;
# its assignments to the same two names override the ones made here.
# The four files are expected under "artifacts/sapbert", the --output_path
# used in the create_linker.py command shown in Step 1.3.

CustomLinkerPaths_2020AA = LinkerPaths(
    ann_index="artifacts/sapbert/nmslib_index.bin",  # noqa
    tfidf_vectorizer="artifacts/sapbert/tfidf_vectorizer.joblib",  # noqa
    tfidf_vectors="artifacts/sapbert/tfidf_vectors_sparse.npz",  # noqa
    concept_aliases_list="artifacts/sapbert/concept_aliases.json",  # noqa
)

# JSONL knowledge base written in Step 1.1; used as the default KB file below.
KB_PATH = str(PROCESSED / "kbs" / "kb_from_sapbert.jsonl")

In [None]:
# # Uncomment for the Wikidata SPARQL KB (overrides the UMLS settings above)

# CustomLinkerPaths_2020AA = LinkerPaths(
#     ann_index="artifacts/sparql/nmslib_index.bin",  # noqa
#     tfidf_vectorizer="artifacts/sparql/tfidf_vectorizer.joblib",  # noqa
#     tfidf_vectors="artifacts/sparql/tfidf_vectors_sparse.npz",  # noqa
#     concept_aliases_list="artifacts/sparql/concept_aliases.json",  # noqa
# )

# KB_PATH = str(PROCESSED / "kbs" / "kb_from_wikidata_sparql.jsonl")

In [None]:
class UMLS2020KnowledgeBase(KnowledgeBase):
    """KnowledgeBase that loads the custom JSONL file selected via ``KB_PATH``.

    ``KB_PATH`` is bound as the default argument when this class is defined,
    so re-run this cell after pointing ``KB_PATH`` at a different file.
    """

    def __init__(self, file_path: str = KB_PATH):
        super().__init__(file_path)


# Register the custom linker files and the KB class under the name "umls2020",
# the name handed to CandidateGenerator further down.
DEFAULT_PATHS["umls2020"] = CustomLinkerPaths_2020AA
DEFAULT_KNOWLEDGE_BASES["umls2020"] = UMLS2020KnowledgeBase

In [None]:
# Candidate generator for the configuration registered under "umls2020";
# link_and_evaluate reads this module-level name.
scispacy_linker = CandidateGenerator(name="umls2020")

Sample

In [None]:
# import sys

# !{sys.executable} -m pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.3/en_core_sci_sm-0.5.3.tar.gz

In [None]:
# nlp = spacy.load("en_core_sci_sm")
# nlp.add_pipe("scispacy_linker", config={"resolve_abbreviations": False, "linker_name": "umls2020"} )


# doc = nlp("Spinal and bulbar muscular atrophy (SBMA) is an \
#            inherited motor neuron disease caused by the expansion \
#            of a polyglutamine tract within the androgen receptor (AR). \
#            SBMA can be caused by this easily.")


# # Let's look at a random entity!
# entity = doc.ents[0]

# print("Name: ", entity)

# # Each entity is linked to UMLS with a score
# # (currently just char-3gram matching).
# linker = nlp.get_pipe("scispacy_linker")
# for umls_ent in entity._.kb_ents:
    
#     print(linker.kb.cui_to_entity[umls_ent[0]])
#     print("=================")
    

Query dataset:  XLBEL

In [None]:
# XLBEL data from SAPBERT repo; one "CUI||mention" record per line.
# The CUI part is kept as-is (check_label treats "|" inside it as a separator
# between alternative CUIs).
test_queries = []
with open("data/de_1k_test_query.txt", encoding="utf-8") as file:
    for line in file:
        line = line.strip("\n")
        if not line:
            # Tolerate blank lines (e.g. an extra newline at the end of the file).
            continue
        # Split on the first "||" only, so a mention containing "||" stays intact.
        cui, name = line.split("||", 1)
        test_queries.append((cui, name))
        

In [None]:
# utility for evaluation metric
def check_label(golden_cui: str, predicted_cuis: list, k: int):
    """Tell whether any of the first ``k`` predictions hits the gold label.

    Gold and predicted identifiers may be composite, i.e. several CUIs joined
    by "|". Order inside a composite is ignored: a prediction counts as a hit
    as soon as it shares at least one CUI with the gold annotation.

    Args:
        golden_cui: Gold CUI, possibly composite ("C1|C2").
        predicted_cuis: Predicted CUIs in ranked order.
        k: Number of leading predictions to inspect.

    Returns:
        True if one of ``predicted_cuis[:k]`` overlaps with ``golden_cui``,
        otherwise False (also when ``k`` is 0 or there are no predictions).
    """
    # The gold set does not depend on the candidate, so build it once.
    gold_parts = set(golden_cui.split("|"))
    # any() stops at the first hit instead of scoring every candidate.
    return any(
        not gold_parts.isdisjoint(candidate.split("|"))
        for candidate in predicted_cuis[:k]
    )

In [None]:
def link_and_evaluate(test_queries, topk, ks=(1, 2, 5)):
    """Link every query with ``scispacy_linker`` and print the hit rate at each k.

    Args:
        test_queries: Iterable of ``(gold_cui, mention_text)`` pairs; the gold
            CUI may be composite ("C1|C2").
        topk: Passed through to ``scispacy_linker`` as its second argument.
        ks: Cut-offs to report. Defaults to 1, 2 and 5.

    Prints the number of queries and, for each k, how many queries had a
    matching CUI among their k highest-scoring candidates, together with that
    count divided by the number of queries (labelled "Precision at k").
    Returns nothing.
    """
    total_entities = 0
    correct_at = {k: 0 for k in ks}

    for label, text_span in tqdm(test_queries):
        # The linker takes a list of mentions; send one and take its result.
        candidates = scispacy_linker([text_span], topk)[0]
        # Rank candidates by their highest similarity score, best first.
        sorted_candidates = sorted(
            candidates, reverse=True, key=lambda x: max(x.similarities)
        )
        candidate_ids = [c.concept_id for c in sorted_candidates]

        for k in correct_at:
            if check_label(golden_cui=label, predicted_cuis=candidate_ids, k=k):
                correct_at[k] += 1

        total_entities += 1

    print("Total entities: ", total_entities)
    for k, correct in correct_at.items():
        # An empty query set would otherwise raise ZeroDivisionError here.
        precision = correct / total_entities if total_entities else 0.0
        print(f"Correct at {k}: ", correct, f"Precision at {k}: ", precision)
    

In [None]:
# Evaluate on the XLBEL queries loaded above, with topk=10 per mention.
link_and_evaluate(test_queries = test_queries, topk=10)

Query dataset: WikiMed-DE-BEL

In [None]:
# Folder the WikiMed-DE-BEL split files (train/dev/test_data_bel.json) are read from.
WikiMed_DE_BEL = RAW / "BEL-silver-standard/WikiMed-DE-BEL"

In [None]:
def prepare_query_data(data):
    """Turn dataset entries into ``(cui, name)`` evaluation queries.

    Every entry contributes its title and each of its mentions. Empty names and
    names whose CUI is missing (``None`` or the string "None") are skipped.
    When the same name occurs with several CUIs, they are merged into one
    composite label joined by "|".

    Args:
        data: Iterable of dicts with the keys "title", "cui" and "mentions",
            where "mentions" is an iterable of dicts with "mention" and "cui".

    Returns:
        A list of ``(cui_label, name)`` tuples, one per distinct name, in order
        of first appearance. CUIs inside a composite label are sorted, so the
        result is identical from run to run.
    """
    from collections import defaultdict

    name_cui_map = defaultdict(set)

    def _register(name, cui):
        # Record the pair unless the name is empty or the CUI is missing.
        if name and cui is not None and cui != "None":
            name_cui_map[name].add(cui)

    for entry in data:
        _register(entry["title"], entry["cui"])
        for mention in entry["mentions"]:
            _register(mention["mention"], mention["cui"])

    # sorted(): iteration order of a set of strings can differ between runs.
    return [("|".join(sorted(cuis)), name) for name, cuis in name_cui_map.items()]

In [None]:
# Parse the training split directly from the open file handle.
train_split_path = WikiMed_DE_BEL / "train_data_bel.json"
with train_split_path.open("r", encoding="utf-8") as f:
    data = json.load(f)

In [None]:
# prepare_query_data returns one tuple per distinct name, so set() drops no
# items here; the cell displays the number of queries in the train split.
test_queries = set(prepare_query_data(data))
len(test_queries)

In [None]:
# Evaluate on the queries built from train_data_bel.json.
link_and_evaluate(test_queries = test_queries, topk=10)

In [None]:
# Parse the development split directly from the open file handle.
dev_split_path = WikiMed_DE_BEL / "dev_data_bel.json"
with dev_split_path.open("r", encoding="utf-8") as f:
    data = json.load(f)

In [None]:
# Rebuild the query set from the dev split (replaces the train queries);
# the cell displays how many queries it contains.
test_queries = set(prepare_query_data(data))
len(test_queries)

In [None]:
# Evaluate on the queries built from dev_data_bel.json.
link_and_evaluate(test_queries = test_queries, topk=10)

In [None]:
# Parse the test split directly from the open file handle.
test_split_path = WikiMed_DE_BEL / "test_data_bel.json"
with test_split_path.open("r", encoding="utf-8") as f:
    data = json.load(f)

In [None]:
# Rebuild the query set from the test split (replaces the dev queries);
# the cell displays how many queries it contains.
test_queries = set(prepare_query_data(data))
len(test_queries)

In [None]:
# Evaluate on the queries built from test_data_bel.json.
link_and_evaluate(test_queries = test_queries, topk=10)