In [1]:
import pandas as pd
from sentence_transformers import SentenceTransformer
from neo4j import GraphDatabase
import os
from getpass import getpass

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import dotenv
# Load environment variables from .env file
dotenv.load_dotenv()

True

In [3]:
path = os.getcwd()
full_path = path.replace('graphRAG', 'dataset\health_reports_data.csv')
df = pd.read_csv(full_path)

In [4]:
# Standardize all text data to lowercase for consistency
for col in ['actual_disease', 'intervention_target_disease', 'region', 'intervention_region', 'severity', 'day_of_week']:
    df[col] = df[col].astype(str).str.lower()

# Convert to datetime objects
df['timestamp'] = pd.to_datetime(df['timestamp'])
df['intervention_start_date'] = pd.to_datetime(df['intervention_start_date'].fillna(pd.Timestamp('1970-01-01')))

# Add unique IDs
df['case_id'] = [f"case_{i}" for i in range(len(df))]
df['patient_id'] = [f"patient_{i}" for i in range(len(df))]

In [5]:
model = SentenceTransformer('all-MiniLM-L6-v2') 

In [6]:
df['symptom_embedding'] = model.encode(df['symptoms_text'].tolist(), show_progress_bar=True).tolist()

Batches:   0%|          | 0/313 [00:00<?, ?it/s]

Batches: 100%|██████████| 313/313 [00:27<00:00, 11.28it/s]


In [7]:
records = df.to_dict('records')

In [8]:
disease_category_map = {
    'gastroenteritis': 'Gastrointestinal',
    'malaria': 'Parasitic',
    'covid19': 'Viral',
    'food_poisoning': 'Bacterial/Toxin-related',
    'influenza': 'Viral',
    'meningitis': 'Infectious',
    'pneumonia': 'Respiratory'
}

In [9]:
for rec in records:
    rec['disease_category'] = disease_category_map.get(rec['actual_disease'], 'Uncategorized')

In [10]:
NEO4J_URI      = os.getenv("NEO4J_URI")
NEO4J_USER     = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")

In [11]:
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

In [12]:
constraints = [
    "CREATE CONSTRAINT IF NOT EXISTS FOR (p:Patient) REQUIRE p.patientId IS UNIQUE",
    "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Case) REQUIRE c.caseId IS UNIQUE",
    "CREATE CONSTRAINT IF NOT EXISTS FOR (d:Disease) REQUIRE d.name IS UNIQUE",
    "CREATE CONSTRAINT IF NOT EXISTS FOR (s:Symptom) REQUIRE s.name IS UNIQUE",
    "CREATE CONSTRAINT IF NOT EXISTS FOR (r:Region) REQUIRE r.name IS UNIQUE",
    "CREATE CONSTRAINT IF NOT EXISTS FOR (i:Intervention) REQUIRE i.id IS UNIQUE",
    "CREATE CONSTRAINT IF NOT EXISTS FOR (y:Year) REQUIRE y.year IS UNIQUE",
    "CREATE CONSTRAINT IF NOT EXISTS FOR (sl:SeverityLevel) REQUIRE sl.level IS UNIQUE",
    "CREATE CONSTRAINT IF NOT EXISTS FOR (dc:DiseaseCategory) REQUIRE dc.name IS UNIQUE",
]

In [13]:
with driver.session(database="neo4j") as session:
    for constraint in constraints:
        try:
            session.run(constraint)
            print(f"Applied or verified constraint: {constraint.split('FOR')[1].strip()}")
        except Exception as e:
            print(f"Could not apply constraint: {constraint}. Error: {e}")

Applied or verified constraint: (p:Patient) REQUIRE p.patientId IS UNIQUE
Applied or verified constraint: (c:Case) REQUIRE c.caseId IS UNIQUE
Applied or verified constraint: (d:Disease) REQUIRE d.name IS UNIQUE
Applied or verified constraint: (s:Symptom) REQUIRE s.name IS UNIQUE
Applied or verified constraint: (r:Region) REQUIRE r.name IS UNIQUE
Applied or verified constraint: (i:Intervention) REQUIRE i.id IS UNIQUE
Applied or verified constraint: (y:Year) REQUIRE y.year IS UNIQUE
Applied or verified constraint: (sl:SeverityLevel) REQUIRE sl.level IS UNIQUE
Applied or verified constraint: (dc:DiseaseCategory) REQUIRE dc.name IS UNIQUE


In [14]:
ingestion_query = """
UNWIND $records AS row

// MERGE core entities to prevent duplicates
MERGE (patient:Patient {patientId: row.patient_id})
    ON CREATE SET patient.age = toInteger(row.age), patient.gender = row.gender, patient.ageGroup = row.age_group
MERGE (disease:Disease {name: row.actual_disease})
MERGE (case_region:Region {name: row.region})
MERGE (intervention_region:Region {name: row.intervention_region})
MERGE (target_disease:Disease {name: row.intervention_target_disease})

// MERGE conceptual nodes
MERGE (severity:SeverityLevel {level: row.severity})
MERGE (category:DiseaseCategory {name: row.disease_category})

MERGE (intervention:Intervention {id: row.intervention_intervention_id})
    ON CREATE SET
        intervention.type = row.intervention_type,
        intervention.startDate = date(row.intervention_start_date),
        intervention.durationDays = toInteger(row.intervention_duration_days),
        intervention.effectivenessScore = toFloat(row.intervention_effectiveness_score),
        intervention.cost = toInteger(row.intervention_cost),
        intervention.populationAffected = toInteger(row.intervention_population_affected),
        intervention.complianceRate = toFloat(row.intervention_compliance_rate)

// MERGE Time Tree nodes
MERGE (year:Year {year: row.timestamp.year})
MERGE (month:Month {month: row.timestamp.month})
MERGE (day:Day {day: row.timestamp.day})
MERGE (weekday:DayOfWeek {name: row.day_of_week})
MERGE (hour:Hour {hour: row.hour})

// CREATE the unique Case node for each row
CREATE (case:Case {
    caseId: row.case_id,
    timestamp: datetime(row.timestamp),
    location: point({latitude: toFloat(row.latitude), longitude: toFloat(row.longitude)}),
    isOutbreakRelated: toBoolean(row.is_outbreak_related),
    contactTracingNeeded: toBoolean(row.contact_tracing_needed),
    hospitalizationRequired: toBoolean(row.hospitalization_required)
})

// Create the rich web of relationships
MERGE (patient)-[:REPORTED]->(case)
MERGE (case)-[:DIAGNOSED_WITH]->(disease)
MERGE (case)-[:HAS_SEVERITY]->(severity)
MERGE (case)-[:OCCURRED_IN]->(case_region)
MERGE (case)-[:AFFECTED_BY]->(intervention)

// Create time relationships
MERGE (case)-[:OCCURRED_ON]->(day)
MERGE (day)-[:OF_MONTH]->(month)
MERGE (month)-[:OF_YEAR]->(year)
MERGE (day)-[:IS_WEEKDAY]->(weekday)
MERGE (case)-[:OCCURRED_AT_HOUR]->(hour)
MERGE (case)-[:REPORTED_IN_YEAR]->(year) // Shortcut relationship

// Create symptom relationships
FOREACH (symptom_name IN row.symptoms_text |
    MERGE (symptom:Symptom {name: symptom_name})
    MERGE (case)-[:PRESENTED_SYMPTOM]->(symptom)
    MERGE (symptom)-[:COMMON_MANIFESTATION_OF]->(disease) // Inverse link
)

// Create intervention relationships
MERGE (intervention)-[:TARGETS_DISEASE]->(target_disease)
MERGE (intervention)-[:APPLIED_IN]->(intervention_region)

// Create conceptual & inferred relationships
MERGE (disease)-[:IS_A_TYPE_OF]->(category)
MERGE (patient)-[:LIVES_IN]->(case_region)
MERGE (patient)-[:HAS_HISTORY_OF]->(disease)
MERGE (disease)-[:PREVALENT_IN]->(case_region)
"""

In [15]:
def ingest_data_in_batches(driver, query, data, batch_size=500):
    total = len(data)
    print(f"Starting ingestion of {total} records with the hyper-robust schema...")
    with driver.session(database="neo4j") as session:
        for i in range(0, total, batch_size):
            batch = data[i:min(i + batch_size, total)]
            session.run(query, records=batch)
            print(f"Processed {len(batch)} records. ({i + len(batch)}/{total})")

In [16]:
ingest_data_in_batches(driver, ingestion_query, records, batch_size=500)
print("\nIngestion complete. Your graph is now ready for advanced querying.")
driver.close()

Starting ingestion of 10000 records with the hyper-robust schema...
Processed 500 records. (500/10000)
Processed 500 records. (1000/10000)
Processed 500 records. (1500/10000)
Processed 500 records. (2000/10000)
Processed 500 records. (2500/10000)
Processed 500 records. (3000/10000)
Processed 500 records. (3500/10000)
Processed 500 records. (4000/10000)
Processed 500 records. (4500/10000)
Processed 500 records. (5000/10000)
Processed 500 records. (5500/10000)
Processed 500 records. (6000/10000)
Processed 500 records. (6500/10000)
Processed 500 records. (7000/10000)
Processed 500 records. (7500/10000)
Processed 500 records. (8000/10000)
Processed 500 records. (8500/10000)
Processed 500 records. (9000/10000)
Processed 500 records. (9500/10000)
Processed 500 records. (10000/10000)

Ingestion complete. Your graph is now ready for advanced querying.


In [2]:
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
os.environ["GOOGLE_API_KEY"] = os.getenv("GOOGLE_API_KEY")
os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")

In [4]:
graph = Neo4jGraph(
    url=os.getenv("NEO4J_URI"), 
    username="neo4j", 
    password=os.getenv("NEO4J_PASSWORD")
)

  graph = Neo4jGraph(


In [5]:
graph.refresh_schema()
print(graph.schema)

Node properties:
Disease {name: STRING}
Region {name: STRING}
Intervention {id: STRING, type: STRING, cost: INTEGER, startDate: DATE, durationDays: INTEGER, effectivenessScore: FLOAT, populationAffected: INTEGER, complianceRate: FLOAT}
Symptom {name: STRING}
Year {year: INTEGER}
Month {month: INTEGER}
Day {day: INTEGER}
DayOfWeek {name: STRING}
Hour {hour: INTEGER}
Case {timestamp: DATE_TIME, caseId: STRING, contactTracingNeeded: BOOLEAN, location: POINT, isOutbreakRelated: BOOLEAN, hospitalizationRequired: BOOLEAN}
Patient {gender: STRING, age: INTEGER, patientId: STRING, ageGroup: STRING}
SeverityLevel {level: STRING}
DiseaseCategory {name: STRING}
Relationship properties:

The relationships:
(:Disease)-[:IS_A_TYPE_OF]->(:DiseaseCategory)
(:Disease)-[:PREVALENT_IN]->(:Region)
(:Intervention)-[:APPLIED_IN]->(:Region)
(:Intervention)-[:TARGETS_DISEASE]->(:Disease)
(:Symptom)-[:COMMON_MANIFESTATION_OF]->(:Disease)
(:Month)-[:OF_YEAR]->(:Year)
(:Day)-[:OF_MONTH]->(:Month)
(:Day)-[:IS_WEE

In [6]:
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-preview-04-17")

In [7]:
llm = ChatGroq(
    groq_api_key=os.environ.get('GROQ_API'),
    model_name='meta-llama/llama-4-maverick-17b-128e-instruct'
)

In [8]:
CYPHER_GENERATION_TEMPLATE_XML = """<cypher_generation_prompt>
    <instructions>
        <title>Instructions for Neo4j Cypher Query Generation</title>
        <rule>You are a world-class Neo4j Cypher query translator. Your sole purpose is to convert a user's question into a valid and efficient Cypher query based on the provided graph schema.</rule>
        <rule>Strictly adhere to the schema. Never use node labels, relationship types, or property names that are not explicitly defined in the `<schema>` block.</rule>
        <important rule>If a user asks for anything related to a region, e.g., 'coastal region', you must insert an underscore and write 'coastal_region' in the query.</important rule>
        <rule>Your output MUST be a single, valid Cypher query and nothing else.</rule>
    </instructions>

    <schema>
        {schema}
    </schema>

    <task>
        <title>Current Task</title>
        <question>{question}</question>
        <cypher>
        </cypher>
    </task>
</cypher_generation_prompt>"""


In [9]:
cypher_prompt = PromptTemplate(
    input_variables=["schema", "question"],
    template=CYPHER_GENERATION_TEMPLATE_XML
)

In [10]:
chain = GraphCypherQAChain.from_llm(
    graph=graph,
    llm=llm,
    verbose=True, 
    cypher_prompt=cypher_prompt,
    allow_dangerous_requests=True
)

In [11]:
def ask_question(query: str):
    """
    Invokes the QA chain and prints the result in a user-friendly format.
    """
    try:
        result = chain.invoke({"query": query})
        print(f"❓ Question: {query}")
        # The final answer is in the 'result' key
        print(f"🤖 Answer: {result['result']}")
    except Exception as e:
        print(f"An error occurred: {e}")
    print("-" * 50)

In [12]:
print("Running corrected queries...\n")

# Test 1: The COVID-19 query 
ask_question("how many cases were reported for covid19 and tell which year and month it was reported in?")

# Test 2: The Intervention query 
ask_question("What was the effectiveness score for interventions targeting malaria?")

# Test 3: Other queries
ask_question("Which disease was most common in the coastal region in May?")
ask_question("What is the average age of male patients diagnosed with malaria?")

Running corrected queries...



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (d:Disease {name: 'covid19'})<-[:DIAGNOSED_WITH]-(c:Case)-[:REPORTED_IN_YEAR]->(y:Year)
RETURN COUNT(c) AS num_cases, y.year AS year, c.timestamp AS timestamp
ORDER BY y.year
[0m
Full Context:
[32;1m[1;3m[{'num_cases': 1, 'year': 2020, 'timestamp': neo4j.time.DateTime(2020, 6, 5, 0, 0, 0, 0, tzinfo=<UTC>)}, {'num_cases': 1, 'year': 2020, 'timestamp': neo4j.time.DateTime(2020, 9, 11, 0, 0, 0, 0, tzinfo=<UTC>)}, {'num_cases': 1, 'year': 2020, 'timestamp': neo4j.time.DateTime(2020, 11, 20, 0, 0, 0, 0, tzinfo=<UTC>)}, {'num_cases': 1, 'year': 2020, 'timestamp': neo4j.time.DateTime(2020, 7, 15, 0, 0, 0, 0, tzinfo=<UTC>)}, {'num_cases': 2, 'year': 2020, 'timestamp': neo4j.time.DateTime(2020, 5, 8, 0, 0, 0, 0, tzinfo=<UTC>)}, {'num_cases': 2, 'year': 2020, 'timestamp': neo4j.time.DateTime(2020, 3, 1, 0, 0, 0, 0, tzinfo=<UTC>)}, {'num_cases': 1, 'year': 2020, 'times

In [13]:
ask_question("What is the most common symptom for patients diagnosed with pneumonia?")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'pneumonia'})<-[:DIAGNOSED_WITH]-(c:Case)-[:PRESENTED_SYMPTOM]->(s:Symptom)
RETURN s.name, COUNT(s) AS symptom_count
ORDER BY symptom_count DESC
LIMIT 1[0m
Full Context:
[32;1m[1;3m[{'s.name': 'Patient presents with chest discomfort, high temperature, cough with phlegm and chills', 'symptom_count': 4}][0m

[1m> Finished chain.[0m
❓ Question: What is the most common symptom for patients diagnosed with pneumonia?
🤖 Answer: Patients diagnosed with pneumonia commonly present with chest discomfort, high temperature, cough with phlegm, and chills. The most common symptoms are likely related to these four, with the patient in the given case having 4 symptoms. However, to directly answer the question, the data shows that the patient presents with 4 symptoms, but it doesn't specify one as more common than the others. A more accurate interpretation would be that the patient presents 

In [16]:
ask_question("What was the most common disease reported in the year 2023?")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease)<-[:DIAGNOSED_WITH]-(c:Case)-[:REPORTED_IN_YEAR]->(y:Year)
WHERE y.year = 2023
RETURN d.name AS diseaseName, COUNT(c) AS caseCount
ORDER BY caseCount DESC
LIMIT 1[0m
Full Context:
[32;1m[1;3m[{'diseaseName': 'gastroenteritis', 'caseCount': 1144}][0m

[1m> Finished chain.[0m
❓ Question: What was the most common disease reported in the year 2023?
🤖 Answer: The most common disease reported was gastroenteritis.
--------------------------------------------------


In [17]:
ask_question("What is the most common symptom for patients diagnosed with pneumonia in the last year?")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (d:Disease {name: 'pneumonia'})<-[:DIAGNOSED_WITH]-(c:Case)-[:PRESENTED_SYMPTOM]->(s:Symptom),
      (c:Case)-[:REPORTED_IN_YEAR]->(y:Year)
WHERE y.year = date().year - 1
RETURN s.name, COUNT(c) AS symptom_count
ORDER BY symptom_count DESC
LIMIT 1
[0m
Full Context:
[32;1m[1;3m[{'s.name': 'Clinical presentation: cough with phlegm, chills and chest pain', 'symptom_count': 2}][0m

[1m> Finished chain.[0m
❓ Question: What is the most common symptom for patients diagnosed with pneumonia in the last year?
🤖 Answer: The most common symptoms are cough with phlegm, chills, and chest pain.
--------------------------------------------------
