In [2]:
import pandas as pd
import os
from neo4j import GraphDatabase
from dotenv import load_dotenv

train_dataset = pd.read_csv('symptoms_train.csv')
train_dataset.head()


Unnamed: 0.1,Unnamed: 0,Symptoms,id,text,long_texts,short_texts,discharge_summary,short_codes
0,0,"['Substernal Chest Pain', 'Sharp Pain', 'Cresc...",147171,CHIEF COMPLAINT: Substernal Chest Pain\n\nPRES...,Acute myocardial infarction of other anterior ...,"AMI anterior wall, init,Ac systolic hrt failur...",Admission Date: [**2102-9-26**] ...,4101142821997142714140142804273145829
1,1,['Back pain'],199961,CHIEF COMPLAINT: \n\nPRESENT ILLNESS: The pati...,"Ankylosing spondylitis,Hypertensive chronic ki...","Ankylosing spondylitis,Hyp kid NOS w cr kid V,...",Admission Date: [**2115-6-29**] Dischar...,"7200,40391,8052,8471,E8859,78057,2859,25060"
2,2,"['Shortness of breath', 'Cough', 'Occasional n...",136812,CHIEF COMPLAINT: \n\nPRESENT ILLNESS: This is ...,Obstructive chronic bronchitis with (acute) ex...,"Obs chr bronc w(ac) exac,Pneumonia, organism N...",Admission Date: [**2106-4-14**] Dischar...,491214862800427894261143889729892449
3,3,"['left arm pain', 'left leg pain', 'pulmonary ...",175700,CHIEF COMPLAINT: s/p rollover MVC with prolong...,"Closed fracture of shaft of fibula with tibia,...","Fx shaft fib w tib-clos,Pneumococcal pneumonia...",Admission Date: [**2159-2-9**] D...,"82322,481,86121,5180,2851,81322,8072,E8160,883..."
4,4,"['Palpitations', 'Shortness of breath', 'Sore ...",193486,CHIEF COMPLAINT: Shortness of breath\n\nPRESEN...,"Other specified cardiac dysrhythmias,End stage...","Cardiac dysrhythmias NEC,End stage renal disea...",Admission Date: [**2136-10-4**] ...,"42789,5856,6822,6164,42830,2761,2869,4160,0411..."


In [3]:
train_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 33684 entries, 0 to 33683
Data columns (total 8 columns):
 #   Column             Non-Null Count  Dtype 
---  ------             --------------  ----- 
 0   Unnamed: 0         33684 non-null  int64 
 1   Symptoms           33684 non-null  object
 2   id                 33684 non-null  int64 
 3   text               33684 non-null  object
 4   long_texts         33684 non-null  object
 5   short_texts        33684 non-null  object
 6   discharge_summary  33684 non-null  object
 7   short_codes        33684 non-null  object
dtypes: int64(2), object(6)
memory usage: 2.1+ MB


In [2]:
icd_9_dataset = pd.read_csv('D_ICD_DIAGNOSES.csv')
icd_9_dataset.head()

Unnamed: 0,ROW_ID,ICD9_CODE,SHORT_TITLE,LONG_TITLE
0,174,1166,TB pneumonia-oth test,"Tuberculous pneumonia [any form], tubercle bac..."
1,175,1170,TB pneumothorax-unspec,"Tuberculous pneumothorax, unspecified"
2,176,1171,TB pneumothorax-no exam,"Tuberculous pneumothorax, bacteriological or h..."
3,177,1172,TB pneumothorx-exam unkn,"Tuberculous pneumothorax, bacteriological or h..."
4,178,1173,TB pneumothorax-micro dx,"Tuberculous pneumothorax, tubercle bacilli fou..."


In [3]:
preprocessed_df = train_dataset[['Symptoms','id', 'short_codes' ]].copy()
preprocessed_df.head()

Unnamed: 0,Symptoms,id,short_codes
0,"['Substernal Chest Pain', 'Sharp Pain', 'Cresc...",147171,4101142821997142714140142804273145829
1,['Back pain'],199961,"7200,40391,8052,8471,E8859,78057,2859,25060"
2,"['Shortness of breath', 'Cough', 'Occasional n...",136812,491214862800427894261143889729892449
3,"['left arm pain', 'left leg pain', 'pulmonary ...",175700,"82322,481,86121,5180,2851,81322,8072,E8160,883..."
4,"['Palpitations', 'Shortness of breath', 'Sore ...",193486,"42789,5856,6822,6164,42830,2761,2869,4160,0411..."


In [4]:
preprocessed_df['Symptoms'] = preprocessed_df['Symptoms'].apply(lambda x : str(x[1:-1]).split(', '))
preprocessed_df.head()

Unnamed: 0,Symptoms,id,short_codes
0,"['Substernal Chest Pain', 'Sharp Pain', 'Cresc...",147171,4101142821997142714140142804273145829
1,['Back pain'],199961,"7200,40391,8052,8471,E8859,78057,2859,25060"
2,"['Shortness of breath', 'Cough', 'Occasional n...",136812,491214862800427894261143889729892449
3,"['left arm pain', 'left leg pain', 'pulmonary ...",175700,"82322,481,86121,5180,2851,81322,8072,E8160,883..."
4,"['Palpitations', 'Shortness of breath', 'Sore ...",193486,"42789,5856,6822,6164,42830,2761,2869,4160,0411..."


In [5]:
preprocessed_df = preprocessed_df.explode('Symptoms')
preprocessed_df = preprocessed_df.assign(short_codes=preprocessed_df['short_codes'].str.split(',')).explode('short_codes')


preprocessed_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 2020063 entries, 0 to 33683
Data columns (total 3 columns):
 #   Column       Dtype 
---  ------       ----- 
 0   Symptoms     object
 1   id           int64 
 2   short_codes  object
dtypes: int64(1), object(2)
memory usage: 61.6+ MB


In [6]:
preprocessed_df.head()

Unnamed: 0,Symptoms,id,short_codes
0,'Substernal Chest Pain',147171,41011
0,'Substernal Chest Pain',147171,42821
0,'Substernal Chest Pain',147171,9971
0,'Substernal Chest Pain',147171,4271
0,'Substernal Chest Pain',147171,41401


In [7]:
preprocessed_df.Symptoms.nunique()

39994

In [8]:
preprocessed_df.head()

Unnamed: 0,Symptoms,id,short_codes
0,'Substernal Chest Pain',147171,41011
0,'Substernal Chest Pain',147171,42821
0,'Substernal Chest Pain',147171,9971
0,'Substernal Chest Pain',147171,4271
0,'Substernal Chest Pain',147171,41401


In [9]:
preprocessed_df.columns = ['Symptoms', 'id', 'ICD9_CODE']

In [10]:
preprocessed_df = preprocessed_df.reset_index(drop= True)

In [11]:
preprocessed_df.loc[
    preprocessed_df['ICD9_CODE'].str.startswith("V"), 'ICD9_CODE'] = preprocessed_df.ICD9_CODE.apply(
    lambda x: x[:4])
preprocessed_df.loc[
    preprocessed_df['ICD9_CODE'].str.startswith("E"), 'ICD9_CODE'] = preprocessed_df.ICD9_CODE.apply(
    lambda x: x[:4])
preprocessed_df.loc[(~preprocessed_df.ICD9_CODE.str.startswith("E")) & (
    ~preprocessed_df.ICD9_CODE.str.startswith("V")), 'ICD9_CODE'] = preprocessed_df.ICD9_CODE.apply(
    lambda x: x[:3])

In [12]:
preprocessed_df['weight'] = preprocessed_df.groupby(['Symptoms', 'ICD9_CODE'])['Symptoms'].transform('size')

preprocessed_df.head()


Unnamed: 0,Symptoms,id,ICD9_CODE,weight
0,'Substernal Chest Pain',147171,410,2
1,'Substernal Chest Pain',147171,428,3
2,'Substernal Chest Pain',147171,997,2
3,'Substernal Chest Pain',147171,427,4
4,'Substernal Chest Pain',147171,414,4


In [13]:
# preprocessed_df['ICD9_CODE']  = preprocessed_df['ICD9_CODE'].astype('str')

In [14]:
preprocessed_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2020063 entries, 0 to 2020062
Data columns (total 4 columns):
 #   Column     Dtype 
---  ------     ----- 
 0   Symptoms   object
 1   id         int64 
 2   ICD9_CODE  object
 3   weight     int64 
dtypes: int64(2), object(2)
memory usage: 61.6+ MB


In [15]:
preprocessed_df.columns = ['Symptoms', 'id', 'ICD9_CODE', 'weight']

In [16]:
# preprocessed_df_merged = pd.merge(preprocessed_df, icd_9_dataset, how = 'inner', on='ICD9_CODE')

In [17]:
# preprocessed_df_merged.info()

In [18]:
# preprocessed_df_merged.Symptoms.nunique()

In [19]:
from tqdm import tqdm

In [20]:
train_dict = {}

for i,j in tqdm(preprocessed_df.iterrows()): 
    train_dict[j.Symptoms[1:-1]] = []

    


2020063it [00:24, 81025.22it/s]


In [21]:
preprocessed_df.head()

Unnamed: 0,Symptoms,id,ICD9_CODE,weight
0,'Substernal Chest Pain',147171,410,2
1,'Substernal Chest Pain',147171,428,3
2,'Substernal Chest Pain',147171,997,2
3,'Substernal Chest Pain',147171,427,4
4,'Substernal Chest Pain',147171,414,4


In [22]:
from tqdm import tqdm

for i, j in tqdm(preprocessed_df.iterrows()):
    symptoms = j.Symptoms[1:-1]
    icd9_code = j.ICD9_CODE
    entry = {
        'disease_code': icd9_code,
        'weight': j.weight,
    }

    
    # Check if the icd9_code is already in the list
    if not any(item['disease_code'] == icd9_code for item in train_dict[symptoms]):
        train_dict[symptoms].append(entry)


2020063it [00:39, 50703.77it/s]


In [24]:
len(train_dict.keys())

39951

In [1]:
import json
with open('train_dict_cropped.json', 'w') as json_file:
    json.dump(train_dict, json_file, indent=4)
    
    

NameError: name 'train_dict' is not defined

In [26]:
for i,j in train_dataset.iterrows() : 
    print(type(j.Symptoms))
    print(type(j.short_codes))
    break

<class 'str'>
<class 'str'>


In [27]:
with open("train_dict_cropped.json") as f:
    data = json.load(f)


In [28]:
import gc
gc.collect()

0

In [29]:
uri = "neo4j://localhost:7687"
auth = ("neo4j", "neo4j_pass4")

driver = GraphDatabase.driver(uri, auth=auth)
driver.verify_connectivity()

In [52]:
# def delete_all_data():
#     with driver.session(database="neo4j") as session:
#         # Execute the Cypher query to delete all nodes and relationships
#         session.run("MATCH (n) DETACH DELETE n")
#         print("All nodes and relationships have been deleted.")


# delete_all_data()



In [87]:
from neo4j import GraphDatabase
from tqdm import tqdm

def remove_apostrophes(text):
    return text.replace("'", "")

def generate_cypher_queries(data):
    queries = []
    unique_symptoms = set()
    unique_diseases = set()
    symptom_disease_pairs = set()
    
    for symptom, diseases_data in tqdm(data.items(), desc="Processing symptoms"):
        # Clean and process symptom
        cleaned_symptom = remove_apostrophes(symptom.strip())
        if cleaned_symptom not in unique_symptoms:
            query = f"MERGE (s:Symptom {{name: '{cleaned_symptom}'}})"
            queries.append(query)
            unique_symptoms.add(cleaned_symptom)
        
        # Process each disease associated with the symptom
        for disease in tqdm(diseases_data, desc=f"Processing diseases for {cleaned_symptom}", leave=False):
            cleaned_short_code = remove_apostrophes(disease["disease_code"].strip())
            
            # Disease nodes
            if cleaned_short_code not in unique_diseases:
                query = f"MERGE (d:Disease {{title: '{cleaned_short_code}'}})"
                queries.append(query)
                unique_diseases.add(cleaned_short_code)
            
            # Relationships between symptoms and diseases
            if (cleaned_symptom, cleaned_short_code) not in symptom_disease_pairs:
                query = (f"MATCH (s:Symptom {{name: '{cleaned_symptom}'}}), "
                         f"(d:Disease {{title: '{cleaned_short_code}'}}) "
                         f"MERGE (s)-[:ASSOCIATED_WITH {{weight: {disease['weight']}}}]->(d)")
                queries.append(query)
                symptom_disease_pairs.add((cleaned_symptom, cleaned_short_code))
    
    return queries




In [88]:
queries = generate_cypher_queries(data)

Processing symptoms: 100%|██████████| 39951/39951 [01:22<00:00, 482.81it/s]


In [89]:
len(queries)

851028

In [90]:
from neo4j import GraphDatabase
from tqdm import tqdm



def run_cypher_queries(queries, batch_size=10000):
    total_queries = len(queries)
    num_batches = (total_queries + batch_size - 1) // batch_size  

    with driver.session() as session:
        for batch_start in tqdm(range(0, total_queries, batch_size), desc="Processing batches", unit="batch", total=num_batches):
            batch_end = min(batch_start + batch_size, total_queries)
            batch_queries = queries[batch_start:batch_end]

            for query in batch_queries:
                session.run(query)
    
run_cypher_queries(queries)

Processing batches: 100%|██████████| 86/86 [1:24:14<00:00, 58.77s/batch] 


In [29]:
# from neo4j import GraphDatabase
# from tqdm import tqdm



# def create_knowledge_graph(tx, data):
#     unique_symptoms = set()
#     unique_diseases = set()
#     symptom_disease_pairs = set()

#     for symptom, diseases in tqdm(data.items()):
#         cleaned_symptom = symptom
#         # Symptom nodes
#         if cleaned_symptom not in unique_symptoms:
#             tx.run("MERGE (s:Symptom {name: $name})", name=cleaned_symptom)
#             unique_symptoms.add(cleaned_symptom)
        
#         # Disease nodes and Relationships
#         for disease in diseases:
#             cleaned_disease_code = (disease["disease_code"]
#             cleaned_short_title = disease["short_title"]

#             if cleaned_disease_code not in unique_diseases:
#                 tx.run(
#                     """
#                     MERGE (d:Disease {code: $code})
#                     ON CREATE SET d.short_title = $short_title
#                     """,
#                     code=cleaned_disease_code,
#                     short_title=cleaned_short_title,
#                 )
#                 unique_diseases.add(cleaned_disease_code)

#             if (cleaned_symptom, cleaned_disease_code) not in symptom_disease_pairs:
#                 tx.run(
#                     """
#                     MATCH (s:Symptom {name: $symptom_name}), (d:Disease {code: $code})
#                     MERGE (s)-[r:ASSOCIATED_WITH]->(d)
#                     ON CREATE SET r.weight = $weight
#                     """,
#                     symptom_name=cleaned_symptom,
#                     code=cleaned_disease_code,
#                     weight=disease["weight"],
#                 )
#                 symptom_disease_pairs.add((cleaned_symptom, cleaned_disease_code))

# uri = "neo4j://localhost:7687"
# auth = ("neo4j", "neo4j_pass2")

# driver = GraphDatabase.driver(uri, auth=auth)
# driver.verify_connectivity()

# try:
#     with driver.session(database="neo4j") as session:
#         session.execute_write(create_knowledge_graph, data)
# finally:
#     driver.close()


In [56]:

# with driver.session(database="neo4j") as session:
#         session.execute_write(create_knowledge_graph, data)
#         session.run("RETURN 1")
