In [2]:
from langchain_ollama import ChatOllama
from langgraph_supervisor import create_supervisor
from langgraph.prebuilt import create_react_agent
from sql import run_sql_workflow
from langchain_ollama import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage

local_model = ChatOllama(model="qwen3:8b", temperature=0.0)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:

# --- Tool: SQL tables tool ---
def query_cases_table(query: str) -> str:
    """Query the 'cases' table using a SQL workflow."""
    schema= """
    === CASES ===

    You are a SQL assistant with access to the `cases` table. Use the structure below to understand how to build and execute SQL queries.

    Structure:
            column_name           data_type
            ------------          ----------
                    id              VARCHAR
            order_date        TIMESTAMP_NS
            employee_id           VARCHAR
                branch           VARCHAR
                supplier           VARCHAR
                avg_time             DOUBLE
    estimated_delivery        TIMESTAMP_NS
                delivery        TIMESTAMP_NS
                on_time            BOOLEAN
                in_full            BOOLEAN
        number_of_items            INTEGER
                ft_items            INTEGER
            total_price             DOUBLE
        total_activities            INTEGER
    rework_activities            INTEGER
    automatic_activities            INTEGER

    Instructions:
    - Only use columns from the schema above.
    - Use standard SQL syntax.
    - For date filtering, use the `order_date` column.
    - For delivery performance analysis, refer to `on_time`, `in_full`, and `delivery`.
    - To calculate totals or averages, use aggregation functions like `SUM`, `AVG`, `COUNT`, etc.
    - Use `WHERE`, `GROUP BY`, and `ORDER BY` clauses as needed.
    - Never assume values; always use only what can be reasonably queried from the table structure.

    Input:
    - You will receive a user query in natural language.
    - Convert it into a valid SQL query using the above structure.
    - Return the executed result.

    Example Input:
    "How many deliveries were on time in March 2024?"

    Example Output:
    "SELECT COUNT(*) FROM cases WHERE on_time = TRUE AND order_date >= '2024-03-01' AND order_date < '2024-04-01';"
    """
    print(f"query: {query}")
    state= run_sql_workflow(query,schema)
    print()
    return state["result"]

sql_agent = create_react_agent(
    model=local_model,
    tools=[query_cases_table],  # You can expand this to include multiple table tools
    name="sql_agent",
    prompt="You are a SQL expert. Use the table tools to query and return results. Do not analyze the output."
)
    

In [None]:
# Define a test function to simulate the agent's responses.
def test_sql_agent():
    # List of test questions
    test_questions = [
        "How many deliveries were on time in March 2024?",  # Example for the cases table
        "How many rework activities were performed by human users?",  # Example for the activities table
        "What is the most common process variant?",  # Example for the variants table
        "How many cases had delivery delays?"  # Example for the cases table
    ]

    # Simulate a conversation with the sql_agent
    for question in test_questions:
        print(f"Testing question: {question}")
        messages = [HumanMessage(content=question)]
        messages = sql_agent.invoke({"messages": messages})
        for m in messages["messages"]:
            m.pretty_print()

# Run the test function
test_sql_agent()

Testing question: How many deliveries were on time in March 2024?
Converting question to SQL What is the average time taken to process an order?
Generated SQL query: SELECT AVG(avg_time) FROM cases;
🚀 Executing query: SELECT AVG(avg_time) FROM cases;
SQL SELECT query executed successfully.
SQL query results: avg(avg_time): 944237.0942919466
SQL error states: False
Token count: 26


In [None]:
messages = [HumanMessage(content="How many deliveries were on time in March 2024?")]
messages = sql_agent.invoke({"messages": messages})
for m in messages["messages"]:
    m.pretty_print()

Converting question to SQL What is the average time taken for order fulfillment?
Generated SQL query: SELECT AVG(estimated_delivery - order_date) AS avg_time_taken FROM cases;
🚀 Executing query: SELECT AVG(estimated_delivery - order_date) AS avg_time_taken FROM cases;
Error executing SQL query: Binder Error: No function matches the given name and argument types 'avg(INTERVAL)'. You might need to add explicit type casts.
	Candidate functions:
	avg(DECIMAL) -> DECIMAL
	avg(SMALLINT) -> DOUBLE
	avg(INTEGER) -> DOUBLE
	avg(BIGINT) -> DOUBLE
	avg(HUGEINT) -> DOUBLE
	avg(DOUBLE) -> DOUBLE


LINE 1: SELECT AVG(estimated_delivery - order_date) AS avg_time_taken FROM...
               ^
SQL query results: Error executing SQL query: Binder Error: No function matches the given name and argument types 'avg(INTERVAL)'. You might need to add explicit type casts.
	Candidate functions:
	avg(DECIMAL) -> DECIMAL
	avg(SMALLINT) -> DOUBLE
	avg(INTEGER) -> DOUBLE
	avg(BIGINT) -> DOUBLE
	avg(HUGEINT) -> DOU

In [3]:
def analyze_data(task: str, context: str) -> str:
    """
    Perform analysis or reasoning based on SQL query output or other structured context.
    
    Parameters:
    - task: A natural language instruction, such as "Identify the trend" or "Summarize the results."
    - context: Output from a SQL query or other relevant data.

    Returns:
    - A string containing the analyzed insight or summary.
    """
    
    # Define the prompt
    prompt = ChatPromptTemplate.from_messages([
        ("system", 
         "You are a data analyst. Your job is to interpret and summarize results from SQL queries or structured outputs.\n"
         "Use the context provided to complete the task given by the user."),
        ("human", 
         "Context:\n{context}\n\nTask:\n{task}\n\nProvide a concise and helpful answer:")
    ])
    
    # Load LLM
    llm = OllamaLLM(model="mistral:latest", temperature=0.0)
    
    # Create chain and run
    chain = prompt | llm | StrOutputParser()
    result = chain.invoke({"task": task, "context": context})
    return result.strip()

def analyze_data_tool(input: dict) -> str:
    """
    Wrapper tool for the analysis agent. Expects a dictionary with 'task' and 'context'.
    """
    task = input.get("task", "")
    context = input.get("context", "")
    return analyze_data(task, context)


analysis_agent = create_react_agent(
    model=local_model,
    tools=[analyze_data_tool],
    name="analysis_agent",
    prompt=(
        "You are an analysis expert. Use the 'analyze_data_tool' to interpret and summarize "
        "results from SQL outputs. Do not generate or modify SQL queries."
    )
)




In [4]:
workflow = create_supervisor(
    [sql_agent],
    model=local_model,
    prompt=(
        "You are a supervisor managing a SQL expert.\n"
        "Use 'sql_agent' for any questions requiring database queries, the sql agent.\n"
        "If an analysis task is needed you can perform it based on the sql results.\n"
        "Always decompose complex queries into subtasks and assign them accordingly."
    )
)

# --- Run the app ---
app = workflow.compile()

In [None]:
# Compile and run
app = workflow.compile()
result = app.invoke({
    "messages": [
        {
            "role": "user",
            "content": "What is the total number of cases and how many of them were delivered late? Can you analyze what that says about delivery performance?"
        }
    ]
})

query: SELECT COUNT(*) as total, SUM(CASE WHEN delivery_status = 'Late' THEN 1 ELSE 0 END) as late FROM cases
query: SELECT COUNT(*) as total, SUM(CASE WHEN delivery_status = 'Late' THEN 1 ELSE 0 END) as late FROM cases
Converting question to SQL SELECT COUNT(*) as total, SUM(CASE WHEN delivery_status = 'Late' THEN 1 ELSE 0 END) as late FROM cases
Converting question to SQL SELECT COUNT(*) as total, SUM(CASE WHEN delivery_status = 'Late' THEN 1 ELSE 0 END) as late FROM cases
Generated SQL query: SELECT COUNT(*) AS total, SUM(CASE WHEN on_time = FALSE THEN 1 ELSE 0 END) AS late FROM cases
🚀 Executing query: SELECT COUNT(*) AS total, SUM(CASE WHEN on_time = FALSE THEN 1 ELSE 0 END) AS late FROM cases
SQL SELECT query executed successfully.
SQL query results: total: 991, late: 100
SQL error states: False
Token count: 14
Generated SQL query: SELECT COUNT(*) AS total, SUM(CASE WHEN on_time = FALSE THEN 1 ELSE 0 END) AS late FROM cases
🚀 Executing query: SELECT COUNT(*) AS total, SUM(CASE WH

In [20]:
for m in result["messages"]:
    m.pretty_print()



What is the total number of cases and how many of them were delivered late? Can you analyze what that says about delivery performance?
Name: sql_agent

 Based on the provided SQL query results:

- The total number of cases is 0.
- There are no late deliveries.

This indicates that there were no cases processed in this dataset. Therefore, I was not able to process any specific question regarding case details or delivery statuses due to the absence of data.
Name: sql_agent

Transferring back to supervisor
Tool Calls:
  transfer_back_to_supervisor (abdf33c2-121f-4319-bce9-c52c12f1600f)
 Call ID: abdf33c2-121f-4319-bce9-c52c12f1600f
  Args:
Name: transfer_back_to_supervisor

Successfully transferred back to supervisor
Name: sql_agent

 Based on the provided data:

- The total number of cases is 991.
- Out of these, there are 100 cases with a delivery status marked as "Late."

This suggests that approximately 10.09% (or 1 in every 10 cases) experienced late deliveries. To analyze the deliv