In [94]:
import pandas as pd
import os
from neo4j import GraphDatabase
from dotenv import load_dotenv

train_dataset = pd.read_csv('symptoms_train.csv')
train_dataset.head()


Unnamed: 0.1,Unnamed: 0,Symptoms,id,text,long_texts,short_texts,discharge_summary,short_codes
0,0,"['Substernal Chest Pain', 'Sharp Pain', 'Cresc...",147171,CHIEF COMPLAINT: Substernal Chest Pain\n\nPRES...,Acute myocardial infarction of other anterior ...,"AMI anterior wall, init,Ac systolic hrt failur...",Admission Date: [**2102-9-26**] ...,4101142821997142714140142804273145829
1,1,['Back pain'],199961,CHIEF COMPLAINT: \n\nPRESENT ILLNESS: The pati...,"Ankylosing spondylitis,Hypertensive chronic ki...","Ankylosing spondylitis,Hyp kid NOS w cr kid V,...",Admission Date: [**2115-6-29**] Dischar...,"7200,40391,8052,8471,E8859,78057,2859,25060"
2,2,"['Shortness of breath', 'Cough', 'Occasional n...",136812,CHIEF COMPLAINT: \n\nPRESENT ILLNESS: This is ...,Obstructive chronic bronchitis with (acute) ex...,"Obs chr bronc w(ac) exac,Pneumonia, organism N...",Admission Date: [**2106-4-14**] Dischar...,491214862800427894261143889729892449
3,3,"['left arm pain', 'left leg pain', 'pulmonary ...",175700,CHIEF COMPLAINT: s/p rollover MVC with prolong...,"Closed fracture of shaft of fibula with tibia,...","Fx shaft fib w tib-clos,Pneumococcal pneumonia...",Admission Date: [**2159-2-9**] D...,"82322,481,86121,5180,2851,81322,8072,E8160,883..."
4,4,"['Palpitations', 'Shortness of breath', 'Sore ...",193486,CHIEF COMPLAINT: Shortness of breath\n\nPRESEN...,"Other specified cardiac dysrhythmias,End stage...","Cardiac dysrhythmias NEC,End stage renal disea...",Admission Date: [**2136-10-4**] ...,"42789,5856,6822,6164,42830,2761,2869,4160,0411..."


In [95]:
icd_9_dataset = pd.read_csv('D_ICD_DIAGNOSES.csv')
icd_9_dataset.head()

Unnamed: 0,ROW_ID,ICD9_CODE,SHORT_TITLE,LONG_TITLE
0,174,1166,TB pneumonia-oth test,"Tuberculous pneumonia [any form], tubercle bac..."
1,175,1170,TB pneumothorax-unspec,"Tuberculous pneumothorax, unspecified"
2,176,1171,TB pneumothorax-no exam,"Tuberculous pneumothorax, bacteriological or h..."
3,177,1172,TB pneumothorx-exam unkn,"Tuberculous pneumothorax, bacteriological or h..."
4,178,1173,TB pneumothorax-micro dx,"Tuberculous pneumothorax, tubercle bacilli fou..."


In [125]:
preprocessed_df = train_dataset[['Symptoms','id', 'short_codes' ]].copy()
preprocessed_df.head()

Unnamed: 0,Symptoms,id,short_codes
0,"['Substernal Chest Pain', 'Sharp Pain', 'Cresc...",147171,4101142821997142714140142804273145829
1,['Back pain'],199961,"7200,40391,8052,8471,E8859,78057,2859,25060"
2,"['Shortness of breath', 'Cough', 'Occasional n...",136812,491214862800427894261143889729892449
3,"['left arm pain', 'left leg pain', 'pulmonary ...",175700,"82322,481,86121,5180,2851,81322,8072,E8160,883..."
4,"['Palpitations', 'Shortness of breath', 'Sore ...",193486,"42789,5856,6822,6164,42830,2761,2869,4160,0411..."


In [126]:
preprocessed_df['Symptoms'] = preprocessed_df['Symptoms'].apply(lambda x : str(x[1:-1]).split(', '))
preprocessed_df.head()

Unnamed: 0,Symptoms,id,short_codes
0,"['Substernal Chest Pain', 'Sharp Pain', 'Cresc...",147171,4101142821997142714140142804273145829
1,['Back pain'],199961,"7200,40391,8052,8471,E8859,78057,2859,25060"
2,"['Shortness of breath', 'Cough', 'Occasional n...",136812,491214862800427894261143889729892449
3,"['left arm pain', 'left leg pain', 'pulmonary ...",175700,"82322,481,86121,5180,2851,81322,8072,E8160,883..."
4,"['Palpitations', 'Shortness of breath', 'Sore ...",193486,"42789,5856,6822,6164,42830,2761,2869,4160,0411..."


In [127]:
preprocessed_df = preprocessed_df.explode('Symptoms')
preprocessed_df = preprocessed_df.assign(short_codes=preprocessed_df['short_codes'].str.split(',')).explode('short_codes')

preprocessed_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 2020063 entries, 0 to 33683
Data columns (total 3 columns):
 #   Column       Dtype 
---  ------       ----- 
 0   Symptoms     object
 1   id           int64 
 2   short_codes  object
dtypes: int64(1), object(2)
memory usage: 61.6+ MB


In [128]:
preprocessed_df.head()

Unnamed: 0,Symptoms,id,short_codes
0,'Substernal Chest Pain',147171,41011
0,'Substernal Chest Pain',147171,42821
0,'Substernal Chest Pain',147171,9971
0,'Substernal Chest Pain',147171,4271
0,'Substernal Chest Pain',147171,41401


In [129]:
preprocessed_df.Symptoms.nunique()

39994

In [130]:
preprocessed_df.head()

Unnamed: 0,Symptoms,id,short_codes
0,'Substernal Chest Pain',147171,41011
0,'Substernal Chest Pain',147171,42821
0,'Substernal Chest Pain',147171,9971
0,'Substernal Chest Pain',147171,4271
0,'Substernal Chest Pain',147171,41401


In [131]:
preprocessed_df.columns = ['Symptoms', 'id', 'ICD9_CODE']

In [132]:
preprocessed_df = preprocessed_df.reset_index(drop= True)

In [133]:
preprocessed_df.loc[
    preprocessed_df['ICD9_CODE'].str.startswith("V"), 'ICD9_CODE'] = preprocessed_df.ICD9_CODE.apply(
    lambda x: x[:3])
preprocessed_df.loc[
    preprocessed_df['ICD9_CODE'].str.startswith("E"), 'ICD9_CODE'] = preprocessed_df.ICD9_CODE.apply(
    lambda x: x[:3])
preprocessed_df.loc[(~preprocessed_df.ICD9_CODE.str.startswith("E")) & (
    ~preprocessed_df.ICD9_CODE.str.startswith("V")), 'ICD9_CODE'] = preprocessed_df.ICD9_CODE.apply(
    lambda x: x[:2])

In [134]:
preprocessed_df['weight'] = preprocessed_df.groupby(['Symptoms', 'ICD9_CODE'])['Symptoms'].transform('size')

preprocessed_df.head()


Unnamed: 0,Symptoms,id,ICD9_CODE,weight
0,'Substernal Chest Pain',147171,41,6
1,'Substernal Chest Pain',147171,42,7
2,'Substernal Chest Pain',147171,99,4
3,'Substernal Chest Pain',147171,42,7
4,'Substernal Chest Pain',147171,41,6


In [135]:
# preprocessed_df.Category.value_counts()

In [136]:
categories = {
    "00-10": ["00", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10"],
    "11-13": ["11", "12", "13"],
    "14-20": ["14", "15", "16", "17", "18", "19", "20"],
    "21-23": ["21", "22", "23"],
    "24-27": ["24", "25", "26", "27"],
    "28": ["28"],
    "29-31": ["29", "30", "31"],
    "32-35": ["32", "33", "34", "35"],
    "36-38": ["36", "37", "38"],
    "39-45": ["39", "40", "41", "42", "43", "44", "45"],
    "46-51": ["46", "47", "48", "49", "50", "51"],
    "52-57": ["52", "53", "54", "55", "56", "57"],
    "58-62": ["58", "59", "60", "61", "62"],
    "63-66": ["63", "64", "65", "66"],
    "67": ["67"],
    "68-70": ["68", "69", "70"],
    "71-73": ["71", "72", "73"],
    "74-75": ["74", "75"],
    "76-77": ["76", "77"],
    "78": ["78"],
    "79": ["79"],
    "80-83": ["80", "81", "82", "83"],
    "84-95": ["84", "85", "86", "87", "88", "89", "90", "91", "92", "93", "94", "95"],
    "96-99": ["96", "97", "98", "99"],
    "V01-V09": ["V01", "V02", "V03", "V04", "V05", "V06", "V07", "V08", "V09"],
    "V10-V19": ["V10", "V11", "V12", "V13", "V14", "V15", "V16", "V17", "V18", "V19"],
    "V20-V29": ["V20", "V21", "V22", "V23", "V24", "V25", "V26", "V27", "V28", "V29"],
    "V30-V39": ["V30", "V31", "V32", "V33", "V34", "V35", "V36", "V37", "V38", "V39"],
    "V40-V49": ["V40", "V41", "V42", "V43", "V44", "V45", "V46", "V47", "V48", "V49"],
    "V50-V59": ["V50", "V51", "V52", "V53", "V54", "V55", "V56", "V57", "V58", "V59"],
    "V60-V69": ["V60", "V61", "V62", "V63", "V64", "V65", "V66", "V67", "V68", "V69"],
    "V70-V82": ["V70", "V71", "V72", "V73", "V74", "V75", "V76", "V77", "V78", "V79", "V80", "V81", "V82"],
    "V83-V84": ["V83", "V84"],
    "V85-V85": ["V85"],
    "V86-V86": ["V86"],
    "V87-V87": ["V87"],
    "V88-V88": ["V88"],
    "V89-V89": ["V89"],
    "V90-V90": ["V90"],
    "V91-V91": ["V91"],
    "E00-E09": ["E00", "E01", "E02", "E03", "E04", "E05", "E06", "E07", "E08", "E09"]
}


In [137]:
icd9_to_category = {code: category for category, codes in categories.items() for code in codes}


preprocessed_df['ICD9_CODE'] = preprocessed_df['ICD9_CODE'].astype(str).str.zfill(2)

preprocessed_df['Category'] = preprocessed_df['ICD9_CODE'].map(icd9_to_category)

preprocessed_df.head()

Unnamed: 0,Symptoms,id,ICD9_CODE,weight,Category
0,'Substernal Chest Pain',147171,41,6,39-45
1,'Substernal Chest Pain',147171,42,7,39-45
2,'Substernal Chest Pain',147171,99,4,96-99
3,'Substernal Chest Pain',147171,42,7,39-45
4,'Substernal Chest Pain',147171,41,6,39-45


In [138]:
preprocessed_df.dropna(subset=['Category'], inplace=True)
preprocessed_df.drop_duplicates(inplace=True)

In [139]:
preprocessed_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 1497629 entries, 0 to 2020062
Data columns (total 5 columns):
 #   Column     Non-Null Count    Dtype 
---  ------     --------------    ----- 
 0   Symptoms   1497629 non-null  object
 1   id         1497629 non-null  int64 
 2   ICD9_CODE  1497629 non-null  object
 3   weight     1497629 non-null  int64 
 4   Category   1497629 non-null  object
dtypes: int64(2), object(3)
memory usage: 68.6+ MB


In [140]:
list(preprocessed_df.ICD9_CODE)[0]

'41'

In [141]:
# preprocessed_df['ICD9_CODE']  = preprocessed_df['ICD9_CODE'].astype('str')

In [142]:
preprocessed_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 1497629 entries, 0 to 2020062
Data columns (total 5 columns):
 #   Column     Non-Null Count    Dtype 
---  ------     --------------    ----- 
 0   Symptoms   1497629 non-null  object
 1   id         1497629 non-null  int64 
 2   ICD9_CODE  1497629 non-null  object
 3   weight     1497629 non-null  int64 
 4   Category   1497629 non-null  object
dtypes: int64(2), object(3)
memory usage: 68.6+ MB


In [143]:
preprocessed_df.columns = ['Symptoms', 'id', 'ICD9_CODE', 'weight', 'Category']

In [144]:
# preprocessed_df_merged = pd.merge(preprocessed_df, icd_9_dataset, how = 'inner', on='ICD9_CODE')

In [145]:
# preprocessed_df_merged.info()

In [146]:
# preprocessed_df_merged.Symptoms.nunique()

In [147]:
from tqdm import tqdm

In [148]:
train_dict = {}

for i,j in tqdm(preprocessed_df.iterrows()): 
    train_dict[j.Symptoms[1:-1]] = []

    


1497629it [00:20, 72788.62it/s]


In [123]:
preprocessed_df.head()

Unnamed: 0,Symptoms,id,weight,Category
0,'Substernal Chest Pain',147171,6,39-45
1,'Substernal Chest Pain',147171,7,39-45
2,'Substernal Chest Pain',147171,4,96-99
7,'Substernal Chest Pain',147171,1,39-45
8,'Sharp Pain',147171,2,39-45


In [149]:
df = preprocessed_df.drop(columns=['ICD9_CODE'], inplace=True)

grouped_df = preprocessed_df.groupby(['Symptoms', 'id', 'Category'], as_index=False)['weight'].sum()

grouped_df

Unnamed: 0,Symptoms,id,Category,weight
0,"""15' fall""",176459,39-45,1
1,"""15' fall""",176459,74-75,1
2,"""15' fall""",176459,80-83,1
3,"""15' fall""",176459,84-95,2
4,"""30' fall""",143133,29-31,4
...,...,...,...,...
1093013,yellow-green sputum',173844,39-45,4
1093014,yellow-green sputum',173844,46-51,2
1093015,yellow-green sputum',173844,52-57,1
1093016,yellow-green sputum',173844,58-62,1


In [151]:
train_dict = {}

for i,j in tqdm(grouped_df.iterrows()): 
    train_dict[j.Symptoms[1:-1]] = []

    


1093018it [00:13, 83833.25it/s]


In [152]:
from tqdm import tqdm

for i, j in tqdm(preprocessed_df.iterrows()):
    symptoms = j.Symptoms[1:-1]
    category = j.Category
    entry = {
        'disease_code_category': category,
        'weight': j.weight,
    }

    
    # Check if the icd9_code is already in the list
    if not any(item['disease_code_category'] == category for item in train_dict[symptoms]):
        train_dict[symptoms].append(entry)


1497629it [00:23, 62815.20it/s]


In [153]:
len(train_dict.keys())

39951

In [155]:
import json
with open('train_dict_cropped_shubham.json', 'w') as json_file:
    json.dump(train_dict, json_file, indent=4)
    
    

In [156]:
for i,j in train_dataset.iterrows() : 
    print(type(j.Symptoms))
    print(type(j.short_codes))
    break

<class 'str'>
<class 'str'>


In [157]:
with open("train_dict_cropped_shubham.json") as f:
    data = json.load(f)


In [158]:
import gc
gc.collect()

25

In [169]:
uri = "neo4j://localhost:7687"
auth = ("neo4j", "neo4j_pass5")

driver = GraphDatabase.driver(uri, auth=auth)
driver.verify_connectivity()

In [176]:
with driver.session() as session:
            session.run("CREATE FULLTEXT INDEX symptomIndex FOR (s:Symptom) ON EACH [s.name]")

In [181]:
def search_symptoms(query_string):
    with driver.session() as session:
        result = session.run(
            f"CALL db.index.fulltext.queryNodes('symptomIndex', '{query_string}') "
            "YIELD node, score "
            "RETURN node.name AS symptom, score "
            "ORDER BY score DESC "
            "LIMIT 20"
        )
        symptoms = [record['symptom'] for record in result]
        return symptoms

search_results = search_symptoms('headache')
print(search_results)  # 

['HEADACHE', 'Headache', 'headache', 'Severe headache during headache episodes', 'Massive headache', 'Headache (migraine)', 'Dull headache', 'whole headache', 'worst headache', 'waxing headache', 'worse headache', 'vasospasm headache', 'Light headache', 'Migraine headache', 'headache-frontal', 'Mild Headache', 'holocranial headache', 'Mild headache', 'increased headache', 'worsening headache']


In [52]:
# def delete_all_data():
#     with driver.session(database="neo4j") as session:
#         # Execute the Cypher query to delete all nodes and relationships
#         session.run("MATCH (n) DETACH DELETE n")
#         print("All nodes and relationships have been deleted.")


# delete_all_data()



In [160]:
data['Sharp Pain']

[{'disease_code_category': '39-45', 'weight': 2},
 {'disease_code_category': '96-99', 'weight': 2},
 {'disease_code_category': '00-10', 'weight': 1},
 {'disease_code_category': '52-57', 'weight': 1},
 {'disease_code_category': '78', 'weight': 1},
 {'disease_code_category': '46-51', 'weight': 1},
 {'disease_code_category': '58-62', 'weight': 2},
 {'disease_code_category': '29-31', 'weight': 1},
 {'disease_code_category': '24-27', 'weight': 1},
 {'disease_code_category': '28', 'weight': 2},
 {'disease_code_category': '80-83', 'weight': 1},
 {'disease_code_category': '32-35', 'weight': 1}]

In [165]:
from neo4j import GraphDatabase
from tqdm import tqdm

def remove_apostrophes(text):
    return text.replace("'", "")

def generate_cypher_queries(data):
    queries = []
    unique_symptoms = set()
    unique_diseases = set()
    symptom_disease_pairs = set()
    
    for symptom, diseases_data in tqdm(data.items(), desc="Processing symptoms"):
        # Clean and process symptom
        cleaned_symptom = remove_apostrophes(symptom.strip())
        if cleaned_symptom not in unique_symptoms:
            query = f"MERGE (s:Symptom {{name: '{cleaned_symptom}'}})"
            queries.append(query)
            unique_symptoms.add(cleaned_symptom)
        
        # Process each disease associated with the symptom
        for disease in tqdm(diseases_data, desc=f"Processing diseases for {cleaned_symptom}", leave=False):
            cleaned_short_code = remove_apostrophes(disease["disease_code_category"])
            
            # Disease nodes
            if cleaned_short_code not in unique_diseases:
                query = f"MERGE (d:Disease {{title: '{cleaned_short_code}'}})"
                queries.append(query)
                unique_diseases.add(cleaned_short_code)
            
            # Relationships between symptoms and diseases
            if (cleaned_symptom, cleaned_short_code) not in symptom_disease_pairs:
                query = (f"MATCH (s:Symptom {{name: '{cleaned_symptom}'}}), "
                         f"(d:Disease {{title: '{cleaned_short_code}'}}) "
                         f"MERGE (s)-[:ASSOCIATED_WITH {{weight: {disease['weight']}}}]->(d)")
                queries.append(query)
                symptom_disease_pairs.add((cleaned_symptom, cleaned_short_code))
    
    return queries




In [166]:
queries = generate_cypher_queries(data)

Processing symptoms: 100%|██████████| 39951/39951 [01:34<00:00, 422.91it/s]


In [167]:
len(queries)

375166

In [168]:
from neo4j import GraphDatabase
from tqdm import tqdm



def run_cypher_queries(queries, batch_size=10000):
    total_queries = len(queries)
    num_batches = (total_queries + batch_size - 1) // batch_size  

    with driver.session() as session:
        for batch_start in tqdm(range(0, total_queries, batch_size), desc="Processing batches", unit="batch", total=num_batches):
            batch_end = min(batch_start + batch_size, total_queries)
            batch_queries = queries[batch_start:batch_end]

            for query in batch_queries:
                session.run(query)
    
run_cypher_queries(queries)

Processing batches: 100%|██████████| 38/38 [50:24<00:00, 79.58s/batch] 


In [29]:
# from neo4j import GraphDatabase
# from tqdm import tqdm



# def create_knowledge_graph(tx, data):
#     unique_symptoms = set()
#     unique_diseases = set()
#     symptom_disease_pairs = set()

#     for symptom, diseases in tqdm(data.items()):
#         cleaned_symptom = symptom
#         # Symptom nodes
#         if cleaned_symptom not in unique_symptoms:
#             tx.run("MERGE (s:Symptom {name: $name})", name=cleaned_symptom)
#             unique_symptoms.add(cleaned_symptom)
        
#         # Disease nodes and Relationships
#         for disease in diseases:
#             cleaned_disease_code = (disease["disease_code"]
#             cleaned_short_title = disease["short_title"]

#             if cleaned_disease_code not in unique_diseases:
#                 tx.run(
#                     """
#                     MERGE (d:Disease {code: $code})
#                     ON CREATE SET d.short_title = $short_title
#                     """,
#                     code=cleaned_disease_code,
#                     short_title=cleaned_short_title,
#                 )
#                 unique_diseases.add(cleaned_disease_code)

#             if (cleaned_symptom, cleaned_disease_code) not in symptom_disease_pairs:
#                 tx.run(
#                     """
#                     MATCH (s:Symptom {name: $symptom_name}), (d:Disease {code: $code})
#                     MERGE (s)-[r:ASSOCIATED_WITH]->(d)
#                     ON CREATE SET r.weight = $weight
#                     """,
#                     symptom_name=cleaned_symptom,
#                     code=cleaned_disease_code,
#                     weight=disease["weight"],
#                 )
#                 symptom_disease_pairs.add((cleaned_symptom, cleaned_disease_code))

# uri = "neo4j://localhost:7687"
# auth = ("neo4j", "neo4j_pass2")

# driver = GraphDatabase.driver(uri, auth=auth)
# driver.verify_connectivity()

# try:
#     with driver.session(database="neo4j") as session:
#         session.execute_write(create_knowledge_graph, data)
# finally:
#     driver.close()


In [56]:

# with driver.session(database="neo4j") as session:
#         session.execute_write(create_knowledge_graph, data)
#         session.run("RETURN 1")
