In [None]:
from pathlib import Path

from dotenv import load_dotenv
from IPython.display import Image, display
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import START, END, StateGraph
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode
from sqlalchemy import create_engine
load_dotenv()

from sql_toolkit import (
  list_table_tool,
  get_table_schema_tool,
  execute_sql_tool
)

In [None]:
# Base chat model; tools are bound to it in a later cell.
# temperature 0.0 minimizes sampling randomness so SQL generation is as repeatable as possible,
# and base_url overrides the default OpenAI API endpoint.
llm_kwargs = {
  "model": "gpt-4o-mini",
  "temperature": 0.0,
  "base_url": "https://openai.vocareum.com/v1",
}
llm = ChatOpenAI(**llm_kwargs)

In [None]:
# Quick connectivity check: one round-trip to the model before the graph is built.
smoke_test_reply = llm.invoke("What's Pokemon")
smoke_test_reply

In [None]:
class State(MessagesState):
  """Graph state: the inherited `messages` list plus the user's natural-language question."""

  user_query: str


workflow = StateGraph(State)

In [None]:
# Tools exposed to the DBA agent, matching the steps in its system prompt:
# list tables, fetch a table's schema, execute SQL.
dba_tools = [
  list_table_tool,
  get_table_schema_tool,
  execute_sql_tool,
]

In [None]:
# Node that executes the tool calls requested by the agent's latest message.
dba_tool_node = ToolNode(dba_tools)
workflow.add_node("dba_tools", dba_tool_node)

In [None]:
# tool_choice="auto": on each turn the model decides whether to call a tool or answer directly.
dba_llm = llm.bind_tools(
  dba_tools,
  tool_choice="auto",
)

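## Agent nodes

`messages_builder` seeds the conversation with the system prompt and the user's question. `dba_agent` then calls the tool-bound LLM on the conversation so far.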

In [None]:
def messages_builder(state: State):
  """Seed the conversation for the DBA agent.

  Builds a system prompt that walks the model through the
  list-tables -> read-schema -> write-SQL -> execute-SQL workflow,
  followed by the user's question taken from ``state["user_query"]``.

  Returns:
    A partial state update ``{"messages": [SystemMessage, HumanMessage]}``.
  """
  instructions = "".join([
    "You are a Sr. SQL developer tasked with generating SQL queries. Perform the following steps:\n",
    "First, find out the appropriate table name based on all tables. ",
    "Then get the table's schema to understand the columns. ",
    "With the table name and the schema, generate the ANSI SQL query you think is applicable to the user question. ",
    "Finally, use a tool to execute the above SQL query and output the result based on the user question.",
  ])
  seed = [SystemMessage(instructions), HumanMessage(state["user_query"])]
  return {"messages": seed}

In [None]:
def dba_agent(state: State):
  """Run one turn of the tool-bound LLM over the conversation so far.

  The reply is tagged with ``name="dba_agent"`` and returned as a partial
  state update under ``"messages"``. It may contain tool calls, which the
  routing function sends to the ``dba_tools`` node.
  """
  reply = dba_llm.invoke(state["messages"])
  # Label the author of this message in the transcript.
  reply.name = "dba_agent"
  return {"messages": reply}

In [None]:
# Register the prompt-seeding node and the LLM node; these names are the ones the edges refer to.
workflow.add_node("messages_builder", messages_builder)
workflow.add_node("dba_agent", dba_agent)

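## Edges

After each `dba_agent` turn, `should_continue` routes to `dba_tools` when the model requested tool calls and ends the run otherwise. Tool results always flow back to `dba_agent`.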

In [None]:
def should_continue(state: State):
  """Route the graph after a ``dba_agent`` turn.

  Returns:
    ``"dba_tools"`` when the latest message requests one or more tool calls,
    otherwise ``END`` so the run finishes with the model's answer.
  """
  messages = state.get("messages") or []
  # An empty history has nothing to route on; stop rather than raise IndexError.
  if not messages:
    return END
  # Only AI messages carry `tool_calls`. Using getattr makes any other message type
  # end the run instead of raising AttributeError.
  if getattr(messages[-1], "tool_calls", None):
    return "dba_tools"
  return END

In [None]:
# Linear start: seed the prompt, then hand control to the agent.
workflow.add_edge(START, "messages_builder")
workflow.add_edge("messages_builder", "dba_agent")

# After each agent turn, either run the requested tools or stop.
workflow.add_conditional_edges("dba_agent", should_continue, ["dba_tools", END])

# Tool results always return to the agent for its next turn.
workflow.add_edge("dba_tools", "dba_agent")

In [None]:
# Compile the node/edge definitions into a runnable graph (no checkpointer is passed).
react_graph = workflow.compile()

In [None]:
# Render the compiled graph's topology as a Mermaid PNG.
graph_png = react_graph.get_graph().draw_mermaid_png()
display(Image(graph_png))

In [None]:
# SQLite database queried by the agent, resolved relative to the notebook's working directory.
DB_PATH = Path("sales.db")

# Connecting to a missing SQLite file silently creates an empty database, so the agent
# would find no tables. Fail fast with a clear error instead.
if not DB_PATH.is_file():
  raise FileNotFoundError(f"SQLite database not found: {DB_PATH.resolve()}")

db_engine = create_engine(f"sqlite:///{DB_PATH}")

In [None]:
# Per-run configuration: the SQLAlchemy engine is passed under "configurable",
# presumably so the sql_toolkit tools can read it at call time (TODO: confirm in sql_toolkit).
config = dict(configurable=dict(db_engine=db_engine))

In [None]:
# Initial graph input: the natural-language question for the agent to answer.
inputs = dict(user_query="How many Dell XPS 15 were sold?")

In [None]:
# Run the graph to completion. The result is the final graph state, a dict whose
# "messages" key holds the full transcript.
messages = react_graph.invoke(inputs, config=config)

In [None]:
# Pretty-print every message in the final state, in order.
transcript = messages["messages"]
for message in transcript:
  message.pretty_print()