### **Installing The Required Dependencies**

In [None]:
!pip install nx-arangodb
!pip install langgraph-prebuilt
!pip install nx-cugraph-cu12 --extra-index-url https://pypi.nvidia.com

In [None]:
!pip install --upgrade langchain langchain-community langchain-openai langgraph

In [None]:
!pip install autogen-agentchat~=0.2

In [None]:
!pip install autogen

# **Setting up the environment variables (this also enables cuGraph as the NetworkX backend)**

In [None]:
# Per this section's heading, this variable enables cuGraph as the NetworkX backend.
# NOTE(review): presumably it has to be set before `networkx` is imported - confirm.
%env NX_CUGRAPH_AUTOCONFIG=True

In [None]:
import os

# Set both locale variables to C.UTF-8 in this process's environment.
# NOTE(review): presumably a workaround for locale errors in shell commands - confirm.
os.environ.update(LC_ALL="C.UTF-8", LANG="C.UTF-8")

# **Importing the Required Libraries**

In [None]:
import networkx as nx
import nx_arangodb as nxadb

from arango import ArangoClient

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from random import randint
import re

from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langchain_community.graphs import ArangoGraph
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_core.tools import tool

# **Connecting the ArangoDB database**

In [None]:
import os
from getpass import getpass

# Connection settings are read from the environment so that no credential is stored
# in the notebook. Host and user keep their previous values as defaults.
# SECURITY: the password that used to be hardcoded in this cell has been shared with
# the notebook and should be rotated.
ARANGO_HOST = os.environ.get("ARANGO_HOST", "https://0635a9c68bd8.arangodb.cloud:8529")
ARANGO_USERNAME = os.environ.get("ARANGO_USERNAME", "root")

db = ArangoClient(hosts=ARANGO_HOST).db(
    username=ARANGO_USERNAME,
    # Prompt interactively only when ARANGO_PASSWORD is not already exported.
    password=os.environ.get("ARANGO_PASSWORD") or getpass("ArangoDB password: "),
    verify=True,
)

print(db)

# **Accessing the graph from the database and running some AQL to validate that the persistence was correct**

In [None]:
# Bind to the graph named "medical" through the database connection opened above.
G_adb = nxadb.Graph(db=db, name="medical")

print(G_adb)

In [None]:
# Show the first node the graph yields, attributes included.
print("Sample Node in G_adb:")
first_node = next(iter(G_adb.nodes(data=True)))
print(first_node)

# Same for the first edge.
print("\nSample Edge in G_adb:")
first_edge = next(iter(G_adb.edges(data=True)))
print(first_edge)


In [None]:
node_count = db.aql.execute("RETURN LENGTH(medical_node)")
print(f"Total Nodes: {list(node_count)[0]}")

In [None]:
edge_count = db.aql.execute("RETURN LENGTH(medical_node_to_medical_node)")
print(f"Total Edges: {list(edge_count)[0]}")

# **Setting up the ArangoGraph wrapper and also initialising the LLM**

In [None]:
# Wrap the database connection in LangChain's ArangoGraph; the tool functions below
# pass this object to ArangoGraphQAChain as `graph=`.
arango_graph = ArangoGraph(db)

In [None]:
# On Colab, copy the OpenAI key from the notebook's secret store into the environment.
# Outside Colab the import is unavailable, so the step is skipped and OPENAI_API_KEY
# has to be present in the environment already.
try:
    from google.colab import userdata
except ImportError:
    userdata = None

if userdata is not None:
    # NOTE(review): the secret is looked up as 'OPEN_API_KEY' (no "AI") - confirm that
    # this matches the name under which the secret is stored.
    os.environ["OPENAI_API_KEY"] = userdata.get('OPEN_API_KEY')

llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

# Smoke test: a single round trip to the model.
llm.invoke("hello!")

In [None]:
import os
from getpass import getpass

# Fallback for environments without Colab secrets. Ask for a key only when none is
# configured yet, so that running every cell in order does not replace a valid key
# set by the previous cell with a placeholder string.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

llm.invoke("hello!")

# **Defining the Required Tools**

Text to Aql to Text

In [None]:
from langchain.chat_models import ChatOpenAI
from langchain.chains import ArangoGraphQAChain

@tool
def text_to_aql_to_text(query: str):
    """Translates a Natural Language Query into AQL while ensuring correct edge types."""
    # Flow: wrap `query` in guidance that restricts the model to the known edge
    # types, hand it to an ArangoGraphQAChain bound to `arango_graph`, and return
    # the chain's "result" entry converted to a string.

    # Relationship labels the generated AQL is allowed to use, pre-joined for the prompt.
    permitted_edges = ", ".join((
        "DISEASE_HAS_TREATMENT",
        "PATIENT_DIAGNOSED_WITH",
        "DISEASE_HAS_DRUG",
        "SYMPTOM_INDICATES_DISEASE",
        "DISEASE_HAS_GENE",
        "DRUG_INTERACTS_WITH",
        "PATIENT_HAS_SYMPTOM",
    ))

    # A new temperature-0 model and QA chain are built on every call.
    qa_chain = ArangoGraphQAChain.from_llm(
        llm=ChatOpenAI(temperature=0, model_name="gpt-4o"),
        graph=arango_graph,
        verbose=True,
        allow_dangerous_requests=True
    )

    # The user's question goes last, after the edge-type rules.
    guided_question = "".join([
        "You are an expert in graph databases and AQL. Your task is to analyze the given ",
        "natural language query and determine whether an edge type from the following list ",
        f"is required: {permitted_edges}.\n\n",
        "If an edge type is required, **strictly select one** from this list and use it in the AQL query.\n",
        "**Do NOT create or assume any other edge types beyond this list.**\n\n",
        "If no edge type is required, generate a standard AQL query without using any edge.\n\n",
        f"Here is the user query:\n\n{query}",
    ])

    answer = qa_chain.invoke(guided_question)
    return str(answer["result"])


Text to NetworkX Algorithm to Text

In [None]:
import re
import networkx as nx
from langchain_openai import ChatOpenAI
from collections import Counter

@tool
def text_to_nx_algorithm_to_text(query):
    """Executes a NetworkX computation on the Medical ArangoDB Graph."""
    # Flow:
    #   1. Ask the LLM for NetworkX code that answers `query`.
    #   2. exec() that code with `G_adb`, `nx` and `Counter` in scope and read FINAL_RESULT.
    #   3. Ask the LLM to phrase FINAL_RESULT as a concise answer, which is returned.
    # Errors raised by the generated code are not re-raised; they are passed to step 3
    # as an "EXEC ERROR: ..." string.

    llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

    print("1) Generating NetworkX code")

    # Generate Python NetworkX Code using LLM.
    # NOTE: `{G_adb}` interpolates the graph's string form into the prompt, and the
    # trailing "# Use the refined prompt here" sits inside the f-string, so it is sent
    # to the model as well; both are left untouched to keep the prompt unchanged.
    text_to_nx = llm.invoke(f"""
    You are an expert Python programmer specializing in **NetworkX and graph traversal**.

    ### **Your Task:**
    - Generate **valid Python NetworkX code** to process queries related to a **Medical Knowledge Graph {G_adb}**.
    - `G_adb` is **already defined and populated**— **DO NOT** reinitialize it (e.g., `G_adb = nx.Graph()`).
    - The graph consists of:
      - **Nodes** with `id` and `type` attributes. Types include:
        - **'Patient', 'Symptom', 'Gene', 'Drug', 'Treatment', 'Disease'**
      - **Edges** with a `type` attribute defining relationships.

    ### **Instructions for Code Generation:**
    1. **Understand the query** and determine the best approach.
    2. **If the query is node-specific** (e.g., diseases, symptoms, patients, drugs, treatments), process the relevant nodes and edges intelligently.
    3. **If the query is general** (e.g., graph structure, connectivity, centrality, shortest paths), judge all possible NetworkX methods and select the most appropriate one.
    4. **Ensure safe execution** by wrapping the logic inside `try-except`.
    5. **Assign the result to `FINAL_RESULT`**.
    6. **Include debug prints** for intermediate steps.

    ### **Query Handling:**
    ✅ **Node-Specific Queries:**
      - Find symptoms of a disease, drug effectiveness, patient symptoms, etc.
      - Use relationships like `"SYMPTOM_INDICATES_DISEASE"`, `"PATIENT_HAS_SYMPTOM"`, `"DISEASE_HAS_GENE"`, `"DRUG_INTERACTS_WITH"`, `"DISEASE_HAS_TREATMENT"`, `"PATIENT_DIAGNOSED_WITH"`, `"DISEASE_HAS_DRUG"`.
    ✅ **Graph Algorithms (General Queries):**
      - For graph-wide questions, analyze all possible NetworkX functions (connectivity, shortest paths, centrality, clustering, etc.).
      - Do **not** explicitly assume a function—first evaluate all relevant possibilities.
    ✅ **Ensure Execution Completeness:**
      - If multiple approaches are valid, return the most informative one.

    ### **Response Rules:**
    ✅ Contain **ONLY valid Python code**—no explanations or markdown.
    ✅ **Use NetworkX** functions appropriately.
    ✅ **Handle errors safely** using `try-except`.
    ✅ **Ensure FINAL_RESULT contains meaningful output (e.g., symptom names instead of node IDs).**

    ---

    **Query:** "{query}"
    **Your response should contain ONLY Python code following these rules.**  # Use the refined prompt here
            """).content

    # Remove a ```python opening fence and any ``` closing fence from the reply.
    text_to_nx_cleaned = re.sub(r"^\`\`\`python\n|\`\`\`$", "", text_to_nx, flags=re.MULTILINE).strip()

    print('-' * 10)
    print(text_to_nx_cleaned)
    print('-' * 10)

    print("\n2) Executing NetworkX code")

    # SECURITY: this executes model-written code with the notebook's full privileges.
    # A single dict is used as both globals and locals. When exec() is given two
    # separate dicts the code runs as if it were a class body, so comprehensions,
    # lambdas and helper functions in the generated code cannot see the variables
    # that the code itself assigned at its top level.
    namespace = {
        "G_adb": G_adb,
        "nx": nx,
        "Counter": Counter
    }

    try:
        exec(text_to_nx_cleaned, namespace)
        # The generated code is asked to leave its answer in FINAL_RESULT.
        FINAL_RESULT = namespace.get("FINAL_RESULT", "Execution failed")
    except Exception as e:
        print(f"EXEC ERROR: {e}")
        FINAL_RESULT = f"EXEC ERROR: {e}"

    print('-' * 10)
    print(f"FINAL_RESULT: {FINAL_RESULT}")
    print('-' * 10)

    print("3) Formulating final answer")

    nx_to_text = llm.invoke(f"""
    The **generated Python code** executed and returned: {FINAL_RESULT}.
    Based on this, generate a **concise** response.

    Your response:
    """).content

    return nx_to_text



Hybrid Query

In [None]:
import re
import json
import networkx as nx
from collections import Counter
from langchain.chat_models import ChatOpenAI
from langchain.chains import ArangoGraphQAChain
import logging
import ast



# Request INFO-level output on the root logger; execute_hybrid_query below reports the
# AQL result via logging.info and failures via logging.error.
logging.basicConfig(level=logging.INFO)

@tool
def execute_hybrid_query(query):
    """Determines execution plan and processes query using AQL and NetworkX accordingly."""
    # Flow:
    #   1. Ask the LLM to split `query` into an AQL part and a NetworkX algorithm name.
    #   2. Run the AQL part through ArangoGraphQAChain.
    #   3. Ask the LLM to list the entity ids mentioned in the AQL result.
    #   4. Ask the LLM for NetworkX code, exec() it on the global `G_adb`, read `nx_result`.
    #   5. Keep only the scores of the extracted entities, pick the maximum and have the
    #      LLM phrase the finding.
    # Always returns {"Final Response": <str>}. Problems in steps 1, 4 and 5 are reported
    # through that key instead of surfacing as NameError / KeyError / ValueError.
    from collections.abc import Mapping

    def _strip_fences(text):
        """Remove a leading ```lang fence line and a trailing ``` fence from an LLM reply."""
        return re.sub(r"^```[A-Za-z]*[ \t]*\n|```[ \t]*$", "", text.strip(), flags=re.MULTILINE).strip()

    llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

    # Step 1: Determine query execution strategy
    prompt = f"""
    You are an AI assistant. Given the query below, extract the **AQL query** and **NetworkX algorithm** separately.

    ### **Rules:**
    - If the query requires retrieving data from ArangoDB, generate a valid **AQL query** based on the dataset structure.
    - If the query requires graph-based computations, extract the **NetworkX algorithm** (e.g., PageRank, Shortest Path, Community Detection,betweenness).
    - Ensure **strict JSON output** with **no explanations**.

    ### **Example Format:**
    {{
      "AQL": "WITH medical_node, medical_node_to_medical_node
              FOR patient IN medical_node
                  FILTER patient.type == 'Patient'
                  FOR diagnosis IN medical_node_to_medical_node
                      FILTER diagnosis._from == patient._id AND diagnosis.type == 'PATIENT_DIAGNOSED_WITH'
                      FOR disease IN medical_node
                          FILTER disease._id == diagnosis._to AND disease.id == 'Diabetes'
                          FOR symptomEdge IN medical_node_to_medical_node
                              FILTER symptomEdge._from == patient._id AND symptomEdge.type == 'PATIENT_HAS_SYMPTOM'
                              FOR symptom IN medical_node
                                  FILTER symptom._id == symptomEdge._to
                                  RETURN symptom",
      "Algorithm": ""
    }}

    ### **Query:** "{query}"

    ### **Your response must contain only JSON:**
    """

    # str.strip("```json") would strip a *set of characters* from both ends rather
    # than the fence itself, so the fences are removed with a regex instead.
    plan_text = _strip_fences(llm.invoke(prompt).content)

    try:
        # strict=False accepts raw newlines inside string values, which the multi-line
        # AQL example in the prompt invites the model to produce.
        execution_plan = json.loads(plan_text, strict=False)
    except json.JSONDecodeError:
        execution_plan = None
    if not isinstance(execution_plan, dict):
        # Without a plan nothing below can run; stop here instead of hitting an
        # unbound `execution_plan`.
        logging.error("Failed to parse LLM response. Raw response: %s", plan_text)
        print("error: Invalid response from LLM.")
        return {"Final Response": "error: Invalid response from LLM."}

    # `or ""` also covers an explicit JSON null for either key.
    aql_query = str(execution_plan.get("AQL") or "").strip()
    nx_query = str(execution_plan.get("Algorithm") or "").strip()

    # Ensure single quotes are retained
    aql_query = aql_query.replace('"', "'")

    print("\n1)Generating and Executing AQL query")
    # Step 2: AQL Execution
    edge_types = [
        "DISEASE_HAS_TREATMENT",
        "PATIENT_DIAGNOSED_WITH",
        "DISEASE_HAS_DRUG",
        "SYMPTOM_INDICATES_DISEASE",
        "DISEASE_HAS_GENE",
        "DRUG_INTERACTS_WITH",
        "PATIENT_HAS_SYMPTOM"
    ]

    if aql_query:
        instruction_prompt = (
            f"You are an expert in graph databases and AQL. Your task is to analyze the given "
            f"natural language query and determine whether an edge type from the following list "
            f"is required: {', '.join(edge_types)}.\n\n"
            "If an edge type is required, **strictly select one** from this list and use it in the AQL query.\n"
            "**Do NOT create or assume any other edge types beyond this list.**\n\n"
            "If no edge type is required, generate a standard AQL query without using any edge.\n\n"
            f"Here is the user query:\n\n{aql_query}"
        )

        chain = ArangoGraphQAChain.from_llm(
            llm=llm,
            graph=arango_graph,
            verbose=True,
            allow_dangerous_requests=True
        )

        try:
            aql_result = chain.invoke(instruction_prompt)["result"]
        except Exception as e:
            # Best effort: carry a marker string forward rather than aborting.
            logging.error("AQL Execution Failed: %s", str(e))
            aql_result = "AQL Execution Failed"
    else:
        aql_result = None

    logging.info("AQL result: %s", aql_result)

    # Step 3: Extract the entity ids mentioned in the AQL result
    system_prompt = """You will be given an original query and a raw_result containing an AQL result.
    Assess both and extract **only the relevant entities** dynamically.

    ### **Rules:**
    1. Identify **categories** (e.g., Symptoms, Drugs, Treatments, Patients) mentioned in both the query and raw_result.
    2. Extract **all items** under those categories.
    3. Return only a **flat list** of extracted entities (no categories, no explanations).

    ### **Example Output Format:**
    ["Blurred Vision", "Fatigue", "Frequent Urination", "Metformin", "Insulin", "Exercise"]"""

    response = llm.invoke([
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Query: {query}\nRaw Result: {aql_result}"}
    ])

    # The reply should be a JSON list; fall back to "no entities" when it is not.
    try:
        extracted_entities = json.loads(_strip_fences(response.content))
    except json.JSONDecodeError:
        logging.error("Failed to parse extracted entities. Raw response: %s", response.content)
        extracted_entities = []
    if not isinstance(extracted_entities, list):
        extracted_entities = []

    print(extracted_entities)

    # Step 4: NetworkX Execution (Over the whole graph)
    print('-' * 10)
    print("\n2)Generating NetworkX code")

    NX_ALGORITHM_TEMPLATES = {
        "PageRank": "nx.pagerank(G_adb)",
        "Betweenness Centrality": "nx.betweenness_centrality(G_adb)",
        "Community Detection": "nx.algorithms.community.greedy_modularity_communities(G_adb)",
    }

    text_to_nx = llm.invoke(f'''
    Given the following NetworkX query:
    {nx_query}
    Generate Python code to execute it on a NetworkX graph `G_adb`
    and store the result in a variable named `nx_result`.
    Use the predefined templates{NX_ALGORITHM_TEMPLATES} for common NetworkX algorithms depending on the query before generating new ones.
    **Your response should contain ONLY Python code.**
            ''').content

    text_to_nx_cleaned = _strip_fences(text_to_nx)
    print(text_to_nx_cleaned)

    # SECURITY: this executes model-written code with the notebook's full privileges.
    # One dict serves as globals and locals so that nested scopes in the generated
    # code can see its own top-level variables. `nx` is provided because the
    # templates above call `nx.*` without importing it.
    namespace = {"G_adb": G_adb, "nx": nx}
    try:
        exec(text_to_nx_cleaned, namespace)
    except Exception as e:
        logging.error("NetworkX Execution Failed: %s", e)
        return {"Final Response": f"NetworkX execution failed: {e}"}

    if "nx_result" not in namespace:
        return {"Final Response": "NetworkX execution failed: the generated code did not define `nx_result`."}
    nx_result = namespace["nx_result"]
    print(nx_result)
    print("-------")

    print(aql_result)

    # Step 5: Filter out the required nodes and formulate the final answer
    symptom_ids = extracted_entities
    print(symptom_ids)

    # Ranking needs one score per node; results such as a list of communities
    # cannot be indexed by node key.
    if not isinstance(nx_result, Mapping):
        return {"Final Response": f"The result of '{nx_query}' is not a per-node score mapping, so the extracted entities cannot be ranked."}

    # Scores of the nodes whose `id` attribute was named in the AQL result.
    symptom_metric = {node_data['id']: nx_result[node]
                        for node, node_data in G_adb.nodes(data=True)
                        if node_data.get('id') in symptom_ids and node in nx_result}
    print(symptom_metric)

    if not symptom_metric:
        # max() on an empty dict would raise ValueError.
        return {"Final Response": "None of the entities extracted from the AQL result could be matched to a NetworkX score."}

    # Entry with the highest score
    max_symptom, max_metric = max(symptom_metric.items(), key=lambda x: x[1])
    final_result = {"id": max_symptom, "metric": max_metric}

    print("\n3) Formulating final answer")

    # Generate Final Combined Response
    final_response = llm.invoke(f'''
    You are a data scientist responsible for inferencing from the final result.

    Given the final result{final_result} and the metric name {nx_query}
    Mention which has the highest metric value and what does it represent
    Format a clear and insightful compact response summarizing the findings. Ensure it is understandable to non-technical users.
    Also you can take help of the 'id' in {aql_result} to just name the options that were available.
    ''').content

    return {
        "Final Response": final_response
    }



Extracting Subgraph

In [None]:
import re
import networkx as nx

@tool
def extract_nx_subgraph(query):
    """
    Extracts a subgraph from a NetworkX graph (converted from ArangoDB) based on an LLM-generated query.

    Parameters:
        query (str): The natural language query specifying subgraph extraction criteria.

    Returns:
        networkx.Graph: The extracted subgraph, or None when the generated code did not
        define a `subgraph` variable.
    """

    llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

    # Copy the global ArangoDB-backed graph `G_adb` into a plain in-memory NetworkX
    # graph, node by node and edge by edge, attributes included.
    G_nx = nx.Graph()

    for node_key, attributes in G_adb.nodes(data=True):
        G_nx.add_node(node_key, **attributes)

    for u, v, attributes in G_adb.edges(data=True):
        G_nx.add_edge(u, v, **attributes)

    # Generate NetworkX extraction code from LLM
    valid_node_types = {'Treatment', 'Unknown', 'Symptom', 'Gene', 'Drug', 'Patient', 'Disease'}

    text_to_nx = llm.invoke(f'''
        You are an AI assistant that generates simple, executable Python code for extracting subgraphs in NetworkX.
        The user provides a query {query} specifying which nodes and edges to extract, and you return Python code (not in a function) that does the extraction.

        - The input graph is `G_nx` (do not redefine it).
        - The graph contains **only** these node types: {valid_node_types}.
        - Dont use any edge to create the subgraph just use G_nx.subgraph(nodes_to_extract).copy()
        - **Ensure that extracted node types match the exact case and spelling as provided.**
        - The code should dynamically determine which nodes and edges to extract based on the query.
        - Use `subgraph = ...` to store the extracted graph.
        - Do not include explanations, comments, or function definitions.
        - The code should be executable via `exec()`.
    ''').content

    # Clean LLM output (remove code block markers if present)
    text_to_nx_cleaned = re.sub(r"^\`\`\`python\n|\`\`\`$", "", text_to_nx, flags=re.MULTILINE).strip()

    print(text_to_nx_cleaned)
    print("-----------")

    # SECURITY: this executes model-written code with the notebook's full privileges.
    # A single dict is used as both globals and locals, so comprehensions in the
    # generated code can see variables the code assigns (e.g. `node_types_to_extract`).
    # That removes the need to regex the set literal out of the code and eval() it
    # into the globals beforehand. The empty-set default is kept for generated code
    # that reads the name without assigning it.
    namespace = {"G_nx": G_nx, "nx": nx, "node_types_to_extract": set()}
    exec(text_to_nx_cleaned, namespace)

    subgraph = namespace.get("subgraph", None)
    if subgraph is None:
        print("No subgraph extracted.")
        return None

    # Get a sample node (if available)
    sample_node = next(iter(subgraph.nodes(data=True)), ("No nodes in subgraph", {}))

    # Get a sample edge (if available)
    sample_edge = next(iter(subgraph.edges(data=True)), ("No edges in subgraph", {}, {}))

    # Print the summary
    print(f"""
    The extracted subgraph includes nodes and edges related to the subgraph.

    - **Nodes**: {subgraph.number_of_nodes()}
    - **Edges**: {subgraph.number_of_edges()}

    ### Sample Node
    - **ID**: {sample_node[0]}
    - **Details**: {sample_node[1]}

    ### Sample Edge
    - **From Node**: {sample_edge[0]}
    - **To Node**: {sample_edge[1]}
    - **Details**: {sample_edge[2]}
    """)

    return subgraph



Extracting Visualisation Over A Subgraph

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import re
import json
import matplotlib.patches as mpatches
@tool
def visualize_metrics(query):
    """
    Extracts a subgraph, applies the required NetworkX algorithm(s),
    and visualizes the results.
    Returns:
        dict: {
            "subgraph_query": str,
            "algorithms": list,
            "subgraph_summary": dict,
            "metrics": dict,
            "message": str
        }
    """
    llm = ChatOpenAI(temperature=0, model_name="gpt-4o")  # Initialize the LLM

    # First LLM: Determine if a subgraph is needed & extract the algorithm
    extraction_prompt = f'''
        You are an AI assistant that determines which NetworkX algorithm should be applied
        to a graph based on the user query: "{query}".
        If a subgraph needs to be created, generate a natural language query for extracting it.

        Return the output in JSON format with the keys:
        "subgraph_query": "query to extract relevant nodes and edges",
        "algorithms": ["algorithm_1", "algorithm_2", ...].
    '''

    extraction_response = llm.invoke(extraction_prompt).content
    extraction_response_cleaned = re.sub(r"```json\n|```", "", extraction_response, flags=re.MULTILINE).strip()
    extraction_data = json.loads(extraction_response_cleaned)  # Convert JSON string to dictionary

    subgraph_query = extraction_data["subgraph_query"]
    algorithms = extraction_data["algorithms"]

    # Extract the subgraph. `extract_nx_subgraph` is wrapped by @tool, so prefer its
    # `invoke` method and fall back to a direct call if it is a plain function.
    subgraph = getattr(extract_nx_subgraph, "invoke", extract_nx_subgraph)(subgraph_query)

    # extract_nx_subgraph returns None when the generated code produced no `subgraph`.
    if subgraph is None or subgraph.number_of_nodes() == 0:
        return {
            "subgraph_query": subgraph_query,
            "algorithms": algorithms,
            "subgraph_summary": {"nodes": 0, "edges": 0},
            "metrics": None,
            "message": "No nodes found in the subgraph."
        }

    # Second LLM: Generate code to apply NetworkX algorithms.
    # The scores are read back below with `subgraph.nodes[n].get(algo)`, so the prompt
    # pins the attribute name to the algorithm string itself.
    algo_prompt = f'''
    You are an AI assistant that generates Python code to apply
    NetworkX algorithms on a given subgraph. The user wants to apply: {algorithms}.

    - The input graph is `subgraph` (do not redefine it).
    - Store the computed values as node attributes in `subgraph`, using each algorithm name from the list above, exactly as written, as the attribute name (e.g. `nx.set_node_attributes(subgraph, scores, "<algorithm name>")`).
    - Use the variable names based on the algorithm (e.g., `pagerank_scores`, `betweenness_scores`).
    - Do not include explanations, comments, or function definitions.
    - The code should be executable via `exec()`.
    '''

    algo_code = llm.invoke(algo_prompt).content
    algo_code_cleaned = re.sub(r"^```python\n|```$", "", algo_code, flags=re.MULTILINE).strip()

    # SECURITY: this executes model-written code with the notebook's full privileges.
    # `nx` is supplied so the generated code works even without its own import.
    exec(algo_code_cleaned, {"subgraph": subgraph, "nx": nx})

    # Collect computed metric values (None for nodes without the attribute)
    metrics = {}
    for algo in algorithms:
        metrics[algo] = {n: subgraph.nodes[n].get(algo, None) for n in subgraph.nodes}

    # One layout shared by all plots; the fixed seed keeps node positions stable
    pos = nx.spring_layout(subgraph, seed=42)

    def plot_graph(metric_name, metric_values):
        """Draw `subgraph` with nodes coloured by `metric_values` (one value per node, in node order)."""
        plt.figure(figsize=(14, 10))  # Increase figure size

        # Normalize node colors based on metric values
        cmap = plt.cm.Paired
        norm = plt.Normalize(vmin=min(metric_values), vmax=max(metric_values))
        node_colors = [cmap(norm(val)) for val in metric_values]

        # Draw the graph with node labels
        nx.draw(subgraph, pos, with_labels=True, node_color=node_colors, cmap=cmap,
                edge_color="black", node_size=600, alpha=0.9, font_size=10, font_weight="bold")

        # Legend: pick the lowest, middle and highest distinct value as representatives
        unique_values = sorted(set(metric_values))  # Get distinct metric values
        num_categories = 3  # Number of legend categories (Low, Medium, High)
        category_labels = ["Low Centrality", "Medium Centrality", "High Centrality"]

        # Map distinct values into 3 categories evenly
        split_indices = np.linspace(0, len(unique_values) - 1, num_categories, dtype=int)
        category_values = [unique_values[i] for i in split_indices]

        # Legend patches use the same colour mapping as the nodes
        legend_patches = [
            mpatches.Patch(color=cmap(norm(val)), label=label) for val, label in zip(category_values, category_labels)
        ]

        # Adjust legend position
        plt.legend(handles=legend_patches, loc="upper left", fontsize=12, frameon=True, shadow=True)

        # Title
        plt.title(f"{metric_name} (Subgraph)", fontsize=14)
        plt.show()

    # Generate plots for all algorithms
    for algo in algorithms:
        # Ensure all nodes have a valid metric value (default to 0 if missing)
        metric_values = [subgraph.nodes[n].get(algo, 0) or 0 for n in subgraph.nodes]

        # A constant metric carries no colour information and cannot be normalized
        # meaningfully, so it is skipped.
        if max(metric_values) == min(metric_values):
            print(f"Skipping visualization for {algo} as all values are identical.")
            continue

        plot_graph(algo.replace("_", " ").title(), metric_values)

    return {
        "subgraph_query": subgraph_query,
        "algorithms": algorithms,
        "subgraph_summary": {
            "nodes": subgraph.number_of_nodes(),
            "edges": subgraph.number_of_edges()
        },
        "metrics": metrics,
        "message": "Visualization completed."
    }


# **Creating The Agent Using Autogen**

In [None]:
import autogen
from langchain_openai import ChatOpenAI
from langchain.tools import tool

# ✅ Define the tools
# Registry used by query_graph_with_autogen: the name the agent replies with is looked
# up here to find the tool to run.
tools = {
    "text_to_aql_to_text": text_to_aql_to_text,
    "text_to_nx_algorithm_to_text": text_to_nx_algorithm_to_text,
    "execute_hybrid_query": execute_hybrid_query,
    "extract_nx_subgraph": extract_nx_subgraph,
    "visualize_metrics": visualize_metrics
}

# ✅ System prompt (same as LangGraph)
# The closing instruction asks for the bare tool name: the caller regex-searches the
# reply and takes the first tool name it finds, so a reply that discusses several
# tools could otherwise select the wrong one.
SYSTEM_PROMPT = """You are an AI assistant that selects the best tool to analyze a graph query.
Use the following rules:
- If the query **only requires retrieving or filtering nodes/edges from ArangoDB**, use `text_to_aql_to_text`.
  - Example: "Find all patients diagnosed with Diabetes."
  - Example: "List all diseases connected to Hypertension."

- If the query **requires complex graph algorithms (PageRank, community detection, shortest path, etc.)**, use `text_to_nx_algorithm_to_text`.
  - Example: "Find the most influential disease using PageRank."
  - Example: "Detect communities of interconnected patients."

- If the query **needs both AQL (to filter data) and NetworkX (to process it)**, use `execute_hybrid_query`.
  - Example: "Identify the most influential symptom among patients with Diabetes using PageRank."
  - Example: "Find the drugs available and compute its betweenness Centrality."

- If the query **requires extracting a subgraph from NetworkX**, use `extract_nx_subgraph`.
  - Example: "Extract a subgraph of diseases, symptoms, and treatments."

- If the query **requires both extracting a subgraph and visualizing NetworkX metrics**, use `visualize_metrics`.
  - Example: "Visualize the Betweenness Centrality values for a subgraph of diseases, symptoms, and treatments."
  - Example: "Show the PageRank values for a subgraph of interconnected medical conditions."

Respond with ONLY the name of the single most appropriate tool from the rules above, with no explanation.
"""

# ✅ Create the agent with the system prompt
graph_agent = autogen.AssistantAgent(
    name="GraphAssistant",
    llm_config={"model": "gpt-4o"},
    system_message=SYSTEM_PROMPT
)

# ✅ Function to process the query using Autogen
def query_graph_with_autogen(query):
    """Let the routing agent choose a tool for `query`, then run that tool on the query.

    Args:
        query: The user's natural-language question.

    Returns:
        Whatever the selected tool returns, or a string starting with "Error" when the
        agent's reply cannot be interpreted or any step raises.
    """
    try:
        # Only `messages` is passed: `return_messages` is not a documented parameter of
        # generate_reply, so it is at best ignored.
        response = graph_agent.generate_reply(messages=[{"role": "user", "content": query}])

        # 🔹 Debug: Print the raw response from Autogen
        print("\n🔍 RAW RESPONSE FROM AUTOGEN:", response)

        # 🔹 A list reply is reduced to its last message; anything else is used as is
        if isinstance(response, list) and len(response) > 0:
            last_message = response[-1]
        else:
            last_message = response

        # 🔹 Pull the reply text out of a message dict or a plain string
        if isinstance(last_message, dict) and "content" in last_message:
            # `content` may be None (e.g. for non-text replies); treat that as empty text.
            raw_tool_name = (last_message["content"] or "").strip()
        elif isinstance(last_message, str):
            raw_tool_name = last_message.strip()
        else:
            return "Error: Unexpected response format from Autogen."

        # ✅ The first registered tool name mentioned in the reply wins
        match = re.search(r"\b(text_to_aql_to_text|text_to_nx_algorithm_to_text|execute_hybrid_query|extract_nx_subgraph|visualize_metrics)\b", raw_tool_name)
        if match:
            tool_name = match.group(0)
        else:
            return f"Error: Unable to extract a valid tool name from Autogen response - `{raw_tool_name}`"

        # 🔹 Validate if the selected tool exists
        if tool_name in tools:
            print(f"\n✅ Selected Tool: {tool_name}")
            selected_tool = tools[tool_name]
            # Tools created with @tool expose `invoke`; calling the tool object directly
            # is the deprecated path. Plain callables are still supported.
            run_tool = getattr(selected_tool, "invoke", selected_tool)
            return run_tool(query)
        else:
            return f"Error: Unknown tool selected by the agent - `{tool_name}`"

    except Exception as e:
        # Deliberate best effort: report the failure as text instead of raising.
        return f"Error processing query: {e}"

# **Query Testing**

In [None]:
# Graph-wide ranking question. NOTE(review): presumably routed to
# text_to_nx_algorithm_to_text under SYSTEM_PROMPT's rules - confirm from the output.
query_graph_with_autogen("Who is the most popular node in the Graph?Explain why")

In [None]:
# Connectivity question that names a graph algorithm (connected components).
print(query_graph_with_autogen("How strongly connected is the network? Used connected components."))

In [None]:
# Retrieval-only questions: the next three cells ask for data lookups and, for the
# first two, mention AQL explicitly.
query_graph_with_autogen("Translate this natural language query into AQL: 'Find all diseases treated by drug Aspirin'")

In [None]:
print(query_graph_with_autogen("Check how drug type Ibuprofen interacts with other drugs.Use aql"))

In [None]:
print(query_graph_with_autogen("List recommended treatments for the disease Hypertension"))

In [None]:
# Variant of the first execute_hybrid_query example in SYSTEM_PROMPT.
query_graph_with_autogen("Identify the most influential symptom among patients diagnosed with Diabetes using PageRank analysis.")

In [None]:
# Identical to the second execute_hybrid_query example in SYSTEM_PROMPT.
query_graph_with_autogen("Find the drugs available and compute its betweenness Centrality.")

In [None]:
# Another filter-then-rank question (disease filter plus PageRank).
query_graph_with_autogen("Identify the most influential drugs for disease id Asthma using PageRank analysis.")

In [None]:
# Variant of the extract_nx_subgraph example in SYSTEM_PROMPT.
query_graph_with_autogen("Extract a subgraph of drugs, Disease, symptoms, and Treatments.")

In [None]:
# Variants of the visualize_metrics examples in SYSTEM_PROMPT.
query_graph_with_autogen("Visualize only the Betweenness values for a subgraph of diseases and treatments.")

In [None]:
query_graph_with_autogen("Visualize only the Pagerank values for a subgraph of Diseases and Symptoms.")

In [None]:
import gradio as gr

# Serve the router as a text-in / text-out web UI. The previous `fn=query_graph`
# referred to a name that is not defined anywhere in this notebook.
# NOTE(review): share=True asks Gradio for a publicly reachable link, and the tools
# behind this function exec() LLM-generated code - confirm that exposure is intended.
gr.Interface(fn=query_graph_with_autogen, inputs="text", outputs="text").launch(share=True)