In [53]:
import pandas as pd
from neo4j import GraphDatabase
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.ollama import OllamaEmbedding

import re
import json

In [38]:
BAAI_BGE_SMALL_EN_V1_5 = "BAAI/bge-small-en-v1.5"
LLAMA3 = "llama3"

In [2]:
# Load the embedding model
embed_model = HuggingFaceEmbedding(model_name=BAAI_BGE_SMALL_EN_V1_5)



## ADA

In [14]:
df = pd.read_csv("../data/raw/CTD_diseases.csv", sep=',')

In [15]:
df.head()

Unnamed: 0,DiseaseName,DiseaseID,AltDiseaseIDs,Definition,ParentIDs,TreeNumbers,ParentTreeNumbers,Synonyms,SlimMappings
0,10p Deletion Syndrome (Partial),MESH:C538288,,,MESH:D002872|MESH:D025063,C16.131.260/C538288|C16.320.180/C538288|C23.55...,C16.131.260|C16.320.180|C23.550.210.050.500.500,"Chromosome 10, 10p- Partial|Chromosome 10, mon...",Congenital abnormality|Genetic disease (inborn...
1,13q deletion syndrome,MESH:C535484,,,MESH:D002872|MESH:D025063,C16.131.260/C535484|C16.320.180/C535484|C23.55...,C16.131.260|C16.320.180|C23.550.210.050.500.500,Chromosome 13q deletion|Chromosome 13q deletio...,Congenital abnormality|Genetic disease (inborn...
2,15q24 Microdeletion,MESH:C579849,DO:DOID:0060395,,MESH:D002872|MESH:D008607|MESH:D025063,C10.597.606.360/C579849|C16.131.260/C579849|C1...,C10.597.606.360|C16.131.260|C16.320.180|C23.55...,15q24 Deletion|15q24 Microdeletion Syndrome|In...,Congenital abnormality|Genetic disease (inborn...
3,16p11.2 Deletion Syndrome,MESH:C579850,,,MESH:D001321|MESH:D002872|MESH:D008607|MESH:D0...,C10.597.606.360/C579850|C16.131.260/C579850|C1...,C10.597.606.360|C16.131.260|C16.320.180|C23.55...,,Congenital abnormality|Genetic disease (inborn...
4,"17,20-Lyase Deficiency, Isolated",MESH:C567076,,,MESH:D000312,C12.050.351.875.253.090.500/C567076|C12.200.70...,C12.050.351.875.253.090.500|C12.200.706.316.09...,"17-Alpha-Hydroxylase-17,20-Lyase Deficiency, C...",Congenital abnormality|Endocrine system diseas...


In [15]:
with open('../data/raw/NCBI_corpus/NCBI_corpus_development.txt', 'r', encoding='utf-8') as file:
    content = file.read()

# Find all the category tags using regular expressions
category_tags = re.findall(r'<category="([^"]+)">', content)

# Remove duplicates by converting the list to a set
unique_category_tags = set(category_tags)

In [16]:
for tag in unique_category_tags:
    print(tag)

DiseaseClass
CompositeMention
SpecificDisease
Modifier


- DiseaseClass: This can correspond to the high-level categories in the CTD's MEDIC-Slim, which classify diseases into broad categories such as genetic diseases, neoplasms, etc.
- CompositeMention: This can be related to composite terms in CTD that might combine aspects of multiple diseases or conditions, although CTD primarily focuses on distinct disease terms rather than composite mentions. -> can be excluded from the analysis
- SpecificDisease: These can be mapped directly to specific disease terms in the CTD, which are detailed with their MeSH or OMIM identifiers.
- Modifier: Modifiers like "tumor" or "cancer" can be seen in context with specific diseases in CTD, modifying the understanding or description of the disease, such as in "breast cancer" or "ovarian cancer".

## Neo4j

In [9]:
# Connect to Neo4j
uri = "neo4j://localhost:7999/neo4j"
username = "neo4j"
password = "password"  # replace with your password

driver = GraphDatabase.driver(uri, auth=(username, password))


In [None]:
# Function to create disease nodes
def create_disease_nodes(tx, disease):
    tx.run("""
        MERGE (d:Disease {DiseaseID: $DiseaseID})
        SET d.DiseaseName = $DiseaseName, d.AltDiseaseIDs = $AltDiseaseIDs,
            d.Definition = $Definition, d.TreeNumbers = $TreeNumbers,
            d.ParentTreeNumbers = $ParentTreeNumbers, d.Synonyms = $Synonyms,
            d.SlimMappings = $SlimMappings
    """, 
    DiseaseID=disease['DiseaseID'],
    DiseaseName=disease['DiseaseName'],
    AltDiseaseIDs=disease['AltDiseaseIDs'],
    Definition=disease['Definition'],
    TreeNumbers=disease['TreeNumbers'],
    ParentTreeNumbers=disease['ParentTreeNumbers'],
    Synonyms=disease['Synonyms'],
    SlimMappings=disease['SlimMappings'])

In [19]:
# Function to create hierarchical relationships
def create_hierarchy(tx, disease):
    if pd.notna(disease['ParentIDs']):
        parent_ids = disease['ParentIDs'].split('|')
        for parent_id in parent_ids:
            tx.run("""
                MATCH (d:Disease {DiseaseID: $DiseaseID})
                MATCH (p:Disease {DiseaseID: $ParentID})
                MERGE (d)-[:SUB_CATEGORY_OF]->(p)
            """, DiseaseID=disease['DiseaseID'], ParentID=parent_id)

In [44]:
# Function to get disease descriptions
def get_disease_descriptions(tx):
    result = tx.run("""
        MATCH (d:Disease) 
        RETURN d.DiseaseID AS DiseaseID, d.DiseaseName AS DiseaseName, d.Definition AS Definition
    """)
    return result.data()

In [28]:
# Function to update disease embeddings
def update_disease_embeddings(tx, disease_id, embedding, embedding_model_name):
    disease_embedding = f"DiseaseEmbedding-{embedding_model_name.replace('.', '_').replace('/', '_')}"

    query = """
        MATCH (d:Disease {DiseaseID: $DiseaseID})
        CALL apoc.create.setProperty(d, $disease_embedding, $embedding)
        YIELD node
        RETURN node;
    """

    tx.run(query, DiseaseID=disease_id, embedding=embedding, disease_embedding=disease_embedding)

In [17]:
with driver.session() as session:
    for _, row in df.iterrows():
        session.write_transaction(create_disease_nodes, row)

In [20]:
# Create hierarchical relationships
with driver.session() as session:
    for _, row in df.iterrows():
        session.write_transaction(create_hierarchy, row)

In [32]:
with driver.session() as session:
    disease_descriptions = session.execute_read(get_disease_descriptions)

In [35]:
# Generate embeddings for the disease data
embeddings = []
for record in disease_descriptions:
    disease_id = record['DiseaseID']
    disease_name = record['DiseaseName']
    
    name_embedding = embed_model.get_text_embedding(disease_name)
    
    embeddings.append((disease_id, name_embedding))

In [36]:
with driver.session() as session:
    for disease_id, embedding in embeddings:
        session.write_transaction(update_disease_embeddings, disease_id, embedding, BAAI_BGE_SMALL_EN_V1_5)

In [None]:
# Close the driver connection
driver.close()

In [11]:
def find_similar_diseases(query_embedding):
    with driver.session() as session:
        query = """
        MATCH (d:Disease)
        WHERE d.DiseaseEmbedding IS NOT NULL
        WITH d, gds.similarity.cosine(d.DiseaseEmbedding, $query_embedding) AS similarity
        RETURN d.DiseaseName AS name, similarity
        ORDER BY similarity DESC
        LIMIT 5
        """
        result = session.run(query, query_embedding=query_embedding)
        return [record["name"] for record in result]

In [12]:
query_embedding = embed_model.get_text_embedding("breast carcinoma")
similar_diseases = find_similar_diseases(query_embedding)
print(similar_diseases)

['Carcinoma', 'Carcinoma, Ductal', 'Carcinoma, Ductal, Breast', 'Breast Neoplasms', 'Breast Carcinoma In Situ']


In [17]:
test_query_2 = embed_model.get_text_embedding("Type II human complement C2 deficiency. Allele-specific amino acid substitutions (Ser189 --> Phe; Gly444 --> Arg) cause impaired C2 secretion.")
similar_diseases_2 = find_similar_diseases(test_query_2)
print(similar_diseases_2)

['Complement Component 3 Deficiency, Autosomal Recessive', 'Complement Component C1s Deficiency', 'COMPLEMENT COMPONENT 8 DEFICIENCY, TYPE II', 'COMPLEMENT COMPONENT 2 DEFICIENCY', 'COMPLEMENT COMPONENT C1r/C1s DEFICIENCY']


## Ollama embeddings

In [40]:
ollama3_embedding = OllamaEmbedding(
    model_name="llama3",
    base_url="http://localhost:11434",
    ollama_additional_kwargs={"mirostat": 0},
)

In [None]:
disease_descriptions = session.execute_read(get_disease_descriptions)
llama3_embeddings = []

for record in disease_descriptions:
    disease_id = record['DiseaseID']
    disease_name = record['DiseaseName']
    name_embedding = ollama3_embedding.get_text_embedding(disease_name)
    
    llama3_embeddings.append((disease_id, name_embedding))

In [48]:
len(llama3_embeddings)

13298

In [50]:
llama3_embeddings[0]

('MESH:C538288',
 [1.537963628768921,
  -6.55998420715332,
  0.732938826084137,
  -0.5272596478462219,
  0.8210824131965637,
  -0.783358097076416,
  -1.6790488958358765,
  2.2116734981536865,
  1.1363089084625244,
  -2.472944736480713,
  -0.863217830657959,
  -1.3022089004516602,
  -0.6505035758018494,
  -0.5636776089668274,
  -1.588836669921875,
  4.209712505340576,
  -0.9245606660842896,
  0.049492813646793365,
  0.0014104063156992197,
  2.365940809249878,
  2.3753600120544434,
  -2.621748447418213,
  -0.9279527068138123,
  -0.5072973966598511,
  -4.370858192443848,
  -2.5936226844787598,
  3.804238796234131,
  -0.08809918165206909,
  -0.28972694277763367,
  -4.29339075088501,
  -0.509268045425415,
  3.8331706523895264,
  3.486872673034668,
  1.1753950119018555,
  5.642094135284424,
  -2.0663490295410156,
  1.3144869804382324,
  -0.9135729074478149,
  0.7163174152374268,
  -2.8381168842315674,
  3.105882167816162,
  -1.7675647735595703,
  -0.28869831562042236,
  -0.35513678193092346,

In [56]:
with driver.session() as session:
    for id, name_embedding in llama3_embeddings:
        session.execute_write(update_disease_embeddings, id, [float(x) for x in name_embedding], LLAMA3)
    

In [42]:
query_embedding = ollama3_embedding.get_text_embedding("Where is blue?")
print(query_embedding)

[-1.7615643739700317, -2.681959867477417, 4.00814962387085, -0.7706565856933594, -1.8953564167022705, 0.05579962953925133, -2.0483341217041016, 3.1704769134521484, 0.860783040523529, 0.5830509662628174, 0.7892054915428162, -0.28447243571281433, -3.0988128185272217, -0.699832022190094, -2.0504283905029297, 0.6201938986778259, -1.7073917388916016, 0.6968017220497131, -3.1414942741394043, 1.4458979368209839, -1.2843347787857056, -2.3614113330841064, 4.591588020324707, 3.3754022121429443, -3.1164793968200684, -1.1824647188186646, 1.3905484676361084, -6.002542495727539, 0.6711130738258362, -2.3338406085968018, 0.8518390655517578, 1.2921946048736572, -0.9943997263908386, -0.791110634803772, 8.179473876953125, 0.8290021419525146, -1.2397180795669556, 2.499133825302124, 2.0315685272216797, 2.9433538913726807, 2.0506622791290283, 0.39730241894721985, 1.8135545253753662, -0.6566274166107178, 0.8271182775497437, 0.783367395401001, 2.329834461212158, -0.22141793370246887, 0.8954024910926819, -2.11

## NCBI embeddings

In [60]:
ncbi_dev_df = pd.read_csv("../data/raw/NCBI_corpus/NCBI_corpus_development.txt", sep="\t", header=None)

In [61]:
ncbi_dev_df.head()

Unnamed: 0,0,1,2
0,8589722,BRCA1 is secreted and exhibits properties of a...,Germline mutations in BRCA1 are responsible fo...
1,8589723,"<category=""Modifier"">Ovarian cancer</category>...",Women who carry a mutation in the BRCA1 gene (...
2,8595416,A novel homeodomain-encoding gene is associate...,"<category=""SpecificDisease"">Myotonic dystrophy..."
3,8605116,Germline mutations in the RB1 gene in patients...,We have analyzed the 27 exons and the promoter...
4,8621452,"<category=""SpecificDisease"">Type II human comp...","<category=""SpecificDisease"">Type II complement..."


## Pinecode

In [37]:
from pinecone import Pinecone, ServerlessSpec

pc = Pinecone(api_key="c516db64-8506-47f8-ba37-3d1e1e389d81")
pc.create_index(
    name="quickstart",
    dimension=384, # Replace with your model dimensions
    metric="cosine", # Replace with your model metric
    spec=ServerlessSpec(
        cloud="aws",
        region="us-east-1"
    ) 
)

In [38]:
# pc.init(api_key="c516db64-8506-47f8-ba37-3d1e1e389d81", environment="us-east-1")
index = pc.Index("quickstart")

In [39]:
# Function to get disease embeddings
def get_disease_embeddings(tx, embedding_model_name):
    dynamic_field_name = f"DiseaseEmbedding{embedding_model_name}"

    result = tx.run("""
        MATCH (d:Disease)
        RETURN d.DiseaseID AS DiseaseID, d.{dynamic_field_name} AS NameEmbedding
    """)
    return result.data()

In [None]:
# Retrieve embeddings from Neo4j
with driver.session() as session:
    disease_embeddings = session.execute_read(get_disease_embeddings, BAAI_BGE_SMALL_EN_V1_5)

In [None]:
# Prepare and upload to Pinecone
vectors = []
for record in disease_embeddings:
    disease_id = record['DiseaseID']
    name_embedding = record['NameEmbedding']
    
    if name_embedding:
        vectors.append({
            'id': f"{disease_id}-name",
            'values': name_embedding
        })

index.upsert(vectors=vectors)