In [None]:
!pip install nx-arangodb

In [None]:
# Print NVIDIA driver/GPU status for this runtime (relevant to the optional
# cuGraph acceleration probed further down).
!nvidia-smi

In [None]:
!pip install nx-cugraph-cu12 --extra-index-url https://pypi.nvidia.com

In [None]:
!pip install --upgrade langchain langchain-community langchain-openai langgraph arango-datasets kaleido

In [91]:
# 5. Import the required modules

import networkx as nx
import nx_arangodb as nxadb

from arango import ArangoClient
from arango_datasets import Datasets

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from random import randint
import re
import json

from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langchain_community.graphs import ArangoGraph
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_core.tools import tool

In [None]:
import os
from getpass import getpass

# Connection settings are taken from the environment so credentials never live in
# the notebook; the password is prompted for interactively when it is not set.
ARANGO_HOST = os.environ.get("ARANGO_HOST", "https://...arangodb.cloud:...")
ARANGO_USERNAME = os.environ.get("ARANGO_USERNAME", "root")
ARANGO_PASSWORD = os.environ.get("ARANGO_PASSWORD") or getpass("ArangoDB password: ")

db = ArangoClient(hosts=ARANGO_HOST).db(
    username=ARANGO_USERNAME,
    password=ARANGO_PASSWORD,
    verify=True,
)

datasets = Datasets(db)

print(datasets.dataset_info("SYNTHEA_P100"))

# Skip the import when the graph is already present so re-running this cell is
# cheap and idempotent.
# NOTE(review): assumes the loader creates a graph named "SYNTHEA_P100" (the
# graph-loading cell below opens it by that name) - confirm.
if not db.has_graph("SYNTHEA_P100"):
    datasets.load("SYNTHEA_P100")

In [None]:
import os
from getpass import getpass

from langchain_openai import ChatOpenAI

# The API key must not be committed with the notebook: use the value already in
# the environment, otherwise ask for it interactively.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

# temperature=0 keeps the classification / code-generation prompts as repeatable as possible.
llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

# Smoke test: fails fast here if the key or model name is wrong.
llm.invoke("hello!")

In [None]:
# Open the ArangoDB-backed graph by name and report its size.
G_adb = nxadb.Graph(name="SYNTHEA_P100", db=db)
node_total, edge_total = len(G_adb.nodes), len(G_adb.edges)
print(f"Graph loaded with {node_total} nodes and {edge_total} edges")

# Optional GPU path: `use_gpu` only becomes True when nx_cugraph imports and a
# cuGraph graph can be built from G_adb; any failure leaves the CPU path in place.
use_gpu = False
try:
    import nx_cugraph as nxcg
    G_cugraph = nxcg.Graph(G_adb)
except Exception as gpu_error:
    print(f"GPU acceleration not available: {str(gpu_error)}")
else:
    use_gpu = True
    print("GPU acceleration enabled via cuGraph")

In [95]:
@tool
def query_medical_graph(query_spec: dict):
    """
    Dynamic medical graph analysis tool that determines the optimal query technique
    and executes it. Supports AQL, NetworkX/cuGraph algorithms, or hybrid approaches.
    
    Parameters:
    - query_spec: A dictionary containing:
      - query: Natural language query about medical data
      - context: Optional additional context
      - parameters: Optional specific query parameters
      - approach: Optional preferred analysis method ("AQL", "NetworkX" or "Hybrid");
        when given it overrides the automatic classification

    Returns the result dictionary of the selected executor, or {"error": ...}
    when no query text is supplied.
    """
    # Extract query components
    query = query_spec.get('query')
    if not query:
        return {"error": "query_spec must contain a non-empty 'query'"}
    context = query_spec.get('context') or {}
    parameters = query_spec.get('parameters') or {}
    approach = query_spec.get('approach')

    executors = {
        'aql': execute_aql_query,
        'networkx': execute_networkx_query,
        'hybrid': execute_hybrid_query,
    }

    # An explicit, recognised approach wins outright. Previously the LLM's
    # category could still route e.g. approach="NetworkX" to AQL.
    requested = str(approach).strip().lower() if approach else None
    if requested in executors:
        return executors[requested](query, parameters, context)

    # Analyze query intent to determine approach
    query_analysis = llm.invoke(f"""
    Analyze this medical query intent: "{query}"
    
    Classify this query into one of the following categories:
    1. SIMPLE_RELATIONSHIP - Direct lookups, basic traversals, simple filtering
    2. COMPLEX_PATTERN - Requires graph algorithms (centrality, community detection, path analysis)
    3. HYBRID - Requires both relationship data and complex graph analytics
    
    Return classification in JSON:
    {{
        "category": "SIMPLE_RELATIONSHIP|COMPLEX_PATTERN|HYBRID",
        "explanation": "Brief explanation why",
        "suggested_approach": "AQL|NetworkX|Hybrid"
    }}
    
    Context: {context if context else "No additional context provided"}
    """).content

    # Fall back to the cheapest route when the reply carries no parseable JSON object.
    analysis_result = {}
    json_match = re.search(r'\{.*\}', query_analysis, re.DOTALL)
    if json_match:
        try:
            analysis_result = json.loads(json_match.group())
        except json.JSONDecodeError:
            analysis_result = {}
    query_category = analysis_result.get('category', 'SIMPLE_RELATIONSHIP')
    suggested_approach = analysis_result.get('suggested_approach', 'AQL')

    # An unrecognised explicit approach still replaces the suggestion (as before);
    # the category then decides the route.
    if approach:
        suggested_approach = approach

    # Execute appropriate query technique
    if suggested_approach == 'AQL' or query_category == 'SIMPLE_RELATIONSHIP':
        return execute_aql_query(query, parameters, context)
    elif suggested_approach == 'NetworkX' or query_category == 'COMPLEX_PATTERN':
        return execute_networkx_query(query, parameters, context)
    else:  # Hybrid approach
        return execute_hybrid_query(query, parameters, context)

In [98]:
def execute_aql_query(query, parameters=None, context=None):
    """Execute AQL queries via LangChain for simple relationship queries.

    Parameters:
    - query: Natural language question to translate to AQL and run.
    - parameters: Optional query parameters, embedded in the prompt as text.
    - context: Optional additional context, embedded in the prompt as text.

    Returns a dict with "result" (the chain's answer), "query_type" ("AQL") and
    "original_query"; "structured_data" is added when the follow-up LLM reply
    contains a parseable JSON object.
    """
    # NOTE(review): `arango_graph` is not defined in the visible cells; presumably
    # an ArangoGraph wrapping `db` - confirm it is created before this is called.
    chain = ArangoGraphQAChain.from_llm(
        llm=llm,
        graph=arango_graph,
        verbose=True,
        allow_dangerous_requests=True
    )
    
    # Enhance query with medical context
    enhanced_query = f"""
    Based on the Synthea medical knowledge graph:
    
    Original query: {query}
    
    Additional context: {context if context else "None"}
    
    Parameters to consider: {parameters if parameters else "None"}
    """
    
    answer = chain.invoke(enhanced_query)['result']
    
    # Structure and analyze results
    analysis = llm.invoke(f"""
    Based on this query result from a medical database:
    
    {answer}
    
    Extract and structure the key information in a format useful for medical analysis.
    Focus on presenting clear, structured data highlighting any potential rare disease insights.
    
    Return a structured JSON with relevant medical fields.
    """).content
    
    response = {
        "result": answer,
        "query_type": "AQL",
        "original_query": query
    }

    # Structured data is best-effort: only attach it when a JSON object parses.
    json_match = re.search(r'\{.*\}', analysis, re.DOTALL)
    if json_match:
        try:
            response["structured_data"] = json.loads(json_match.group())
        except json.JSONDecodeError:
            pass  # keep the plain-text answer only
    return response

def execute_networkx_query(query, parameters=None, context=None):
    """Execute NetworkX/cuGraph algorithms for complex pattern analysis.

    The LLM generates Python code for the query, the code is executed against
    `G_adb` (plus `G_cugraph`/`nxcg` when `use_gpu` is set), and the value it
    leaves in `FINAL_RESULT` is then interpreted by the LLM. One LLM-assisted
    repair is attempted if the first execution raises.

    Parameters:
    - query: Natural language analysis request.
    - parameters: Optional parameters, embedded in the prompt as text.
    - context: Optional context, embedded in the prompt as text.

    Returns a dict with "result", "interpretation", "query_type" ("NetworkX") and
    "original_query", or a dict with "error" when both executions fail.
    """
    # Generate appropriate NetworkX/cuGraph code
    code_generation_prompt = f"""
    I have a NetworkX Graph `G_adb` representing a medical knowledge graph from Synthea.
    
    The query is: "{query}"
    
    Additional context: {context if context else "None"}
    Parameters: {parameters if parameters else "None"}
    
    Generate Python code using NetworkX and/or cuGraph algorithms to answer this query.
    Focus on detecting rare disease patterns, unusual symptom clusters, or atypical progressions.
    
    Your code should:
    1. Extract relevant subgraph if needed
    2. Apply appropriate graph algorithms (centrality, community detection, path analysis)
    3. Interpret results for rare disease analysis
    4. Store final answer in variable FINAL_RESULT
    
    Only provide executable Python code without explanations.
    """

    nx_code = llm.invoke(code_generation_prompt).content
    nx_code_cleaned = re.sub(r"^```python\n|```$", "", nx_code, flags=re.MULTILINE).strip()

    def _build_namespace():
        # Try to use GPU acceleration when available
        namespace = {"G_adb": G_adb, "nx": nx, "db": db}
        if use_gpu:
            namespace.update({"G_cugraph": G_cugraph, "nxcg": nxcg})
        return namespace

    # SECURITY: exec() runs LLM-generated code with full access to `db` and the
    # kernel. Only use this against trusted prompts / a sandboxed runtime.
    #
    # A single dict is used as the execution namespace: with separate globals and
    # locals the code runs as if inside a class body, so helper functions and
    # comprehensions in the generated code cannot see its own top-level variables.
    try:
        namespace = _build_namespace()
        exec(nx_code_cleaned, namespace)
        result = namespace.get("FINAL_RESULT", "No result was generated")
    except Exception as e:
        # Try to fix common errors
        fix_prompt = f"""
        The following NetworkX code failed with error: {str(e)}
        
        Code:
        {nx_code_cleaned}
        
        Please fix the code to address this error. Focus only on fixing the error.
        """
        
        fixed_code = llm.invoke(fix_prompt).content
        fixed_code_cleaned = re.sub(r"^```python\n|```$", "", fixed_code, flags=re.MULTILINE).strip()
        
        try:
            # Fresh namespace so a FINAL_RESULT left by the failed run is not reused.
            namespace = _build_namespace()
            exec(fixed_code_cleaned, namespace)
            result = namespace.get("FINAL_RESULT", "No result was generated after fixing code")
        except Exception as e2:
            return {
                "error": f"Graph analysis failed: {str(e2)}",
                "query_type": "NetworkX",
                "original_query": query
            }
    
    # Enhance the result with medical interpretation
    result_interpretation = llm.invoke(f"""
    I executed graph analytics on a medical knowledge graph to answer: "{query}"
    
    The analysis result is: {result}
    
    Please interpret this result in the context of rare disease diagnosis, explaining:
    1. What patterns or insights were found
    2. How these relate to potential rare disease diagnosis
    3. Clinical significance of these findings
    
    Return your interpretation in a structured JSON format with medical insights.
    """).content
    
    # Prefer the parsed JSON object; otherwise return the raw interpretation text.
    interpretation = result_interpretation
    json_match = re.search(r'\{.*\}', result_interpretation, re.DOTALL)
    if json_match:
        try:
            interpretation = json.loads(json_match.group())
        except json.JSONDecodeError:
            interpretation = result_interpretation

    return {
        "result": result,
        "interpretation": interpretation,
        "query_type": "NetworkX",
        "original_query": query
    }

def execute_hybrid_query(query, parameters=None, context=None):
    """Execute hybrid queries combining AQL and NetworkX for complex analysis.

    The LLM splits the query into an AQL part and a NetworkX part; the AQL part
    runs first and its answer is passed to the NetworkX part as context, after
    which the LLM integrates both results.

    Parameters:
    - query: Natural language question.
    - parameters: Optional parameters forwarded to both executors.
    - context: Optional context; a dict is forwarded as-is, any other non-empty
      value is wrapped as {"context": value}.

    Returns a dict with "result" (integrated LLM text), "aql_component",
    "networkx_component", "query_type" ("Hybrid") and "original_query";
    "structured_data" is added when the integrated text contains parseable JSON.
    """
    parameters = {} if parameters is None else parameters
    if isinstance(context, dict):
        base_context = dict(context)
    elif context:
        base_context = {"context": context}
    else:
        base_context = {}

    # Determine what parts need AQL vs NetworkX
    query_decomposition = llm.invoke(f"""
    I need to analyze this medical query using a hybrid approach combining AQL and NetworkX:
    
    "{query}"
    
    Break this query into components that should be handled by:
    1. AQL - for retrieving specific data relationships
    2. NetworkX - for complex pattern analysis
    
    Provide decomposition in JSON format:
    {{
        "aql_component": "What specific data should be retrieved using AQL",
        "networkx_component": "What analysis should be performed using NetworkX",
        "integration_strategy": "How to combine the results"
    }}
    """).content
    
    decomposition = None
    json_match = re.search(r'\{.*\}', query_decomposition, re.DOTALL)
    if json_match:
        try:
            decomposition = json.loads(json_match.group())
        except json.JSONDecodeError:
            decomposition = None

    if decomposition is not None:
        aql_component = decomposition.get('aql_component', '')
        networkx_component = decomposition.get('networkx_component', '')
        integration_strategy = decomposition.get('integration_strategy', '')
    else:
        # Default decomposition if parsing fails
        aql_component = f"Retrieve relevant data for: {query}"
        networkx_component = f"Analyze patterns in the data for: {query}"
        integration_strategy = "Combine the results to provide insights"
    
    # Step 1: Execute AQL query to get base data
    aql_result = execute_aql_query(aql_component, parameters, base_context)
    
    # Step 2: Prepare NetworkX analysis with context from AQL results
    enhanced_context = {
        **base_context,
        "aql_results": aql_result.get("result", "")
    }
    
    # Step 3: Execute NetworkX analysis
    networkx_result = execute_networkx_query(networkx_component, parameters, enhanced_context)

    # A failed graph analysis returns {"error": ...} without "result"; surface the
    # error to the integration step instead of an empty string.
    networkx_summary = networkx_result.get('result', networkx_result.get('error', ''))
    
    # Step 4: Integrate results
    integration_prompt = f"""
    I executed a hybrid analysis on a medical knowledge graph using both AQL and NetworkX:
    
    Original query: "{query}"
    
    AQL component result: {aql_result.get('result', '')}
    
    NetworkX component result: {networkx_summary}
    
    Integration strategy: {integration_strategy}
    
    Please integrate these results to provide a comprehensive answer.
    Focus on insights related to rare disease patterns, unusual symptoms, or diagnostic pathways.
    
    Return your integrated analysis in a structured JSON format with medical insights.
    """
    
    integrated_result = llm.invoke(integration_prompt).content
    
    response = {"result": integrated_result}
    json_match = re.search(r'\{.*\}', integrated_result, re.DOTALL)
    if json_match:
        try:
            response["structured_data"] = json.loads(json_match.group())
        except json.JSONDecodeError:
            pass  # structured view is best-effort
    response.update({
        "aql_component": aql_result,
        "networkx_component": networkx_result,
        "query_type": "Hybrid",
        "original_query": query
    })
    return response

In [99]:
def _aql_rows(aql_query, bind_vars=None):
    """Run a parameterised AQL query on `db` and return every row as a list."""
    return list(db.aql.execute(aql_query, bind_vars=bind_vars or {}))


def _symptom_codes_for(symptoms):
    """Map symptom descriptions to observation codes; all-digit entries pass through unchanged."""
    codes = []
    for symptom in symptoms:
        if str(symptom).isdigit():
            codes.append(symptom)
            continue
        # Case-insensitive substring match on the observation description.
        codes.extend(_aql_rows(
            """
            FOR obs IN observations
                FILTER LOWER(obs.DESCRIPTION) LIKE CONCAT("%", LOWER(@symptom), "%")
                RETURN DISTINCT obs.CODE
            """,
            {"symptom": str(symptom)},
        ))
    return codes


def _rare_conditions_for_symptoms(symptom_codes):
    """Return up to 10 conditions seen in at most 10% of the patients showing the given codes."""
    # The intermediate list is named `condition_stats` so it does not share a
    # name with the `conditions` collection it is computed from.
    return _aql_rows(
        """
        // Find patients with these symptoms
        LET patients_with_symptoms = (
            FOR obs IN observations
                FILTER obs.CODE IN @symptom_codes
                RETURN DISTINCT obs.PATIENT
        )

        // Find conditions these patients have
        LET condition_stats = (
            FOR patient IN patients_with_symptoms
                FOR cond IN conditions
                    FILTER cond.PATIENT == patient
                    COLLECT code = cond.CODE, description = cond.DESCRIPTION
                    WITH COUNT INTO count
                    SORT count ASC
                    RETURN {
                        code: code,
                        description: description,
                        patient_count: count,
                        prevalence: count / LENGTH(patients_with_symptoms)
                    }
        )

        // Return relatively rare conditions
        FOR stat IN condition_stats
            FILTER stat.patient_count <= 0.1 * LENGTH(patients_with_symptoms)
            SORT stat.patient_count ASC
            LIMIT 10
            RETURN stat
        """,
        {"symptom_codes": list(symptom_codes)},
    )


def _patient_records(patient_id):
    """Fetch (observations sorted by date, conditions) for one patient."""
    observations = _aql_rows(
        """
        FOR obs IN observations
            FILTER obs.PATIENT == @patient_id
            SORT obs.DATE
            RETURN {
                code: obs.CODE,
                description: obs.DESCRIPTION,
                date: obs.DATE,
                value: obs.VALUE
            }
        """,
        {"patient_id": patient_id},
    )
    conditions = _aql_rows(
        """
        FOR c IN conditions
            FILTER c.PATIENT == @patient_id
            RETURN {
                code: c.CODE,
                description: c.DESCRIPTION,
                start: c.START,
                stop: c.STOP
            }
        """,
        {"patient_id": patient_id},
    )
    return observations, conditions


def _evidence_path(patient_id, observations, conditions, condition_id=None):
    """Build the observation -> condition evidence graph; None when no matching condition exists."""
    if condition_id is None:
        # Default to the patient's condition with the fewest records overall.
        rarest = _aql_rows(
            """
            FOR c IN conditions
                FILTER c.PATIENT == @patient_id
                LET condition_count = (
                    FOR c2 IN conditions
                        FILTER c2.CODE == c.CODE
                        COLLECT WITH COUNT INTO count
                        RETURN count
                )
                SORT condition_count[0] ASC
                LIMIT 1
                RETURN {
                    code: c.CODE,
                    description: c.DESCRIPTION,
                    count: condition_count[0]
                }
            """,
            {"patient_id": patient_id},
        )
        if not rarest:
            return None
        condition_id = rarest[0]["code"]

    condition_info = next(
        (c for c in conditions if str(c["code"]) == str(condition_id)), None
    )
    if condition_info is None:
        return None

    G = nx.DiGraph()
    condition_node = f"condition_{condition_info['code']}"
    G.add_node(condition_node,
               label=condition_info['description'],
               type="condition",
               date=condition_info['start'])

    for obs in observations:
        # The query always returns a `date` key, so test the value (it may be null).
        if not obs.get("date"):
            continue
        obs_node = f"obs_{obs['code']}_{obs['date']}"
        G.add_node(obs_node,
                   label=obs['description'],
                   value=obs.get('value', 'N/A'),
                   type="observation",
                   date=obs['date'])
        # Connect if observation came before diagnosis
        if condition_info['start'] and obs['date'] <= condition_info['start']:
            G.add_edge(obs_node, condition_node)

    return {
        "condition": condition_info,
        "node_count": len(G.nodes()),
        "edge_count": len(G.edges()),
        "graph": G
    }


def _similar_cases(patient_id, observations, top_k):
    """Rank patients having rare conditions by Jaccard similarity of their observation codes."""
    query_codes = list(dict.fromkeys(
        obs["code"] for obs in observations if obs["code"] is not None
    ))
    ranked = _aql_rows(
        """
        LET all_patients = (
            FOR p IN patients
                RETURN DISTINCT p._key
        )

        FOR patient_id IN all_patients
            LET patient_obs = (
                FOR obs IN observations
                    FILTER obs.PATIENT == patient_id
                    RETURN DISTINCT obs.CODE
            )

            LET rare_conditions = (
                FOR c IN conditions
                    FILTER c.PATIENT == patient_id
                    // Count occurrences of this condition
                    LET condition_count = (
                        FOR c2 IN conditions
                            FILTER c2.CODE == c.CODE
                            COLLECT WITH COUNT INTO count
                            RETURN count
                    )
                    // Only keep conditions in <5% of patients
                    FILTER condition_count[0] <= 0.05 * LENGTH(all_patients)
                    RETURN {
                        code: c.CODE,
                        description: c.DESCRIPTION
                    }
            )

            // Only include patients with rare conditions
            FILTER LENGTH(rare_conditions) > 0

            // Jaccard similarity: UNION_DISTINCT, because UNION keeps duplicates
            LET shared = INTERSECTION(patient_obs, @query_codes)
            LET jaccard_similarity = LENGTH(shared) /
                                     LENGTH(UNION_DISTINCT(patient_obs, @query_codes))

            SORT jaccard_similarity DESC
            LIMIT @limit

            RETURN {
                patient_id: patient_id,
                similarity: jaccard_similarity,
                rare_conditions: rare_conditions,
                shared_observation_count: LENGTH(shared),
                total_observation_count: LENGTH(patient_obs)
            }
        """,
        # +1 because the query patient may appear in the ranking themselves
        {"query_codes": query_codes, "limit": top_k + 1},
    )
    # Filter out query patient
    others = [p for p in ranked if p['patient_id'] != patient_id]
    return others[:top_k]


def _progression_timeline(observations, conditions):
    """Merge dated observations and condition start/stop events into one date-sorted list."""
    timeline = []
    for obs in observations:
        if obs.get("date"):
            timeline.append({
                "date": obs["date"],
                "event_type": "observation",
                "description": obs["description"],
                "code": obs["code"],
                "value": obs.get("value", "")
            })
    for cond in conditions:
        if cond.get("start"):
            timeline.append({
                "date": cond["start"],
                "event_type": "condition_start",
                "description": cond["description"],
                "code": cond["code"]
            })
        if cond.get("stop"):
            timeline.append({
                "date": cond["stop"],
                "event_type": "condition_end",
                "description": cond["description"],
                "code": cond["code"]
            })
    # Null dates were dropped above, so the sort key is always comparable.
    timeline.sort(key=lambda event: event["date"])
    return timeline


def _rare_conditions_for_patient(patient_id):
    """Return the patient's conditions whose rarity score is at least 0.9."""
    return _aql_rows(
        """
        LET all_patients = (
            FOR p IN patients
                RETURN DISTINCT p._key
        )

        FOR c IN conditions
            FILTER c.PATIENT == @patient_id
            // Get count of this condition across patients
            LET condition_count = (
                FOR c2 IN conditions
                    FILTER c2.CODE == c.CODE
                    COLLECT WITH COUNT INTO count
                    RETURN count
            )
            // Calculate rarity
            LET rarity = 1 - (condition_count[0] / LENGTH(all_patients))
            // Only return rare conditions
            FILTER rarity >= 0.9
            SORT rarity DESC
            RETURN {
                code: c.CODE,
                description: c.DESCRIPTION,
                rarity: rarity,
                patient_count: condition_count[0],
                total_patients: LENGTH(all_patients)
            }
        """,
        {"patient_id": patient_id},
    )


@tool
def analyze_patient_symptoms(patient_spec: dict):
    """
    Advanced patient symptom analyzer for rare disease identification.
    
    Parameters:
    - patient_spec: A dictionary containing:
      - patient_id: Optional patient ID
      - symptoms: Optional list of symptoms to analyze
      - metadata: Optional additional information
      - analysis_type: Analysis type ("rare_disease", "evidence_path", 
                       "similar_cases", "progression")

    Returns a dictionary of findings plus an "interpretation" entry written by
    the LLM, or {"error": ...} when neither patient_id nor symptoms is given.
    All user-supplied values reach the database as AQL bind variables.
    """
    # Extract patient specification
    patient_id = patient_spec.get('patient_id')
    symptoms = patient_spec.get('symptoms') or []
    if isinstance(symptoms, str):
        symptoms = [symptoms]  # a bare string would otherwise be iterated per character
    # Copy so the caller's metadata dict is never modified.
    metadata = dict(patient_spec.get('metadata') or {})
    analysis_type = patient_spec.get('analysis_type', 'rare_disease')

    if not patient_id and not symptoms:
        # Nothing to look up; interpreting empty data would only invite guesswork.
        return {"error": "Provide either 'patient_id' or a non-empty 'symptoms' list"}

    results = {}

    # Symptom-based analysis (no patient ID)
    if not patient_id:
        symptom_codes = _symptom_codes_for(symptoms)
        results["potential_rare_conditions"] = _rare_conditions_for_symptoms(symptom_codes)
        results["symptom_analysis"] = {
            "input_symptoms": symptoms,
            "symptom_codes_identified": symptom_codes,
            "analysis_type": "symptom-based"
        }

    # Patient ID-based analysis
    else:
        # The original interpolated the id inside quotes, i.e. compared as a string.
        patient_key = str(patient_id)
        patient_observations, patient_conditions = _patient_records(patient_key)

        results["patient_data"] = {
            "observations": patient_observations,
            "conditions": patient_conditions,
            "patient_id": patient_id
        }

        # Analyze based on analysis type
        if analysis_type == "evidence_path":
            evidence = _evidence_path(
                patient_key,
                patient_observations,
                patient_conditions,
                metadata.get("condition_id"),
            )
            if evidence is not None:
                results["evidence_path"] = evidence

        elif analysis_type == "similar_cases":
            try:
                top_k = max(1, int(metadata.get('top_k', 5)))
            except (TypeError, ValueError):
                top_k = 5
            results["similar_cases"] = {
                "similar_patients": _similar_cases(patient_key, patient_observations, top_k),
                "query_patient_id": patient_id
            }

        elif analysis_type == "progression":
            results["progression_analysis"] = {
                "timeline": _progression_timeline(patient_observations, patient_conditions),
                "patient_id": patient_id
            }

        else:  # Default to rare disease analysis
            results["rare_disease_analysis"] = {
                "rare_conditions": _rare_conditions_for_patient(patient_key),
                "patient_id": patient_id
            }

    # Enhance results with LLM insights
    analysis_prompt = f"""
    I'm analyzing medical data for potential rare disease patterns:
    
    {results}
    
    Provide a clinical interpretation in the context of rare disease diagnosis.
    Focus on:
    1. Key patterns or insights
    2. Potential rare disease indications
    3. Suggested next steps for investigation
    
    Format as JSON with structured fields for different aspects.
    """
    
    analysis_result = llm.invoke(analysis_prompt).content

    # Prefer the parsed JSON object; otherwise keep the raw text.
    interpretation = analysis_result
    json_match = re.search(r'\{.*\}', analysis_result, re.DOTALL)
    if json_match:
        try:
            interpretation = json.loads(json_match.group())
        except json.JSONDecodeError:
            interpretation = analysis_result
    results["interpretation"] = interpretation

    return results

In [100]:
@tool
def generate_medical_visualization(visualization_spec: dict):
    """
    High-quality medical visualization generator that adapts to data type.
    
    Parameters:
    - visualization_spec: A dictionary with:
      - data: Data to visualize
      - type: Visualization type ("network", "timeline", "heatmap", "sankey", "auto")
      - title: Title for the visualization
      - context: Contextual information for visualization

    Returns the dictionary produced by the selected create_*_visualization
    function, or {"error": ...} for an unsupported type.
    """
    # Extract visualization specification (None is treated like a missing entry)
    data = visualization_spec.get('data')
    if data is None:
        data = {}
    viz_type = str(visualization_spec.get('type') or 'auto').strip().lower()
    title = visualization_spec.get('title') or 'Medical Data Visualization'
    context = visualization_spec.get('context') or {}

    # Determine best visualization if auto
    if viz_type == 'auto':
        # Membership tests only make sense on containers; scalars and strings
        # fall through to the default instead of raising or substring-matching.
        is_container = isinstance(data, (dict, list, tuple, set))

        def has(key):
            return is_container and key in data

        if isinstance(data, nx.Graph) or has('graph') or (has('nodes') and has('edges')):
            viz_type = 'network'
        elif has('timeline') or (is_container and any('date' in str(item) for item in data)):
            viz_type = 'timeline'
        elif has('matrix') or has('heatmap'):
            viz_type = 'heatmap'
        else:
            viz_type = 'network'  # Default for medical data

    # Create appropriate visualization (names are resolved only for the branch taken)
    if viz_type == 'network':
        return create_network_visualization(data, title, context)
    elif viz_type == 'timeline':
        return create_timeline_visualization(data, title, context)
    elif viz_type == 'heatmap':
        return create_heatmap_visualization(data, title, context)
    elif viz_type == 'sankey':
        return create_sankey_visualization(data, title, context)
    else:
        return {"error": f"Unsupported visualization type: {viz_type}"}

In [101]:
import plotly.graph_objects as go

def create_network_visualization(data, title, context):
    """Render a graph as a static Matplotlib PNG and return a summary of it.

    Parameters:
    - data: an nx.Graph, a dict with a "graph" entry holding an nx.Graph, or a
      dict with "nodes" (each having "id") and "edges" (each having "source" and
      "target"). Anything else produces a placeholder graph built from `context`.
    - title: figure title; also used (sanitised) as the output file name.
    - context: optional dict with "condition" and "symptoms" for the placeholder.

    Returns a dict with "image_path", "node_count", "edge_count" and
    "visualization_type" ("network").
    """
    context = context or {}

    # Extract or create graph structure
    if isinstance(data, nx.Graph):
        G = data
    elif isinstance(data, dict) and isinstance(data.get('graph'), nx.Graph):
        G = data['graph']
    elif isinstance(data, dict) and 'nodes' in data and 'edges' in data:
        G = nx.Graph()
        for node in data['nodes']:
            G.add_node(node['id'], **{k: v for k, v in node.items() if k != 'id'})
        for edge in data['edges']:
            G.add_edge(edge['source'], edge['target'],
                       **{k: v for k, v in edge.items() if k not in ['source', 'target']})
    else:
        # Create a default medical knowledge graph based on context
        G = nx.Graph()
        condition = context.get('condition', 'Unknown Condition')
        G.add_node('condition', label=condition, type='condition')

        symptoms = context.get('symptoms', ['Symptom 1', 'Symptom 2', 'Symptom 3'])
        for i, symptom in enumerate(symptoms):
            G.add_node(f'symptom_{i}', label=symptom, type='symptom')
            G.add_edge(f'symptom_{i}', 'condition', type='indicates')

    # Fixed seed keeps the layout reproducible between runs
    pos = nx.spring_layout(G, seed=42)

    # Map node types to colors
    color_map = {
        'condition': 'red',
        'symptom': 'blue',
        'observation': 'blue',
        'medication': 'green',
        'procedure': 'purple',
        'allergy': 'orange',
        'patient': 'yellow',
        'unknown': 'gray'
    }

    # Group nodes by type so each type is drawn (and listed in the legend) once
    node_types = {}
    for node in G.nodes():
        node_types.setdefault(G.nodes[node].get('type', 'unknown'), []).append(node)

    node_labels = {node: G.nodes[node].get('label', str(node)) for node in G.nodes()}

    # The title becomes a file name: replace anything that is not a word
    # character, dot or hyphen (spaces, path separators, ...) with "_".
    file_stem = re.sub(r'[^\w.\-]', '_', str(title)).lower() or 'visualization'
    img_path = f"{file_stem}.png"

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        nx.draw_networkx_edges(G, pos, alpha=0.3, edge_color='gray', ax=ax)

        for node_type, nodes in node_types.items():
            nx.draw_networkx_nodes(G, pos,
                                   nodelist=nodes,
                                   node_color=color_map.get(node_type, 'gray'),
                                   node_size=500,
                                   alpha=0.8,
                                   label=node_type,
                                   ax=ax)

        nx.draw_networkx_labels(G, pos, labels=node_labels,
                                font_size=10, font_color='black', ax=ax)

        if node_types:  # an empty graph has nothing to list in a legend
            ax.legend()

        ax.set_title(title, fontsize=15)
        ax.axis('off')

        # Save visualization to file
        fig.savefig(img_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

    return {
        "image_path": img_path,
        "node_count": len(G.nodes()),
        "edge_count": len(G.edges()),
        "visualization_type": "network"
    }
def create_timeline_visualization(data, title, context):
    """Create interactive timeline visualization for disease progression.

    Parameters:
    - data: Either a list of event dicts, or a mapping whose 'timeline' entry
      holds that list. The events must provide a 'date' field; 'event_type'
      and 'description' get defaults when absent, and 'value' / 'code' are
      shown in the hover text when present.
    - title: Figure title; also used to derive the output file names.
    - context: Accepted for signature parity with the other visualization
      builders; not read by this function.

    Returns:
    - A dict with 'visualization_path' (HTML), 'image_path' (PNG),
      'event_count' and 'visualization_type', or a dict with a single
      'error' key when the input cannot be plotted.
    """
    # Extract timeline data
    timeline_data = data['timeline'] if 'timeline' in data else data

    # Ensure timeline_data is a list
    if not isinstance(timeline_data, list):
        return {"error": "Timeline data must be a list of events"}

    # Sort a copy so the caller's list is never reordered as a side effect
    timeline_data = sorted(timeline_data, key=lambda x: x.get('date', ''))

    # Prepare data for plotting
    df = pd.DataFrame(timeline_data)

    # If dataframe doesn't have certain columns, add defaults
    if 'date' not in df.columns:
        return {"error": "Timeline data must contain 'date' field"}

    if 'event_type' not in df.columns:
        df['event_type'] = 'event'

    if 'description' not in df.columns:
        df['description'] = 'Unknown event'

    # Map event types to colors
    color_map = {
        'condition_start': 'red',
        'condition_end': 'orange',
        'observation': 'blue',
        'medication_start': 'green',
        'medication_end': 'lightgreen',
        'procedure': 'purple',
        'event': 'gray'
    }

    # Optional detail columns surfaced in the hover text. They are handed to
    # Plotly Express via custom_data so each per-event_type trace receives only
    # its own rows, and each column gets its own customdata slot.
    detail_columns = [col for col in ('value', 'code') if col in df.columns]

    # Create interactive timeline using Plotly
    fig = px.scatter(df, x='date', y='event_type',
                     color='event_type', hover_name='description',
                     color_discrete_map=color_map,
                     custom_data=detail_columns or None,
                     title=title)

    # Add details to hover
    hover_template = "<b>%{hovertext}</b><br>Date: %{x}<br>Type: %{y}<br>"
    for slot, col in enumerate(detail_columns):
        hover_template += f"{col.capitalize()}: %{{customdata[{slot}]}}<br>"

    fig.update_traces(hovertemplate=hover_template)

    # Customize layout
    fig.update_layout(
        xaxis_title="Time",
        yaxis_title="Event Type",
        legend_title="Event Type",
        template="plotly_white",
        height=600
    )

    # Add a dotted vertical guide line at every event date
    fig.update_layout(
        shapes=[
            dict(
                type="line",
                xref="x", yref="paper",
                x0=event_date, y0=0,
                x1=event_date, y1=1,
                line=dict(
                    color="Gray",
                    width=1,
                    dash="dot",
                )
            )
            for event_date in df['date']
        ]
    )

    # Save visualization to HTML file
    output_path = f"{title.replace(' ', '_').lower()}.html"
    fig.write_html(output_path)

    # Also save a static image
    img_path = f"{title.replace(' ', '_').lower()}.png"
    fig.write_image(img_path)

    return {
        "visualization_path": output_path,
        "image_path": img_path,
        "event_count": len(timeline_data),
        "visualization_type": "timeline"
    }

def create_heatmap_visualization(data, title, context):
    """Create heatmap visualization for symptom-disease relationships.

    Parameters:
    - data: A mapping with a 'matrix' entry (plus optional 'row_labels' and
      'col_labels'), or a pandas DataFrame whose index/columns become the
      labels. Anything else triggers the placeholder fallback below.
    - title: Figure title; also used to derive the output file names.
    - context: Mapping consulted only by the fallback: 'symptoms' and
      'conditions' supply the row/column labels when both are non-empty.

    Returns:
    - A dict with 'visualization_path' (HTML), 'image_path' (PNG),
      'matrix_dimensions' ("ROWSxCOLS"), 'visualization_type' and
      'placeholder_data' (True when the matrix was filled with random values
      because no real matrix was supplied).
    """
    placeholder_data = False

    # Extract heatmap data
    if 'matrix' in data:
        matrix_data = data['matrix']
        row_labels = data.get('row_labels', [])
        col_labels = data.get('col_labels', [])
    elif isinstance(data, pd.DataFrame):
        matrix_data = data.values
        row_labels = data.index.tolist()
        col_labels = data.columns.tolist()
    else:
        # No real matrix was provided: fall back to random values so a figure
        # can still be drawn. The result is flagged so callers do not mistake
        # it for measured symptom-condition relationships.
        placeholder_data = True
        symptoms = context.get('symptoms')
        conditions = context.get('conditions')

        if not (symptoms and conditions):
            # Complete fallback with placeholder labels
            symptoms = ["Symptom 1", "Symptom 2", "Symptom 3", "Symptom 4", "Symptom 5"]
            conditions = ["Disease A", "Disease B", "Disease C", "Disease D", "Disease E"]

        matrix_data = np.random.rand(len(symptoms), len(conditions))
        row_labels = symptoms
        col_labels = conditions

    # Create heatmap visualization using Plotly
    fig = go.Figure(data=go.Heatmap(
        z=matrix_data,
        x=col_labels,
        y=row_labels,
        colorscale='Blues',
        hovertemplate='%{y} → %{x}: %{z:.3f}<extra></extra>'
    ))

    # Customize layout
    fig.update_layout(
        title=title,
        xaxis=dict(title='Conditions'),
        yaxis=dict(title='Symptoms'),
        template="plotly_white"
    )

    # Save visualization to HTML file
    output_path = f"{title.replace(' ', '_').lower()}.html"
    fig.write_html(output_path)

    # Also save a static image
    img_path = f"{title.replace(' ', '_').lower()}.png"
    fig.write_image(img_path)

    # Use len() rather than truthiness: a NumPy matrix raises ValueError when
    # evaluated as a boolean, while len() works for lists and arrays alike.
    row_count = len(matrix_data)
    col_count = len(matrix_data[0]) if row_count else 0

    return {
        "visualization_path": output_path,
        "image_path": img_path,
        "matrix_dimensions": f"{row_count}x{col_count}",
        "visualization_type": "heatmap",
        "placeholder_data": placeholder_data
    }

def create_sankey_visualization(data, title, context):
    """Create Sankey diagram for patient pathways or disease progression.

    Parameters:
    - data: A mapping. When it has both 'links' and 'nodes' they are used
      as-is (links need 'source', 'target' and 'value'). Otherwise nodes and
      links are derived from 'pathways' (a list of {'events': [...]}) or,
      failing that, from a single 'timeline' list of events.
    - title: Figure title; also used to derive the output file names.
    - context: Mapping; 'patient_id' is attached to the pathway synthesized
      from a timeline (it is recorded on the pathway but not otherwise used).

    Returns:
    - A dict with 'visualization_path' (HTML), 'image_path' (PNG),
      'node_count', 'link_count' and 'visualization_type'.
    """
    # Prepare data for Sankey diagram
    if 'links' in data and 'nodes' in data:
        links = data['links']
        nodes = data['nodes']
    else:
        # Try to construct from pathway data if available
        nodes = []
        node_ids = {}      # "<event_type>_<description>" -> node index
        link_index = {}    # (source index, target index) -> link dict

        # Extract pathway data
        pathways = data.get('pathways', [])
        if not pathways and 'timeline' in data:
            # Convert timeline to a single pathway; sort a copy so the
            # caller's timeline list is not reordered as a side effect
            events = sorted(data['timeline'], key=lambda x: x.get('date', ''))

            # Group by patient if applicable
            patient_id = context.get('patient_id', 'unknown')
            pathways = [{
                'patient_id': patient_id,
                'events': events
            }]

        # Create nodes and links from pathways
        for pathway in pathways:
            events = pathway.get('events', [])

            # Resolve every event to its node index, creating nodes on first sight
            event_node_ids = []
            for event in events:
                event_type = event.get('event_type', 'event')
                description = event.get('description', 'Unknown')

                # Create a unique node identifier
                node_key = f"{event_type}_{description}"

                if node_key not in node_ids:
                    node_ids[node_key] = len(nodes)
                    nodes.append({
                        'id': node_ids[node_key],
                        'name': description,
                        'type': event_type
                    })
                event_node_ids.append(node_ids[node_key])

            # Add links between sequential events; repeated transitions
            # increment the existing link's value instead of adding a new link
            for source_id, target_id in zip(event_node_ids, event_node_ids[1:]):
                link = link_index.get((source_id, target_id))
                if link is None:
                    link_index[(source_id, target_id)] = {
                        'source': source_id,
                        'target': target_id,
                        'value': 1
                    }
                else:
                    link['value'] += 1

        # Dict insertion order keeps links in first-seen order
        links = list(link_index.values())

    # Fall back to the list position when a caller-supplied node has no 'id',
    # and to 'unknown' when its type is missing or None
    node_labels = [
        node.get('name', f"Node {node.get('id', position)}")
        for position, node in enumerate(nodes)
    ]
    node_colors = [get_color_for_type(node.get('type') or 'unknown') for node in nodes]

    # Create Sankey diagram
    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=node_labels,
            color=node_colors
        ),
        link=dict(
            source=[link['source'] for link in links],
            target=[link['target'] for link in links],
            value=[link['value'] for link in links]
        )
    )])

    # Customize layout
    fig.update_layout(
        title=title,
        font=dict(size=12),
        template="plotly_white"
    )

    # Save visualization to HTML file
    output_path = f"{title.replace(' ', '_').lower()}.html"
    fig.write_html(output_path)

    # Also save a static image
    img_path = f"{title.replace(' ', '_').lower()}.png"
    fig.write_image(img_path)

    return {
        "visualization_path": output_path,
        "image_path": img_path,
        "node_count": len(nodes),
        "link_count": len(links),
        "visualization_type": "sankey"
    }

def get_color_for_type(event_type):
    """Return the RGBA color string used for a node or event type.

    Parameters:
    - event_type: Type name such as 'condition' or 'medication_start'.
      Matching is case-insensitive. None and non-string values are accepted
      and mapped to the default color instead of raising.

    Returns:
    - An 'rgba(r, g, b, a)' string; gray for unrecognized types.
    """
    default_color = 'rgba(128, 128, 128, 0.8)'
    color_map = {
        'condition': 'rgba(255, 0, 0, 0.8)',
        'condition_start': 'rgba(255, 0, 0, 0.8)',
        'condition_end': 'rgba(255, 150, 0, 0.8)',
        'observation': 'rgba(0, 0, 255, 0.8)',
        'medication': 'rgba(0, 128, 0, 0.8)',
        'medication_start': 'rgba(0, 128, 0, 0.8)',
        'medication_end': 'rgba(144, 238, 144, 0.8)',
        'procedure': 'rgba(128, 0, 128, 0.8)',
        'symptom': 'rgba(0, 191, 255, 0.8)',
        'allergy': 'rgba(255, 165, 0, 0.8)',
        'patient': 'rgba(255, 255, 0, 0.8)',
        'event': 'rgba(128, 128, 128, 0.8)',
        'unknown': default_color
    }
    # A missing type (None) has no .lower(); treat it like an unknown type
    if event_type is None:
        return default_color
    return color_map.get(str(event_type).lower(), default_color)



In [103]:
def create_rare_disease_agent():
    """Build the LangGraph agent that can call the notebook's three medical tools."""
    # Tools exposed to the agent, in registration order
    medical_tools = [
        query_medical_graph,
        analyze_patient_symptoms,
        generate_medical_visualization,
    ]
    return create_react_agent(llm, medical_tools)

In [104]:
# Build the agent once at module level; query_graph and the dashboard handlers reuse this instance
rare_disease_agent = create_rare_disease_agent()

# Define helper functions to query the agent
def query_graph(query):
    """Send one user message to the medical graph agent and return the content of its last message."""
    user_message = {"role": "user", "content": query}
    agent_state = rare_disease_agent.invoke({"messages": [user_message]})
    final_message = agent_state["messages"][-1]
    return final_message.content

In [106]:
def create_dashboard():
    """Create an interactive dashboard for medical graph exploration.

    Builds a four-tab ipywidgets layout (Query Explorer, Patient Explorer,
    Visualization, About & Help) and wires each "run" button to the
    module-level ``rare_disease_agent``. The content of the agent's last
    message is rendered into the tab's output area; when that text references
    an existing file under a recognized path key, the file is embedded as well.

    Relies on ``widgets``, ``display``, ``HTML`` and ``clear_output`` being in
    scope (presumably imported from ipywidgets / IPython.display in another
    cell - not visible here).

    Returns:
    - A ``widgets.VBox`` holding the header and the tab container.
    """
    # Local import: agent text and file paths are escaped before being placed in HTML
    import html

    # Create main layout
    header = widgets.HTML("""
    <div style="background-color: #1A5276; color: white; padding: 20px; border-radius: 5px; margin-bottom: 20px;">
        <h1 style="margin: 0;">MedGraph Explorer</h1>
        <p>Interactive Medical Graph Analysis for Rare Disease Patterns</p>
    </div>
    """)

    # Create tabs
    tabs = widgets.Tab()

    # Tab 1: Query Explorer
    query_input = widgets.Textarea(
        value='',
        placeholder='Enter your medical query here...',
        description='Query:',
        disabled=False,
        layout=widgets.Layout(width='100%', height='100px')
    )

    query_type = widgets.RadioButtons(
        options=['Auto', 'AQL', 'NetworkX', 'Hybrid'],
        value='Auto',
        description='Query Type:',
        disabled=False
    )

    example_queries = widgets.Dropdown(
        options=[
            'What are the most common symptoms of autoimmune disorders?',
            'Find patients with rare conditions and analyze their symptom patterns',
            'Identify clusters of patients with similar rare disease profiles',
            'Show the progression timeline for patients with lupus',
            'Visualize the relationship between diabetes and heart disease',
            'What are the most common comorbidities for rare diseases?'
        ],
        description='Examples:',
        disabled=False,
        layout=widgets.Layout(width='100%')
    )

    run_button = widgets.Button(
        description='Run Query',
        button_style='primary',
        tooltip='Click to run query',
        icon='search'
    )

    output_area = widgets.Output()

    # Tab 2: Patient Explorer
    patient_id_input = widgets.Text(
        value='',
        placeholder='Enter patient ID',
        description='Patient ID:',
        disabled=False,
        layout=widgets.Layout(width='50%')
    )

    analysis_type = widgets.Dropdown(
        options=['rare_disease', 'evidence_path', 'similar_cases', 'progression'],
        value='rare_disease',
        description='Analysis:',
        disabled=False,
        layout=widgets.Layout(width='50%')
    )

    run_patient_button = widgets.Button(
        description='Analyze Patient',
        button_style='primary',
        tooltip='Click to analyze patient',
        icon='user'
    )

    patient_output = widgets.Output()

    # Tab 3: Visualization Explorer
    viz_type = widgets.Dropdown(
        options=['network', 'timeline', 'heatmap', 'sankey', 'auto'],
        value='auto',
        description='Viz Type:',
        disabled=False,
        layout=widgets.Layout(width='50%')
    )

    viz_title = widgets.Text(
        value='Medical Data Visualization',
        description='Title:',
        disabled=False,
        layout=widgets.Layout(width='50%')
    )

    viz_data_input = widgets.Textarea(
        value='',
        placeholder='Enter JSON data for visualization or leave empty to use last query results',
        description='Data (optional):',
        disabled=False,
        layout=widgets.Layout(width='100%', height='100px')
    )

    run_viz_button = widgets.Button(
        description='Generate Visualization',
        button_style='primary',
        tooltip='Click to generate visualization',
        icon='chart-line'
    )

    viz_output = widgets.Output()

    # Tab 4: About & Help
    about_content = widgets.HTML("""
    <div style="padding: 20px; background-color: #f8f9fa; border-radius: 5px;">
        <h2>About MedGraph Explorer</h2>
        <p>This tool enables advanced exploration of medical data using graph analytics and AI to identify rare disease patterns.</p>
        
        <h3>Features:</h3>
        <ul>
            <li><strong>Intelligent Query Processing:</strong> Natural language queries with automatic selection of AQL or graph algorithms</li>
            <li><strong>Patient Analysis:</strong> Examine individual patients for rare conditions and symptom patterns</li>
            <li><strong>Interactive Visualizations:</strong> Network graphs, timelines, and heatmaps for medical insights</li>
            <li><strong>GPU Acceleration:</strong> Leverages NVIDIA cuGraph when available for high-performance analytics</li>
        </ul>
        
        <h3>How to Use:</h3>
        <ol>
            <li>Enter a natural language query about medical data in the <strong>Query Explorer</strong> tab</li>
            <li>Analyze specific patients using the <strong>Patient Explorer</strong> tab</li>
            <li>Create custom visualizations in the <strong>Visualization Explorer</strong> tab</li>
        </ol>
        
        <h3>Example Queries:</h3>
        <ul>
            <li>"What are the most common initial symptoms for rare autoimmune disorders?"</li>
            <li>"Find clusters of patients with similar rare disease progression"</li>
            <li>"Visualize the relationship between symptoms X and condition Y"</li>
        </ul>
        
        <p><em>Developed for the ArangoDB GraphRAG & NVIDIA cuGraph Hackathon</em></p>
    </div>
    """)

    # Assemble tabs
    query_tab = widgets.VBox([
        example_queries,
        query_input,
        widgets.HBox([query_type, run_button]),
        widgets.HTML("<hr>"),
        widgets.HTML("<h3>Results:</h3>"),
        output_area
    ])

    patient_tab = widgets.VBox([
        widgets.HBox([patient_id_input, analysis_type]),
        run_patient_button,
        widgets.HTML("<hr>"),
        widgets.HTML("<h3>Patient Analysis:</h3>"),
        patient_output
    ])

    viz_tab = widgets.VBox([
        widgets.HBox([viz_type, viz_title]),
        viz_data_input,
        run_viz_button,
        widgets.HTML("<hr>"),
        widgets.HTML("<h3>Visualization:</h3>"),
        viz_output
    ])

    about_tab = widgets.VBox([about_content])

    # Set up tabs
    tabs.children = [query_tab, patient_tab, viz_tab, about_tab]
    tabs.set_title(0, 'Query Explorer')
    tabs.set_title(1, 'Patient Explorer')
    tabs.set_title(2, 'Visualization')
    tabs.set_title(3, 'About & Help')

    # Store for results for later use
    latest_results = {'query': None, 'patient': None, 'viz': None}

    # Path keys searched for in the agent's reply. The query tab only looks
    # for an interactive file; the other tabs also accept a static image.
    html_only_keys = ("visualization_path",)
    html_or_image_keys = ("visualization_path", "image_path")

    # --- Shared helpers for the three button handlers ---

    def _ask_agent(prompt):
        """Send one user message to the agent and return the content of its last message."""
        result = rare_disease_agent.invoke({
            "messages": [{
                "role": "user",
                "content": prompt
            }]
        })
        return result["messages"][-1].content

    def _embed_referenced_file(result_text, path_keys):
        """Embed the first file referenced as "<key>": "<path>" in result_text, if it exists."""
        key_pattern = "|".join(re.escape(key) for key in path_keys)
        match = re.search(rf'"(?:{key_pattern})":\s*"([^"]+)"', result_text)
        if not match or not os.path.exists(match.group(1)):
            return
        raw_path = match.group(1)
        # The path comes from model output, so escape it before using it as an attribute value
        safe_path = html.escape(raw_path, quote=True)
        if raw_path.endswith('.html'):
            display(HTML(f'<iframe src="{safe_path}" width="100%" height="600px"></iframe>'))
        else:
            display(HTML(f'<img src="{safe_path}" style="max-width:100%;">'))

    def _render_result(heading, result_content, path_keys):
        """Replace the current output with the agent's answer and any linked visualization."""
        result_text = result_content if isinstance(result_content, str) else str(result_content)
        clear_output()
        # Escape the model's text so stray '<' or '&' cannot break or inject markup
        display(HTML(f"<h3>{heading}</h3><p>{html.escape(result_text)}</p>"))
        try:
            _embed_referenced_file(result_text, path_keys)
        except Exception as e:
            print(f"Error displaying visualization: {e}")

    # Set up event handlers
    def on_example_select(change):
        if change['type'] == 'change' and change['name'] == 'value':
            query_input.value = change['new']

    def on_run_button_click(b):
        with output_area:
            clear_output()
            print(f"Running query: {query_input.value}")
            print(f"Analysis type: {query_type.value}")
            print("Processing...")

            # Prepare query parameters
            query_params = {
                'query': query_input.value,
                'approach': None if query_type.value == 'Auto' else query_type.value.lower()
            }

            try:
                # Run query through agent and keep the answer for the visualization tab
                answer = _ask_agent(f"Query the medical graph: {json.dumps(query_params)}")
                latest_results['query'] = answer
                _render_result("Query Results:", answer, html_only_keys)
            except Exception as e:
                clear_output()
                print(f"Error running query: {e}")

    def on_run_patient_button_click(b):
        with patient_output:
            clear_output()

            patient_id = patient_id_input.value.strip()
            if not patient_id:
                print("Please enter a valid patient ID")
                return

            print(f"Analyzing patient: {patient_id}")
            print(f"Analysis type: {analysis_type.value}")
            print("Processing...")

            # Prepare patient analysis parameters
            patient_params = {
                'patient_id': patient_id,
                'analysis_type': analysis_type.value
            }

            try:
                # Run patient analysis through agent and keep the answer for later use
                answer = _ask_agent(f"Analyze patient symptoms: {json.dumps(patient_params)}")
                latest_results['patient'] = answer
                _render_result("Patient Analysis:", answer, html_or_image_keys)
            except Exception as e:
                clear_output()
                print(f"Error analyzing patient: {e}")

    def on_run_viz_button_click(b):
        with viz_output:
            clear_output()

            print(f"Generating visualization: {viz_title.value}")
            print(f"Visualization type: {viz_type.value}")
            print("Processing...")

            # Prepare visualization parameters
            viz_params = {
                'type': viz_type.value,
                'title': viz_title.value
            }

            # Use provided data or last query results
            if viz_data_input.value.strip():
                try:
                    viz_params['data'] = json.loads(viz_data_input.value)
                except json.JSONDecodeError:
                    # Only malformed JSON is reported here; anything else should surface
                    clear_output()
                    print("Error: Invalid JSON data")
                    return
            else:
                # Use most recent results
                viz_params['data'] = {
                    'latest_query': latest_results['query'],
                    'latest_patient': latest_results['patient'],
                    'latest_viz': latest_results['viz']
                }

            try:
                # Generate visualization through agent and keep the answer for later use
                answer = _ask_agent(f"Generate medical visualization: {json.dumps(viz_params)}")
                latest_results['viz'] = answer
                _render_result("Visualization:", answer, html_or_image_keys)
            except Exception as e:
                clear_output()
                print(f"Error generating visualization: {e}")

    # Connect event handlers
    example_queries.observe(on_example_select, names='value')
    run_button.on_click(on_run_button_click)
    run_patient_button.on_click(on_run_patient_button_click)
    run_viz_button.on_click(on_run_viz_button_click)

    # Assemble dashboard
    dashboard = widgets.VBox([header, tabs])
    return dashboard

In [None]:
# Build the dashboard widget tree and render it as this cell's output
dashboard = create_dashboard()
display(dashboard)

In [107]:
def generate_visualization_directly(viz_spec):
    """Helper function to call the visualization builders directly without tool validation.

    Parameters:
    - viz_spec: A dictionary containing:
      - 'data': the payload to visualize (defaults to an empty dict; an
        explicit None is treated the same way)
      - 'type': 'network', 'timeline', 'heatmap', 'sankey' or 'auto'
        (default 'auto')
      - 'title': figure title (default 'Medical Data Visualization')
      - 'context': extra mapping forwarded to the builder (default {})

    Returns:
    - The builder's result dict. When a timeline, heatmap or sankey builder
      raises, a dict with 'error' (generic message) and 'details' (the
      exception text) is returned instead. Unknown types yield an 'error'
      dict. Exceptions from the network builder are not caught here.
    """
    data = viz_spec.get('data', {})
    if data is None:
        # An explicit null payload would break the membership tests below
        data = {}
    viz_type = viz_spec.get('type', 'auto')
    title = viz_spec.get('title', 'Medical Data Visualization')
    context = viz_spec.get('context', {})

    def _failure(label, exc):
        """Build the error result for a builder that raised, keeping the reason."""
        return {
            "error": f"{label} visualization failed. Try using a simpler format.",
            "details": str(exc)
        }

    # Determine best visualization if auto
    if viz_type == 'auto':
        if isinstance(data, nx.Graph) or 'graph' in data or ('nodes' in data and 'edges' in data):
            viz_type = 'network'
        elif 'timeline' in data or any('date' in str(item) for item in data):
            viz_type = 'timeline'
        elif 'matrix' in data or 'heatmap' in data:
            viz_type = 'heatmap'
        else:
            viz_type = 'network'  # Default for medical data

    # Create appropriate visualization
    if viz_type == 'network':
        return create_network_visualization(data, title, context)
    elif viz_type == 'timeline':
        try:
            return create_timeline_visualization(data, title, context)
        except Exception as exc:
            return _failure("Timeline", exc)
    elif viz_type == 'heatmap':
        try:
            return create_heatmap_visualization(data, title, context)
        except Exception as exc:
            return _failure("Heatmap", exc)
    elif viz_type == 'sankey':
        # The dashboard offers 'sankey' and a builder exists, so route it here as well
        try:
            return create_sankey_visualization(data, title, context)
        except Exception as exc:
            return _failure("Sankey", exc)
    else:
        return {"error": f"Unsupported visualization type: {viz_type}"}

In [None]:
# Example spec requesting a network view. It has no 'data' key, so
# generate_visualization_directly falls back to an empty dict for the data and
# forwards only that, the title and the context to the network builder.
viz_spec = {
    "type": "network",
    "title": "Rare Disease Relationship Network",
    "context": {
        "condition": "Rare Autoimmune Disorder",
        "symptoms": ["Joint Pain", "Fatigue", "Rash", "Fever", "Weight Loss"]
    }
}
visualization_result = generate_visualization_directly(viz_spec)
# json.dumps raises if the returned value is not JSON-serializable
print(json.dumps(visualization_result, indent=2))