# Enhancing Text-to-SQL Agents with Step-by-Step Reasoning


[Enhancing Text-to-SQL Agents with Step-by-Step Reasoning](https://yia333.medium.com/implementing-reasoning-in-text-to-sql-agents-f979331176b4)


## SETUP


In [2]:
import os
import re
from typing import Annotated, Any, Dict, Optional

from typing_extensions import TypedDict
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from sqlalchemy import create_engine

from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver

## LLM


In [3]:
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

In [4]:
llm_openai = ChatOpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY)
llm_openai.invoke("what is 2+2").content

'2 + 2 equals 4.'

## DATA


In [5]:
db_file = "./chinook.db"
engine = create_engine(f"sqlite:///{db_file}")
db = SQLDatabase(engine=engine)
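
SQLite normally creates an empty database file when it connects to a path that does not exist, so a wrong `db_file` path can leave the agent with no tables and no error. A minimal pre-flight check with the standard library might look like the sketch below. `list_tables` is a hypothetical helper, not part of the notebook or of LangChain, and the demo uses an in-memory database rather than `chinook.db`.

```python
import sqlite3


def list_tables(conn: sqlite3.Connection) -> list[str]:
    """Return the names of user tables in a SQLite database, sorted by name."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
        "ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]


# Demonstration on an in-memory database; for the notebook one would pass
# sqlite3.connect("./chinook.db") instead.
demo = sqlite3.connect(":memory:")
demo.execute("CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT)")
demo.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
print(list_tables(demo))
```

An empty list from this check is a hint that the path points at a freshly created file rather than the intended database.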

In [6]:
toolkit = SQLDatabaseToolkit(db=db, llm=llm_openai)
sql_db_toolkit_tools = toolkit.get_tools()

## CREATE THE CHAIN


In [17]:
# define the prompt
query_gen_system = """
You are an SQL expert who helps analyze database queries. You have access to tools for interacting with the database. When given a question, you'll think through it carefully and explain your reasoning in natural language.

Then you'll walk through your analysis process:

1. First, you'll understand what tables and data you need
2. Then, you'll verify the schema and relationships
3. Finally, you'll construct an appropriate SQL query

For each query, you'll think about:
- What tables are involved and how they connect
- Any special conditions or filters needed
- How to handle potential edge cases
- The most efficient way to get the results

<reasoning>
You will **always** include this section before writing a query. Here, you will:
- Explain what information you need and why
- Describe your expected outcome
- Identify potential challenges
- Justify your query structure

If this section is missing, you will rewrite your response to include it.
</reasoning>

<analysis>
Here you break down the key components needed for the query:
- Required tables and joins
- Important columns and calculations
- Any specific filters or conditions
- Proper ordering and grouping
</analysis>

<query>
The final SQL query
</query>

<error_check>
If there's an error, you'll explain:
- What went wrong
- Why it happened
- How to fix it
</error_check>

<final_check>
Before finalizing, you will verify:
- Did you include a clear reasoning section?
- Did you explain your approach before querying?
- Did you provide an analysis of the query structure?
- If any of these are missing, you will revise your response.
</final_check>

Important rules:
1. Only use SELECT statements, no modifications
2. Verify all schema assumptions
3. Use proper SQLite syntax
4. Limit results to 10 unless specified
5. Double-check all joins and conditions
6. Always include tool_analysis and tool_reasoning for each tool call
"""

In [None]:
# create the prompt template
query_gen_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", query_gen_system),
        MessagesPlaceholder(variable_name="messages"),
    ]
)

In [None]:
# create chain
query_gen_model = query_gen_prompt | llm_openai.bind_tools(tools=sql_db_toolkit_tools)
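
The system prompt asks the model to rewrite its answer when the reasoning section is missing, but nothing in the chain enforces that. One way to check a reply on the Python side is sketched here. `missing_sections` and `REQUIRED_SECTIONS` are hypothetical names, and treating `reasoning`, `analysis`, and `query` as the required set is an assumption, not something the notebook specifies.

```python
import re

# Assumed minimum set of tags a final answer should contain.
REQUIRED_SECTIONS = ("reasoning", "analysis", "query")


def missing_sections(text: str, required=REQUIRED_SECTIONS) -> list[str]:
    """Return the tag names that have no non-empty <tag>...</tag> block in text."""
    missing = []
    for tag in required:
        match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
        if not match or not match.group(1).strip():
            missing.append(tag)
    return missing


sample = "<reasoning>Need album counts per artist.</reasoning><query>SELECT 1</query>"
print(missing_sections(sample))  # ['analysis']
```

A non-empty result could be used to re-prompt the model or to log a warning, depending on how strict the application needs to be.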

## GRAPH

In [27]:
class State(TypedDict):
    messages: Annotated[list, add_messages]


graph_builder = StateGraph(State)

In [28]:
def query_gen_node(state: State):
    # Name the prompt variable explicitly so the input shape is unambiguous.
    return {"messages": [query_gen_model.invoke({"messages": state["messages"]})]}


checkpointer = MemorySaver()
graph_builder.add_node("query_gen", query_gen_node)

<langgraph.graph.state.StateGraph at 0x1175c14f0>

In [29]:
query_gen_tools_node = ToolNode(tools=sql_db_toolkit_tools)
graph_builder.add_node("query_gen_tools", query_gen_tools_node)

<langgraph.graph.state.StateGraph at 0x1175c14f0>

In [30]:
graph_builder.add_conditional_edges(
    "query_gen", tools_condition, {"tools": "query_gen_tools", END: END}
)

graph_builder.add_edge("query_gen_tools", "query_gen")
graph_builder.set_entry_point("query_gen")
graph = graph_builder.compile(checkpointer=checkpointer)
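
The conditional edge sends control to the tool node when the model's latest message requests a tool call and ends the run otherwise, and the plain edge then loops tool results back to `query_gen`. The routing decision can be sketched in plain Python as follows. `route` is a hypothetical stand-in that uses dicts in place of message objects; it is not LangGraph's `tools_condition` implementation, and the `"__end__"` value for `END` is an assumption made for the demo.

```python
END = "__end__"  # stand-in for langgraph.graph.END (assumed value)


def route(messages: list[dict]) -> str:
    """Return "tools" if the latest message carries tool calls, else END."""
    last = messages[-1]
    return "tools" if last.get("tool_calls") else END


# Edge map mirroring the one passed to add_conditional_edges in the notebook.
edge_map = {"tools": "query_gen_tools", END: END}

with_call = [{"content": "", "tool_calls": [{"name": "sql_db_schema", "args": {}}]}]
without_call = [{"content": "<query>SELECT 1</query>", "tool_calls": []}]
print(edge_map[route(with_call)])
print(edge_map[route(without_call)])
```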

In [31]:
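def format_section(title: str, content: str) -> str:
    # Empty content yields an empty string; otherwise pad the content with
    # blank lines and put the title (if any) on its own line above it.
    if not content:
        return ""
    header = f"{title}\n" if title else ""
    return f"\n{header}{content}\n"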

In [32]:
def extract_section(text: str, section: str) -> str:
    pattern = f"<{section}>(.*?)</{section}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""
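
A short usage example shows how `extract_section` behaves on multi-line content and on a tag that is absent. The function body is repeated from the cell above so the snippet runs on its own; the `reply` text is made up for illustration.

```python
import re


def extract_section(text: str, section: str) -> str:
    # Same logic as the notebook cell, repeated so the example is self-contained.
    pattern = f"<{section}>(.*?)</{section}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""


reply = """<reasoning>
Albums live in the Album table.
Each row links to Artist through ArtistId.
</reasoning>
<query>SELECT Title FROM Album LIMIT 10</query>"""

print(extract_section(reply, "reasoning"))  # both lines, outer whitespace stripped
print(extract_section(reply, "query"))
print(repr(extract_section(reply, "analysis")))  # '' because the tag is absent
```

`re.DOTALL` is what lets `.` match the newlines inside the reasoning block, and the non-greedy `(.*?)` stops at the first closing tag.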

In [33]:
def process_event(event: Dict[str, Any]) -> Optional[str]:
    if "query_gen" in event:
        messages = event["query_gen"]["messages"]
        for message in messages:
            content = message.content if hasattr(message, "content") else ""

            reasoning = extract_section(content, "reasoning")
            if reasoning:
                print(format_section("", reasoning))

            analysis = extract_section(content, "analysis")
            if analysis:
                print(format_section("", analysis))

            error_check = extract_section(content, "error_check")
            if error_check:
                print(format_section("", error_check))

            final_check = extract_section(content, "final_check")
            if final_check:
                print(format_section("", final_check))

            if hasattr(message, "tool_calls"):
                for tool_call in message.tool_calls:
                    tool_name = tool_call["name"]
                    if tool_name == "sql_db_query":
                        return tool_call["args"]["query"]

            query = extract_section(content, "query")
            if query:
                # Prefer the body of a ```sql fence; otherwise use the raw section text.
                sql_match = re.search(r"```sql\n(.*?)\n```", query, re.DOTALL)
                return sql_match.group(1).strip() if sql_match else query

    return None
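
The prompt's first rule restricts the model to `SELECT` statements, but the SQL returned by `process_event` is not checked against it. A conservative guard could look like the sketch below. `is_read_only` is a hypothetical helper, and its keyword list is an assumption that can reject valid queries, for example one that calls SQLite's `REPLACE()` function or contains a semicolon inside a string literal.

```python
import re

# Keywords treated as writes or schema changes (assumed, deliberately broad).
WRITE_KEYWORDS = r"(?i)\b(insert|update|delete|drop|alter|create|replace|attach|pragma)\b"


def is_read_only(sql: str) -> bool:
    """Conservatively accept a single statement that starts with SELECT or WITH."""
    stripped = sql.strip().rstrip(";").strip()
    if not stripped or ";" in stripped:
        return False  # empty input or more than one statement
    if not re.match(r"(?i)(select|with)\b", stripped):
        return False
    # Reject write keywords anywhere, e.g. a WITH clause followed by DELETE.
    return not re.search(WRITE_KEYWORDS, stripped)


print(is_read_only("SELECT * FROM Album LIMIT 10;"))
print(is_read_only("WITH t AS (SELECT 1) DELETE FROM Album"))
```

A check like this complements the prompt rather than replacing database-level protection such as opening the SQLite file read-only.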

In [34]:
def run_query(query_text: str):
    print(f"\nAnalyzing your question: {query_text}")
    final_sql = None

    for event in graph.stream(
        {"messages": [("user", query_text)]}, config={"configurable": {"thread_id": 12}}
    ):
        sql = process_event(event)
        if sql:
            final_sql = sql

    if final_sql:
        print(
            "\nBased on my analysis, here's the SQL query that will answer your question:"
        )
        print(f"\n{final_sql}")
        return final_sql
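
Because the graph is compiled with a `MemorySaver` checkpointer and `run_query` always passes `thread_id` 12, every question is appended to the same conversation history. If separate sessions should start from a clean history, a fresh thread id per session is one option. `new_thread_config` below is a hypothetical helper that only builds the config dict.

```python
import uuid


def new_thread_config() -> dict:
    """Build a config dict with a fresh, unique thread_id."""
    return {"configurable": {"thread_id": str(uuid.uuid4())}}


session_a = new_thread_config()
session_b = new_thread_config()
# Two sessions get different thread ids, so their histories stay separate.
print(session_a["configurable"]["thread_id"] != session_b["configurable"]["thread_id"])
```

The returned dict would be passed as `config=` to `graph.stream` in place of the hard-coded one; keeping a single id is still the right choice when follow-up questions should see earlier context.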

In [35]:
def interactive_sql():
    print("\nWelcome to the SQL Assistant! Type 'exit' to quit.")

    while True:
        try:
            query = input("\nWhat would you like to know? ")
            if query.lower() in ["exit", "quit"]:
                print("\nThank you for using SQL Assistant!")
                break

            run_query(query)

        except KeyboardInterrupt:
            print("\nThank you for using SQL Assistant!")
            break
        except Exception as e:
            print(f"\nAn error occurred: {str(e)}")
            print("Please try again with a different query.")

In [None]:
if __name__ == "__main__":
    interactive_sql()


Welcome to the SQL Assistant! Type 'exit' to quit.

Analyzing your question: what is this database about

Analyzing your question: how many tables are there

Analyzing your question: get me top 5 elements

Analyzing your question: get me top albums
