In [2]:
import pandas as pd
import torch
import faiss
import pickle
import numpy as np
import gc
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import torch.nn.functional as F


In [3]:
# Filenames

path__ = "/nfs/production/literature/santosh_tirunagari/BACKUP/work/github/source_data/dictionaries/"
# path__ = "/homes/ines"
INPUT_FILENAME = "/nfs/production/literature/santosh_tirunagari/BACKUP/work/github/source_data/knowledge_base/bao/BAO.csv"
OUTPUT_PICKLE_FILENAME = path__ + "bao.pkl"
OUTPUT_LIST = path__ + "bao_list.txt"
FAISS_INDEX_FILENAME = path__ + "bao_terms.index"
# OUTPUT_INDEXED_TERMS_FILENAME = path__+"work/github/ML_annotations/normalisation/dictionary/bao_indexed_terms.pkl"


In [5]:
# Load pre-trained BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bioformers/bioformer-8L')
model = BertModel.from_pretrained('bioformers/bioformer-8L')

def get_bert_embedding(text):
    # Tokenize input text and convert to tensors
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    
    # Get hidden states from BERT model
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Extract the embeddings from the last hidden layer
    hidden_states = outputs.last_hidden_state
    
    # Perform mean pooling by averaging token embeddings across all tokens
    mean_embedding = hidden_states.mean(dim=1).squeeze()
    
    return mean_embedding.numpy()

# def get_average_embeddings_batched_transformers(sentences, model_name="tavakolih/all-MiniLM-L6-v2-pubmed-full"):
#     """Return average embeddings for sentences using a Transformers model."""
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     model = AutoModel.from_pretrained(model_name)

#     # Tokenize sentences
#     encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

#     # Compute token embeddings
#     with torch.no_grad():
#         model_output = model(**encoded_input)

#     # Perform pooling (mean pooling function is used here)
#     def mean_pooling(model_output, attention_mask):
#         token_embeddings = model_output[0]  # First element of model_output contains token embeddings
#         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
#         sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
#         sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
#         return sum_embeddings / sum_mask

#     sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

#     # Normalize embeddings
#     sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

#     return sentence_embeddings

def create_index(embeddings_np, d):
    """Create an accurate Faiss index using L2 distance."""
    index = faiss.IndexFlatL2(d)  # No quantization, exact search
    index.add(embeddings_np)  # Directly add embeddings to the index
    return index




In [6]:
def process_column_content(s):
    """Clean and strip unwanted characters and split by pipe if present."""
    # First, clean the string by removing specific patterns
    cleaned = s.strip().lower()
    
    # Check if the cleaned string contains a pipe symbol and split if it does
    if '|' in cleaned:
        return cleaned.split('|')
    else:
        return cleaned


df = pd.read_csv(INPUT_FILENAME, usecols=['Class ID', 'Preferred Label', 'Synonyms', 'Definitions', 'alternative term'], 
                 sep=',', engine='python', on_bad_lines='skip')


term_to_id = {}
embeddings = []  
indexed_terms = []


flattened_data = []
for _, row in df.iterrows():
    term_id = row['Class ID']
    for col in ['Preferred Label', 'Synonyms', 'Definitions', 'alternative term']:
        term_names = row[col]
        if pd.notnull(term_names):  # Check if the term_name is not NaN
            processed_terms = process_column_content(term_names)
            if isinstance(processed_terms, list):
                for term in processed_terms:
                    flattened_data.append((term_id, term))
            else:
                flattened_data.append((term_id, processed_terms))

# Convert flattened data to a DataFrame for easier manipulation
flattened_df = pd.DataFrame(flattened_data, columns=['Class ID', 'Term Name'])

flattened_df

Unnamed: 0,Class ID,Term Name
0,http://purl.obolibrary.org/obo/CHEBI_50444,adenosine phosphodiesterase inhibitor
1,http://purl.obolibrary.org/obo/CHEBI_131787,dopamine receptor d2 antagonist
2,http://purl.obolibrary.org/obo/CHEBI_131787,d2r antagonist
3,http://purl.obolibrary.org/obo/CHEBI_131787,d2 receptor antagonist
4,http://purl.obolibrary.org/obo/CHEBI_131789,runx1 inhibitor
...,...,...
33353,http://purl.obolibrary.org/obo/DOID_3953,adrenal cancer
33354,http://purl.obolibrary.org/obo/DOID_3953,tumor of the adrenal gland
33355,http://purl.obolibrary.org/obo/DOID_3953,malignant neoplasm of adrenal gland
33356,http://purl.obolibrary.org/obo/DOID_3953,malignant adrenal tumor


In [7]:
# Load the CSV into a DataFrame
df = flattened_df

# Preprocess and collect embeddings
bert_embeddings = []
term_to_id = {}

for _, row in tqdm(df.iterrows(), total=df.shape[0], desc="Processing terms"):
    term_id = row['Class ID']
    term_name = row['Term Name']
    
    term_to_id[term_name] = term_id
    embedding = get_bert_embedding(term_name)
    bert_embeddings.append(embedding)

# Convert list of embeddings to a NumPy array
embeddings_np = np.array(bert_embeddings)
d = embeddings_np.shape[1]  # Dimensionality of embeddings

# Create and save the Faiss index
index = create_index(embeddings_np, d)
faiss.write_index(index, FAISS_INDEX_FILENAME)

# Save term to ID mapping
with open(OUTPUT_PICKLE_FILENAME, "wb") as outfile:
    pickle.dump({"term_to_id": term_to_id}, outfile)


Processing terms:   0%|                             | 0/33358 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Processing terms: 100%|████████████████| 33358/33358 [03:43<00:00, 149.14it/s]


In [13]:

# Write terms to a text file
with open(OUTPUT_LIST, "w") as txt_file:
    for term in term_to_id.keys():
        txt_file.write(term + "\n")

def match_entity(ner_entity):
    # Step 1: TF-IDF matching using 2-grams and 3-grams
    tfidf_vectorizer = TfidfVectorizer(ngram_range=(2, 3))
    tfidf_matrix = tfidf_vectorizer.fit_transform(df['Term Name'])
    query_vector = tfidf_vectorizer.transform([ner_entity])
    similarities = cosine_similarity(query_vector, tfidf_matrix)
    best_match_idx = similarities.argmax()
    best_match_score = similarities.max()
    
    if best_match_score > 0.0:
        matched_term_name = df.iloc[best_match_idx]['Term Name']
        matched_term_id = df.iloc[best_match_idx]['Class ID']
        return matched_term_id, matched_term_name, best_match_score
    
    # Step 2: BERT + FAISS if no good TF-IDF match
    query_embedding = get_bert_embedding(ner_entity)
    distances, indices = index.search(np.array([query_embedding]), k=1)
    
    matched_term_name = df.iloc[indices[0][0]]['Term Name']
    matched_term_id = df.iloc[indices[0][0]]['Class ID']
    score = 1 - distances[0][0]
    
    return matched_term_id, matched_term_name, score




In [14]:
# List of NER entities
ner_entities = ["SPR Analysis", "TR-FRET method", "Radioligand displacement"]

# Loop over each NER entity and find the match
for ner_entity in ner_entities:
    term_id, term_name, score = match_entity(ner_entity)
    print(f"NER Entity: '{ner_entity}' -> Matched Term Name: '{term_name}', Term ID: {term_id}, Score: {score}")


NER Entity: 'SPR Analysis' -> Matched Term Name: 'bradford protein assay', Term ID: http://www.bioassayontology.org/bao#BAO_0002463, Score: -48.478981018066406
NER Entity: 'TR-FRET method' -> Matched Term Name: 'tr-fret', Term ID: http://www.bioassayontology.org/bao#BAO_0000004, Score: 1.0
NER Entity: 'Radioligand displacement' -> Matched Term Name: 'potentials of mean force scoring function', Term ID: http://www.bioassayontology.org/bao#BAO_0002404, Score: -91.42115783691406


In [None]:
# Free up memory after use
del embeddings_np
gc.collect()