In [69]:
import pandas as pd
import numpy as np
import mlflow
from neo4j import GraphDatabase
from sklearn.metrics import precision_score, recall_score
import matplotlib.pyplot as plt

import ast
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

In [63]:
ncbi_dev_df = pd.read_csv('../data/processed/ncbi_dev.csv')

ncbi_dev_df['SpecificDisease'] = ncbi_dev_df['SpecificDisease'].apply(lambda x: x if x else None)
ncbi_dev_df = ncbi_dev_df.dropna(subset=['SpecificDisease'])

In [64]:
ncbi_dev_df.head()

Unnamed: 0,ID,Raw,DiseaseClass,CompositeMention,SpecificDisease,Modifier
1,8589723,Ovarian cancer risk in BRCA1 carriers is modif...,"['cancers', 'inherited cancer syndrome']",['hereditary breast and ovarian cancer'],"['breast cancer', 'ovarian cancer', 'breast ca...","['Ovarian cancer', 'cancer']"
2,8595416,A novel homeodomain-encoding gene is associate...,,,"['Myotonic dystrophy', 'DM']","['myotonic dystrophy', 'DM', 'DM']"
3,8605116,Germline mutations in the RB1 gene in patients...,,['familial or sporadic bilateral retinoblastoma'],['hereditary retinoblastoma'],
4,8621452,Type II human complement C2 deficiency. Allele...,,,"['Type II human complement C2 deficiency', 'Ty...",
5,8622978,Defective dimerization of von Willebrand facto...,,,"['type IID von Willebrand disease', 'type IID ...","['von Willebrand', 'von Willebrand']"


In [41]:
BAAI_BGE_SMALL_EN_V1_5 = "BAAI/bge-small-en-v1.5"

In [8]:
baai_embed_model = HuggingFaceEmbedding(model_name=BAAI_BGE_SMALL_EN_V1_5)



In [25]:
test = ast.literal_eval(ncbi_dev_df.loc[1]['SpecificDisease'])

In [26]:
test_embed = [baai_embed_model.get_text_embedding(x) for x in test]

In [27]:
test_embed

[[-0.003422154113650322,
  -0.02032032236456871,
  -0.01774859055876732,
  -0.07792509347200394,
  0.006694750394672155,
  0.032258033752441406,
  0.0684860348701477,
  -0.0016329901991412044,
  0.031529124826192856,
  -0.027491429820656776,
  0.035244207829236984,
  -0.06385888159275055,
  -0.010010877624154091,
  -0.012356391176581383,
  -0.03366026282310486,
  0.05220355838537216,
  0.03395463153719902,
  0.009422757662832737,
  -0.1040399894118309,
  0.021456746384501457,
  -0.024837497621774673,
  -0.0112640131264925,
  -0.005289907101541758,
  0.03820544108748436,
  0.05764637142419815,
  -0.0014770017005503178,
  -0.012669562362134457,
  -0.0500645637512207,
  -0.09161483496427536,
  -0.11865147948265076,
  0.024060003459453583,
  -0.044547926634550095,
  0.04131670296192169,
  0.0018130569951608777,
  -0.018222125247120857,
  -0.051625389605760574,
  0.017194285988807678,
  0.03361809626221657,
  0.022005662322044373,
  -0.011509822681546211,
  -0.012212990783154964,
  -0.01376

In [28]:
def get_embeddings(x):
    disease_list = ast.literal_eval(x)
    return [baai_embed_model.get_text_embedding(x) for x in disease_list]

In [65]:
ncbi_dev_df['SpecificDisease_BAAI_BGE_SMALL_EN_V1_5'] = ncbi_dev_df['SpecificDisease'].apply(lambda x: get_embeddings(x))

In [66]:
ncbi_dev_df.head()

Unnamed: 0,ID,Raw,DiseaseClass,CompositeMention,SpecificDisease,Modifier,SpecificDisease_BAAI_BGE_SMALL_EN_V1_5
1,8589723,Ovarian cancer risk in BRCA1 carriers is modif...,"['cancers', 'inherited cancer syndrome']",['hereditary breast and ovarian cancer'],"['breast cancer', 'ovarian cancer', 'breast ca...","['Ovarian cancer', 'cancer']","[[-0.003422154113650322, -0.02032032236456871,..."
2,8595416,A novel homeodomain-encoding gene is associate...,,,"['Myotonic dystrophy', 'DM']","['myotonic dystrophy', 'DM', 'DM']","[[-0.04073653742671013, -0.03838193044066429, ..."
3,8605116,Germline mutations in the RB1 gene in patients...,,['familial or sporadic bilateral retinoblastoma'],['hereditary retinoblastoma'],,"[[0.007475364953279495, 0.03328303247690201, -..."
4,8621452,Type II human complement C2 deficiency. Allele...,,,"['Type II human complement C2 deficiency', 'Ty...",,"[[-0.08893324434757233, 0.01917962171137333, -..."
5,8622978,Defective dimerization of von Willebrand facto...,,,"['type IID von Willebrand disease', 'type IID ...","['von Willebrand', 'von Willebrand']","[[-0.08828882873058319, -0.019136495888233185,..."


In [47]:
# from utils.data_processing import find_similar_diseases

In [35]:
uri = "neo4j://localhost:7999/neo4j"
username = "neo4j"
password = "password"  # replace with your password

driver = GraphDatabase.driver(uri, auth=(username, password))

In [48]:
def find_similar_diseases(query_embedding, embedding_model_name):
    disease_embedding = f"DiseaseEmbedding-{embedding_model_name.replace('.', '_').replace('/', '-')}"

    with driver.session() as session:
        query = """
        MATCH (d:Disease)
        WHERE apoc.map.get(d, $disease_embedding, null) IS NOT NULL
        WITH d, gds.similarity.cosine(apoc.map.get(d, $disease_embedding, null), $query_embedding) AS similarity
        RETURN d.DiseaseName AS name, similarity
        ORDER BY similarity DESC
        LIMIT 5
        """
        result = session.run(query, query_embedding=query_embedding, disease_embedding=disease_embedding)
        
        return [record["name"] for record in result]

In [54]:
query_embedding = baai_embed_model.get_text_embedding(ncbi_dev_df.loc[14]['Raw'])

In [55]:
similar_diseases = find_similar_diseases(query_embedding, "BAAI/bge-small-en-v1.5")

In [56]:
print(similar_diseases)

['Huntington Disease-Like 3', 'Huntington Disease-Like 2', 'Huntington Disease', 'CHA heavy chain disease protein, human', 'SPASTIC ATAXIA 4, AUTOSOMAL RECESSIVE']


In [57]:
print(ncbi_dev_df.loc[14]['SpecificDisease'])

['Huntington disease', 'HD', 'HD', 'HD']


In [67]:
def evaluate_embeddings(embeddings, labels):
    # Assuming embeddings and labels are lists of lists or arrays

    # Calculate cosine similarity between embeddings (example)
    cosine_similarities = [np.dot(emb, lbl) / (np.linalg.norm(emb) * np.linalg.norm(lbl)) for emb, lbl in zip(embeddings, labels)]
    
    # Threshold to determine positive/negative matches (example)
    threshold = 0.5
    predicted_labels = [1 if sim >= threshold else 0 for sim in cosine_similarities]
    true_labels = [1 for _ in labels]  # Assuming all provided labels are true positives for simplicity

    # Calculate precision and recall
    precision = precision_score(true_labels, predicted_labels)
    recall = recall_score(true_labels, predicted_labels)

    metrics = {'precision': precision, 'recall': recall}
    
    return metrics

In [74]:
mlflow.end_run()

In [72]:
# Start an MLflow run
mlflow.start_run()

<ActiveRun: >

In [73]:
# Log model 1 results
model1_metrics = evaluate_embeddings(ncbi_dev_df['SpecificDisease_BAAI_BGE_SMALL_EN_V1_5'], ncbi_dev_df['SpecificDisease'])
mlflow.log_metrics(model1_metrics)
mlflow.log_param("model_name", BAAI_BGE_SMALL_EN_V1_5)

UFuncTypeError: ufunc 'multiply' did not contain a loop with signature matching types (dtype('<U23'), dtype('<U116')) -> None

In [None]:
# Example visualization
labels = ['Model 1', 'Model 2']
precision = [model1_metrics['precision']]
recall = [model1_metrics['recall']]

x = range(len(labels))

plt.figure(figsize=(10, 5))
plt.bar(x, precision, width=0.4, label='Precision', color='b', align='center')
plt.bar(x, recall, width=0.4, label='Recall', color='g', align='edge')

plt.xlabel('Models')
plt.ylabel('Scores')
plt.title('Model Evaluation Metrics')
plt.xticks(x, labels)
plt.legend()
plt.show()