# Flight Graph Insights: A Divide-and-Conquer Strategy

This notebook demonstrates a workflow to analyze flight data using a graph-based approach with ArangoDB, NetworkX, LangChain, and LangGraph. 

The notebook is divided into several sections:
1. Installation of required packages.
2. Importing and configuring packages.
3. Connecting to the ArangoDB database and preparing the flight dataset.
4. Persisting and materializing the graph.
5. Building a query agent using LangChain & LangGraph.
6. Building a user interface using Gradio.

<p align="center">
    <img src="https://raw.githubusercontent.com/SivaTSS/flight_graph_agent/main/images/intro_flights.png" style="height: 500px;">
</p>

## Install Required Packages

In [1]:
# # 1. Install nx-arangodb via pip
# # Github: https://github.com/arangodb/nx-arangodb

# !pip install nx-arangodb

In [2]:
# # 2. Check if you have an NVIDIA GPU
# # Note: If this returns "command not found", then GPU-based algorithms via cuGraph are unavailable

# !nvidia-smi
# !nvcc --version

In [3]:
# # 3. Install nx-cugraph via pip
# # Note: Only enable this installation if the step above is working!

# !pip install nx-cugraph-cu12 --extra-index-url https://pypi.nvidia.com # Requires CUDA-capable GPU

In [4]:
# # 4. Install LangChain & LangGraph

# !pip install --upgrade langchain langchain-community langchain-openai langgraph

In [5]:
# !pip install pandas
# !pip install matplotlib
# !pip install networkx

In [6]:
# ! pip install arango_datasets


---
## Import Packages and Configure

In [7]:
# Import the required modules

import networkx as nx
import nx_arangodb as nxadb

from arango import ArangoClient

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from random import randint
import re

from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langchain_community.graphs import ArangoGraph
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_core.tools import tool

[22:03:23 -0400] [INFO]: NetworkX-cuGraph is unavailable: No module named 'cupy'.


In [8]:
import json

# Load credentials from the JSON file.
# Keys read later in this notebook: "arangodb_host", "arangodb_password", "openai_api_key".
file_path = "../keys.json"
with open(file_path, 'r', encoding='utf-8') as file:
    key_data = json.load(file)


In [9]:
# Connect to the ArangoDB database 

db = ArangoClient(hosts=key_data["arangodb_host"]).db(username="root", password=key_data["arangodb_password"], verify=True)

print(db)

<StandardDatabase _system>


---
## Prepare flights dataset

### About the dataset:
This dataset represents a comprehensive record of flights and airports, organized as a graph in ArangoDB. It is divided into two main collections: airports and flights.

### Graph Schema
**Graph Name:** FLIGHTS
#### Edge Definitions:
- **Edge Collection:** flights
- **From Vertex Collection:** airports
- **To Vertex Collection:** airports

This graph schema indicates that each flight (an edge) connects two airports (nodes). By modeling the data this way, we can easily explore relationships and connectivity between airports.
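
To make the edge-and-node model concrete, here is a minimal sketch in plain Python (no database needed) that builds an adjacency map from a few edge documents shaped like those in the `flights` collection. The sample routes and the variable names are assumptions made for illustration; only the `_from` and `_to` fields come from the dataset's edge format.

```python
from collections import defaultdict

# Edge documents shaped like `flights` entries; only `_from` and `_to` matter here.
# The routes are illustrative and not an excerpt of the real dataset.
sample_flights = [
    {"_from": "airports/ATL", "_to": "airports/CHS"},
    {"_from": "airports/ATL", "_to": "airports/IAH"},
    {"_from": "airports/IAH", "_to": "airports/ATL"},
    {"_from": "airports/ATL", "_to": "airports/CHS"},  # parallel edge: same route, another flight
]

# Map each departure airport to the list of arrival airports (one entry per flight).
adjacency = defaultdict(list)
for edge in sample_flights:
    adjacency[edge["_from"]].append(edge["_to"])

# Number of outgoing flights per airport, counting parallel edges.
out_degree = {airport: len(dests) for airport, dests in adjacency.items()}
# Distinct destinations per airport, ignoring how many flights serve each route.
distinct_routes = {airport: sorted(set(dests)) for airport, dests in adjacency.items()}

print(out_degree)
print(distinct_routes)
```

Because several flights can serve the same route, the flight count and the route count differ, which is why the graph is later loaded as a multigraph.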



---
### Collection Schema Details

#### Airports Collection

- **Type**: Document  
- **Key Attributes**:
  - `_key`, `_id`, `_rev`: Standard document identifiers.
  - `name`, `city`, `state`, `country`: Provide detailed location information for each airport.
  - `lat`, `long`: Geographic coordinates used for mapping and spatial analysis.
  - `vip`: A Boolean flag that might indicate whether the airport has special status or features.

##### Example document from the airports collection:

```json
{
  "_key": "00M",
  "_id": "airports/00M",
  "_rev": "_jVk9JKu---",
  "name": "Thigpen ",
  "city": "Bay Springs",
  "state": "MS",
  "country": "USA",
  "lat": 31.95376472,
  "long": -89.23450472,
  "vip": false
}
```

---

#### Flights Collection

- **Type**: Edge  
- **Key Attributes**:
  - `_key`, `_id`, `_rev`: Standard edge identifiers.
  - `_from`, `_to`: References to the departure and arrival airports in the `airports` collection.
- **Flight Details**:
  - `Year`, `Month`, `Day`, `DayOfWeek`: Date information for the flight.
  - `DepTime`, `ArrTime`, `DepTimeUTC`, `ArrTimeUTC`: Time information, including both local and UTC times.
  - `UniqueCarrier`, `FlightNum`, `TailNum`: Airline and flight-specific identifiers.
  - `Distance`: The flight distance, which can be used for further analysis such as calculating fuel efficiency or route optimization.

##### Example edge from the flights collection:

```json
{
  "_key": "306520629",
  "_id": "flights/306520629",
  "_from": "airports/ATL",
  "_to": "airports/CHS",
  "_rev": "_jVk9djm---",
  "Year": 2008,
  "Month": 1,
  "Day": 1,
  "DayOfWeek": 2,
  "DepTime": 2,
  "ArrTime": 57,
  "DepTimeUTC": "2008-01-01T05:02:00.000Z",
  "ArrTimeUTC": "2008-01-01T05:57:00.000Z",
  "UniqueCarrier": "FL",
  "FlightNum": 579,
  "TailNum": "N937AT",
  "Distance": 259
}
```
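
As a small worked example of what these fields allow, the sketch below derives the flight duration and an average speed from the example edge. It assumes that `Distance` is given in miles, which the schema itself does not state, and the format string is simply chosen to match the timestamps shown.

```python
from datetime import datetime

# Fields copied from the example edge above.
edge = {
    "DepTimeUTC": "2008-01-01T05:02:00.000Z",
    "ArrTimeUTC": "2008-01-01T05:57:00.000Z",
    "Distance": 259,
}

# Matches the timestamp layout of the example (milliseconds, trailing "Z").
UTC_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

dep = datetime.strptime(edge["DepTimeUTC"], UTC_FORMAT)
arr = datetime.strptime(edge["ArrTimeUTC"], UTC_FORMAT)

duration_minutes = (arr - dep).total_seconds() / 60
# Assumption: Distance is in miles, so the result is miles per hour.
average_speed = edge["Distance"] / (duration_minutes / 60)

print(f"{duration_minutes:.0f} min, {average_speed:.1f} mph")
```

For this edge the duration is 55 minutes, which gives roughly 282.5 mph under the miles assumption.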


The following commented-out code loads the FLIGHTS dataset into ArangoDB. It needs to be run only once: it creates the `flights` and `airports` collections and a graph named FLIGHTS, with `flights` as the edge collection and `airports` as the node collection.

In [10]:
# from arango_datasets import Datasets


In [11]:
# # Connect to the datasets interface
# datasets = Datasets(db)

# # List all available datasets
# print("Available datasets:")
# print(datasets.list_datasets())

# # Show information about the FLIGHTS dataset
# print("\nFlights dataset info:")
# print(datasets.dataset_info("FLIGHTS"))

# # Load the FLIGHTS dataset into the connected database
# flights_data = datasets.load("FLIGHTS")



---
## Persist the Graph in ArangoDB

In this section, the FLIGHTS graph is retrieved from ArangoDB through `nx-arangodb`. The graph starts in a "lazy" state, meaning it fetches data on demand. To work more efficiently with NetworkX, we create a fully materialized copy in memory, so that all nodes and edges are loaded and available for analysis.

<p align="center">
    <img src="https://raw.githubusercontent.com/arangodb/nx-arangodb/main/doc/_static/nxadb.png" style="height: 200px;">
    <img src="https://raw.githubusercontent.com/arangodb/nx-arangodb/main/doc/_static/dispatch.png" style="height: 200px;">
</p>
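
The difference between a lazy and a materialized graph can be illustrated with a toy model in plain Python. This is not how `nx-arangodb` is implemented; `LazyGraph`, `REMOTE_EDGES`, and the fetch counter are hypothetical stand-ins that only show why copying the data once pays off when an algorithm revisits the same nodes many times.

```python
# Toy stand-in for a remote edge store; in the notebook this role is played by ArangoDB.
REMOTE_EDGES = [("ATL", "CHS"), ("ATL", "IAH"), ("IAH", "ATL"), ("ATL", "CHS")]


class LazyGraph:
    """Hypothetical lazy view: every successor lookup goes back to the remote store."""

    def __init__(self, remote_edges):
        self._remote = remote_edges
        self.fetches = 0  # number of simulated round-trips

    def nodes(self):
        # Collect nodes in first-seen order.
        seen = []
        for src, dst in self._remote:
            for node in (src, dst):
                if node not in seen:
                    seen.append(node)
        return seen

    def successors(self, node):
        self.fetches += 1
        return [dst for src, dst in self._remote if src == node]


def materialize(lazy):
    """Copy all successor lists into a local dict, one fetch per node."""
    return {node: lazy.successors(node) for node in lazy.nodes()}


lazy = LazyGraph(REMOTE_EDGES)
local = materialize(lazy)
fetches_after_copy = lazy.fetches

# Repeated lookups now hit the local dict and cause no further round-trips.
for _ in range(100):
    _ = local["ATL"]

print(fetches_after_copy, lazy.fetches)
```

In this toy setup the fetch counter stops growing once the copy exists, which is the effect the materialization step below aims for on the real graph.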

In [12]:
def materialize_graph(lazy_graph):
    """
    Create a materialized copy of a lazy graph.
    This function copies all nodes and edges from a lazy graph into a new NetworkX MultiDiGraph.
    """
    materialized = nx.MultiDiGraph()
    materialized.add_nodes_from(lazy_graph.nodes(data=True))
    materialized.add_edges_from(lazy_graph.edges(data=True))
    return materialized

In [13]:
# Create the lazy graph from ArangoDB and materialize it.
G_adb = nxadb.MultiDiGraph(name="FLIGHTS", db=db)
G_adb = materialize_graph(G_adb)
print("Graph materialized and loaded into memory.")

print(G_adb)

[22:03:34 -0400] [INFO]: Graph 'FLIGHTS' exists.
[22:03:34 -0400] [INFO]: Default node type set to 'airports'


Graph materialized and loaded into memory.
MultiDiGraph with 3375 nodes and 286463 edges


In [14]:
# Example: Print the degree of a specific airport node.
G_adb.degree("airports/IAH")

15270

## Building the Agentic App with LangChain & LangGraph

In this section, the focus is on creating an intelligent application that processes natural language queries related to the flight network. Key aspects of this section include:

#### **Agent Creation:**  
  An agent is developed to interpret and process natural language queries. It is equipped with multiple tools that convert queries either to AQL for direct database interrogation or to executable Python code for NetworkX-based graph analysis.

#### **Tool Integration:**  
  Tools are integrated to handle distinct tasks:
  - One tool translates natural language queries into AQL, executes them against the graph, and returns the results.
  - Another tool converts queries into Python code, which is executed using NetworkX algorithms to perform tasks like graph traversal and centrality calculations.
  - A hybrid approach is implemented to decompose complex queries into manageable sub-tasks, ensuring that each sub-task produces a concise output.
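
The context-passing part of the hybrid approach can be sketched as follows. This is a simplified illustration rather than the notebook's tool: `run_subtasks` and `echo_solver` are hypothetical names, and the solver stands in for an agent call.

```python
def run_subtasks(sub_tasks, solve):
    """Run sub-tasks in order, feeding earlier answers into later prompts.

    `solve(prompt)` is a placeholder for whatever answers a single sub-task.
    """
    results = []
    for task in sub_tasks:
        if results:
            context = "\n".join(
                f"Result {i}: {r}" for i, r in enumerate(results, start=1)
            )
            prompt = f"Previously obtained results:\n{context}\n{task}"
        else:
            prompt = task
        results.append(solve(prompt))
    return results


# Hypothetical solver: reports how many earlier results it was shown.
def echo_solver(prompt):
    return f"answered with {prompt.count('Result ')} prior result(s)"


answers = run_subtasks(
    ["Find the top 2 airlines.", "Compute degree centrality for them."],
    echo_solver,
)
print(answers)
```

Each sub-task sees every earlier answer, so keeping individual answers short keeps later prompts small.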

#### **Execution Flow:**  
  The process involves three primary steps:
  1. **Query Translation:** Natural language queries are transformed into either AQL or Python code.
  2. **Code Execution:** The generated code is executed on the materialized FLIGHTS graph to obtain relevant metrics and insights.
  3. **Result Synthesis:** The outputs are synthesized into a clear and concise final answer that addresses the original query.
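
The execute-and-retry part of this flow can be sketched without a language model. In the snippet below, `run_generated_code` and `fake_corrector` are hypothetical names, and the corrector returns a canned fix where the real tool would ask the model. Unlike the tool defined later in this notebook, the sketch passes a single namespace to `exec()`; with separate globals and locals dictionaries, functions defined by the generated code cannot see its top-level variables.

```python
def run_generated_code(code, corrector, seed=None, max_attempts=3):
    """Run `code` with exec() and ask `corrector` for a fix after each failure.

    `corrector(code, error_message)` stands in for the model call and returns new code.
    Returns (FINAL_RESULT, attempt_number) on success.
    """
    last_error = "no attempts made"
    for attempt in range(1, max_attempts + 1):
        # Fresh namespace per attempt so a failed run leaves no partial state behind.
        namespace = dict(seed or {})
        try:
            exec(code, namespace)
            if "FINAL_RESULT" not in namespace:
                raise KeyError("FINAL_RESULT was not set")
            return namespace["FINAL_RESULT"], attempt
        except Exception as exc:
            last_error = f"{type(exc).__name__}: {exc}"
            # Only request a correction if another attempt will actually run it.
            if attempt < max_attempts:
                code = corrector(code, last_error)
    raise RuntimeError(f"gave up after {max_attempts} attempts: {last_error}")


seen_errors = []


# Hypothetical corrector: records the error and returns a canned fix.
def fake_corrector(code, error_message):
    seen_errors.append(error_message)
    return "FINAL_RESULT = max(degrees, key=degrees.get)"


broken_code = "FINAL_RESULT = max(degree, key=degree.get)"  # wrong variable name
result, attempts = run_generated_code(
    broken_code, fake_corrector, seed={"degrees": {"ATL": 3, "IAH": 1}}
)
print(result, attempts)
```

Running model-generated code with `exec()` gives it the full privileges of the notebook process, so this pattern is only appropriate in a trusted, disposable environment.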

#### **Frameworks and Libraries:**  
  The application uses LangChain for language model interactions and the AQL question-answering chain, LangGraph for the ReAct agent that selects and calls the tools, and NetworkX for in-memory graph analysis. Rich is used to print step-by-step progress to the console.

Together, these pieces let a user ask questions about the flight network in natural language and have them answered by database queries, graph algorithms, or a combination of both.

<p align="center">
    <img src="https://raw.githubusercontent.com/SivaTSS/flight_graph_agent/main/images/pipeline.png" style="height: 500px;">
</p>

In [15]:
# Create the ArangoGraph LangChain wrapper
arango_graph = ArangoGraph(db)

In [16]:
arango_graph.schema

{'Graph Schema': [{'graph_name': 'FLIGHTS',
   'edge_definitions': [{'edge_collection': 'flights',
     'from_vertex_collections': ['airports'],
     'to_vertex_collections': ['airports']}]}],
 'Collection Schema': [{'collection_name': 'airports',
   'collection_type': 'document',
   'document_properties': [{'name': '_key', 'type': 'str'},
    {'name': '_id', 'type': 'str'},
    {'name': '_rev', 'type': 'str'},
    {'name': 'name', 'type': 'str'},
    {'name': 'city', 'type': 'str'},
    {'name': 'state', 'type': 'str'},
    {'name': 'country', 'type': 'str'},
    {'name': 'lat', 'type': 'float'},
    {'name': 'long', 'type': 'float'},
    {'name': 'vip', 'type': 'bool'}],
   'example_document': {'_key': '00M',
    '_id': 'airports/00M',
    '_rev': '_jVk9JKu---',
    'name': 'Thigpen ',
    'city': 'Bay Springs',
    'state': 'MS',
    'country': 'USA',
    'lat': 31.95376472,
    'long': -89.23450472,
    'vip': False}},
  {'collection_name': 'flights',
   'collection_type': 'edge',


In [17]:
# Set openai api key
import os
os.environ["OPENAI_API_KEY"] = key_data["openai_api_key"]

model_name="gpt-4o-mini"

In [18]:
from rich.console import Console
from rich.panel import Panel
from rich.rule import Rule
from rich.text import Text

# Initialize the Rich console for colored output
console = Console()

In [19]:
# Text to AQL Tool

@tool
def text_to_aql_to_text(query: str):
    """This tool is available to invoke the
    ArangoGraphQAChain object, which enables you to
    translate a Natural Language Query into AQL, execute
    the query, and translate the result back into Natural Language.
    """

    chain = ArangoGraphQAChain.from_llm(
        llm=llm,
        graph=arango_graph,
        verbose=True,
        allow_dangerous_requests=True,
        max_aql_generation_attempts=1
    )
    
    result = chain.invoke(query)

    return str(result["result"])

In [20]:
# Text to NetworkX/cuGraph Tool with code corrector

@tool
def text_to_nx_algorithm_to_text(query: str):
    """This tool is available to invoke a NetworkX Algorithm on
    the ArangoDB Graph. You are responsible for accepting the
    Natural Language Query, establishing which algorithm needs to
    be executed, executing the algorithm, and translating the results back
    to Natural Language, with respect to the original query.

    If the query (e.g., traversals, shortest path) can be solved using the Arango Query Language,
    then do not use this tool.
    """


    # --- Step 1: Generate NetworkX Code ---
    
    console.rule("[bold cyan]Step 1: Generating NetworkX Code[/bold cyan]")

    text_to_nx = llm.invoke(f"""
    I have a NetworkX Graph called `G_adb`. It has the following schema: {arango_graph.schema}

    I have the following graph analysis query: {query}.

    Generate the Python Code required to answer the query using the `G_adb` object.

    Be very precise on the NetworkX algorithm you select to answer this query. Think step by step.

    Only assume that networkx is installed, and other base python dependencies.

    Always set the last variable as `FINAL_RESULT`, which represents the answer to the original query.

    Only provide python code that I can directly execute via `exec()`. Do not provide any instructions.

    Make sure that `FINAL_RESULT` stores a short & concise answer. Avoid setting this variable to a long sequence.

    Your code:
    """).content

    text_to_nx_cleaned = re.sub(r"^```python\n|```$", "", text_to_nx, flags=re.MULTILINE).strip()
    
    console.rule("[bold green]Generated NetworkX Code[/bold green]")
    console.print(text_to_nx_cleaned, style="green")
    console.rule()

    
    # --- Step 2: Execute Generated Code with code corrector---

    console.rule("[bold cyan]Step 2: Executing NetworkX Code[/bold cyan]")
    global_vars = {"G_adb": G_adb, "nx": nx}
    local_vars = {}
    current_code = text_to_nx_cleaned
    success = False

    # Code corrector implementation
    attempt = 1
    MAX_ATTEMPTS = 3
        

    while attempt <= MAX_ATTEMPTS and not success:
        console.print(f"[bold blue]Attempt {attempt}: Executing code[/bold blue]")
        try:
            exec(current_code, global_vars, local_vars)
            success = True
        except Exception as e:
            error_message = str(e)
            print(f"EXEC ERROR on attempt {attempt}: {error_message}")

            correction_prompt = f"""
            I have the following Python code that is intended to operate on a NetworkX Graph called `G_adb` with the schema: {arango_graph.schema}

            The code was generated to answer the graph analysis query: {query}.

            However, when executing, it produced the following error: {error_message}

            Please correct the Python code to fix this error. Ensure that the final variable `FINAL_RESULT` contains a short and concise answer to the query.
            
            Only provide corrected Python code that can be directly executed via `exec()`, without any additional explanation.

            Make sure that `FINAL_RESULT` stores a short & concise answer. Avoid setting this variable to a long sequence.

            Your corrected code:
            """
            corrected_code = llm.invoke(correction_prompt).content
            current_code = re.sub(r"^```python\n|```$", "", corrected_code, flags=re.MULTILINE).strip()
            console.print("[bold yellow]Corrected Code:[/bold yellow]")
            console.print(current_code, style="yellow")
            attempt += 1
            local_vars = {}
            
    if not success:
        return f"EXEC ERROR after {MAX_ATTEMPTS} attempts: {error_message}"

    console.rule("[bold green]Final Execution Result[/bold green]")
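    # Fall back gracefully if the generated code ran but never assigned FINAL_RESULT.
    FINAL_RESULT = local_vars.get("FINAL_RESULT", "FINAL_RESULT was not set by the generated code.")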
    console.print(f"[bold magenta]FINAL_RESULT:[/bold magenta] {FINAL_RESULT}", style="magenta")
    console.rule()

    # --- Step 3: Generate Final Answer ---

    console.rule("[bold cyan]Step 3: Formulating Final Answer[/bold cyan]")

    nx_to_text = llm.invoke(f"""
        I have a NetworkX Graph called `G_adb`. It has the following schema: {arango_graph.schema}

        I have the following graph analysis query: {query}.

        I have executed the following python code to help me answer my query:

        ---
        {current_code}
        ---

        The `FINAL_RESULT` variable is set to the following: {FINAL_RESULT}.

        Based on my original Query and FINAL_RESULT, generate a short and concise response to
        answer my query.
        
        Your response:
    """).content
    console.rule("[bold green]Final Answer Generated[/bold green]")
    return nx_to_text
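
The fence-stripping step used above can be checked in isolation. The snippet shows what the tool's regular expression does with a typical reply and where it falls short; `strip_fences` and `TOLERANT_PATTERN` are hypothetical additions and not part of the notebook.

```python
import re

# Pattern used by the tools in this notebook.
NOTEBOOK_PATTERN = r"^```python\n|```$"
# Hypothetical alternative: accept any (or no) language tag after a line-start fence.
TOLERANT_PATTERN = r"^```[\w+-]*[ \t]*\n?"


def strip_fences(text, pattern=TOLERANT_PATTERN):
    """Remove Markdown code fences matched by `pattern`, then trim whitespace."""
    return re.sub(pattern, "", text, flags=re.MULTILINE).strip()


tagged = "```python\nFINAL_RESULT = 1 + 1\n```"
short_tag = "```py\nFINAL_RESULT = 1 + 1\n```"

notebook_tagged = strip_fences(tagged, NOTEBOOK_PATTERN)    # fences removed
notebook_short = strip_fences(short_tag, NOTEBOOK_PATTERN)  # opening fence survives
tolerant_short = strip_fences(short_tag)                    # fences removed

print(repr(notebook_tagged))
print(repr(notebook_short))
print(repr(tolerant_short))
```

With the notebook's pattern, a reply tagged with something other than `python` keeps its opening fence, fails with a syntax error on the first attempt, and is then handed to the correction loop.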

In [21]:
import yaml
def print_section(title, message=None):
    """
    Helper function to print a section header and an optional message using Rich.
    """
    header = f"[bold cyan]== {title} =="
    if message:
        panel = Panel(Text(message), title=header, expand=False)
    else:
        panel = Panel("", title=header, expand=False)
    console.print(panel)

def query_graph_taskwise(query):
    """
    Process a graph query task using the task-wise agent approach.
    """
    llm_task_wise = ChatOpenAI(temperature=0, model_name=model_name)
    tools = [text_to_aql_to_text, text_to_nx_algorithm_to_text]
    app = create_react_agent(llm_task_wise, tools)    
    final_state = app.invoke({"messages": [{"role": "user", "content": query}]})
    return final_state["messages"][-1].content


In [22]:
# Define a hybrid tool to split a complex query into manageable sub-tasks.
@tool
def text_to_hybrid_model_to_text(query: str):
    """
    This tool is available to answer a graph analysis query on the ArangoDB Graph using a hybrid approach.
    It splits the query into smaller, actionable sub-tasks when the original query is too complex for a 
    single AQL or code-based solution. Each sub-task must specify a limited output size (for example, top 10, top 20)
    and ensure that no sub-task returns more than 25 items.
    The results from sub-tasks are combined into a concise natural language answer.
    """
    
    print_section("Hybrid Mode Activated", f"Processing query: {query}")
    
    # Split the query into simpler sub-tasks
    split_prompt = f"""
    I have an ArangoDB Database with the following schema: {arango_graph.schema}
    The original query is: {query}
    This query is complex. Please split it into 2 or 3 high-level, actionable sub-tasks that need to be performed.
    IMPORTANT:
      - Do NOT include any AQL queries or code in your output—only describe the sub-tasks succinctly.
      - If a sub-task involves listing or ranking items (e.g., airlines), specify an output limit (for example, 'top 10', 'top 20').
      - Ensure that no sub-task is expected to return more than 25 items.
    Output your answer as a YAML list. For example:
    - Determine the top 10 airlines based on flight volume.
    - Compute a network metric (e.g., clustering coefficient) for these airlines, limiting results to the top 20.
    Your response:
    """
    sub_queries_yaml = llm.invoke(split_prompt).content.strip()
    sub_queries_yaml = re.sub(r"^```yaml\n|```$", "", sub_queries_yaml, flags=re.MULTILINE).strip()
    print_section("LLM Returned YAML", sub_queries_yaml)
    
    try:
        sub_queries = yaml.safe_load(sub_queries_yaml)
        if not isinstance(sub_queries, list):
            raise Exception("Parsed output is not a list.")
        print_section("Parsed Sub-Queries", str(sub_queries))
    except Exception as e:
        print_section("Error Parsing YAML", f"{e}\nUsing original query as single sub-task.")
        sub_queries = [query]
    
    branch_results = []
    for i, sub_query in enumerate(sub_queries):
        print_section(f"Processing Sub-Task {i+1}/{len(sub_queries)}", sub_query)
        
        # Include previously obtained results in the prompt context.
        if branch_results:
            context = "\n".join([f"Result {j+1}: {result}" for j, result in enumerate(branch_results)])
            context_prompt = f"Previously obtained results:\n{context}\n"
        else:
            context_prompt = ""
        
        # Prepare a cumulative query that incorporates previous results.
        cumulative_query = f"""
        {context_prompt}
        Now, process the following sub-task, taking into account the previous results:
        {sub_query}
        Provide an answer that builds on the previous information and ensure that any list output is limited (e.g., top 10, top 20, and not more than 25 items).
        """
        sub_result = query_graph_taskwise(cumulative_query)
        print_section(f"Result for Sub-Task {i+1}", sub_result)
        branch_results.append(sub_result)
    
    combine_prompt = f"""
    I have the following answers from sub-tasks:
    {branch_results}
    Combine these results into a concise final answer to the original query:
    {query}
    Provide only the final answer.
    """
    combined_result = llm.invoke(combine_prompt).content.strip()
    print_section("Combined Final Result", combined_result)
    return combined_result


In [23]:
# Create the Agentic Application
llm = ChatOpenAI(temperature=0, model_name=model_name)
tools = [text_to_aql_to_text, text_to_nx_algorithm_to_text, text_to_hybrid_model_to_text]

def query_graph(query):
    app = create_react_agent(llm, tools)    
    final_state = app.invoke({"messages": [{"role": "user", "content": query}]})
    return final_state["messages"][-1].content

---
### Sample Queries Execution 

In [24]:
query_graph("Identify the airline operating the most flights in the dataset.")

> Entering new ArangoGraphQAChain chain...
AQL Query (1):
WITH flights
FOR flight IN flights
    COLLECT airline = flight.UniqueCarrier WITH COUNT INTO count
    SORT count DESC
    LIMIT 10
    RETURN { airline: airline, count: count }

AQL Result:
[{'airline': 'WN', 'count': 48065}, {'airline': 'AA', 'count': 24797}, {'airline': 'OO', 'count': 22509}, {'airline': 'MQ', 'count': 19945}, {'airline': 'US', 'count': 18596}, {'airline': 'DL', 'count': 18107}, {'airline': 'UA', 'count': 17945}, {'airline': 'XE', 'count': 16764}, {'airline': 'NW', 'count': 15244}, {'airline': 'CO', 'count': 12398}]

> Finished chain.

> Entering new ArangoGraphQAChain chain...
AQL Query (1):
WITH airports, flights
FOR flight IN flights
    FILTER flight.UniqueCarrier == 'WN'
    RETURN { 
        FlightNum: flight.FlightNum, 
        Destination: flight._to 
    }

AQL Result:
[{'FlightNum': 680, 'Destination': 'airports/PBI'}, {'FlightNum': 3945, 'Destination': 'airports/BNA'}, {'FlightNum': 3558, 'Destination': 'airports/MCI'}, {'FlightNum': 106, 'Destination': 'airports/BWI'}, {'FlightNum': 1100, 'Destination': 'airports/TPA'}, {'FlightNum': 174, 'Destination': 'airports/DTW'}, {'FlightNum': 1663, 'Destination': 'airports/PHL'}, {'FlightNum': 2998, 'Destination': 'airports/PHL'}, {'FlightNum': 823, 'Destination': 'airports/PHX'}, {'FlightNum': 2449, 'Destination': 'airports/SAN'}]

> Finished chain.


'The airline operating the most flights in the dataset is **Southwest Airlines (WN)**, with a total of 48,065 flights.'

In [34]:
query_graph("Determine the airport with the highest betweenness centrality in the network.")

'The airport with the highest betweenness centrality in the network is Hartsfield-Jackson Atlanta International Airport (ATL).'

In [35]:
query_graph("Identify the top 5 airports by degree centrality.")

"The top 5 airports by degree centrality are:\n\n1. ATL (Hartsfield-Jackson Atlanta International Airport)\n2. ORD (O'Hare International Airport)\n3. DFW (Dallas/Fort Worth International Airport)\n4. DEN (Denver International Airport)\n5. LAX (Los Angeles International Airport)"

In [36]:
query_graph("Is the flight network strongly connected or weakly connected? Provide details.")

'The flight network is not connected, meaning it is neither strongly connected nor weakly connected. This indicates that there are segments of the network that are isolated from others, preventing complete connectivity among all nodes (airports).'

In [38]:
query_graph("Find the airport with the highest number of outgoing flights.")

> Entering new ArangoGraphQAChain chain...
AQL Query (1):
WITH airports
FOR airport IN airports SORT airport.outgoingFlights DESC LIMIT 1 RETURN airport

AQL Result:
[{'_key': 'ZZV', '_id': 'airports/ZZV', '_rev': '_jVk9J1C-_n', 'name': 'Zanesville Municipal', 'city': 'Zanesville', 'state': 'OH', 'country': 'USA', 'lat': 39.94445833, 'long': -81.89210528, 'vip': False}]

> Finished chain.


"The airport with the highest number of outgoing flights is Zanesville Municipal, located in Zanesville, Ohio, USA. This airport is identified by the key 'ZZV' and has geographical coordinates of approximately 39.94° latitude and -81.89° longitude. It is not classified as a VIP airport."

In [39]:
query_graph("Who is connected to Node 0?")

> Entering new ArangoGraphQAChain chain...
AQL Query (1):
WITH airports, flights
FOR v IN 1..1 OUTBOUND 'airports/0' flights RETURN v

AQL Result:
[]

> Finished chain.


'It seems that there are currently no connections from Node 0. This could mean that Node 0 does not have any outbound connections or that it may not exist in the database. If you have another node in mind or would like to explore different connections, please let me know!'

In [42]:
query_graph("How strongly connected is the network? Used connected components")

EXEC ERROR on attempt 1: not implemented for directed type


'The network has no strongly connected components, which means there are no subsets of nodes where every node is reachable from every other node within the directed graph.'

In [43]:
query_graph("Find the most influential airports using PageRank")

'The most influential airport based on PageRank analysis is Hartsfield-Jackson Atlanta International Airport (ATL).'

In [44]:
query_graph("For top 2 major airline, compute network metrics such as average path length for their flight network.")

> Entering new ArangoGraphQAChain chain...
AQL Query (1):
WITH flights
FOR airline IN (
    FOR flight IN flights
    COLLECT carrier = flight.UniqueCarrier WITH COUNT INTO flightCount
    RETURN { airline: carrier, count: flightCount }
)
SORT airline.count DESC
LIMIT 2
RETURN airline

AQL Result:
[{'airline': 'WN', 'count': 48065}, {'airline': 'AA', 'count': 24797}]

> Finished chain.

> Entering new ArangoGraphQAChain chain...

AQL Query (1):
WITH flights
FOR flight IN flights COLLECT airline = flight.UniqueCarrier WITH COUNT INTO totalFlights RETURN { airline: airline, totalFlights: totalFlights }

AQL Result:
[{'airline': '9E', 'totalFlights': 10493}, {'airline': 'AA', 'totalFlights': 24797}, {'airline': 'AQ', 'totalFlights': 1935}, {'airline': 'AS', 'totalFlights': 5993}, {'airline': 'B6', 'totalFlights': 8060}, {'airline': 'CO', 'totalFlights': 12398}, {'airline': 'DL', 'totalFlights': 18107}, {'airline': 'EV', 'totalFlights': 10762}, {'airline': 'F9', 'totalFlights': 3760}, {'airline': 'FL', 'totalFlights': 10148}]

> Entering new ArangoGraphQAChain chain...

> Entering new ArangoGraphQAChain chain...
AQL Query (1):
WITH airports, flights
FOR flight IN flights
    FILTER flight.UniqueCarrier IN ['WN', 'AA']
    LIMIT 10
    RETURN flight

AQL Result:
[{'_key': '306520635', '_id': 'flights/306520635', '_from': 'airports/MIA'

> Finished chain.


'The top 2 major airlines are:\n\n1. **American Airlines (AA)**: 24,797 flights\n2. **Delta Airlines (DL)**: 18,107 flights\n\nThe average path length for the flight network of these airlines is **478 miles**.'

In [45]:
query_graph("For the top two airlines based on flight volume, compute the centrality metrics")

> Entering new ArangoGraphQAChain chain...
AQL Query (1):
WITH flights
FOR airline IN (
    FOR flight IN flights
    COLLECT carrier = flight.UniqueCarrier WITH COUNT INTO flightCount
    RETURN { airline: carrier, count: flightCount }
)
SORT airline.count DESC
LIMIT 2
RETURN airline

AQL Result:
[{'airline': 'WN', 'count': 48065}, {'airline': 'AA', 'count': 24797}]

> Finished chain.


"The top two airlines based on flight volume are:\n\n1. **Southwest Airlines (WN)** with 48,065 flights\n   - **Highest Degree Centrality**: Las Vegas (LAS) - 0.857\n   - **Notable Airports**: Chicago Midway (MDW) - 0.746, Phoenix (PHX) - 0.667\n\n2. **American Airlines (AA)** with 24,797 flights\n   - **Highest Degree Centrality**: Dallas/Fort Worth (DFW) - 0.949\n   - **Notable Airports**: Chicago O'Hare (ORD) - 0.603, Miami (MIA) - 0.385\n\nThese metrics indicate the centrality of the airlines and their key airports in the network."

---
## Building UI using Gradio

In [25]:
# ! pip install seaborn

In [26]:
# ! pip install plotly

In [27]:
def text_to_visualizer_llm_to_plot(query: str, result: str):
    """
    This function generates a visualization for a graph query using only matplotlib.
    It receives the original query and its textual result, then asks the LLM to produce Python code 
    that creates a plot using matplotlib. The code should assign the resulting matplotlib figure 
    object to a variable named FINAL_PLOT. Do NOT include any saving or display commands in the generated code.
    
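    After executing the generated code, the function saves the figure to the hard-coded
    filename "visualization_output.png", reads the file back, and returns its raw PNG bytes
    (or None if no plot could be produced). It relies on `os`, `io`, `time`, and `PIL.Image`
    being imported elsewhere in the notebook before it is called.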
    """
    console.rule("[bold cyan]Step 1: Generating Visualization Code (Matplotlib Only)[/bold cyan]")
    prompt = f"""
    I have a graph query: {query}
    and its result is: {result}.
    Generate Python code using matplotlib to produce a visualization plot relevant to the query and its result.
    Assign the resulting matplotlib figure object to a variable named FINAL_PLOT.
    Do NOT include any code to save or display the plot.
    Your code:
    """
    vis_code = llm.invoke(prompt).content
    vis_code_cleaned = re.sub(r"^```python\n|```$", "", vis_code, flags=re.MULTILINE).strip()
    console.rule("[bold green]Generated Visualization Code[/bold green]")
    console.print(vis_code_cleaned, style="green")
    console.rule("[bold cyan]Step 2: Executing Visualization Code[/bold cyan]")
    global_vars = {"G_adb": G_adb, "nx": nx, "plt": plt, "pd": pd}
    local_vars = {}
    current_code = vis_code_cleaned
    success = False
    attempt = 1
    MAX_ATTEMPTS = 3

    while attempt <= MAX_ATTEMPTS and not success:
        console.print(f"[bold blue]Attempt {attempt}: Executing visualization code[/bold blue]")
        try:
            exec(current_code, global_vars, local_vars)
            success = True
        except Exception as e:
            error_message = str(e)
            print(f"Visualization EXEC ERROR on attempt {attempt}: {error_message}")
            correction_prompt = f"""
            I have the following Python code intended to generate a visualization plot for a graph query using matplotlib.
            The query is: {query}
            The result of the query is: {result}
            The code produced the following error: {error_message}
            Please correct the code to fix this error. Ensure that the final variable FINAL_PLOT is assigned 
            to a matplotlib figure object. Do NOT include any code to save or display the plot.
            Your corrected code:
            """
            corrected_code = llm.invoke(correction_prompt).content
            current_code = re.sub(r"^```python\n|```$", "", corrected_code, flags=re.MULTILINE).strip()
            console.print("[bold yellow]Corrected Visualization Code:[/bold yellow]")
            console.print(current_code, style="yellow")
            attempt += 1
            local_vars = {}
    if not success:
        return None

    console.rule("[bold green]Final Visualization Code Executed[/bold green]")
    FINAL_PLOT = local_vars.get("FINAL_PLOT", None)
    if FINAL_PLOT is None:
        console.print("[bold red]No plot was generated.[/bold red]")
        return None

    # Saving of the plot to "visualization_output.png"
    try:
        output_filename = "visualization_output.png"
        FINAL_PLOT.savefig(output_filename, format="png")
        plt.close(FINAL_PLOT)
        abs_path = os.path.abspath(output_filename)
        print("Plot saved to:", abs_path)
        time.sleep(0.5)
        with open(output_filename, "rb") as f:
            img_bytes = f.read()
            
        vis_image = Image.open(io.BytesIO(img_bytes))
        print("Loaded image size:", vis_image.size, "mode:", vis_image.mode)
        return img_bytes
    except Exception as e:
        console.print(f"[bold red]Error saving or loading plot: {e}[/bold red]")
        return None
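
The execute-and-correct loop above can be exercised without an LLM. The following is a minimal sketch, not the notebook's implementation: `strip_code_fences`, `run_with_retries`, and `fake_fixer` are hypothetical names, and `fake_fixer` stands in for the `llm.invoke(...)` correction call. It also assumes a single namespace dict for `exec`, which differs from the separate `global_vars`/`local_vars` used above; with one dict, a function defined by the generated code can see names that the same code assigns or imports at its top level.

```python
import re

FENCE = "`" * 3  # a Markdown code fence, built here to avoid a literal fence in the source
FENCE_RE = re.compile(rf"^{FENCE}(?:python)?\n|{FENCE}$", flags=re.MULTILINE)


def strip_code_fences(text):
    """Remove an opening fence line and a closing fence from LLM output."""
    return FENCE_RE.sub("", text).strip()


def run_with_retries(code, fix_code, max_attempts=3):
    """Exec `code`; on failure, ask `fix_code(code, error)` for a replacement.

    Returns the namespace of the first successful run, or None if every
    attempt fails.
    """
    for attempt in range(1, max_attempts + 1):
        namespace = {}  # fresh namespace per attempt, so failed runs leave no state
        try:
            exec(code, namespace)  # one dict serves as both globals and locals
            return namespace
        except Exception as exc:
            if attempt == max_attempts:
                return None  # no retries left, so do not request another fix
            code = strip_code_fences(fix_code(code, str(exc)))
    return None


def fake_fixer(code, error):
    """Hypothetical stand-in for the LLM: returns a fenced, corrected snippet."""
    body = "values = [1, 2, 3]\n\ndef total():\n    return sum(values)\n\nFINAL_PLOT = total()\n"
    return f"{FENCE}python\n{body}{FENCE}"


broken = "FINAL_PLOT = undefined_name + 1"  # raises NameError on the first attempt
ns = run_with_retries(broken, fake_fixer)
print(ns["FINAL_PLOT"])
```

In the corrected snippet, `total()` reads `values` from the namespace it was defined in. With separate globals and locals dicts, `values` would land in the locals dict and the call would raise `NameError`, which is one reason generated plotting code that defines helper functions can fail under `exec`.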


In [28]:
# !pip install --upgrade gradio

In [29]:
# !pip install ansi2html

In [30]:
import gradio as gr
import io
import os
import sys
import time
import threading
import builtins
import re
from contextlib import redirect_stdout, redirect_stderr
from rich.console import Console
from ansi2html import Ansi2HTMLConverter
from PIL import Image

In [31]:
def run_query_generator_with_visualization(user_query, sample_query):
    """
    Executes the graph query and then calls the visualization function.
    Yields three outputs: final text answer, full logs, and the visualization (if generated).
    """
    # Fall back to the sample query when the text box is empty or whitespace-only.
    final_query = (user_query or "").strip() or sample_query
    captured_output = io.StringIO()

    original_stdout = sys.stdout
    original_stderr = sys.stderr
    original_print = builtins.print
    original_console = console

    new_console = Console(file=captured_output, force_terminal=True)

    def custom_print(*args, **kwargs):
        # Echo to the real stdout and record exactly one copy in the log buffer.
        # sys.stdout is redirected below, so calling original_print() without an
        # explicit file would write every message into the buffer twice.
        kwargs["file"] = original_stdout
        original_print(*args, **kwargs)
        message = " ".join(map(str, args)) + "\n"
        captured_output.write(message)

    result_container = {}

    def target():
        try:
            result_container["result"] = query_graph(final_query)
        except Exception as e:
            result_container["result"] = f"Error: {str(e)}"

    thread = threading.Thread(target=target)
    conv = Ansi2HTMLConverter(inline=True)

    builtins.print = custom_print
    sys.stdout = captured_output
    sys.stderr = captured_output
    globals()['console'] = new_console

    try:
        thread.start()
        while thread.is_alive():
            time.sleep(0.5)
            ansi_output = captured_output.getvalue()
            html_output = conv.convert(ansi_output, full=False)
            html_output_wrapped = f'<pre style="white-space: pre-wrap;">{html_output}</pre>'
            yield "", html_output_wrapped, None
        thread.join()
    finally:
        # Restore the global hooks even if the generator is closed early.
        sys.stdout = original_stdout
        sys.stderr = original_stderr
        globals()['console'] = original_console
        builtins.print = original_print

    final_logs = conv.convert(captured_output.getvalue(), full=False)
    final_logs_wrapped = f'<pre style="white-space: pre-wrap;">{final_logs}</pre>'
    final_result = result_container.get("result", "No result")

    try:
        vis_image = text_to_visualizer_llm_to_plot(final_query, final_result)
        if vis_image is not None:
            vis_image = Image.open(io.BytesIO(vis_image))
    except Exception as e:
        # Report the failure instead of discarding it silently.
        print(f"Visualization failed: {e}")
        vis_image = None

    yield final_result, final_logs_wrapped, vis_image
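
The streaming pattern used above (run the work in a thread, poll a buffer, yield snapshots) can be tried on its own, without Gradio or the query agent. The following is a minimal sketch under stated assumptions: `stream_logs` and `demo_work` are hypothetical names, it captures only `print` output via `sys.stdout`, and it yields `(result, logs)` pairs, not the three-output tuples the Gradio handler yields.

```python
import io
import sys
import threading
import time


def stream_logs(work, poll_interval=0.05):
    """Run `work()` in a thread and yield (result, logs) snapshots while it runs."""
    buffer = io.StringIO()
    result = {}

    def target():
        try:
            result["value"] = work()
        except Exception as exc:
            result["value"] = f"Error: {exc}"

    original_stdout = sys.stdout
    # print() looks up sys.stdout at call time, so output from the worker
    # thread lands in the buffer while the redirection is active.
    sys.stdout = buffer
    thread = threading.Thread(target=target)
    try:
        thread.start()
        while thread.is_alive():
            time.sleep(poll_interval)
            yield None, buffer.getvalue()  # intermediate snapshot, no result yet
        thread.join()
    finally:
        # Runs even if the consumer closes the generator before it finishes.
        sys.stdout = original_stdout
    yield result.get("value"), buffer.getvalue()


def demo_work():
    """Hypothetical stand-in for a long-running query."""
    print("step 1")
    time.sleep(0.1)
    print("step 2")
    return "done"


updates = list(stream_logs(demo_work))
final_result, final_logs = updates[-1]
print(final_result)
print(final_logs, end="")
```

Because `sys.stdout` is process-wide, anything the consumer prints between yields also ends up in the buffer, and two concurrent calls would interleave their logs. That is acceptable for a single-user demo but worth keeping in mind before sharing the app.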

In [32]:
# Sample queries.
sample_queries = [
    "List the top 10 popular airports",
    "Identify the airline operating the most flights in the dataset.",
    "Determine the airport with the highest betweenness centrality in the network.",
    "Identify the top 5 airports by degree centrality.",
    "Is the flight network strongly connected or weakly connected? Provide details.",
    "Find the airport with the highest number of outgoing flights.",
    "Who is connected to Node 0?",
    "How strongly connected is the network? Use connected components.",
    "Find the most influential airports using PageRank",
    "For the top 2 major airlines, compute network metrics such as average path length for their flight networks.",
    "For the top two airlines based on flight volume, compute the centrality metrics"
]


In [33]:
with gr.Blocks(title="Flight Graph Insights") as demo:
    gr.Markdown("## Flight Graph Insights: A Divide-and-Conquer Strategy")
    gr.Markdown(
        "On the top left, enter your graph query (or choose a sample). The top right displays the visualization output. "
        "Below, you'll see the final answer and the full internal thinking process."
    )
    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(label="Graph Query", placeholder="Type your query here...", lines=3)
            sample_dropdown = gr.Dropdown(label="Sample Query", choices=sample_queries, value=sample_queries[0])
            run_button = gr.Button("Run Query")
        with gr.Column(scale=1):
            visualization_output = gr.Image(label="Visualization", type="pil")

    final_output = gr.Textbox(label="Final Answer", placeholder="The final answer will appear here...", lines=5)
    logs_output = gr.HTML(label="Full Thinking Process")
    
    run_button.click(
        fn=run_query_generator_with_visualization,
        inputs=[query_input, sample_dropdown],
        outputs=[final_output, logs_output, visualization_output]
    )

demo.launch()

* Running on local URL:  http://127.0.0.1:7880

To create a public link, set `share=True` in `launch()`.




Plot saved to: /home/siva/arango_hack/flight_graph_agent/notebooks/visualization_output.png
Loaded image size: (1200, 600) mode: RGBA
