In [1]:
from pathlib import Path
import json


import os
from tqdm import tqdm

# Gated Hugging Face models need a user access token.
# How to create one: https://huggingface.co/docs/hub/en/security-tokens
#
# Do not commit a real token. Either export HF_TOKEN in the shell before
# starting Jupyter and leave ACCESS_TOKEN empty, or paste the token below for
# a local run only.

ACCESS_TOKEN = ""
# Only write HF_TOKEN when a token was actually supplied: assigning the empty
# string unconditionally would overwrite a token that is already exported in
# the environment.
if ACCESS_TOKEN:
    os.environ['HF_TOKEN'] = ACCESS_TOKEN

In [2]:
import torch
from torch import Tensor, device

def cos_sim(a: Tensor, b: Tensor) -> Tensor:
    """
    Pairwise cosine similarity between the rows of ``a`` and the rows of ``b``.

    Inputs that are not tensors are converted with ``torch.tensor``; a 1-D
    input is treated as a single row.

    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """
    unit_rows = []
    for value in (a, b):
        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value)
        if len(value.shape) == 1:
            # promote a single vector to a one-row matrix
            value = value.unsqueeze(0)
        # L2-normalise every row so the matrix product below yields cosines
        unit_rows.append(torch.nn.functional.normalize(value, p=2, dim=1))

    a_unit, b_unit = unit_rows
    return torch.mm(a_unit, b_unit.transpose(0, 1))

In [None]:
# utility for precision @ k evaluation metric 
def check_label(golden_cui:str , predicted_cuis:list, k:int ):
    """
    Tell whether any of the first ``k`` predictions hits the gold label.

    A CUI may be composite, i.e. several ids joined by "|"; the order of the
    parts is not taken into account. A prediction counts as a hit when it
    shares at least one part with ``golden_cui``.

    Returns 'True' if one of the first ``k`` predictions is a hit,
    otherwise 'False'.
    """
    hits = [
        not set(candidate.split("|")).isdisjoint(set(golden_cui.split("|")))
        for candidate in predicted_cuis[:k]
    ]
    return any(hits)

In [None]:
# check_label("A|B" , ["D", "A", "A|B|C"], k=1 )

# Knowledge Base

* data/umls_onto_all_lang_cased_wikimed_only_399931.txt: 
  
 UMLS subset KB from SAPBERT Repo


* "data/qids_with_cui_kb.txt":  
  
WikiData SPARQL KB. It is already provided in the repo; however, it can be regenerated as follows:



In [None]:
# kb = []

# with open( "data/raw/kbs/qids_with_cui_output.jsonl", "r", encoding="utf-8") as f:
#     for line in f:
#         entry = json.loads(line)
#         aliases = entry["aliases"]
#         for cui in set(entry["cui"]):
            
#             if entry['label']:
               
#                 kb.append(f"{cui}||{entry['label']}\n")

#             for alias in aliases:
#                 if alias:
#                     kb.append(f"{cui}||{alias}\n")
                    
# kb = set(kb)

# with open('data/qids_with_cui_kb.txt', 'w', encoding="utf-8") as f:
#     for entry in kb:
#         f.write(entry)

# Query Datasets:

#### XLBEL:

"data/de_1k_test_query.txt": provided by sapbert repo


#### WikiMed-BEL-DE

"data/de_wikimed_bel_train_query.txt"

"data/de_wikimed_bel_dev_query.txt"

"data/de_wikimed_bel_test_query.txt"


The files are already in the repo. However, they can be prepared from WikiMed-BEL-DE in the following way:

In [None]:
def prepare_query_data(data):
    """
    Turn BEL entries into query lines of the form "<cui>[|<cui>...]||<name>",
    each terminated by a newline.

    Every entry contributes its own ("title", "cui") pair plus one
    ("mention", "cui") pair per item in entry["mentions"]. Pairs with an empty
    name or with the string "None" as cui are skipped. All CUIs seen for the
    same name are merged into one composite label.

    :param data: iterable of dicts with the keys "title", "cui" and "mentions";
        each mention is a dict with the keys "mention" and "cui"
    :return: list of query lines, one per distinct name, in first-seen order
    """
    from collections import defaultdict

    name_cui_map = defaultdict(set)

    def _register(name, cui):
        # "None" (a string) is presumably the dataset's marker for "no CUI"
        if name and cui != "None":
            name_cui_map[name].add(cui)

    for entry in data:
        _register(entry["title"], entry["cui"])
        for m in entry["mentions"]:
            _register(m["mention"], m["cui"])

    # Some names have more than one cui. Sort them: the iteration order of a
    # set of strings differs between interpreter runs, which would make the
    # generated query files non-reproducible.
    return [
        f"{'|'.join(sorted(cuis))}||{name}\n"
        for name, cuis in name_cui_map.items()
    ]

In [None]:
# import json

# # train
# with open( "data/raw/BEL-silver-standard/WikiMed-DE-BEL/train_data_bel.json", "r", encoding="utf-8") as f:
#     data = json.loads(f.read())

# test_queries = set(prepare_query_data(data))

# with open('data/de_wikimed_bel_train_query.txt', 'w', encoding="utf-8") as f:
#     for entry in test_queries:
#         f.write(entry)
        
# # dev
# with open( "data/raw/BEL-silver-standard/WikiMed-DE-BEL/dev_data_bel.json", "r", encoding="utf-8") as f:
#     data = json.loads(f.read())

# test_queries = set(prepare_query_data(data))

# with open('data/de_wikimed_bel_dev_query.txt', 'w', encoding="utf-8") as f:
#     for entry in test_queries:
#         f.write(entry)


# # test
# with open( "data/raw/BEL-silver-standard/WikiMed-DE-BEL/test_data_bel.json", "r", encoding="utf-8") as f:
#     data = json.loads(f.read())

# test_queries = set(prepare_query_data(data))

# with open('data/de_wikimed_bel_test_query.txt', 'w', encoding="utf-8") as f:
#     for entry in test_queries:
#         f.write(entry)



# Select Knowledge Base

All the Query datasets will be linked to the selected KB



In [None]:
# load KB

kb_path = "data/qids_with_cui_kb.txt" # WikiData Sparql KB
# kb_path = "data/umls_onto_all_lang_cased_wikimed_only_399931.txt" # UMLS subset KB from SAPBERT Repo

# Each KB line has the form "<cui>||<name>".
kb_tuples = []

with open(kb_path, encoding="utf-8") as file:
    for line in file:
        line = line.strip("\n")
        if not line:
            # tolerate blank lines (e.g. a trailing empty line)
            continue
        # Split on the first "||" only, so a name that itself contains "||"
        # does not break the 2-tuple unpacking. A line without any "||"
        # still raises ValueError.
        cui, desc = line.split("||", 1)
        kb_tuples.append((cui, desc))

# Names only, in file order; position i corresponds to kb_tuples[i].
kb_entites = [desc for _, desc in kb_tuples]
# position -> (cui, name); indexing dict is faster than list
kb_cui_desc_map = dict(enumerate(kb_tuples))


# Entity Linking using  Embedding 

All the embedding models follow the following steps to link a mention to KB

* Compute embeddings of all the entities in the KB and store them in a FAISS index.

* Compute an embedding of each mention in the query dataset.

* Find the K nearest neighbours by comparing each query embedding with the KB embeddings using cosine similarity.



In [None]:
def evaluate_on_fly( retrieved_kb_ids,  queries_cui_desc_map , kb_cui_desc_map ):
    """
    Score retrieved KB candidates against the gold CUIs and print the hit
    rate within the top 1 and the top 5 candidates (printed as "Recall").

    :param retrieved_kb_ids: one list of KB ids per query, best candidate
        first; the list at position i belongs to query i
    :param queries_cui_desc_map: maps query position -> (golden_cui, mention)
    :param kb_cui_desc_map: maps KB id -> (cui, name)
    :return: list with one dict per query holding 'mention', 'golden_cui' and
        'candidates'; each candidate is a dict with 'name', 'labelcui' and
        'match' (whether that candidate shares a CUI with the gold label)
    """
    total_entities = 0
    correct_at_1 = 0
    correct_at_5 = 0

    queries_results = []

    for i, sample_kb_ids in tqdm(enumerate(retrieved_kb_ids), total = len(retrieved_kb_ids)):
        golden_cui, gold_mention = queries_cui_desc_map[i]

        # FAISS fills missing neighbours with id -1 when the index holds fewer
        # vectors than the requested top_k; such ids have no KB entry.
        candidates = [kb_cui_desc_map[idx] for idx in sample_kb_ids if idx >= 0]
        candidate_ids = [candidate[0] for candidate in candidates]

        dict_candidates = [
            {
                'name': candidate[1],
                'labelcui': candidate[0],
                # argument roles are swapped here, which is harmless because
                # the overlap test in check_label is symmetric
                'match': check_label(candidate[0], [golden_cui], k=1),
            }
            for candidate in candidates
        ]
        queries_results.append({
            'mention': gold_mention,
            'golden_cui': golden_cui, # golden_cui can be composite cui
            'candidates': dict_candidates
        })

        if check_label(golden_cui = golden_cui , predicted_cuis= candidate_ids, k=1 ):
            correct_at_1 += 1
        if check_label(golden_cui = golden_cui , predicted_cuis= candidate_ids, k=5 ):
            correct_at_5 += 1

        total_entities += 1

    # Guard against an empty query set instead of raising ZeroDivisionError.
    recall_at_1 = correct_at_1 / total_entities if total_entities else 0.0
    recall_at_5 = correct_at_5 / total_entities if total_entities else 0.0

    print("Total entities: ", total_entities)
    print(
        "Correct at 1: ", correct_at_1, "Recall at 1: ", recall_at_1
    )

    print(
        "Correct at 5: ",
        correct_at_5,
        "Recall at 5: ",
        recall_at_5,
    )
    return queries_results



In [None]:
import faiss

class FAISSIndex:
    """
    Small wrapper around a FAISS "Flat" index using the inner-product metric.

    Vectors are passed through ``faiss.normalize_L2`` before they are added
    or searched, so the inner-product scores act as cosine similarities.
    ``normalize_L2`` is called on the array handed in by the caller, without
    making a copy first.
    """

    def __init__(self, embedding_size) -> None:
        """Create an empty index for vectors of dimension ``embedding_size``."""
        self.index = faiss.index_factory(
            embedding_size, "Flat", faiss.METRIC_INNER_PRODUCT
        )

    def create_index(self, embeddings):
        """Normalise ``embeddings`` and append them to the index; returns the index."""
        faiss.normalize_L2(embeddings)
        index = self.index
        index.add(embeddings)
        return index

    def save_index(self, save_path):
        """Write the current index to ``save_path``."""
        faiss.write_index(self.index, save_path)

    def load_index(self, load_path):
        """Replace the current index with the one stored at ``load_path``."""
        self.index = faiss.read_index(load_path)

    def index_look_up(self, query_embedding, top_k):
        """
        Search the ``top_k`` nearest index entries for every query vector.

        :return: tuple ``(ids, distances)``; ``ids`` is converted to nested
            Python lists, ``distances`` is returned as produced by the search
        """
        faiss.normalize_L2(query_embedding)
        scores, neighbour_ids = self.index.search(query_embedding, top_k)
        return (neighbour_ids.tolist(), scores)
        
        

In [3]:
# import nmslib

# class ANNIndex:
#     def __init__(self, matrix_is_sparse=False) -> None:
#         self.matrix_is_sparse = matrix_is_sparse
        
#         if self.matrix_is_sparse:
#             # initialize a new index, using a HNSW index on Cosine Similarity
#             self.index = nmslib.init(
#                 method="hnsw",
#                 space="cosinesimil_sparse",
#                 data_type=nmslib.DataType.SPARSE_VECTOR,
#             )
#         else:
#             self.index = nmslib.init(method="hnsw", space="cosinesimil")

#     def create_index(self, embeddings):
        
#         self.index.addDataPointBatch(embeddings)
        
#         return self.index

#     def save_index(self, save_path):
#         self.index.createIndex(
#             {"M": 50, "efConstruction": 200, "post": 2}, print_progress=True
#         )
        
#         if self.matrix_is_sparse:
#             self.index.saveIndex(str(save_path), save_data=True)
#         else:
            
#             self.index.saveIndex(str(save_path), save_data=False)

#     def load_index(self, load_path):
#         if self.matrix_is_sparse:
#             self.index = nmslib.init(
#                 method="hnsw",
#                 space="cosinesimil_sparse",
#                 data_type=nmslib.DataType.SPARSE_VECTOR,
#             )
#             self.index.loadIndex(str(load_path), load_data=True)
#         else:
#             self.index = nmslib.init(method="hnsw", space="cosinesimil")
#             self.index.loadIndex(str(load_path))

#     def index_look_up(self, query_embedding, top_k):
#         # import pdb; pdb.set_trace()
#         temp = self.index.knnQueryBatch(
#             query_embedding, k=top_k, 
#             num_threads=0
#         )
#         ids = [i[0].tolist() for i in temp]
#         # distances = [i[1].tolist() for i in temp]
#         return (ids, [])

In [4]:
from pathlib import Path
import os  
# Output folders, relative to the working directory; created if missing.
ANN_INDEXES = Path('artifacts', 'ann_indexes')  # saved index files, one per KB/model pair
RESULTS = Path('artifacts', 'results')  # per-dataset candidate dumps (JSON)

ANN_INDEXES.mkdir(parents=True, exist_ok=True)
RESULTS.mkdir(parents=True, exist_ok=True)

## Model 1: jina-embeddings-v2-base-de

A German/English bilingual text embedding model supporting 8192 sequence length. 

For reference, the authors of the original SapBERT repo set the max length to 25. We restrict the length to 40 for all models, as German text is usually longer.

In [None]:
from transformers import AutoModel
from torch.utils.data import Dataset, DataLoader

In [None]:
# --- Model 1 (jina-embeddings-v2-base-de): build or load the KB index, then
# --- link every query dataset against it and print the hit rates.
# Relies on kb_path, kb_entites, kb_cui_desc_map, ANN_INDEXES, FAISSIndex and
# evaluate_on_fly defined in earlier cells.
load_saved_index = False # load saved index
batch_size = 8  # batch size for embedding encoder
data_loader_batch_size = 50000  # number of KB names encoded and added to the index per chunk
# Index file name: "<KB file stem>_jinaai.bin" inside ANN_INDEXES
index_path = ANN_INDEXES / f"{os.path.splitext(os.path.basename(kb_path))[0]}_jinaai.bin"
index_path = str(index_path)


# ann_index = ANNIndex(matrix_is_sparse=False)
# NOTE(review): 768 must equal the embedding width of the model loaded below — confirm
ann_index = FAISSIndex(embedding_size=768)

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-de", trust_remote_code=True, )
# set length in the encode() calls


# from sentence_transformers import SentenceTransformer
# model = SentenceTransformer("jina")
# model.max_seq_length = 40

if load_saved_index:
    ann_index.load_index(index_path)
else: 
    print("Creating Index")
    # shuffle=False keeps KB order, so index positions line up with kb_cui_desc_map keys
    loader = DataLoader(kb_entites, shuffle=False, batch_size=data_loader_batch_size,)
    
    for kb_entity_batch in tqdm(loader):
        # NOTE(review): max_length is commented out here but set to 40 for the
        # queries below, so KB and query sides are encoded with different
        # length settings — confirm this is intended.
        kb_embeddings = model.encode(kb_entity_batch, show_progress_bar =True,
                                     batch_size =batch_size , device = "cuda",
                                    #  max_length=40,
                                     )

        # each chunk is appended to the same index
        ann_index.create_index(kb_embeddings)
        
    print("Index Created")
    ann_index.save_index(save_path= str(index_path))
  
    
# Query Datasets
for query_path in [
                    "data/de_1k_test_query.txt",
                   "data/de_wikimed_bel_train_query.txt",
                   "data/de_wikimed_bel_dev_query.txt",
                   "data/de_wikimed_bel_test_query.txt",
                   ]:

    # load data
    # each line is "<cui>||<mention>"; a line that does not split into exactly
    # two parts raises ValueError
    query_tuples = []
    with open(query_path, encoding="utf-8") as file:
        for line in file:
            cui, desc = line.strip("\n").split("||")
            query_tuples.append((cui, desc))

    queries = [i[1] for i in query_tuples]
    # position -> (cui, mention), the lookup format evaluate_on_fly indexes by query position
    queries_cui_desc_map = {i:e for i, e in enumerate(query_tuples)}

    # embed
    query_embeddings = model.encode(queries, show_progress_bar =True, batch_size =batch_size , device = "cuda", 
                                    max_length=40,
                                    )
    
    
    print("Index Lookup started")
    # 20 neighbours are retrieved per query; the returned distances are discarded
    retrieved_kb_ids, _ = ann_index.index_look_up(query_embeddings, top_k=20) # batch lookup is faster
    
    # query_candidate_ids = []
    # for sample in retrieved_kb_ids:
    #     query_candidate_ids.append([kb_cui_desc_map[idx] for idx in sample])
                    
    print("======================")
    print(query_path)
    # only the printed metrics are used here; the returned per-query results are not stored
    evaluate_on_fly( retrieved_kb_ids,  queries_cui_desc_map , kb_cui_desc_map )
    print("======================")
   
    

# drop the references before the next model is loaded (presumably to free memory)
del model
del query_embeddings

## Model 2: SAPBERT

We use the "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR" model. There is also a large version, "cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR-large", which could be tried.

In [None]:
import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """
    Average the token embeddings over the sequence axis, counting only the
    positions where ``attention_mask`` is non-zero.

    :param model_output: sequence whose first element holds the token embeddings
    :param attention_mask: mask with one value per token position
    :return: one pooled vector per sequence
    """
    token_embeddings = model_output[0]
    # broadcast the mask over the embedding axis so it can weight every feature
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, 1)
    # clamp so a fully masked sequence does not divide by zero
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts

def gererate_embedding(model, tokenizer , all_names, batch_size):
    """
    Encode ``all_names`` batch by batch and return the CLS embeddings.

    Each batch is tokenized with padding/truncation to 40 tokens, moved to
    the device the model's parameters live on, and passed through the model;
    the first token's vector of the model's first output is used as the
    embedding.

    :param model: torch module called as ``model(**tokens)``
    :param tokenizer: object providing ``batch_encode_plus``
    :param all_names: sliceable sequence of strings
    :param batch_size: number of names per forward pass
    :return: numpy array with one row per name, in input order
    """
    bs = batch_size # batch size during inference
    # Follow the model's device instead of assuming CUDA is available.
    target_device = next(model.parameters()).device
    all_embs = []
    # Inference only: without no_grad every forward pass records an autograd
    # graph, which wastes memory for no benefit.
    with torch.no_grad():
        for i in tqdm(np.arange(0, len(all_names), bs)):
            toks = tokenizer.batch_encode_plus(all_names[i:i+bs],
                                            padding="max_length",
                                            max_length=40,
                                            truncation=True,
                                            return_tensors="pt")
            toks_on_device = {k: v.to(target_device) for k, v in toks.items()}

            # cls
            cls_rep = model(**toks_on_device)[0][:,0,:] # use CLS representation as the embedding
            all_embs.append(cls_rep.cpu().numpy())

            # mean
            # model_output = model(**toks_on_device)
            # all_embs.append(mean_pooling(model_output, toks_on_device['attention_mask']).cpu().numpy())

    all_embs = np.concatenate(all_embs, axis=0)
    return all_embs


In [None]:
# --- Model 2 (SapBERT): build or load the KB index, link every query dataset
# --- against it, print the hit rates and dump the candidates as JSON.
# Relies on kb_path, kb_entites, kb_cui_desc_map, ANN_INDEXES, RESULTS,
# FAISSIndex, gererate_embedding and evaluate_on_fly from earlier cells.
load_saved_index = False
batch_size = 32  # batch size for embedding encoder
data_loader_batch_size = 50000  # number of KB names encoded and added to the index per chunk
# Index file name: "<KB file stem>_sapbert.bin" inside ANN_INDEXES
index_path = ANN_INDEXES / f"{os.path.splitext(os.path.basename(kb_path))[0]}_sapbert.bin"
index_path = str(index_path)

# Results folder is named after the index file stem
results_path = RESULTS / f"{os.path.splitext(os.path.basename(index_path))[0]}"
results_path.mkdir(parents=True, exist_ok=True)


# ann_index = ANNIndex(matrix_is_sparse=False)
# NOTE(review): 768 must equal the hidden size of the model loaded below — confirm
ann_index = FAISSIndex(embedding_size=768)

tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR",
                                          use_fast=True ,
                                        #   local_files_only=True
                                          )
# the model is moved to the GPU right after loading
model = AutoModel.from_pretrained("cambridgeltl/SapBERT-UMLS-2020AB-all-lang-from-XLMR",
                                #   local_files_only=True,
                            #    torch_dtype=torch.float16
                                  ).cuda()




if load_saved_index:
    ann_index.load_index(index_path)
else:
    print("Creating Index")
    # shuffle=False keeps KB order, so index positions line up with kb_cui_desc_map keys
    loader = DataLoader(kb_entites, shuffle=False, batch_size=data_loader_batch_size,)
    
    for kb_entity_batch in tqdm(loader):
        kb_embeddings = gererate_embedding(model, tokenizer, kb_entity_batch, batch_size=batch_size)
        

        # each chunk is appended to the same index
        ann_index.create_index(kb_embeddings)
    
    print("Index Created")
    ann_index.save_index(save_path= str(index_path))
  
    
# Query Datasets
for query_path in [
                    "data/de_1k_test_query.txt",
                   "data/de_wikimed_bel_train_query.txt",
                   "data/de_wikimed_bel_dev_query.txt",
                   "data/de_wikimed_bel_test_query.txt",
                   ]:
   

    # load data
    # each line is "<cui>||<mention>"; a line that does not split into exactly
    # two parts raises ValueError
    query_tuples = []
    with open(query_path, encoding="utf-8") as file:
        for line in file:
            cui, desc = line.strip("\n").split("||")
            query_tuples.append((cui, desc))

    queries = [i[1] for i in query_tuples]
    # position -> (cui, mention), the lookup format evaluate_on_fly indexes by query position
    queries_cui_desc_map = {i:e for i, e in enumerate(query_tuples)}

    # embed
    query_embeddings = gererate_embedding(model, tokenizer, queries, batch_size=batch_size)
    
    print("Index Lookup started")
    # 100 neighbours are retrieved per query; the returned distances are discarded
    retrieved_kb_ids, _ = ann_index.index_look_up(query_embeddings, top_k=100) # batch lookup is faster
    
  
                    
    print("======================")
    print(query_path )
    results = evaluate_on_fly( retrieved_kb_ids,  queries_cui_desc_map , kb_cui_desc_map )
    
    # one JSON file per query dataset: "<query file stem>.json" inside results_path
    with open(results_path / f"{os.path.splitext(os.path.basename(query_path))[0]}.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False)
        
    print("======================")
   
    

# drop the references before the next model is loaded (presumably to free memory)
del model
del tokenizer

## Model 3: BGE M3

https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/BGE_M3

A Multilingual model which can generate dense, sparse, and [colbert style](https://til.simonwillison.net/llms/colbert-ragatouille) embeddings. We only use dense embeddings.


It can simultaneously perform the three common retrieval functionalities of embedding model: dense retrieval, multi-vector retrieval (colbert style), and sparse retrieval. This could be tried if GPU permits.



In [None]:
# --- Model 3 (BGE-M3, dense vectors only): build or load the KB index, link
# --- every query dataset against it, print the hit rates and dump the
# --- candidates as JSON.
# Relies on kb_path, kb_entites, kb_cui_desc_map, ANN_INDEXES, RESULTS,
# FAISSIndex and evaluate_on_fly from earlier cells.
load_saved_index = False
batch_size = 16  # batch size for embedding encoder
data_loader_batch_size = 50000  # number of KB names encoded and added to the index per chunk
# Index file name: "<KB file stem>_bge_m3.bin" inside ANN_INDEXES
index_path = ANN_INDEXES / f"{os.path.splitext(os.path.basename(kb_path))[0]}_bge_m3.bin"
index_path = str(index_path)

# Results folder is named after the index file stem
results_path = RESULTS / f"{os.path.splitext(os.path.basename(index_path))[0]}"
results_path.mkdir(parents=True, exist_ok=True)

# ann_index = ANNIndex(matrix_is_sparse=False)
# NOTE(review): 1024 must equal the width of the dense vectors returned below — confirm
ann_index = FAISSIndex(embedding_size=1024)

from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel('BAAI/bge-m3',  use_fp16=True, device="cuda"  )




if load_saved_index:
    ann_index.load_index(index_path)
else:
    print("Creating Index")
    # shuffle=False keeps KB order, so index positions line up with kb_cui_desc_map keys
    loader = DataLoader(kb_entites, shuffle=False, batch_size=data_loader_batch_size,)
    
    for kb_entity_batch in tqdm(loader):
        # only the dense output is requested; sparse and colbert outputs are switched off
        kb_embeddings = model.encode(kb_entity_batch, return_dense=True, return_sparse=False, 
                             return_colbert_vecs=False, 
                             batch_size=batch_size, max_length=40 )

        # cast to float32 before handing the vectors to the index
        kb_embeddings = kb_embeddings['dense_vecs'].astype("float32")
        # import pdb;pdb.set_trace()

        # each chunk is appended to the same index
        ann_index.create_index(kb_embeddings)
    
    print("Index Created")
    ann_index.save_index(save_path= str(index_path))
  
    
# Query Datasets
for query_path in [
                    "data/de_1k_test_query.txt",
                   "data/de_wikimed_bel_train_query.txt",
                   "data/de_wikimed_bel_dev_query.txt",
                   "data/de_wikimed_bel_test_query.txt",
                   ]:

    # load data
    # each line is "<cui>||<mention>"; a line that does not split into exactly
    # two parts raises ValueError
    query_tuples = []
    with open(query_path, encoding="utf-8") as file:
        for line in file:
            cui, desc = line.strip("\n").split("||")
            query_tuples.append((cui, desc))

    queries = [i[1] for i in query_tuples]
    # position -> (cui, mention), the lookup format evaluate_on_fly indexes by query position
    queries_cui_desc_map = {i:e for i, e in enumerate(query_tuples)}

    # embed
    # same encode settings as used for the KB side above
    query_embeddings = model.encode(queries, return_dense=True, return_sparse=False,
                                    return_colbert_vecs=False, batch_size=batch_size, 
                                    max_length=40 )
    
    query_embeddings = query_embeddings['dense_vecs'].astype("float32")
    
    print("Index Lookup started")
    # 100 neighbours are retrieved per query; the returned distances are discarded
    retrieved_kb_ids, _ = ann_index.index_look_up(query_embeddings, top_k=100) # batch lookup is faster
    
    
                    
    print("======================")
    print(query_path)
    results = evaluate_on_fly( retrieved_kb_ids,  queries_cui_desc_map , kb_cui_desc_map )
    # one JSON file per query dataset: "<query file stem>.json" inside results_path
    with open(results_path / f"{os.path.splitext(os.path.basename(query_path))[0]}.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False)
    print("======================")
   
    

# drop the reference at the end of the cell (presumably to free memory)
del model
