In [1]:
import os
import pandas as pd
from tqdm import tqdm
import openai
from neo4j import GraphDatabase
from tenacity import retry, stop_after_attempt, wait_random_exponential
import re

In [2]:
# pip install --upgrade openai

In [3]:
openai.api_key = "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" 

NEO4J_URI = "neo4j+s://05c3a7eb.databases.neo4j.io"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "B99jg7bGi7__z63ibtnFj_Sj0vkc-HAaaSCcmlKUuYA"


MIMIC_DIR = "/work/pi_hongyu_umass_edu/shared/physionet/mimiciii/" 
NOTE_CSV_PATH = os.path.join(MIMIC_DIR, "NOTEEVENTS.csv")

In [4]:
def load_discharge_notes(csv_path):
    try:
        df = pd.read_csv(csv_path, usecols=["SUBJECT_ID", "HADM_ID", "CATEGORY", "TEXT"])
        discharge_notes = df[df["CATEGORY"] == "Discharge summary"].dropna()
        return discharge_notes
    except Exception as e:
        print(f"Error loading CSV: {e}")
        return pd.DataFrame()

In [5]:
def create_neo4j_driver():
    try:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        return driver
    except Exception as e:
        print(f"Neo4j connection error: {e}")
        return None

In [6]:
import openai
import json
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_random_exponential

client = OpenAI(api_key=openai.api_key)

@retry(stop=stop_after_attempt(3), wait=wait_random_exponential(min=1, max=5))
def extract_entities_with_openai(text):
    prompt = f"""
    You are a medical text extraction assistant. Your task is to extract key medical entities from the discharge summary provided below. The entities to extract are:

    1. **Symptoms**: Extract each symptom as an object with fields:
       - `name` (e.g., "headache")
       - `negated`: true or false depending on if the symptom was denied or not
       - `duration` (e.g., "3 days")
       - `intensity` (e.g., "mild", "severe")
       - `frequency` (e.g., "daily", "intermittent")

    2. **Medical History**: Identify any past medical conditions (e.g., "hypertension", "diabetes").

    3. **Medications**: List any medications mentioned (e.g., "aspirin", "insulin").

    4. **Family History**: Extract as a list of objects with:
       - `member` (e.g., "father", "mother")
       - `condition` (e.g., "heart disease", "cancer")

    5. **Allergies**: List any allergies (e.g., "penicillin").

    6. **Procedures**: List any procedures or surgeries (e.g., "appendectomy", "MRI").

    7. **Vitals**: List key vitals or measurements (e.g., "blood pressure 130/80", "heart rate 98").

    8. **Social History**: List any lifestyle, behavioral, or environmental details about the patient such as occupation, smoking, alcohol, drug use, living conditions, or family structure.
    
    9. **Family Medical History**: List containing only the names of medical conditions mentioned in family history.


    The output should be a **valid JSON object** with this structure:
    {{
      "Symptoms": [
        {{
          "name": "symptom1",
          "negated": false,
          "duration": "",
          "intensity": "",
          "frequency": ""
        }},
        ...
      ],
      "Medical History": ["condition1", "condition2", ...],
      "Medications": ["medication1", "medication2", ...],
      "Family History": [
        {{
          "member": "father",
          "condition": "heart disease"
        }},
        ...
      ],
      "Family Medical History": ["heart disease"],
      "Allergies": ["allergy1", "allergy2", ...],
      "Procedures": ["procedure1", "procedure2", ...],
      "Vitals": ["vital1", "vital2", ...],
      "Social History": ["history1", "history2", ...]
    }}

    Please extract the relevant entities from the following discharge summary and return **only the JSON object** without any explanations or formatting.

    Discharge Summary:
    {text}
    """

    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful medical assistant."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.2
        )

        raw_output = response.choices[0].message.content.strip()

        # Remove code formatting like ```json
        if raw_output.startswith("```"):
            raw_output = raw_output.strip("```json").strip("```").strip()

        return json.loads(raw_output)

    except json.JSONDecodeError as e:
        print("JSON parsing error:", e)
        print("Raw model output:\n", raw_output)
        return {}

    except Exception as e:
        print("Unexpected error during OpenAI call:", e)
        return {}


In [7]:
def refine_entities_with_self_review(text):
    raw_entities = extract_entities_with_openai(text)
    if not raw_entities:
        return {}

    review_prompt = f"""
                        You are a medical reasoning assistant tasked with reviewing and refining structured clinical entity extraction for accuracy.

                        You are given:
                        1. The original discharge summary describing the medical condition of a specific patient.
                        2. A JSON object containing previously extracted medical entities from the discharge summary.

                        ---

                        Your task is to:
                        1. **Identify the true patient** described in the discharge summary. This may be:
                           - A newborn
                           - A mother
                           - A child or adult
                           - A family member (e.g., father or sibling)

                           Carefully determine who the discharge summary is actually about and focus only on extracting and retaining information about that person.

                        2. **Review the extracted medical entities** and correct any errors, misattributions, or hallucinations. Specifically, follow these instructions:

                        - Re-read the entire discharge summary carefully.
                        - Validate whether each extracted entity truly applies to the patient, not someone else (such as the patient's mother, father, or sibling).
                        - If any entity (e.g., a symptom, history, allergy, etc.) actually refers to someone else (especially common in prenatal or neonatal cases), remove it.
                        - If an entity is mentioned in the discharge summary but clearly denied (e.g., "no fever", "no history of asthma"), ensure it is correctly labeled with `"negated": true`.
                        - Ensure "Family Medical History" contains only valid conditions (no duplicates or invented terms).
                        - If a symptom or medical fact is ambiguous or not clearly attributed to the patient, remove it to avoid hallucination.
                        - Do not invent or add any new entities. Only refine or delete based on the original context.
                        - Do not modify the structure or keys of the JSON.

                        ---

                        ### Entity Types to Review:
                        - Symptoms (especially carefully — often attributed to mothers in neonatal notes)
                        - Medical History
                        - Family History
                        - Allergies
                        - Procedures
                        - Vitals
                        - Social History

                        ---

                        ### Output Format:
                        Your response must follow this format strictly:
                        1. First line: `Identified patient: <short phrase>` (e.g., "the newborn", "the mother")
                        2. Second line onward: Only the corrected JSON object with this exact structure:

                        {{
                          "Symptoms": [
                            {{
                              "name": "symptom1",
                              "negated": false,
                              "duration": "",
                              "intensity": "",
                              "frequency": ""
                            }}
                          ],
                          "Medical History": ["condition1", "condition2"],
                          "Medications": ["medication1", "medication2"],
                          "Family History": [
                            {{
                              "member": "father",
                              "condition": "heart disease"
                            }}
                          ],
                          "Family Medical History": ["heart disease"],
                          "Allergies": ["allergy1", "allergy2"],
                          "Procedures": ["procedure1", "procedure2"],
                          "Vitals": ["vital1", "vital2"],
                          "Social History": ["history1", "history2"]
                        }}

                        Do not include explanations, formatting, or markdown.

                        ---

                        Original Extracted JSON:
                        {json.dumps(raw_entities, indent=2)}

                        ---

                        Discharge Summary:
                        {text}
                        """

    try:
        review_response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": "You are a clinical reviewer validating that all extracted entities apply only to the correct patient."
                },
                {"role": "user", "content": review_prompt}
            ],
            temperature=0.2
        )

        raw_response = review_response.choices[0].message.content.strip()

        # Clean markdown if present
        if raw_response.startswith("```"):
            raw_response = raw_response.replace("```json", "").replace("```", "").strip()

        # Split to isolate JSON
        lines = raw_response.splitlines()
        json_start_idx = next((i for i, line in enumerate(lines) if line.strip().startswith("{")), 0)
        json_str = "\n".join(lines[json_start_idx:])

        return json.loads(json_str)

    except json.JSONDecodeError as e:
        print("JSON parse error during self-refine:", e)
        print("Raw response:\n", raw_response)
        return raw_entities

    except Exception as e:
        print("Error during self-refine:", e)
        return raw_entities


In [8]:
def create_patient_kg(tx, patient_id, adm_id, entities):
    # Core nodes
    tx.run("MERGE (p:Patient {id: $pid})", pid=patient_id)
    tx.run("MERGE (a:Admission {id: $aid})", aid=adm_id)

    # Patient --HAS_ADMISSION--> Admission 
    tx.run("""
        MATCH (p:Patient {id: $pid}), (a:Admission {id: $aid})
        MERGE (p)-[:HAS_ADMISSION]->(a)
    """, pid=patient_id, aid=adm_id)

    # Symptoms and associated attributes
    for sym in entities.get("Symptoms", []):
        name = sym.get("name", "").strip()
        if not name:
            continue

        # Create symptom node
        tx.run("MERGE (s:Symptom {name: $name})", name=name)

        # Link from Admission to Symptom
        rel_type = "HAS_NOSYMPTOM" if sym.get("negated") else "HAS_SYMPTOM"
        tx.run(f"""
            MATCH (a:Admission {{id: $aid}}), (s:Symptom {{name: $name}})
            MERGE (a)-[:{rel_type}]->(s)
        """, aid=adm_id, name=name)

        # Duration
        duration = sym.get("duration")
        if duration:
            tx.run("MERGE (d:Duration {value: $val})", val=duration)
            tx.run("""
                MATCH (s:Symptom {name: $name}), (d:Duration {value: $val})
                MERGE (s)-[:HAS_DURATION]->(d)
            """, name=name, val=duration)

        # Intensity
        intensity = sym.get("intensity")
        if intensity:
            tx.run("MERGE (i:Intensity {level: $val})", val=intensity)
            tx.run("""
                MATCH (s:Symptom {name: $name}), (i:Intensity {level: $val})
                MERGE (s)-[:HAS_INTENSITY]->(i)
            """, name=name, val=intensity)

        # Frequency
        frequency = sym.get("frequency")
        if frequency:
            tx.run("MERGE (f:Frequency {value: $val})", val=frequency)
            tx.run("""
                MATCH (s:Symptom {name: $name}), (f:Frequency {value: $val})
                MERGE (s)-[:HAS_FREQUENCY]->(f)
            """, name=name, val=frequency)

    # Medical History
    for h in entities.get("Medical History", []):
        tx.run("MERGE (h:History {detail: $val})", val=h)
        tx.run("""
            MATCH (p:Patient {id: $pid}), (h:History {detail: $val})
            MERGE (p)-[:HAS_MEDICAL_HISTORY]->(h)
        """, pid=patient_id, val=h)

    # Vitals
    for v in entities.get("Vitals", []):
        tx.run("MERGE (v:Vital {detail: $val})", val=v)
        tx.run("""
            MATCH (a:Admission {id: $aid}), (v:Vital {detail: $val})
            MERGE (a)-[:HAS_VITAL]->(v)
        """, aid=adm_id, val=v)

    # Allergies
    for a in entities.get("Allergies", []):
        tx.run("MERGE (al:Allergy {substance: $val})", val=a)
        tx.run("""
            MATCH (p:Patient {id: $pid}), (al:Allergy {substance: $val})
            MERGE (p)-[:HAS_ALLERGY]->(al)
        """, pid=patient_id, val=a)

    # Social History
    for sh in entities.get("Social History", []):
        tx.run("MERGE (s:SocialHistory {detail: $val})", val=sh)
        tx.run("""
            MATCH (p:Patient {id: $pid}), (s:SocialHistory {detail: $val})
            MERGE (p)-[:HAS_SOCIAL_HISTORY]->(s)
        """, pid=patient_id, val=sh)

    # Family History
    for fam in entities.get("Family History", []):
        member = fam.get("member", "").strip()
        condition = fam.get("condition", "").strip()
        if not member or not condition:
            continue

        tx.run("MERGE (f:FamilyMember {role: $role})", role=member)
        tx.run("""
            MATCH (p:Patient {id: $pid}), (f:FamilyMember {role: $role})
            MERGE (p)-[:HAS_FAMILY_MEMBER]->(f)
        """, pid=patient_id, role=member)

        tx.run("MERGE (fh:FamilyMedicalHistory {condition: $cond})", cond=condition)
        tx.run("""
            MATCH (f:FamilyMember {role: $role}), (fh:FamilyMedicalHistory {condition: $cond})
            MERGE (f)-[:HAS_MEDICAL_HISTORY]->(fh)
        """, role=member, cond=condition)


In [10]:
from pprint import pprint  

def main():
    notes_df = load_discharge_notes(NOTE_CSV_PATH)
    if notes_df.empty:
        print("No discharge summaries found.")
        return

    driver = create_neo4j_driver()
    if not driver:
        print("Neo4j driver initialization failed.")
        return

    # Sample 5 unique patients (first note for each)
    sample_df = notes_df.groupby("SUBJECT_ID").head(1).head(5)
    # sample_df = notes_df[notes_df['SUBJECT_ID'].isin(notes_df['SUBJECT_ID'].drop_duplicates().sample(5, random_state=42))].groupby('SUBJECT_ID').sample(1, random_state=42).reset_index(drop=True)


    with driver.session() as session:
        for idx, row in enumerate(sample_df.itertuples(), 1):
            patient_id = str(row.SUBJECT_ID)
            adm_id = str(row.HADM_ID)
            text = row.TEXT[:4000]  # Truncate for GPT model input

            print(f"\n--- Patient {idx}: SUBJECT_ID = {patient_id}, HADM_ID = {adm_id} ---")
            print("\nDischarge Summary:\n")
            print(text + "...\n")

            try:
                # Initial OpenAI extraction
                initial_entities = extract_entities_with_openai(text)
                print("Initial Extracted NER JSON:\n")
                pprint(initial_entities)

                # Self-refinement
                refined_entities = refine_entities_with_self_review(text)
                print("\nRefined NER JSON (After Self-Review):\n")
                pprint(refined_entities)

                # Insert refined result into KG
                if refined_entities:
                    session.execute_write(create_patient_kg, patient_id, adm_id, refined_entities)

            except Exception as e:
                print(f"Error processing note for patient {patient_id}: {e}")

    driver.close()
    print("\nAll discharge summaries processed.")


In [11]:
if __name__ == "__main__":
    main()


--- Patient 1: SUBJECT_ID = 22532, HADM_ID = 167853.0 ---

Discharge Summary:

Admission Date:  [**2151-7-16**]       Discharge Date:  [**2151-8-4**]


Service:
ADDENDUM:

RADIOLOGIC STUDIES:  Radiologic studies also included a chest
CT, which confirmed cavitary lesions in the left lung apex
consistent with infectious process/tuberculosis.  This also
moderate-sized left pleural effusion.

HEAD CT:  Head CT showed no intracranial hemorrhage or mass
effect, but old infarction consistent with past medical
history.

ABDOMINAL CT:  Abdominal CT showed lesions of
T10 and sacrum most likely secondary to osteoporosis. These can
be followed by repeat imaging as an outpatient.



                            [**First Name8 (NamePattern2) **] [**First Name4 (NamePattern1) 1775**] [**Last Name (NamePattern1) **], M.D.  [**MD Number(1) 1776**]

Dictated By:[**Hospital 1807**]
MEDQUIST36

D:  [**2151-8-5**]  12:11
T:  [**2151-8-5**]  12:21
JOB#:  [**Job Number 1808**]
...

Initial Extracted NER JSON

In [12]:
def relationship_extraction_prompt(conversation_history, text, patient_admission):
    subject_id = patient_admission['SubjectID']
    hadm_id = patient_admission['AdmissionID']
    schema = generate_schema()
    prompt = f"""

    Based on the doctor's query, first determine what the doctor is asking for. Then extract the appropriate relationship and nodes from the knowledge graph. \n
    For admissions related queries, the query should focus on "HAS_ADMISSION" relationship and "Admission" node. \n
    For patient information related queries, the query should focus on the "Patient" node. \n
    If the doctor asked about a symptom (e.g. cough, fever, etc.), the query should check if the "symptom" node and the "HAS_SYMPTOM" or "HAS_NOSYMOTOM" relationship; \n
    If the doctor asked about the duration, frequency, and intensity of a symptom, the query should first check if the symptom exist. If it exist, then check the "duration", "frequency" and "intensity" node respectively, and "HAS_DURATION", "HAS_FREQUENCY", "HAS_INTENSITY" relationship respectively. \n
    If the doctor asked about medical history, the query should check "History" node and the HAS_MEDICAL_HISTORY relationship. \n
    If the doctor asked about vitals (temperature, blood pressure etc), the query should check the "Vital" node and "HAS_VITAL" relationship. \n
    If the doctor asked about social history (smoking, alcohol consumption etc), the query should check the "SocialHistory" node and "HAS_SOCIAL_HISTORY" relationship. \n
    If the doctor aksed about family history, the query should first check the "HAS_FAMILY_MEMBER" relationship and "FamilyMember" node. Then, the query should check the "HAS_MEDICAL_HISTORY" relationship and "FamilyMedicalHistory" node associated with the "FamilyMember" node. \n 
    Output_format: Enclose your output in the following format. Do not give any explanations or reasoning, just provide the answer. For example:
    {{'Nodes': ['symptom', 'duration'], 'Relationships': ['HAS_SYMPTOM', 'HAS_DURATION']}}

    The natural language query is:
    {text}

    The previous conversation history is:
    {conversation_history}


    The Knowledge Graph Schema is:
    {schema}
    """

    return prompt

In [13]:
node_properties_query = """
            CALL apoc.meta.data()
            YIELD label, other, elementType, type, property
            WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
            WITH label AS nodeLabels, collect(property) AS properties
            RETURN {labels: nodeLabels, properties: properties} AS output
            """

rel_properties_query = """
            CALL apoc.meta.data()
            YIELD label, other, elementType, type, property
            WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
            WITH label AS nodeLabels, collect(property) AS properties
            RETURN {type: nodeLabels, properties: properties} AS output
            """

rel_query = """
            CALL apoc.meta.data()
            YIELD label, other, elementType, type, property
            WHERE type = "RELATIONSHIP" AND elementType = "node"
            RETURN {source: label, relationship: property, target: other} AS output
            """


In [14]:
def generate_schema():
    node_props = execute_cypher_query(node_properties_query)
    rel_props = execute_cypher_query(rel_properties_query)
    rels = execute_cypher_query(rel_query)
    return schema_text(node_props, rel_props, rels)


In [15]:
def schema_text(node_props, rel_props, rels):
    return f"""
        This is the schema representation of the Neo4j database.
        Node properties are the following:
        {node_props}

        Relationship properties are the following:
        {rel_props}

        Relationship point from source to target nodes
        {rels}

        Make sure to respect relationship types and directions
        """

In [16]:
def _run_cypher_query(tx, cypher_query):
        result = tx.run(cypher_query)
        return [record[0] for record in result]

In [17]:
def execute_cypher_query(cypher_query):
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    with driver.session() as session:
        result = session.execute_read(_run_cypher_query, cypher_query)
        return result

In [18]:
def cypher_query_construction_prompt(conversation_history, text, patient_admission, nodes_edges, abstraction_context = None):
        subject_id = patient_admission['SubjectID']
        hadm_id = patient_admission['AdmissionID']
        schema = generate_schema()
        prompt = f"""
        Write a cypher query to extract the requested information from the natural language query. The SUBJECT_ID is '{subject_id}', and the HADM_ID is '{hadm_id}'.
        The nodes and edges the query should focus on are {nodes_edges} \n
        Note that if the doctor's query is vague, it should be referring to the current context.\n
        The Cypher query should be case insensitive and check if the keyword is contained in any fields (no need for exact match). \n
        The Cypher query should handle fuzzy matching for keywords such as 'temperature', 'blood pressure', 'heart rate', etc., in the LABEL attribute of Vital nodes.\n
        The Cypher query should also handel matching smoke, smoking, tobacco if asked about smoking and social history; similarly for drinking, or alcohol. \n
        Only return the query as it should be executable directly, and no other text. Don't include any new line characters, or back ticks, or the word 'cypher', or square brackets, or quotes.\n
        
        The previous conversation history is:
        {conversation_history}
        
        The natural language query is:
        {text}
        
        The Knowledge Graph Schema is:
        {schema}"""

        if abstraction_context is not None:
            prompt += f"""
            The step back context is:
            {abstraction_context} 
        """

        prompt += """
        Here are a few examples of Cypher queries, you should replay SUBJECT_ID and HADM_ID based on input:\n
        Example 1: To check if the patient has seizures as a symptom, the cypher query should be: 
        MATCH (p:Patient {{SUBJECT_ID: '23709'}})
        OPTIONAL MATCH (p)-[:HAS_ADMISSION]->(a:Admission {{HADM_ID: '182203'}})-[:HAS_SYMPTOM]->(s:Symptom)
        WHERE s.name =~ '(?i).*seizure.*'
        WITH p, a, s
        OPTIONAL MATCH (p)-[:HAS_ADMISSION]->(a)-[:HAS_NOSYMPTOM]->(ns:Symptom)
        WHERE ns.name =~ '(?i).*seizure.*'
        RETURN 
        CASE 
            WHEN s IS NOT NULL THEN 'HAS seizure'
            WHEN ns IS NOT NULL THEN 'DOES NOT HAVE seizures'
            ELSE 'DONT KNOW'
        END AS status

        Example 2: To check how long has the patient had fevers as a symptom, the cypher query should be:
        MATCH (p:Patient {{SUBJECT_ID: '23709'}})
        OPTIONAL MATCH (p)-[:HAS_ADMISSION]->(a:Admission {{HADM_ID: '182203'}})-[:HAS_SYMPTOM]->(s:Symptom)
        WHERE s.name =~ '(?i).*fever.*'
        WITH p, a, s
        OPTIONAL MATCH (s)-[:HAS_DURATION]->(d:Duration)
        RETURN 
        CASE 
            WHEN s IS NULL THEN 'DOES NOT HAVE fevers'
            WHEN d IS NULL THEN 'DONT KNOW'
            ELSE d.name
        END AS fever_duration

        Example 3: To check for family history, the cupher query should be: 
        MATCH (p:Patient {{SUBJECT_ID: 23709}})-[:HAS_FAMILY_MEMBER]->(fm:FamilyMember)
        OPTIONAL MATCH (fm)-[:HAS_MEDICAL_HISTORY]->(fmh:FamilyMedicalHistory)
        RETURN fm.name AS family_member, fmh.name AS medical_history
        
        """
        return prompt

In [19]:
def clean_cypher_query(query):
        # Remove surrounding quotes
        query = query.strip('"')
        # Remove surrounding brackets
        query = query.strip('[]')
        # Remove newline characters
        query = query.replace('\\n', ' ')
        # Remove any leading or trailing whitespace characters
        query = query.strip()
        # Normalize whitespace within the query
        query = re.sub(r'\s+', ' ', query)
        return query

In [20]:
def _fetch_random_patient_admission(tx):
    query = """
    MATCH (p:Patient)-[:HAS_ADMISSION]->(a:Admission)
    WITH p, a, rand() AS random
    ORDER BY random
    LIMIT 1
    RETURN p.SUBJECT_ID AS SubjectID, a.HADM_ID AS AdmissionID
    """
    result = tx.run(query)
    return result.single()

In [43]:
def abstraction_generation_prompt(conversation_history, text):
    prompt = f"""
    You are an AI and Medical EHR expert. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to use for cypher query generation. \n 
    If the question is vague, consider the conversation history and the current context. Do not give any explanations or reasoning, just provide the answer. 
    Here are a few examples: \n
    input: Do you have fevers as a symptom? \n
    output: What symptoms does the patient has? \n
    input: Is your current temperature above 97 degrees? \n
    output: What is the patient's temperature? \n

    The current conversation history is:
    {conversation_history}
    The original query is:
    {text}
    """
    return prompt

## Rewrite Query Result Function
## This function combines the query results, and relationship to convert it to natural language
def query_result_rewrite(doctor_query, cypher_query, query_result):
    prompt = f"""
    You are a doctor's assistant. Based on the cypher_query, please structure the retrieved query results into natural language. Include all subject, relationship and object. 
    For example: \n
    doctor query: what symptoms do you have?
    cypher query: MATCH (p:Patient)-[:HAS_ADMISSION]->(a:Admission {{HADM_ID: '182203'}})
    MATCH (a)-[:HAS_SYMPTOM]->(s:Symptom)
    WHERE p.SUBJECT_ID = '23709'
    RETURN s.name AS Symptom 

    retrieved result: ['black and bloody stools', 'lightheadedness', 'shortness of breath']

    output: The patient has symptoms of black and bloody stools, lightheadedness, shortness of breath. 

    The doctor's original query is:
    {doctor_query}
    The cypher query is:
    {cypher_query}
    The retrieved results are:
    {query_result}
    """

    return prompt


## Summarization Function
def summarize_text_prompt(conversation_history, doctor_query, patient_response):
    prompt = f"""
    You are the doctor's assistent responsible for summarizing the conversation between the doctor and the patient.
    Be very brief, include the all the conversation history, doctor and patient's query and response. The last sentence should be about the current context (e.g. vital, symptom, or history).
    Write in full sentences and do not fabricate symptoms or history.
    The previous conversation is as follows:
    {conversation_history}
    The doctor has asked about the following query:
    {doctor_query}
    The patient's response to the doctor's query:
    {patient_response}
    """
    return prompt

## Rewrite Function
def rewrite_response_prompt(conversation_history, doctor_query, query_result, patient_admission, personality):
    subject_id = patient_admission['SubjectID']
    hadm_id = patient_admission['AdmissionID']
    prompt = f"""
    You are a virtual patient in an office visit. Your personality is {personality}.
    Your conversation history with the doctor is as follows:
    {conversation_history}
    The doctor has asked about the following query, focusing on the current context (e.g. vital, symptom, or history):
    {doctor_query}
    The query result is:
    {query_result}
    Based on all above information, please write your response to the doctor following your personality traits. Note that if the doctor's query is vague, it should be referring to the current context.
    If the query result is empty, return 'I don't know.' DO NOT fabricate any symptom or medical history. DO NOT add non-existent details to the response. DO NOT inclue any quotes, write in first person perspective. 
    """
    return prompt

## Checker Function
# def checker_construction_prompt( doctor_query, query_result, conversation_history):
#     """
#     This function check if the query result is appropirately answered the question, if not, the checker will rewrite the doctor's query and try to generate the cypher query again.
#     The checker will try 3 times until it stops and claim the query is not answered. and return "I don't know".
#     """
#     prompt = f"""
#     You are a doctor's assistant. You are recording and evaluating patient's responses to doctor's query.
#     The conversation history between the doctor and patient is as follows:
#     {conversation_history}
#     The doctor's query is:
#     {doctor_query}
#     The query result is:
#     {query_result}
#     Based on the above conversation, determine if the patient's response is an appropriate answer to the doctor's query.
#     If so, return 'Y' and do not return anything else; if not, rewrite the doctor's query based on the current context; only return the modified query and nothing else.
#     """
#     return prompt

def checker_construction_prompt(doctor_query, query_result, conversation_history):
        prompt = f"""
            You are a doctor's assistant. Your task is to evaluate whether the knowledge graph query result appropriately answers the doctor’s question.

            - The `doctor_query` is what the doctor asked the patient.
            - The `query_result` is the output from the knowledge graph (e.g., a symptom name, vital sign, etc.).
            - The `conversation_history` provides context for the interaction so far.

            Your instructions:
            1. If the `query_result` fully and correctly answers the `doctor_query`, reply with **only**: `Y`
            2. If it does not fully answer the question or is unclear, rewrite the doctor’s query to make it more specific or better suited for Cypher query generation.
            3. Do not return any explanation, markdown, or formatting — only return either `Y` or the rewritten query.

            ---
            Conversation history:
            {conversation_history}

            Doctor's query:
            {doctor_query}

            Query result:
            {query_result}
            """
        return prompt


In [22]:
personality_profiles = [
    ["Responsible", "Organized", "Analytical"],
    ["Anxious", "Detailed", "Inquisitive"],
    ["Optimistic", "Outgoing", "Cooperative"],
    ["Pessimistic", "Reserved", "Skeptical"],
    ["Energetic", "Impulsive", "Adventurous"],
    ["Caring", "Patient", "Empathetic"],
    ["Practical", "Stoic", "Independent"],
    ["Emotional", "Trusting", "Open"]
]

In [44]:
import openai  

client = openai.OpenAI(api_key = openai.api_key)  # Create an OpenAI client

def run_model(prompt):
    try:
        full_promt = f"{prompt}\n"
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a medical text extraction assistant. Your task is to extract key medical entities from the discharge summary provided below. "},
                {"role": "user", "content": full_promt}
            ],
            max_tokens=500
        )

        content = response.choices[0].message.content.strip()
        return str(content)

    except Exception as e:
        print("Error:", e)
        return None


In [51]:
def session(doctor_query, conversation_history, patient_admission, personality_profile, max_token = 4096):
    if doctor_query.lower() == 'exit':
        return "Session terminated by the user."
    print("Extracting relevant nodes and edges based on query.")
    nodes_edges_query_cypher_prompt = relationship_extraction_prompt(conversation_history, doctor_query, patient_admission)
    if len(nodes_edges_query_cypher_prompt) > max_token:
        nodes_edges_query_cypher_prompt = nodes_edges_query_cypher_prompt[:max_token]
    nodes_edges_results = run_model(nodes_edges_query_cypher_prompt)
    print(f"Nodes and edges extracted: {nodes_edges_results}")
    
    ## Step 1: Construct Abstraction Query Prompt
    print("Step 1: Constructing Abstraction Cypher query prompt based on the doctor's query.")
    abstraction_query_prompt = abstraction_generation_prompt(conversation_history, doctor_query)
    if len(abstraction_query_prompt) > max_token:
        abstraction_query_prompt = abstraction_query_prompt[:max_token]
    abstraction_query_nl = run_model(abstraction_query_prompt)
    print(f"Abstraction query in natural language generated: {abstraction_query_nl}")
    
    ## Step 3: Generate Abstraction Cypher Query
    print("Constructing Cypher query prompt based on the abstraction query.")
    abstraction_query_cypher_prompt = cypher_query_construction_prompt(conversation_history, abstraction_query_nl, patient_admission, nodes_edges_results)
    if len(abstraction_query_cypher_prompt) > max_token:
        abstraction_query_cypher_prompt = abstraction_query_cypher_prompt[:max_token]
    abstraction_query_cypher = run_model(abstraction_query_cypher_prompt)
    print(f"Abstraction cypher generated: {abstraction_query_cypher}")

    ## Step 3.5: Clean Cypher Query
    abstraction_query_cypher = clean_cypher_query(abstraction_query_cypher)

    ## Step 4: Execute the generated Cypher query
    print("Step 4: Executing the generated Cypher query.")
    abstraction_result = execute_cypher_query(abstraction_query_cypher)
    if abstraction_result:
        ## Rewrite to natural language
        abstraction_result_rewrite_prompt = query_result_rewrite(abstraction_query_nl, abstraction_query_cypher, abstraction_result)
        abstract_result = run_model(abstraction_result_rewrite_prompt)

    print(f"Abstraction Query result: {abstraction_result}")

    ## Step One: Original doctor's query
    print(f"Step Zero: The doctors has asked about: {doctor_query}")
    print("Step One: Constructing Cypher query prompt based on the doctor's query.")
    cypher_query_prompt = cypher_query_construction_prompt(conversation_history, doctor_query, patient_admission, nodes_edges_results, abstraction_context=abstraction_result)

    ## Step 2.2: Construct Cypher Query
    if len(cypher_query_prompt) > max_token:
        cypher_query_prompt = cypher_query_prompt[:max_token]
        print(f"Cypher query prompt truncated to {max_token} characters.")
    cypher_query = run_model(cypher_query_prompt)
    print(f"Cypher query generated: {cypher_query}")

    ## Step 2.3: Clean Cypher Query
    cypher_query = clean_cypher_query(cypher_query)

    ## Step Three: Execute the generated Cypher query
    print("Step Three: Executing the generated Cypher query.")
    query_result = execute_cypher_query(cypher_query)
    if query_result:
        ## Rewrite to natural language
        query_result_rewrite_prompt = query_result_rewrite(doctor_query, cypher_query, query_result)
        query_result = run_model(query_result_rewrite_prompt)
    print(f"Query result: {query_result}")

    ## Step Four: Evaluate if the query properly answered the question
    for attempt in range(2):
        print(f"Attempt {attempt + 1}: Evaluating the query result.")
        checker_prompt = checker_construction_prompt(doctor_query, query_result, conversation_history)
        if len(checker_prompt) > max_token:
            checker_prompt = checker_prompt[:max_token]
            print(f"Checker prompt truncated to {max_token} characters.")
        checked_result = run_model(checker_prompt)
        print(f"Checked result: {checked_result}")

        ## If the answer is deemed appropriate, stop the loop
        if checked_result.strip() == 'Y':
            print("Checked result is appropriate. Breaking the loop.")
            break
            


        ## If the answer is deemed inappropriate, restructure the question and try again
        print("Checked result is inappropriate. Restructuring the question.")
        cypher_query_prompt = cypher_query_construction_prompt(conversation_history, checked_result, patient_admission, nodes_edges_results)
        if len(cypher_query_prompt) > max_token:
            cypher_query_prompt = cypher_query_prompt[:max_token]
            print(f"Cypher query prompt truncated to {max_token} characters.")
        cypher_query = run_model(cypher_query_prompt)
        query_result = execute_cypher_query(cypher_query)
        query_result = query_result_rewrite(doctor_query, cypher_query, query_result)
        print(f"New query result: {query_result}")
        # if not query_result or len(query_result) == 0:
        #     query_result = ["I don't know"]
        #     logging.info("No appropriate answer after restructuring. Setting query result to 'I don't know'.")
        #     break
        
        ## If after three rounds, still no appropriate answer, return "I don't know."
    if checked_result.strip() != 'Y':
        query_result = ["I don't know"]
        print("After two rounds, still no appropriate answer. Returning 'I don't know'.")

    ## Step Five: Given Query Results, generate the patient response
    print("Step Five: Generating the patient response.")
    if query_result == ["I don't know"]:
        patient_response = "I don't know"
    else:
        rewrite_prompt = rewrite_response_prompt(conversation_history, doctor_query, query_result, patient_admission, personality_profile)
        if len(rewrite_prompt) > max_token:
            rewrite_prompt = rewrite_prompt[:max_token]
            print(f"Rewrite prompt truncated to {max_token} characters.")
        patient_response = run_model(rewrite_prompt)
        print(f"Patient response generated: {patient_response}")

    ## Step Six: Update the conversation history
    print("Step Six: Updating the conversation history.")
    summarization_prompt = summarize_text_prompt(conversation_history, doctor_query, patient_response)
    if len(summarization_prompt) > max_token:
        summarization_prompt = summarization_prompt[:max_token]
        print(f"Summarization prompt truncated to {max_token} characters.")
    summarization = run_model(summarization_prompt)
    print(f"Conversation history updated: {summarization}")
    
    ## Update the conversation history based on the most recent interaction
    conversation_history = summarization
    print(f"Conversation history: {conversation_history}")

    return patient_response, conversation_history




In [54]:
patient_admission = { 'SubjectID': '26880', 'AdmissionID': '135453.0'}
max_token = 4096
doctor_query = input("Enter the doctor query")
conversation_history = f"The patient has ID {patient_admission['SubjectID']}, and the admission ID {patient_admission['AdmissionID']}"
personality_profile = ["Responsible", "Organized", "Analytical","Terse"]

Enter the doctor query what are the symptoms


In [55]:
patient_response, conversation_history = session(doctor_query, conversation_history, patient_admission, personality_profile, max_token = 4096)

Extracting relevant nodes and edges based on query.
Nodes and edges extracted: {'Nodes': ['symptom'], 'Relationships': ['HAS_SYMPTOM']}
Step 1: Constructing Abstraction Cypher query prompt based on the doctor's query.
Abstraction query in natural language generated: What symptoms does the patient have?
Constructing Cypher query prompt based on the abstraction query.
Abstraction cypher generated: MATCH (p:Patient {id: '26880'})
OPTIONAL MATCH (p)-[:HAS_ADMISSION]->(a:Admission {id: '135453.0'})-[:HAS_SYMPTOM]->(s:Symptom)
RETURN s.name AS symptom
Step 4: Executing the generated Cypher query.
Abstraction Query result: ['agitation', 'depression', 'sadness', 'attempted suicide']
Step Zero: The doctors has asked about: what are the symptoms
Step One: Constructing Cypher query prompt based on the doctor's query.
Cypher query prompt truncated to 4096 characters.
Cypher query generated: MATCH (p:Patient {id: '26880'}) OPTIONAL MATCH (p)-[:HAS_ADMISSION]->(a:Admission {id: '135453.0'})-[:HAS_SY