In [53]:
import os
import pandas as pd
import logging
from dotenv import load_dotenv
from neo4j import GraphDatabase

# Load credentials from .env
load_dotenv()

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class PrimeKGImporter:
    """Import PrimeKG CSV data into Neo4j and enrich drug/disease nodes.

    Connection details are read from the ``NEO4J_URI``, ``NEO4J_USERNAME`` and
    ``NEO4J_PASSWORD`` environment variables. The instance can also be used as
    a context manager, in which case the driver is closed on exit.
    """

    def __init__(self):
        self.uri = os.getenv("NEO4J_URI")
        self.auth = (os.getenv("NEO4J_USERNAME"), os.getenv("NEO4J_PASSWORD"))
        self.driver = GraphDatabase.driver(self.uri, auth=self.auth)
        try:
            self.driver.verify_connectivity()
        except Exception:
            # Do not leak the driver when the connectivity check fails.
            self.driver.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    @staticmethod
    def _import_url(csv_path):
        """Return the ``file:///`` URL for ``csv_path``; only the base name is kept."""
        return f"file:///{os.path.basename(csv_path)}"

    def _run_and_consume(self, query, **params):
        """Run ``query`` in its own session and consume the result before returning.

        Consuming inside the ``with`` block makes the query finish here, so a
        failure is raised to the caller before any success message is logged.
        """
        with self.driver.session() as session:
            return session.run(query, **params).consume()

    def set_constraints(self):
        """Create the uniqueness constraint on ``node_index`` and the lookup indexes.

        All statements use ``IF NOT EXISTS`` so the method can be re-run.
        """
        queries = [
            "CREATE CONSTRAINT node_index_unique IF NOT EXISTS FOR (n:Entity) REQUIRE n.node_index IS UNIQUE",
            "CREATE INDEX entity_type_idx IF NOT EXISTS FOR (n:Entity) ON (n.type)",
            "CREATE INDEX entity_id_idx IF NOT EXISTS FOR (n:Entity) ON (n.id)",
            "CREATE INDEX disease_index_node_index IF NOT EXISTS FOR (n:disease) ON (n.node_index)"
        ]
        with self.driver.session() as session:
            for statement in queries:
                session.run(statement).consume()
        logging.info("Constraints on node_index and indexes established.")

    def import_nodes_by_type(self, nodes_csv_path, node_type):
        """Import the rows of the nodes CSV whose ``node_type`` equals ``node_type``.

        Each row is merged as an ``:Entity`` node keyed by ``node_index`` and
        additionally labelled with its ``node_type`` through
        ``apoc.create.addLabels``, which allows queries such as
        ``MATCH (n:drug)`` or ``MATCH (n:disease)``.

        Args:
            nodes_csv_path: Path of the nodes CSV; only its file name is used
                to build the ``file:///`` URL passed to ``LOAD CSV``.
            node_type: Value of the ``node_type`` column to import.
        """
        query = """
        LOAD CSV WITH HEADERS FROM $file_url AS row
        WITH row WHERE row.node_type = $target_type
        MERGE (n:Entity {node_index: toInteger(row.node_index)})
        SET n.id = row.node_id,
            n.type = row.node_type,
            n.name = row.node_name,
            n.source = row.node_source
        WITH n, row
        CALL apoc.create.addLabels(n, [row.node_type]) YIELD node
        RETURN count(node)
        """
        self._run_and_consume(
            query,
            file_url=self._import_url(nodes_csv_path),
            target_type=node_type,
        )
        logging.info(f"Nodes of type '{node_type}' imported and labeled.")

    def import_edges_by_type(self, edges_csv_path, relation_type):
        """Import the rows of the edges CSV whose ``relation`` equals ``relation_type``.

        Source and target are matched as ``:Entity`` nodes by ``x_index`` and
        ``y_index``; the relationship type is taken from the ``relation``
        column and ``display_relation`` is stored as the ``display`` property.

        NOTE(review): ``apoc.create.relationship`` presumably creates a new
        relationship on every call, so re-running this for the same type would
        duplicate edges -- confirm, and consider ``apoc.merge.relationship``.

        Args:
            edges_csv_path: Path of the edges CSV; only its file name is used.
            relation_type: Value of the ``relation`` column to import.
        """
        query = """
        LOAD CSV WITH HEADERS FROM $file_url AS row
        WITH row WHERE row.relation = $rel_type
        MATCH (source:Entity {node_index: toInteger(row.x_index)})
        MATCH (target:Entity {node_index: toInteger(row.y_index)})
        CALL apoc.create.relationship(source, row.relation, {display: row.display_relation}, target) YIELD rel
        RETURN count(rel)
        """
        self._run_and_consume(
            query,
            file_url=self._import_url(edges_csv_path),
            rel_type=relation_type,
        )
        logging.info(f"Imported {relation_type} edges.")

    def enrich_drug_features(self, drug_features_csv):
        """
        Enriches 'drug' nodes with 'mechanism_of_action' and 'description'.

        Args:
            drug_features_csv: Path of the drug features CSV; only its file
                name is used to build the ``file:///`` URL.
        """
        query = """
        LOAD CSV WITH HEADERS FROM $file_url AS row
        MATCH (n:drug {node_index: toInteger(row.node_index)})
        SET n.mechanism_of_action = row.mechanism_of_action,
            n.description = row.description
        """
        self._run_and_consume(query, file_url=self._import_url(drug_features_csv))
        logging.info("Drug nodes enriched with mechanisms and descriptions.")

    @staticmethod
    def _distinct_text(values, drop_blank):
        """Return the distinct non-null values of ``values`` as strings, in first-seen order.

        When ``drop_blank`` is true, whitespace-only strings are skipped too.
        """
        seen = {}
        for value in values:
            if pd.isnull(value):
                continue
            text = str(value)
            if drop_blank and text.strip() == "":
                continue
            seen[text] = None
        return list(seen)

    @staticmethod
    def _collapse_disease_features(df):
        """Collapse the disease feature rows to one record per ``node_index``.

        Returns a list of dicts with the keys ``node_index``, ``final_name``,
        ``mondo_name``, ``temp_definition`` and ``mayo_symptoms``. Missing or
        empty values are ``None`` so they can be sent as Cypher nulls.
        """
        if df.empty:
            return []

        prepared = df.assign(
            final_name=df['group_name_bert'].fillna(df['mondo_name']),
            temp_definition=df['mondo_definition'].fillna(df['orphanet_definition']),
        )
        distinct = PrimeKGImporter._distinct_text
        collapsed = prepared.groupby('node_index').agg({
            'final_name': 'first',
            'mondo_name': lambda col: " | ".join(sorted(distinct(col, drop_blank=False))),
            # First-seen order (instead of an unordered set) keeps the stored
            # text identical from one run to the next.
            'temp_definition': lambda col: "\n\n".join(distinct(col, drop_blank=True)),
            'mayo_symptoms': lambda col: "\n\n".join(distinct(col, drop_blank=True)),
        }).reset_index()

        records = []
        for raw_record in collapsed.to_dict('records'):
            record = {}
            for key, value in raw_record.items():
                if hasattr(value, "item"):
                    # Unwrap numpy scalars into plain Python values.
                    value = value.item()
                if pd.isna(value) or value == "":
                    value = None
                record[key] = value
            records.append(record)
        return records

    def enrich_disease_features(self, disease_features_csv, batch_size=2000):
        """
        Refined enrichment:
        Uses group_name_bert -> mondo_name fallback for name.
        Uses mondo_definition -> orphanet_definition fallback for text.
        Merges all unique definitions found across duplicates.

        Args:
            disease_features_csv: Path of the disease features CSV, read with pandas.
            batch_size: Number of collapsed rows sent per ``UNWIND`` query.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        logging.info(f"Reading {disease_features_csv}...")
        df = pd.read_csv(disease_features_csv)

        logging.info("Merging sub-types and multi-source definitions...")
        records = self._collapse_disease_features(df)
        total_unique = len(records)

        query = """
        UNWIND $batch AS row
        MATCH (n:disease {node_index: toInteger(row.node_index)})
        SET n.name = row.final_name,
            n.subtypes = row.mondo_name,
            n.definition = row.temp_definition,
            n.symptoms = row.mayo_symptoms
        """

        with self.driver.session() as session:
            logging.info(f"Starting batch update for {total_unique} unique nodes...")

            for start_idx in range(0, total_unique, batch_size):
                end_idx = min(start_idx + batch_size, total_unique)
                batch_data = records[start_idx:end_idx]

                # Consume so a failing batch raises before progress is reported.
                session.run(query, batch=batch_data).consume()

                percent_done = (end_idx / total_unique) * 100
                print(f"Progress: {percent_done:.2f}% ({end_idx}/{total_unique} nodes enriched)", end='\r')

            print(f"\nSuccess: {total_unique} disease nodes enriched.")

    def run_graph_analysis(self):
        """Print how many disease and drug nodes exist and how many lack text features.

        A property counts as missing when it is null or the empty string.
        """
        analysis_query = """
        CALL {
            MATCH (d:disease)
            RETURN
                count(d) AS total_diseases,
                sum(CASE WHEN d.definition IS NULL OR d.definition = "" THEN 1 ELSE 0 END) AS null_definitions
        }

        CALL {
            MATCH (dr:drug)
            RETURN
                count(dr) AS total_drugs,
                sum(CASE WHEN dr.mechanism_of_action IS NULL OR dr.mechanism_of_action = "" THEN 1 ELSE 0 END) AS null_mechanisms,
                sum(CASE WHEN dr.description IS NULL OR dr.description = "" THEN 1 ELSE 0 END) AS null_descriptions
        }
        RETURN total_diseases, null_definitions, total_drugs, null_mechanisms,null_descriptions
        """

        with self.driver.session() as session:
            result = session.run(analysis_query).single()

            print("--- Disease Analysis ---")
            print(f"Total Diseases (x): {result['total_diseases']}")
            print(f"Null Definitions (y): {result['null_definitions']}")

            print("\n--- Drug Analysis ---")
            print(f"Total Drugs (a): {result['total_drugs']}")
            print(f"Null Mechanism of Action (b): {result['null_mechanisms']}")
            print(f"Null Descriptions (c): {result['null_descriptions']}")



In [54]:

# Open the Neo4j connection with the credentials loaded from the environment;
# the constructor verifies connectivity before returning.
importer = PrimeKGImporter()


In [38]:
# Import the rows of nodes.csv whose node_type is "drug" as :Entity nodes
# that also carry the :drug label.
importer.import_nodes_by_type("nodes.csv","drug")  # 7957 drug nodes

2026-02-14 09:53:50,472 - INFO - Nodes of type 'drug' imported and labeled.


In [39]:
# Import the rows of nodes.csv whose node_type is "disease" as :Entity nodes
# that also carry the :disease label.
importer.import_nodes_by_type("nodes.csv","disease")  # 17080 disease nodes

2026-02-14 09:53:53,206 - INFO - Nodes of type 'disease' imported and labeled.


In [40]:
# Create the node_index uniqueness constraint and the lookup indexes
# (all statements use IF NOT EXISTS, so re-running is harmless).
# NOTE(review): this cell was executed after the node imports; creating the
# constraint before the first MERGE is usually preferable -- consider moving
# this cell above the import cells.
importer.set_constraints()

2026-02-14 09:53:54,139 - INFO - Received notification from DBMS server: {severity: INFORMATION} {code: Neo.ClientNotification.Schema.IndexOrConstraintAlreadyExists} {category: SCHEMA} {title: `CREATE CONSTRAINT node_index_unique IF NOT EXISTS FOR (e:Entity) REQUIRE (e.node_index) IS UNIQUE` has no effect.} {description: `CONSTRAINT node_index_unique FOR (e:Entity) REQUIRE (e.node_index) IS UNIQUE` already exists.} {position: None} for query: 'CREATE CONSTRAINT node_index_unique IF NOT EXISTS FOR (n:Entity) REQUIRE n.node_index IS UNIQUE'
2026-02-14 09:53:54,151 - INFO - Received notification from DBMS server: {severity: INFORMATION} {code: Neo.ClientNotification.Schema.IndexOrConstraintAlreadyExists} {category: SCHEMA} {title: `CREATE RANGE INDEX entity_type_idx IF NOT EXISTS FOR (e:Entity) ON (e.type)` has no effect.} {description: `RANGE INDEX index_3d7b3819 FOR (e:Entity) ON (e.type)` already exists.} {position: None} for query: 'CREATE INDEX entity_type_idx IF NOT EXISTS FOR (n:En

In [41]:
# Set mechanism_of_action and description on :drug nodes from drug_features.csv.
importer.enrich_drug_features("drug_features.csv")

2026-02-14 09:55:05,974 - INFO - Drug nodes enriched with mechanisms and descriptions.


In [None]:
import pandas as pd
import csv
input_path="import/disease_features.csv"
output_path="import/disease_features_cleaned.csv"
df = pd.read_csv(input_path)

# Clean common issues in text columns (mondo_definition, etc.)
# Replace actual double quotes inside the text with single quotes
# or escape them. Single quotes are safer for Neo4j LOAD CSV.
text_columns = ['mondo_definition', 'mayo_symptoms', 'orphanet_definition']

for col in text_columns:
    if col in df.columns:
        # Replace " with ' to avoid breaking the CSV structure.
        # astype(str) turns missing values into the literal text "nan", so
        # mask them back to missing; they are then exported as empty fields.
        df[col] = (
            df[col]
            .astype(str)
            .str.replace('"', "'", regex=False)
            .mask(df[col].isna())
        )

# Export with explicit quoting rules
df.to_csv(
    output_path,
    index=False,
    quoting=csv.QUOTE_ALL,
    quotechar='"',
    escapechar='\\' # Adds a backslash if it finds a stray quote
)

print(f"Cleaned file saved to {output_path}")


Cleaned file saved to import/disease_features_cleaned.csv


In [42]:
# NOTE(review): mid-notebook import -- consider moving it to the top import
# cell so that a fresh Restart & Run All has every dependency available
# before any class method is called.
import numpy as np
# Collapse the cleaned disease feature rows per node_index and write name,
# subtypes, definition and symptoms onto the matching :disease nodes.
importer.enrich_disease_features("import/disease_features_cleaned.csv")

2026-02-14 09:55:13,207 - INFO - Reading import/disease_features_cleaned.csv...
2026-02-14 09:55:14,821 - INFO - Merging sub-types and multi-source definitions...
2026-02-14 09:55:16,425 - INFO - Starting batch update for 17080 unique nodes...
2026-02-14 09:55:16,440 - INFO - Received notification from DBMS server: {severity: INFORMATION} {code: Neo.ClientNotification.Schema.IndexOrConstraintAlreadyExists} {category: SCHEMA} {title: `CREATE RANGE INDEX disease_node_idx IF NOT EXISTS FOR (e:disease) ON (e.node_index)` has no effect.} {description: `RANGE INDEX disease_index_node_index FOR (e:disease) ON (e.node_index)` already exists.} {position: None} for query: 'CREATE INDEX disease_node_idx IF NOT EXISTS FOR (n:disease) ON (n.node_index)'


Progress: 100.00% (17080/17080 nodes enriched)
Success: 17080 disease nodes enriched.


In [None]:

# Reload the raw (uncleaned) feature table and check how often a node_index repeats.
input_path = "import/disease_features.csv"
df = pd.read_csv(input_path)

print(f"Total Rows in CSV: {len(df)}")
print(f"Unique node_index count: {df['node_index'].nunique()}")

# keep=False flags every row of a repeated node_index, not only the later copies.
duplicates = df.loc[df.duplicated(subset=['node_index'], keep=False)].sort_values('node_index')
if len(duplicates) > 0:
    print("\nExample of duplicated node_index in CSV:")
    print(duplicates.loc[:, ['node_index', 'mondo_name', 'mondo_id', "mondo_definition", "group_name_bert"]].head(4))

Total Rows in CSV: 44133
Unique node_index count: 17080

Example of duplicated node_index in CSV:
       node_index                      mondo_name  mondo_id  \
31059       27158  osteogenesis imperfecta type 1      8146   
31048       27158  osteogenesis imperfecta type 5     12591   
31049       27158  osteogenesis imperfecta type 5     12591   
31050       27158  osteogenesis imperfecta type 5     12591   

                                        mondo_definition  \
31059  Osteogenesis imperfecta type I is a mild type ...   
31048  Osteogenesis imperfecta type V is a moderate t...   
31049  Osteogenesis imperfecta type V is a moderate t...   
31050  Osteogenesis imperfecta type V is a moderate t...   

               group_name_bert  
31059  osteogenesis imperfecta  
31048  osteogenesis imperfecta  
31049  osteogenesis imperfecta  
31050  osteogenesis imperfecta  


In [11]:
# Group by node_index and see how many unique definitions/symptoms each one has
diff_check = df.groupby('node_index').agg({
    'mondo_definition': 'nunique',
    'mayo_symptoms': 'nunique',
    'mondo_name': 'unique'
})

# Filter for nodes that have more than one unique definition
conflicts = diff_check[diff_check['mondo_definition'] > 1]

print(f"Number of diseases with conflicting definitions: {len(conflicts)}")
if not conflicts.empty:
    print("\nExample of conflicting names for one index:")
    print(conflicts['mondo_name'].head(1).values)

Number of diseases with conflicting definitions: 881

Example of conflicting names for one index:
[array(['osteogenesis imperfecta type 13',
        'osteogenesis imperfecta type 11',
        'osteogenesis imperfecta type 17',
        'osteogenesis imperfecta type 12',
        'osteogenesis imperfecta type 5', 'osteogenesis imperfecta type 7',
        'osteogenesis imperfecta, type 21',
        'osteogenesis imperfecta type 1', 'osteogenesis imperfecta type 4',
        'osteogenesis imperfecta, type 20',
        'osteogenesis imperfecta type 10',
        'osteogenesis imperfecta, type 18',
        'osteogenesis imperfecta type 16',
        'osteogenesis imperfecta type 9',
        'osteogenesis imperfecta, type 19',
        'osteogenesis imperfecta type 3',
        'osteogenesis imperfecta type 15',
        'osteogenesis imperfecta type 2', 'osteogenesis imperfecta type 6',
        'osteogenesis imperfecta type 14',
        'osteogenesis imperfecta type 8', 'osteogenesis imperfecta'],


In [55]:
# Print the number of disease and drug nodes and how many of them still lack
# a definition, mechanism of action, or description.
importer.run_graph_analysis()



--- Disease Analysis ---
Total Diseases (x): 17080
Null Definitions (y): 4182

--- Drug Analysis ---
Total Drugs (a): 7957
Null Mechanism of Action (b): 4715
Null Descriptions (c): 3366
