In [1]:
from dotenv import load_dotenv
load_dotenv()

from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini")

In [2]:
# Setup Model for response validation
import regex as re
import pandas as pd
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Literal

class FinalResponse(BaseModel):
    original_query: str = Field(description="The user's original natural language query.")
    output: str = Field(description="The result of the executed query.")
    charts: Optional[List[Dict]] = Field(default=None, description="List of charts generated, if any.")

class SupervisorResponse(BaseModel):
    next_action: Literal["FINISH", "EXECUTE_QUERY", "Schema_Query_Agent"]
    final_response: Optional[FinalResponse] = None

class SchemaQueryResponse(BaseModel):
    final_query: str = Field(description="The constructed single-line Pandas query.")

In [3]:
from langchain_core.output_parsers import JsonOutputParser
schema_query_parser = JsonOutputParser(pydantic_object=SchemaQueryResponse)

# Helper functions for Supervisor Agent
def extract_query_from_state(state):
    for message in reversed(state['messages']):
        if message.name == "Schema_Query_Agent":
            pattern = r"The constructed Pandas query is: (.*)"
            match = re.search(pattern, message.content)
            return match.group(1)
        else:
            return "Query Not generated"

from Agent_tools import data_df as df
from langchain_experimental.tools import PythonAstREPLTool
python_repl_tool = PythonAstREPLTool(locals={'df':df})
def execute_query(query):
    if query is None:
        return f"Query is empty"
    try:
        result = python_repl_tool.run(query)
    except Exception as e:
        return f"Error executing query: {e}"
    
    if isinstance(result, (pd.DataFrame, pd.Series)):
        return result.to_string()
    else:
        return str(result)

from langchain.output_parsers import PydanticOutputParser
supervisor_parser = PydanticOutputParser(pydantic_object=SupervisorResponse)
def validate_response(response):
    try:
        final_response = SupervisorResponse.model_validate(response)
        return_answer = final_response.final_response
        return return_answer.json() #converting to string
    except Exception as e:  
        return f"Error parsing final response: {e}"

In [4]:
#Create Supervisor Agent
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from Agent_prompts import get_supervisor_prompt

members = ["Schema_Query_Agent"]
supervisor_prompt = get_supervisor_prompt()
options = ["FINISH", "EXECUTE_QUERY", "Schema_Query_Agent"]

final_supervisor_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", supervisor_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, what should be the next action? "
            "Select one of: {options}",
        ),
    ]
).partial(options=str(options), members=", ".join(members), format_instructions=supervisor_parser.get_format_instructions())

def supervisor_agent(state):
    supervisor_chain = final_supervisor_prompt | llm.with_structured_output(SupervisorResponse)
    response = supervisor_chain.invoke(state)
    next_action = response.next_action

    if next_action=='EXECUTE_QUERY':
        query=extract_query_from_state(state)
        result=execute_query(query)
        answer_eq = f"The answer after executing query: {query} is {result}"
        state['messages'].append(HumanMessage(content=answer_eq, name="EXECUTE_QUERY"))
        state['next'] = "supervisor"
    elif next_action == "Schema_Query_Agent":
        state['next'] = "Schema_Query_Agent"
    elif next_action == "FINISH":
        final_response = validate_response(response)
        if final_response:
            state['messages'].append(HumanMessage(content=final_response, name="Supervisor"))
        else:
            state['messages'].append(HumanMessage(content="No final response provided.", name="Supervisor"))
        state['next'] = "FINISH"
    else:
        state['messages'].append(HumanMessage(content="Unexpected error", name="Supervisor"))
        state['next'] = "FINISH"
    return state

In [5]:
#Helper utility to create agent nodes in graph
def agent_node(state, agent, name):
    result = agent.invoke(state)
    agent_message = result["messages"][-1].content
    if schema_query_parser:
        try:
            output = schema_query_parser.parse(agent_message)
        except Exception as e:
            output = f"Error parsing agent output: {e}"
    else:
        output = agent_message

    answer = f"The constructed Pandas query is: {output.get("final_query")}"

    return {
        "messages": [HumanMessage(content=answer, name=name)]
    }

In [6]:
#Create Worker Agents
import functools
from langgraph.prebuilt import create_react_agent
from Agent_prompts import get_schema_query_prompt

#schema info and query builder agent
from langgraph.checkpoint.memory import MemorySaver
from langchain.tools.render import render_text_description
from Agent_tools import get_dataset_info_tool, get_dataset_indexing_structure

memory=MemorySaver()
schema_query_tools = [get_dataset_info_tool, get_dataset_indexing_structure]

# Create Tools description
tools_list = render_text_description(list(schema_query_tools))
tool_names = ", ".join((t.name for t in schema_query_tools))
format_instructions = schema_query_parser.get_format_instructions()
SQ_agent_prompt = get_schema_query_prompt(tools_list, tool_names, format_instructions)

schema_query_agent = create_react_agent(llm, tools=schema_query_tools, state_modifier=SQ_agent_prompt, checkpointer=memory)
schema_query_node = functools.partial(agent_node, agent=schema_query_agent, name="Schema_Query_Agent")

In [7]:
#Create langgraph flow
import operator
from typing import Sequence, Annotated
from typing_extensions import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph import END, StateGraph, START

class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    next: str

workflow = StateGraph(AgentState)
workflow.add_node("Schema_Query_Agent", schema_query_node)
workflow.add_node("supervisor", supervisor_agent)

for member in members:
    workflow.add_edge(member, "supervisor")

conditional_map = {
    "Schema_Query_Agent": "Schema_Query_Agent",
    "EXECUTE_QUERY": "supervisor",  # Loop back to supervisor after execution
    "supervisor": "supervisor",
    "FINISH": END
}

workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
workflow.add_edge(START, "supervisor")
graph = workflow.compile()

In [8]:
# from IPython.display import Image, display
# display(Image(graph.get_graph(xray=True).draw_mermaid_png()))

In [9]:
# for s in graph.stream(
#     {
#         "messages": [
#             HumanMessage(content="What is the percentage of male population in North West of UK County?")
#         ]
#     }
# ):
#     if "__end__" not in s:
#         print(s)
#         print("----")

{'supervisor': {'messages': [HumanMessage(content='What is the percentage of male population in North West of UK County?', additional_kwargs={}, response_metadata={})], 'next': 'Schema_Query_Agent'}}
----
{'Schema_Query_Agent': {'messages': [HumanMessage(content="The constructed Pandas query is: df.loc[('Gender', 'Male', 'Percentage'), ('UK County', 'North West')]", additional_kwargs={}, response_metadata={}, name='Schema_Query_Agent')]}}
----
{'supervisor': {'messages': [HumanMessage(content='What is the percentage of male population in North West of UK County?', additional_kwargs={}, response_metadata={}, id='572cf82c-732e-43f2-bab1-c93ed2b2b7a1'), HumanMessage(content='What is the percentage of male population in North West of UK County?', additional_kwargs={}, response_metadata={}, id='572cf82c-732e-43f2-bab1-c93ed2b2b7a1'), HumanMessage(content="The constructed Pandas query is: df.loc[('Gender', 'Male', 'Percentage'), ('UK County', 'North West')]", additional_kwargs={}, response_m

In [14]:
final_state = graph.invoke(
    {
        "messages": [
            HumanMessage(content="What is the total count of males across all regions in the UK County?")
        ]
    }
)

In [15]:
print(final_state["messages"][-1].content)

{"original_query":"What is the total count of males across all regions in the UK County?","output":"1431.0","charts":null}
