# K‑MIMIC Knowledge Graph Demo (LLM‑assisted)
**Last updated:** 2025-07-29

This notebook demonstrates how to construct a knowledge graph (KG) from **K‑MIMIC** data using an **LLM‑assisted pipeline**. 


## What you will do
1. **Set up** the environment and load configurations.  
2. **Load and preprocess** K‑MIMIC data needed for the demo.  
3. **Extract entities/relations** from text using an LLM (prompt‑based).  
4. **Construct a knowledge graph** (in memory, with NetworkX) from the tabular data and LLM‑extracted triplets.  
5. **Visualize and/or export** the resulting graph and artifacts.  

## Requirements (minimal)
- Python 3.9+
- Common data libraries: `pandas`, `numpy`, `tqdm`
- Graph tools: `networkx`, `pyvis` (for HTML visualization)
- LLM stack (sections 4 and 5): `transformers`, `torch`, `accelerate`, `bitsandbytes`, `sentencepiece` (for open Hugging Face LLMs), plus a Hugging Face access token if the model repository is gated

## Table of Contents
- [0. Notebook Settings](#0-notebook-settings)
- [1. Setup & Imports](#1-setup--imports)
- [2. Data Loading](#2-data-loading)
- [3. Knowledge Graph Construction & Visualization](#3-knowledge-graph-construction--visualization)
- [4. LLM-based Relation Extraction](#4-llm-based-relation-extraction)
- [5. Application of Knowledge Graph](#5-application-of-knowledge-graph)

## 0. Notebook Settings
This section sets the data directory and pandas display options so that runs are consistent across environments.


In [None]:
# === Notebook Settings ===
from pathlib import Path

# Configuration for project and data directories
DATA_DIR = Path("datasets/KMIMIC")   # <-- UPDATE THIS PATH IF NEEDED

print(f"DATA_DIR:    {DATA_DIR.resolve()}")

# Display options (optional)
try:
    import pandas as pd
    pd.set_option("display.max_colwidth", 120)
    pd.set_option("display.width", 120)
except Exception:
    pass

## 1. Setup & Imports

In [None]:
# === Setup & Imports ===
# Installation (the PyTorch index URL below points to CUDA 12.8 wheels; change it to match your CUDA version)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
!pip install --quiet networkx pyvis transformers bitsandbytes accelerate sentencepiece

In [None]:
# === Setup & Imports ===
# Import Libraries

import os
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import networkx as nx
from pyvis.network import Network
from IPython.display import display  # used by the visualization cells
import re
import json

print("Libraries imported successfully.")

In [None]:
# Setup Model and Tokenizer

# This cell configures and loads the Large Language Model (LLM).
# The LLM infers semantic relationships between entities (section 4) and answers questions over the graph (section 5).

# --- LLM Configuration ---
# Pin the notebook to the first GPU. This only takes effect if it is set before CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Cache directory for Hugging Face downloads. Passing `cache_dir` explicitly is more reliable than
# setting TRANSFORMERS_CACHE here, because `transformers` has already been imported.
CACHE_DIR = "./transformers_cache"

# Hugging Face access token: read it from the HF_TOKEN environment variable.
# Never hard-code a real token in a notebook.
access_token = os.environ.get("HF_TOKEN")
model_id = "epfl-llm/meditron-7b"

# Configuration for loading the model in 4-bit precision to save memory
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token, cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=access_token,
    cache_dir=CACHE_DIR
)
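The 4-bit setting matters because weight memory scales with the number of bits stored per parameter. Below is a rough back-of-the-envelope sketch, not a measurement. It assumes a round figure of 7 billion parameters and counts weights only. It ignores quantization constants, any layers `bitsandbytes` keeps in higher precision, activations, and the KV cache.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes), weights only."""
    return num_params * bits_per_param / 8 / 1e9

# Assumed round parameter count for a "7B" model; the exact count differs slightly
NUM_PARAMS = 7e9

bf16_gb = weight_memory_gb(NUM_PARAMS, 16)  # bfloat16: 16 bits per weight
nf4_gb = weight_memory_gb(NUM_PARAMS, 4)    # NF4: 4 bits per weight

print(f"bf16 weights:  ~{bf16_gb:.1f} GB")
print(f"4-bit weights: ~{nf4_gb:.1f} GB")
```

Under these assumptions, 4-bit loading cuts weight memory by about 4x relative to bfloat16. That is the motivation for the `BitsAndBytesConfig` above.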


## 2. Data Loading
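If a file is missing, the loading loop in the next cell stops at the first `FileNotFoundError`. A small pre-flight check can report every missing file at once. The sketch below uses only the standard library. The helper name `find_missing_files` is hypothetical, and the file names mirror the `file_paths` dictionary in the next cell.

```python
from pathlib import Path

# File names mirror the tables loaded in the Data Loading cell
REQUIRED_FILES = [
    "patients.csv",
    "admissions.csv",
    "diagnoses_icd.csv",
    "procedures_icd.csv",
    "prescriptions.csv",
]

def find_missing_files(data_dir, required=REQUIRED_FILES):
    """Return the required file names that do not exist as files under data_dir."""
    data_dir = Path(data_dir)
    return [name for name in required if not (data_dir / name).is_file()]

missing = find_missing_files("datasets/KMIMIC")
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All required files found.")
```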

In [None]:
# === Data Loading ===
# --- Configuration: File Paths ---
# DATA_DIR is defined in the Notebook Settings cell; update it there if your files are elsewhere.
# We focus on a few key tables to demonstrate the concept.
from pathlib import Path
DATA_DIR = Path(DATA_DIR)
file_paths = {
    "patients": DATA_DIR / "patients.csv",
    "admissions": DATA_DIR / "admissions.csv",
    "diagnoses": DATA_DIR / "diagnoses_icd.csv",
    "procedures": DATA_DIR / "procedures_icd.csv",
    "prescriptions": DATA_DIR / "prescriptions.csv",
    # "labs": DATA_DIR / "labevents.csv",
    # "d_diagnoses": DATA_DIR / "d_icd_diagnoses.csv",
    # "d_procedures": DATA_DIR / "d_icd_procedures.csv",
}
dataframes = {}

# --- Data Loading ---
try:
    print("Loading KMIMIC tabular data...")
    for name, path in file_paths.items():
        print(f" - Loading {path}...")
        dataframes[name] = pd.read_csv(path, on_bad_lines='skip')
    print("\nAll data loaded successfully.")

except FileNotFoundError as e:
    print(f"\n--- ERROR: File Not Found ---")
    print(f"Could not find file: {e.filename}")
    print("Please ensure all required K-MIMIC CSV files are in the specified directory.")
    print("Required files: patients.csv, admissions.csv, diagnoses_icd.csv, procedures_icd.csv, prescriptions.csv")
    dataframes = {} # Clear dataframes to prevent further errors


# --- Map ICD-9-CM procedure codes to descriptions for the procedures table ---
# Download the ICD-9-CM procedure code titles from:
# https://www.cms.gov/medicare/coding-billing/icd-10-codes/icd-9-cm-diagnosis-procedure-codes-abbreviated-and-full-code-titles
# and save the procedure titles file as 'icd9cm_proc.txt' in the same directory as this notebook.
# Each line holds a code (first 4 characters) followed by its description.
from pathlib import Path
ICD9_PROC_FILE = Path("icd9cm_proc.txt")

def convert_code(code):
    """Convert a numeric ICD-9-CM procedure code to the 4-digit format used in the lookup file
    (e.g. 3.2 -> '0320', 39.95 -> '3995'). Returns None for missing or non-numeric codes."""
    if pd.isna(code):
        return None
    try:
        value = float(code)
    except (TypeError, ValueError):
        return None
    first, second = "{:.2f}".format(value).split(".")
    # Pad the category part to two digits; the decimal part is already two digits
    return first.zfill(2) + second

if dataframes and ICD9_PROC_FILE.exists():
    # Parse the lookup file line by line: code = first 4 characters, description = the rest
    data = []
    with open(ICD9_PROC_FILE, "r") as f:
        for line in f:
            data.append([line[:4].strip(), line[4:].strip()])
    icd9proc = pd.DataFrame(data, columns=["code", "description"]).drop_duplicates("code")

    procedures = dataframes["procedures"]
    procedures["icd_code_4"] = procedures["icd_code"].apply(convert_code)
    # Look up descriptions by code; .map() preserves the row order and index of the procedures table
    procedures["opname"] = procedures["icd_code_4"].map(icd9proc.set_index("code")["description"])
elif dataframes:
    print(f"'{ICD9_PROC_FILE}' not found; procedure descriptions will be unavailable.")
    dataframes["procedures"]["opname"] = None
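Code normalization is easy to get wrong, so here is a standalone sketch that can be tested without the CMS file. Like the cell above, it assumes each lookup line starts with a 4-character code followed by the description. It also assumes procedure codes arrive as decimal numbers such as `3.2` or `39.95`. `to_cms_code` and `parse_lookup_lines` are illustrative names, and the sample lines are made up.

```python
def to_cms_code(code):
    """Convert a decimal ICD-9-CM procedure code (e.g. 3.2 or "39.95") to 4-digit form ("0320", "3995")."""
    first, second = "{:.2f}".format(float(code)).split(".")
    return first.zfill(2) + second  # pad the category part to two digits

def parse_lookup_lines(lines):
    """Build a {code: description} dict from fixed-width lines: 4-char code, then description."""
    lookup = {}
    for line in lines:
        code, description = line[:4].strip(), line[4:].strip()
        if code and code not in lookup:  # keep the first description for duplicate codes
            lookup[code] = description
    return lookup

# Made-up sample lines in the assumed fixed-width layout
sample_lines = ["0320 Example procedure A\n", "3995 Example procedure B\n"]
lookup = parse_lookup_lines(sample_lines)

for raw in [3.2, "39.95", 12.34]:
    cms = to_cms_code(raw)
    print(raw, "->", cms, "->", lookup.get(cms, "<no match>"))
```

A plain dictionary lookup also avoids a pitfall of building the column from `DataFrame.merge(...)`. If the lookup table has duplicate codes, the merged frame has more rows than the procedures table, and the assigned column no longer lines up.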

## 3. Knowledge Graph Construction & Visualization

In [None]:
# === Knowledge Graph Construction ===
# To make the graph manageable for this demonstration, we will select a small sample of patients.

if dataframes:
    # --- Sample the Data ---
    num_patients_to_process = 5 # Increase for a larger, more detailed graph
    sample_patient_ids = dataframes["patients"].head(num_patients_to_process)['subject_id']
    
    # Filter all relevant tables for the sampled patients
    for name in ["patients", "admissions", "prescriptions"]:
        dataframes[name] = dataframes[name][dataframes[name]['subject_id'].isin(sample_patient_ids)]

    sample_admission_ids = dataframes["admissions"]['hadm_id']
    # Filter tables that link to admissions
    for name in ["diagnoses", "procedures"]:#, "labs"]:
        dataframes[name] = dataframes[name][dataframes[name]['hadm_id'].isin(sample_admission_ids)]

    print(f"Sampled data for {len(sample_patient_ids)} patients.")
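The sampling cell filters in two hops. It keeps the first patients, then filters the admission-level tables (diagnoses, procedures) through the admissions that survive. The minimal stdlib sketch below runs the same cascade on toy records; the rows and the function name `sample_cascade` are invented for illustration.

```python
patients = [{"subject_id": 1}, {"subject_id": 2}, {"subject_id": 3}]
admissions = [
    {"subject_id": 1, "hadm_id": 10},
    {"subject_id": 2, "hadm_id": 20},
    {"subject_id": 3, "hadm_id": 30},
]
diagnoses = [{"hadm_id": 10, "icd_name": "dx A"}, {"hadm_id": 30, "icd_name": "dx B"}]

def sample_cascade(patients, admissions, child_tables, num_patients):
    """Keep the first num_patients patients, their admissions, and child rows of those admissions."""
    keep_subjects = {p["subject_id"] for p in patients[:num_patients]}
    kept_admissions = [a for a in admissions if a["subject_id"] in keep_subjects]
    keep_hadm = {a["hadm_id"] for a in kept_admissions}
    kept_children = {
        name: [row for row in rows if row["hadm_id"] in keep_hadm]
        for name, rows in child_tables.items()
    }
    return kept_admissions, kept_children

adm, children = sample_cascade(patients, admissions, {"diagnoses": diagnoses}, num_patients=2)
print([a["hadm_id"] for a in adm])                     # [10, 20]
print([d["icd_name"] for d in children["diagnoses"]])  # ['dx A']
```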

### Knowledge Graph Construction

In [None]:
# === Knowledge Graph Construction & Visualization ===
# This cell constructs the knowledge graph by defining nodes (patients, admissions, diagnoses, etc.)
# and edges based on the relationships (foreign keys) in the tables. We use a MultiDiGraph
# to allow for multiple instances of the same event (e.g., multiple lab tests of the same type).

if dataframes:
    G = nx.MultiDiGraph() # Initialize a directed graph that can have parallel edges

    # --- Add Patient Nodes ---
    for _, row in dataframes["patients"].iterrows():
        patient_id = f"p_{row['subject_id']}"
        G.add_node(patient_id, label=f"Patient {row['subject_id']}", type='patient', gender=row['sex'], anchor_age=row['anchor_age'])

    # --- Add Admission Nodes and Patient-Admission Edges ---
    for _, row in dataframes["admissions"].iterrows():
        patient_id = f"p_{row['subject_id']}"
        admission_id = f"a_{row['hadm_id']}"
        G.add_node(admission_id, label=f"Admission {row['hadm_id']}", type='admission', admission_type=row['admission_type'], ethnicity=row['ethnicity'], insurance=row['insurance'])
        G.add_edge(patient_id, admission_id, label='HAD_ADMISSION')

    # --- Add Diagnosis Nodes and Admission-Diagnosis Edges ---
    for _, row in dataframes["diagnoses"].dropna(subset=['icd_name']).iterrows():
        admission_id = f"a_{row['hadm_id']}"
        diag_id = f"d_{row['icd_code']}_{row['icd_version']}" # Unique ID for diagnosis code
        G.add_node(diag_id, label=row['icd_name'], type='diagnosis', icd_version=row['icd_version'])
        G.add_edge(admission_id, diag_id, label='DIAGNOSED_WITH')

    # --- Add Procedure Nodes and Admission-Procedure Edges ---
    for _, row in dataframes["procedures"].dropna(subset=['icd_code']).iterrows():
        admission_id = f"a_{row['hadm_id']}"
        proc_id = f"pr_{row['icd_code']}_{row['icd_version']}" # Unique ID for procedure code
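        # Fall back to the raw code when no description was found
        proc_label = row['opname'] if pd.notna(row.get('opname')) else f"Proc: {row['icd_code']}"
        G.add_node(proc_id, label=proc_label, type='procedure', icd_version=row['icd_version'])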
        G.add_edge(admission_id, proc_id, label='HAD_PROCEDURE')
        
    # --- Add Prescription Nodes and Patient-Prescription Edges ---
    for _, row in dataframes["prescriptions"].dropna(subset=["drug(English)"]).head(30).iterrows(): # Sample to keep graph clean
        patient_id = f"p_{row['subject_id']}"
        drug_name = str(row['drug(English)']).lower()
        if not G.has_node(drug_name):
            G.add_node(drug_name, label=drug_name.title(), type='prescription')
        G.add_edge(patient_id, drug_name, label='PRESCRIBED')

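    # --- Add Lab Event Nodes and Admission-Lab Edges ---
    # Runs only if the (currently commented-out) labs table was loaded in the Data Loading cell
    if "labs" in dataframes: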
        for _, row in dataframes["labs"].dropna(subset=['itemid', 'value']).head(50).iterrows(): # Sample to keep graph clean
            admission_id = f"a_{row['hadm_id']}"
            if G.has_node(admission_id):
                lab_item_id = f"l_{row['itemid']}"
                # For this demo, we label lab items by their ID. A real-world application
                # would merge this with the d_labitems.csv table for friendly names.
                if not G.has_node(lab_item_id):
                    G.add_node(lab_item_id, label=f"Lab Item {row['itemid']}", type='lab_item')
                # The edge represents the actual test event and its result
                G.add_edge(admission_id, lab_item_id, label='HAD_LAB_TEST', value=row['value'], unit=row['valueuom'], flag=row.get('flag'))

    print(f"\nKnowledge Graph built successfully.")
    print(f" - Nodes: {G.number_of_nodes()}")
    print(f" - Edges: {G.number_of_edges()}")
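This cell uses a `MultiDiGraph`, while the Stage 1 + Stage 2 cell in section 4 uses a `DiGraph`. The difference shows up directly in the edge counts. A multigraph keeps one edge per source row. A simple directed graph keeps one edge per (source, target) pair and updates that edge's attributes on repeats. The stdlib sketch below models both behaviours on invented prescription rows. It is a simplified model of the semantics, not how NetworkX stores edges internally.

```python
from collections import defaultdict

# Invented (patient, drug, label) rows: patient p_1 receives aspirin twice
rows = [
    ("p_1", "aspirin", "PRESCRIBED"),
    ("p_1", "aspirin", "PRESCRIBED"),
    ("p_1", "heparin", "PRESCRIBED"),
]

# Multigraph semantics: every row becomes its own parallel edge
multi_edges = defaultdict(list)
for src, dst, label in rows:
    multi_edges[(src, dst)].append({"label": label})

# Simple-graph semantics: one edge per (source, target); repeats overwrite the attributes
simple_edges = {}
for src, dst, label in rows:
    simple_edges[(src, dst)] = {"label": label}

print("multigraph edges:", sum(len(v) for v in multi_edges.values()))  # 3
print("simple-graph edges:", len(simple_edges))                        # 2
```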



### Visualize the Knowledge Graph

In [None]:
# === Graph Construction & Visualization ===
# This cell creates an interactive HTML visualization of the graph.
# Nodes are colored by type for clarity.

if 'G' in locals() and G.number_of_nodes() > 0:
    net = Network(notebook=True, height='800px', width='100%', bgcolor='#222222', font_color='white', cdn_resources='in_line', directed=True)

    # Define colors for different node types
    color_map = {
        'patient': '#007bff',      # Blue
        'admission': '#28a745',   # Green
        'diagnosis': '#dc3545',   # Red
        'procedure': '#ffc107',   # Yellow
        'prescription': '#6f42c1',# Purple
        'lab_item': '#17a2b8'     # Teal
    }

    # Add nodes to the Pyvis network with appropriate styling
    for node, data in G.nodes(data=True):
        node_type = data.get('type', 'N/A')
        net.add_node(node, label=data.get('label', node), color=color_map.get(node_type, '#adb5bd'), title=f"Type: {node_type}")

    # Add edges
    for source, target, data in G.edges(data=True):
        net.add_edge(source, target, label=data.get('label', ''))

    # Set physics options for a better layout
    net.set_options("""
    var options = {
      "nodes": { "font": { "size": 14 } },
      "edges": { "font": { "align": "top", "size": 12 }, "smooth": { "type": "dynamic" } },
      "physics": {
        "forceAtlas2Based": { "gravitationalConstant": -50, "centralGravity": 0.01, "springLength": 150 },
        "minVelocity": 0.75, "solver": "forceAtlas2Based"
      }
    }
    """)
    
    print("Generating interactive knowledge graph visualization...")
    display(net.show("tabular_knowledge_graph_revised.html"))
    print("Interactive graph 'tabular_knowledge_graph_revised.html' has been saved to your directory.")
else:
    print("No graph was built, so visualization is skipped.")
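The cell above passes `net.set_options` a hand-written JavaScript-style string. Recent pyvis releases parse that string as JSON starting from the first `{`; this is an assumption worth checking against your installed version. If it holds, the same settings can be kept as a Python dict and serialized with `json.dumps`, which makes them easy to adjust programmatically. A minimal sketch:

```python
import json

# Same physics/layout settings as the visualization cell, expressed as a Python dict
vis_options = {
    "nodes": {"font": {"size": 14}},
    "edges": {"font": {"align": "top", "size": 12}, "smooth": {"type": "dynamic"}},
    "physics": {
        "forceAtlas2Based": {"gravitationalConstant": -50, "centralGravity": 0.01, "springLength": 150},
        "minVelocity": 0.75,
        "solver": "forceAtlas2Based",
    },
}

options_json = json.dumps(vis_options)
# In the notebook, this string would be passed as: net.set_options(options_json)
print(options_json)
```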

## 4. LLM-based Relation Extraction

### Using the LLM to Extract More Relationships

In [None]:
# === LLM Response Parsing ===
def extract_relationships(response: str):
    """
    Extracts a balanced JSON object from an LLM response (even if trimmed),
    cleans it (removes trailing commas), and returns the `relationships` list.
    """
    # 1) Locate the first opening brace
    start = response.find('{')
    if start == -1:
        return []

    # 2) Extract substring from that point
    substr = response[start:]

    # 3) Find the matching closing brace by counting
    depth = 0
    end_idx = None
    for i, ch in enumerate(substr):
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                end_idx = i + 1
                break

    # If we never closed, take everything (best-effort)
    raw_json = substr[:end_idx] if end_idx is not None else substr

    # 4) Clean up whitespace and trailing commas
    text = raw_json.strip()
    text = re.sub(r",\s*([}\]])", r"\1", text)

    # 5) Try parsing
    try:
        data = json.loads(text)
    except json.JSONDecodeError as e:
        print(f"JSON parse error: {e}")
        print("Attempted JSON:\n", text)
        return []

    # 6) Return relationships list
    rels = data.get("relationships", [])
    return rels if isinstance(rels, list) else []
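`extract_relationships` finds the closing brace by counting `{` and `}`. An unbalanced brace inside a string value can throw that count off; in the example below, the counter would stop at the `}` inside `"see }"`. An alternative sketch uses the standard library's `json.JSONDecoder.raw_decode`. It parses one complete JSON value starting at a given index, reports where that value ended, and ignores any trailing text. The function name `first_json_object` is hypothetical. Unlike the cell above, it does not try to salvage truncated output.

```python
import json
import re

def first_json_object(response: str):
    """Return the first JSON object that parses cleanly in an LLM response, or None."""
    decoder = json.JSONDecoder()
    # Same trailing-comma cleanup as extract_relationships (it would also touch ",}" inside strings)
    cleaned = re.sub(r",\s*([}\]])", r"\1", response)
    start = cleaned.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(cleaned, start)
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass
        start = cleaned.find("{", start + 1)  # try the next candidate opening brace
    return None

response = 'Here you go: {"relationships": [["Aspirin", "TREATS", "Pain"],], "note": "see }"} Thanks!'
print(first_json_object(response))
```

In the notebook, `first_json_object(response)` could replace the brace-counting steps, followed by the same `.get("relationships", [])` check.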

### Build the Knowledge Graph from Tabular and LLM-Extracted Data

In [None]:
# === Knowledge Graph Construction (Structural + Semantic Layers) ===
# This cell constructs the knowledge graph in two stages:
# 1. Structural Layer: Build a base graph from the explicit relationships (foreign keys) in the tables.
# 2. Semantic Layer: Use the LLM to infer and add more complex relationships between entities within the context of a single admission.

def get_llm_response(prompt):
    """
    Helper function to get a response from the loaded LLM.
    This version correctly decodes only the newly generated tokens.
    """
    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids_length = inputs["input_ids"].shape[1]  # number of prompt tokens

    # Generate a response
    outputs = model.generate(**inputs, max_new_tokens=1024, pad_token_id=tokenizer.eos_token_id)

    # Decode only the tokens that were generated after the prompt
    response_ids = outputs[0][input_ids_length:]
    return tokenizer.decode(response_ids, skip_special_tokens=True)

if dataframes:
    # Use a standard DiGraph to prevent duplicate edges between nodes
    G = nx.DiGraph()

    # --- STAGE 1: Build Structural Layer ---
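    # Mirrors the previous cell, but keeps all prescriptions and labels procedures as "Proc: <code>"
    # so that entity names returned by the LLM can be matched back to graph nodes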
    for _, row in dataframes["patients"].iterrows():
        G.add_node(f"p_{row['subject_id']}", label=f"Patient {row['subject_id']}", type='patient', gender=row['sex'], anchor_age=row['anchor_age'])
    for _, row in dataframes["admissions"].iterrows():
        G.add_node(f"a_{row['hadm_id']}", label=f"Admission {row['hadm_id']}", type='admission', admission_type=row['admission_type'], ethnicity=row['ethnicity'])
        G.add_edge(f"p_{row['subject_id']}", f"a_{row['hadm_id']}", label='HAD_ADMISSION')
    for _, row in dataframes["diagnoses"].dropna(subset=['icd_name']).iterrows():
        diag_id = f"d_{row['icd_code']}_{row['icd_version']}"
        G.add_node(diag_id, label=row['icd_name'], type='diagnosis')
        G.add_edge(f"a_{row['hadm_id']}", diag_id, label='DIAGNOSED_WITH')
    for _, row in dataframes["procedures"].dropna(subset=['icd_code']).iterrows():
        proc_id = f"pr_{row['icd_code']}_{row['icd_version']}"
        G.add_node(proc_id, label=f"Proc: {row['icd_code']}", type='procedure')
        G.add_edge(f"a_{row['hadm_id']}", proc_id, label='HAD_PROCEDURE')
    for _, row in dataframes["prescriptions"].dropna(subset=["drug(English)"]).iterrows():
        drug_name = str(row['drug(English)']).lower()
        # Use drug name as node ID for simplicity, ensuring one node per drug
        drug_node_id = drug_name
        if not G.has_node(drug_node_id):
            G.add_node(drug_node_id, label=drug_name.title(), type='prescription')
        # This edge connects patient to prescription, not admission
        G.add_edge(f"p_{row['subject_id']}", drug_node_id, label='PRESCRIBED')

    print("Stage 1 (Structural Graph) Complete.")
    print(f" - Nodes: {G.number_of_nodes()}, Edges: {G.number_of_edges()}")

    # --- STAGE 2: Add Semantic Layer with LLM ---
    print("\nStarting Stage 2: Adding semantic relationships with LLM...")
    for hadm_id in dataframes["admissions"]['hadm_id'].head(1):
        # Gather the entities for this admission (at most 3 of each type, to keep the demo prompt short)
        diagnoses = dataframes["diagnoses"][dataframes["diagnoses"]['hadm_id'] == hadm_id]['icd_name'].unique().tolist()[:3]
        procedures = [f"Proc: {c}" for c in dataframes["procedures"][dataframes["procedures"]['hadm_id'] == hadm_id]['icd_code'].unique().tolist()][:3]
        drugs = dataframes["prescriptions"][dataframes["prescriptions"]['hadm_id'] == hadm_id]['drug(English)'].dropna().str.title().unique().tolist()[:3]
        
        if not diagnoses or (not procedures and not drugs):
            continue

        # This revised prompt strictly enforces a JSON output and provides structured lists of entities.
        prompt = f"""
        <s>[INST]
        <<SYS>>
        You are a clinical knowledge extraction expert. Your task is to identify plausible relationships between the provided lists of diagnoses, procedures, and drugs from a single patient admission.

        **Instructions:**
        1.  Analyze the lists below to find connections. For example, a drug might treat a diagnosis, or a procedure might be indicated by a diagnosis.
        2.  Create relationship triplets in the format `[ENTITY 1, RELATION, ENTITY 2]`.
        3.  Valid relations include `TREATS`, `MANAGES`, `IS_INDICATION_FOR`, `SIDE_EFFECT`, `DIAGNOSES`.
        4.  The output MUST be a valid JSON object with a single key "relationships", which contains a list of the identified triplets.
        5.  If no plausible relationships can be inferred, return an empty list: `{{"relationships": []}}`.
        6.  Give at most one relation between any two entities.
        <</SYS>>

        **Entities from Patient Admission:**

        * **Diagnoses:** {json.dumps(diagnoses)}
        * **Procedures:** {json.dumps(procedures)}
        * **Drugs:** {json.dumps(drugs)}

        Return at most 10 relationship triplets, formatted as JSON.

        **JSON Output:**
        [/INST]
        """
        
        response = get_llm_response(prompt)
        
        # Extract triplets from the LLM response (extract_relationships handles JSON parse errors itself)
        try:
            triplets = extract_relationships(response)

            for triplet in triplets:
                if isinstance(triplet, list) and len(triplet) == 3 and all(isinstance(t, str) for t in triplet):
                    source_label, relation, target_label = triplet

                    # Find the corresponding nodes in the graph using case-insensitive label matching
                    source_node = [n for n, d in G.nodes(data=True) if str(d.get('label', '')).lower() == source_label.lower()]
                    target_node = [n for n, d in G.nodes(data=True) if str(d.get('label', '')).lower() == target_label.lower()]

                    if source_node and target_node:
                        # Add the new semantic edge
                        G.add_edge(source_node[0], target_node[0], label=relation.upper(), type='semantic')
                        print(f"  + Added semantic edge: {source_label} --[{relation.upper()}]--> {target_label}")
                    else:
                        print(f"  - Skipped triplet with an entity not found in the graph: {triplet}")
        except Exception as e:
            # Catch any unexpected errors while adding semantic edges
            print(f"  - An unexpected error occurred during semantic edge extraction: {e}")

    print("\nKnowledge Graph build complete (Structural + Semantic).")
    print(f" - Final Nodes: {G.number_of_nodes()}, Final Edges: {G.number_of_edges()}")
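The loop above rescans every node for each triplet. It also accepts whatever relation string the model returns, even though the prompt lists only five valid relations. Below is a sketch of a stricter variant that keeps the same case-insensitive label matching. It builds a lowercase label-to-node index once, drops triplets whose relation is not in the allowed set, and returns the rejects for inspection. `ALLOWED_RELATIONS` mirrors the prompt. `resolve_triplets` and the toy labels are hypothetical.

```python
ALLOWED_RELATIONS = {"TREATS", "MANAGES", "IS_INDICATION_FOR", "SIDE_EFFECT", "DIAGNOSES"}

def resolve_triplets(triplets, node_labels, allowed=ALLOWED_RELATIONS):
    """Map [entity, relation, entity] triplets to (source_id, RELATION, target_id) edges.

    node_labels: {node_id: label}; matching is case-insensitive on the full label.
    Returns (edges, rejected), where rejected holds triplets that failed validation or matching.
    """
    # Build the lookup once; if two nodes share a label, the first one wins
    index = {}
    for node_id, label in node_labels.items():
        index.setdefault(str(label).lower(), node_id)

    edges, rejected = [], []
    for triplet in triplets:
        if not (isinstance(triplet, list) and len(triplet) == 3 and all(isinstance(t, str) for t in triplet)):
            rejected.append(triplet)
            continue
        source, relation, target = triplet
        relation = relation.strip().upper()
        src_id, dst_id = index.get(source.lower()), index.get(target.lower())
        if relation in allowed and src_id is not None and dst_id is not None:
            edges.append((src_id, relation, dst_id))
        else:
            rejected.append(triplet)
    return edges, rejected

labels = {"metformin": "Metformin", "d_E11_10": "Type 2 diabetes mellitus"}
edges, rejected = resolve_triplets(
    [["metformin", "treats", "type 2 diabetes mellitus"], ["Metformin", "CAUSES", "Type 2 diabetes mellitus"]],
    labels,
)
print(edges)     # [('metformin', 'TREATS', 'd_E11_10')]
print(rejected)  # [['Metformin', 'CAUSES', 'Type 2 diabetes mellitus']]
```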

### Visualize the Knowledge Graph

In [None]:
# === Graph Construction & Visualization ===
# This cell creates an interactive HTML visualization of the graph.
# Nodes are colored by type for clarity.

if 'G' in locals() and G.number_of_nodes() > 0:
    net = Network(notebook=True, height='800px', width='100%', bgcolor='#222222', font_color='white', cdn_resources='in_line', directed=True)

    # Define colors for different node types
    color_map = {
        'patient': '#007bff',      # Blue
        'admission': '#28a745',   # Green
        'diagnosis': '#dc3545',   # Red
        'procedure': '#ffc107',   # Yellow
        'prescription': '#6f42c1',# Purple
        'lab_item': '#17a2b8'     # Teal
    }

    # Add nodes to the Pyvis network with appropriate styling
    for node, data in G.nodes(data=True):
        node_type = data.get('type', 'N/A')
        net.add_node(node, label=data.get('label', node), color=color_map.get(node_type, '#adb5bd'), title=f"Type: {node_type}")

    # Add edges
    for source, target, data in G.edges(data=True):
        net.add_edge(source, target, label=data.get('label', ''))

    # Set physics options for a better layout
    net.set_options("""
    var options = {
      "nodes": { "font": { "size": 14 } },
      "edges": { "font": { "align": "top", "size": 12 }, "smooth": { "type": "dynamic" } },
      "physics": {
        "forceAtlas2Based": { "gravitationalConstant": -50, "centralGravity": 0.01, "springLength": 150 },
        "minVelocity": 0.75, "solver": "forceAtlas2Based"
      }
    }
    """)
    
    print("Generating interactive knowledge graph visualization...")
    display(net.show("semantic_knowledge_graph.html"))
    print("Interactive graph 'semantic_knowledge_graph.html' has been saved to your directory.")
else:
    print("No graph was built, so visualization is skipped.")

## 5. Application of Knowledge Graph

### Advanced GraphRAG Query Demonstrations

In [None]:
# === Application: GraphRAG for Clinical Questions ===
#
# This cell demonstrates a GraphRAG (Retrieval-Augmented Generation) system.
# 1. A user asks a question in natural language.
# 2. The LLM translates the question into a structured graph query.
# 3. A Python function executes this query against our NetworkX graph.
# 4. The LLM receives the query results and synthesizes a final, human-readable answer.

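def execute_graph_query(graph, query_plan):
    """Executes a single-hop query plan against the NetworkX graph.

    Expected plan structure:
        {"target_node": {"type": ..., "property_value": ...},
         "edge": {"label": ...},
         "connected_node": {"type": ...}}
    Returns the labels of `connected_node`s that have an `edge` pointing TO a matching target node.
    """
    try:
        target_type = query_plan['target_node']['type']
        target_value = str(query_plan['target_node']['property_value']).lower()
        edge_label = query_plan['edge']['label']
        connected_type = query_plan['connected_node']['type']
    except (KeyError, TypeError) as e:
        print(f"Query plan is missing a required field: {e}")
        return []

    results = []
    for node_id, properties in graph.nodes(data=True):
        # Target nodes: matching type and a label containing the search text
        if properties.get('type') == target_type and target_value in str(properties.get('label', '')).lower():
            # Follow incoming edges: neighbor --[edge_label]--> target node
            for neighbor in graph.predecessors(node_id):
                # On a DiGraph (as built in section 4) this returns the edge's attribute dict
                edge_data = graph.get_edge_data(neighbor, node_id) or {}
                if graph.nodes[neighbor].get('type') == connected_type and edge_data.get('label') == edge_label:
                    results.append(graph.nodes[neighbor].get('label'))
    return results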

def ask_question_with_graphrag(graph, question):
    """Orchestrates the GraphRAG process."""
    print(f"User Question: \"{question}\"")
    
    # 1. Use LLM to translate NLQ to a structured query plan
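    plan_instruction = f"""
    Convert the following question into a JSON query plan to be executed on a knowledge graph.
    The graph has node types 'patient', 'admission', 'diagnosis', 'procedure', and 'prescription'.
    Edge labels are 'HAD_ADMISSION' (patient -> admission), 'DIAGNOSED_WITH' (admission -> diagnosis),
    'HAD_PROCEDURE' (admission -> procedure), and 'PRESCRIBED' (patient -> prescription),
    plus LLM-inferred labels such as 'TREATS' or 'IS_INDICATION_FOR'.

    Return only a JSON object with exactly this structure:
    {{"target_node": {{"type": "<node type>", "property_value": "<text to match in the node label>"}},
      "edge": {{"label": "<edge label>"}},
      "connected_node": {{"type": "<type of the node whose edge points to the target node>"}}}}

    Question: "{question}"

    JSON Query Plan:
    """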
    plan_prompt = f"""<s>[INST]
<<SYS>>
You are a helpful assistant that converts natural language questions into structured JSON query plans.
<</SYS>>

{plan_instruction}
[/INST]
"""
    
    response = get_llm_response(plan_prompt)
    
    try:
        query_plan_str = re.search(r'\{.*\}', response, re.DOTALL).group(0)
        query_plan = json.loads(query_plan_str)
        print(f"\nStep 1: Generated Query Plan:\n{json.dumps(query_plan, indent=2)}")
    except Exception as e:
        print(f"Could not parse query plan from LLM response: {e}")
        return

    # 2. Execute the query on the graph
    query_results = execute_graph_query(graph, query_plan)
    print(f"\nStep 2: Retrieved {len(query_results)} results from the graph.")
    if query_results:
        print("Results:", query_results)

    # 3. Use LLM to synthesize the final answer
    answer_instruction = f"""
    Based on the following information retrieved from a clinical knowledge graph, provide a concise answer to the user's original question.

    Original Question: "{question}"
    Retrieved Information: "{', '.join(map(str, query_results)) if query_results else 'No information found.'}"

    Answer:
    """
    answer_prompt = f"""<s>[INST]
<<SYS>>
You are a helpful biomedical assistant that answers questions based on provided data.
<</SYS>>

{answer_instruction}
[/INST]
"""
    final_answer = get_llm_response(answer_prompt)
    
    print(f"\nStep 3: Synthesized Final Answer:\n{final_answer}")
    return final_answer
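To make the query-plan contract concrete, here is a dependency-free sketch of the single-hop lookup that `execute_graph_query` performs. It runs on a toy graph stored as plain dicts; all nodes and edges are invented. A plan names a target node by type and label substring, an edge label, and the type of the node at the other end of an incoming edge.

```python
# Toy graph: node attributes plus (source, target, label) edges
nodes = {
    "warfarin": {"type": "prescription", "label": "Warfarin"},
    "d_af": {"type": "diagnosis", "label": "Atrial fibrillation"},
    "a_1": {"type": "admission", "label": "Admission 1"},
    "p_1": {"type": "patient", "label": "Patient 1"},
}
edges = [
    ("warfarin", "d_af", "TREATS"),
    ("a_1", "d_af", "DIAGNOSED_WITH"),
    ("p_1", "a_1", "HAD_ADMISSION"),
]

def run_single_hop(plan, nodes, edges):
    """Return labels of `connected_node`s with an `edge` pointing to a matching `target_node`."""
    target_type = plan["target_node"]["type"]
    needle = plan["target_node"]["property_value"].lower()
    targets = {n for n, a in nodes.items() if a["type"] == target_type and needle in a["label"].lower()}
    return [
        nodes[src]["label"]
        for src, dst, label in edges
        if dst in targets and label == plan["edge"]["label"] and nodes[src]["type"] == plan["connected_node"]["type"]
    ]

plan = {
    "target_node": {"type": "diagnosis", "property_value": "atrial fibrillation"},
    "edge": {"label": "TREATS"},
    "connected_node": {"type": "prescription"},
}
print(run_single_hop(plan, nodes, edges))  # ['Warfarin']
```

With the same target, changing the edge to `DIAGNOSED_WITH` and the connected type to `admission` returns the admissions with that diagnosis instead. Reaching the patients would require a second hop, which a single-hop executor cannot take.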

In [None]:
# === GraphRAG Query Demonstrations ===
# This cell demonstrates how to use the `ask_question_with_graphrag` function
# to answer different types of clinical and operational questions.

if 'G' in locals() and G.number_of_nodes() > 0:
    print("--- Advanced GraphRAG Query Demonstrations ---")
    print("Running a series of example questions against the knowledge graph...")

    # --- Example 1: Patient Cohort Analysis ---
    # This query identifies a specific group of patients based on their clinical characteristics.
    # The LLM needs to create a query plan that finds 'Pneumonia' nodes, then finds connected
    # 'admission' nodes, and finally checks the 'anchor_age' property on the connected 'patient' nodes.
    # Note: `execute_graph_query` follows only a single hop and applies no property filters, so it
    # cannot enforce the age condition. The example shows how the LLM formulates a plan even when
    # the execution engine is basic.
    print("\n\n" + "="*50)
    print("Example 1: Patient Cohort Analysis")
    print("="*50)
    cohort_question = "Which patients older than 70 were diagnosed with Pneumonia?"
    ask_question_with_graphrag(G, cohort_question)

    # --- Example 2: Treatment Pathway and Efficacy Questions ---
    # This query leverages the semantic 'TREATS' relationships added by the LLM in Stage 2.
    # It seeks to understand common treatment patterns for a specific condition.
    print("\n\n" + "="*50)
    print("Example 2: Treatment Pathway Inquiry")
    print("="*50)
    treatment_question = "How is Atrial Fibrillation treated?"
    # For this to work well, the LLM needs to generate a query plan that looks for 'TREATS'
    # edges pointing TO the 'Atrial Fibrillation' diagnosis node.
    ask_question_with_graphrag(G, treatment_question)

    # --- Example 3: Exploratory and Complex Queries ---
    # This is a more complex, multi-hop query that the simple `execute_graph_query` function
    # cannot handle in a single pass. However, we can demonstrate how the LLM attempts to
    # break it down. An advanced implementation would use an iterative query engine.
    print("\n\n" + "="*50)
    print("Example 3: Exploratory / Multi-Hop Query")
    print("="*50)
    exploratory_question = "What are the common procedures for patients who are prescribed both Insulin and Metformin?"
    # The LLM would need to generate a complex plan:
    # 1. Find all patients prescribed Insulin.
    # 2. Find all patients prescribed Metformin.
    # 3. Find the intersection of these two patient groups.
    # 4. For that intersection, find all connected procedures.
    ask_question_with_graphrag(G, exploratory_question)

else:
    print("Knowledge graph 'G' not found. Please run the previous cells to build the graph first.")
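Example 3 lists four steps that a single-hop executor cannot chain. Below is a minimal sketch of those steps over plain (source, target, label) edge tuples. It follows the graph's schema: `PRESCRIBED` from patient to drug, `HAD_ADMISSION` from patient to admission, and `HAD_PROCEDURE` from admission to procedure. The toy IDs and the function name `procedures_for_co_prescribed` are hypothetical. In this graph, prescriptions attach to patients rather than admissions, so the procedures found may come from any of the patient's admissions.

```python
from collections import Counter

# Toy edges following the graph schema
edges = [
    ("p_1", "insulin", "PRESCRIBED"), ("p_1", "metformin", "PRESCRIBED"),
    ("p_2", "insulin", "PRESCRIBED"),
    ("p_1", "a_10", "HAD_ADMISSION"), ("p_2", "a_20", "HAD_ADMISSION"),
    ("a_10", "pr_3995_9", "HAD_PROCEDURE"), ("a_20", "pr_8856_9", "HAD_PROCEDURE"),
]

def sources_with_edge(edges, target, label):
    """All source nodes that have an edge with `label` pointing to `target`."""
    return {src for src, dst, lbl in edges if dst == target and lbl == label}

def procedures_for_co_prescribed(edges, drug_a, drug_b):
    # Steps 1-3: patients prescribed both drugs (set intersection)
    patients = sources_with_edge(edges, drug_a, "PRESCRIBED") & sources_with_edge(edges, drug_b, "PRESCRIBED")
    # Step 4: walk patient -> admission -> procedure and count procedure occurrences
    admissions = {dst for src, dst, lbl in edges if src in patients and lbl == "HAD_ADMISSION"}
    return Counter(dst for src, dst, lbl in edges if src in admissions and lbl == "HAD_PROCEDURE")

print(procedures_for_co_prescribed(edges, "insulin", "metformin").most_common())  # [('pr_3995_9', 1)]
```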
