In [None]:
!pip install sqlalchemy fastapi langchain langchain_experimental langchain-groq pymysql langgraph

# Method 1: SQL Chain

In [None]:
import os
from fastapi import FastAPI
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
from langchain_groq import ChatGroq
from langchain.tools import Tool
from langchain_experimental.sql import SQLDatabaseChain
from langchain.sql_database import SQLDatabase
from dotenv import load_dotenv
import re

# Load environment variables
# (done before os.getenv("GROQ_API_KEY") is read further down).
load_dotenv()

# SQLite Database Connection
# DB_PATH is a relative path, so it is resolved against the current working directory.
DB_PATH = "subway.db"
engine = create_engine(f"sqlite:///{DB_PATH}")
# `db` wraps the engine in LangChain's SQLDatabase helper; the chain and
# safe_db_query below run their SQL through it.
db = SQLDatabase(engine)

# Initialize Groq model
# The API key is read from the GROQ_API_KEY environment variable; os.getenv
# returns None when the variable is not set.
llm = ChatGroq(
    model="deepseek-r1-distill-llama-70b",
    temperature=0,
    api_key=os.getenv("GROQ_API_KEY")
)

In [54]:
# Build the SQLDatabaseChain that connects `llm` to `db`.
# verbose=True makes the chain print its run trace (the
# "Entering new SQLDatabaseChain chain..." text in the cell output).
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

def safe_db_query(query: str):
    """Answer a natural-language question about the ``outlets`` table.

    The question is embedded in a schema-aware prompt and sent through
    ``db_chain``. When a SQL statement can be extracted from the chain's
    reply, it is checked to be read-only, executed with ``db.run`` and the
    rows are appended to the reply.

    Args:
        query: The user's question in plain language.

    Returns:
        The chain's reply, followed by a "**Query Result:**" section when a
        statement was extracted and executed. On failure a message starting
        with "Error invoking LLM:", "Invalid SQL query:", "Database error:"
        or "Processing error:" is returned instead of raising.
    """
    # Define the prompt template with the database schema.
    prompt_template = """
    You are an SQL expert with deep knowledge of SQLite.
    The database has a table named "outlets" with the following columns:
    - id (INTEGER, primary key)
    - name (TEXT, Subway outlet name)
    - address (TEXT, full address)
    - work_day_start (TEXT, work day start time)
    - work_day_end (TEXT, work day end time)
    - start_time (TEXT, start time)
    - end_time (TEXT, end time)
    - latitude (REAL)
    - longitude (REAL)

    **Guidelines for SQL Generation:**
    - Only use the listed columns. Do not reference any other tables.
    - Do not execute DROP, UPDATE, or DELETE statements—only read operations (SELECT) are allowed.
    - Always sanitize queries to avoid SQL injection.
    - Provide results only in JSON format, containing outlet details.

    User Query: {query}

    Please format your response as:

    THINKING: [your chain-of-thought analysis]
    SQL: [SQL query]
    ANSWER: [plain-language summary]
    """
    # Create the full prompt by inserting the user query into the template.
    full_prompt = prompt_template.format(query=query)

    try:
        # Invoke the chain using the full prompt.
        response = db_chain.invoke(full_prompt)
    except Exception as e:
        return f"Error invoking LLM: {str(e)}"

    # If the response is a dict, extract its "result" key.
    if isinstance(response, dict):
        response = response.get("result", "")

    try:
        # Patterns that locate the SQL in the reply, most specific first.
        # The fenced-block pattern needs its closing fence: a lazy group with
        # nothing after it always captures the empty string.
        sql_query = None
        patterns = [
            r"```sql\s*(.*?)```",  # Fenced ```sql code block.
            r"SQL(?: Query)?:\s*(.*?)(?=\n\s*ANSWER:|\Z)",  # "SQL:" (the format the prompt asks for) or "SQL Query:".
            r"\*\*(.*?)\*\*",  # Bold-wrapped SQL.
        ]

        # Try each pattern and keep the first non-empty capture, so an empty
        # match does not stop the later patterns from being tried.
        for pattern in patterns:
            match = re.search(pattern, response, re.DOTALL)
            if match and match.group(1).strip():
                sql_query = match.group(1).strip()
                break

        if not sql_query:
            # If no SQL query is extracted, return the full chain response.
            return response

        # Cleanup: remove markdown wrappers at the edges only. SQL itself uses
        # "*", "_", "<" and ">" (SELECT *, end_time, comparisons), so those
        # characters must not be stripped from inside the statement.
        sql_query = sql_query.strip("`* \t\r\n")
        sql_query = re.sub(r'\s+', ' ', sql_query).strip()
        sql_query = sql_query.split(';')[0]  # Take only the first statement.

        # Validate that the statement is read-only, as the prompt promises.
        # The keyword scan also rejects writes hidden behind a WITH prefix.
        valid_commands = ("SELECT", "WITH", "EXPLAIN")
        write_keywords = r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE)\b"
        upper_sql = sql_query.upper()
        if not upper_sql.startswith(valid_commands) or re.search(write_keywords, upper_sql):
            return f"Invalid SQL query: {sql_query}"

        # Execute the SQL query.
        try:
            query_result = db.run(sql_query)
            return f"{response}\n\n**Query Result:**\n{query_result}"
        except Exception as e:
            return f"Database error: {str(e)}\nGenerated SQL: {sql_query}"

    except Exception as e:
        return f"Processing error: {str(e)}. Response: {response}"

# Define a tool for querying the database using our safe_db_query function.
# The Tool wrapper attaches a name and description to the function; in this
# notebook it is only reached through `db_tool.func` (see query_outlets).
db_tool = Tool(
    name="Database Query",
    func=safe_db_query,
    description="Returns the full model response including chain-of-thought and SQL query explanation."
)

def query_outlets(query_text: str):
    """Run a question through ``db_tool`` and hand back whatever it returns.

    If the tool call itself raises, the exception is converted into a dict
    with a single ``'result'`` key holding the error text.
    """
    try:
        answer = db_tool.func(query_text)
    except Exception as exc:
        return {'result': f"Error processing query: {str(exc)}"}
    return answer

# Try the chain on a sample question; `result` is printed in the next cell.
result = query_outlets("Which outlets close the latest?")



[1m> Entering new SQLDatabaseChain chain...[0m

    You are an SQL expert with deep knowledge of SQLite.
    The database has a table named "outlets" with the following columns:
    - id (INTEGER, primary key)
    - name (TEXT, Subway outlet name)
    - address (TEXT, full address)
    - work_day_start (TEXT, work day start time)
    - work_day_end (TEXT, work day end time)
    - start_time (TEXT, start time)
    - end_time (TEXT, end time)
    - latitude (REAL)
    - longitude (REAL)

    **Guidelines for SQL Generation:**
    - Only use the listed columns. Do not reference any other tables.
    - Do not execute DROP, UPDATE, or DELETE statements—only read operations (SELECT) are allowed.
    - Always sanitize queries to avoid SQL injection.
    - Provide results only in JSON format, containing outlet details.

    User Query: Which outlets close the latest?

    Please format your response as:

    THINKING: [your chain-of-thought analysis]
    SQL: [SQL query]
    ANSWER: [plain

In [55]:
# Show what query_outlets returned for the sample question.
print(result)

<think>
Okay, so I need to figure out which Subway outlets close the latest. The user provided a table called "outlets" with several columns, including "end_time". My goal is to write an SQL query that retrieves the outlets with the latest closing times.

First, I should look at the columns available. The "end_time" column seems to hold the closing times, so I'll focus on that. Since the user wants the latest closing times, I should sort the results in descending order based on "end_time". That way, the outlets that close later will appear at the top.

I also need to remember to limit the results to at most five, as per the user's instructions. Using the LIMIT clause with 5 will ensure I don't get more than five results. 

I should make sure to only select the necessary columns: "id", "name", "address", and "end_time". This way, the query is efficient and only retrieves the information needed.

I need to be careful with the column names to ensure they match exactly what's in the table.

## Method 2: AI Agent to interact with SQL Database

In [40]:
from langchain_community.utilities import SQLDatabase

# Rebind `db` to a SQLDatabase built directly from the SQLite URI.
db = SQLDatabase.from_uri("sqlite:///subway.db")
# Sanity checks: the SQL dialect and the tables the wrapper can see.
print(db.dialect)
print(db.get_usable_table_names())
# Last expression of the cell, so the notebook displays the returned rows.
db.run("SELECT * FROM outlets LIMIT 5;")

sqlite
['outlets']


"[(1, 'Subway Menara UOA Bangsar', 'Jalan Bangsar Utama 1, Unit 1-2-G, Menara UOA Bangsar, Kuala Lumpur, 59000', 'Monday', 'Sunday', '08:00 AM', '08:00 PM', 3.126969, 101.6768848), (2, 'Subway Jln Pinang', 'G9, Wisma UOA II, 19, Jalan Pinang, Kuala Lumpur, 50450', 'Monday', 'Saturday', '08:00 AM', '09:00 PM', 3.1525875, 101.712256), (3, 'Subway UOA Damansara', 'Unit 50-G-5, Ground Floor, Wisma UOA Damansara, No. 50, Jalan Dungun, Kuala Lumpur, 50490', 'Monday', 'Saturday', '08:00 AM', '08:30 PM', 3.1517288, 101.6660061), (4, 'Subway Mont Kiara', 'E-01-16 ,Block E, Plaza Mont Kiara, 2 Jalan Kiara, Mont Kiara, Kuala Lumpur, 50480', 'Monday', 'Sunday', '10:00 AM', '10:00 PM', 3.1658129, 101.6510419), (5, 'Subway Avenue K', 'Lot UC-8 & 9, Upper Concourse Level, Avenue K, No. 156, Jalan Ampang, Kuala Lumpur, 50450', 'Monday', 'Sunday', '08:00 AM', '10:00 PM', 3.159418, 101.7134125)]"

In [41]:
from typing import Any

from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode


def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Wrap ``tools`` in a ToolNode whose failures are routed to
    ``handle_tool_error``, so the agent receives the error as a message.
    """
    tool_node = ToolNode(tools)
    error_reporter = RunnableLambda(handle_tool_error)
    # exception_key="error" is the state key handle_tool_error reads the exception from.
    return tool_node.with_fallbacks([error_reporter], exception_key="error")


def handle_tool_error(state) -> dict:
    """Turn a failed tool run into one error ToolMessage per pending tool call.

    Reads the exception from ``state["error"]`` and the tool calls from the
    last message in ``state["messages"]``.
    """
    error = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        # Each reply is tied to its originating call through tool_call_id.
        replies.append(
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}

In [42]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Build LangChain's SQL toolkit around the database and model, then pick out
# the two tools the graph refers to by name.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

# next() raises StopIteration if no tool with the given name is present.
list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")

# Quick check that both tools respond for this database.
print("List of tables: ", list_tables_tool.invoke(""))

print("Schema: ",get_schema_tool.invoke("outlets"))

List of tables:  outlets
Schema:  
CREATE TABLE outlets (
	id INTEGER, 
	name TEXT, 
	address TEXT, 
	work_day_start TEXT, 
	work_day_end TEXT, 
	start_time TEXT, 
	end_time TEXT, 
	latitude REAL, 
	longitude REAL, 
	PRIMARY KEY (id)
)

/*
3 rows from outlets table:
id	name	address	work_day_start	work_day_end	start_time	end_time	latitude	longitude
1	Subway Menara UOA Bangsar	Jalan Bangsar Utama 1, Unit 1-2-G, Menara UOA Bangsar, Kuala Lumpur, 59000	Monday	Sunday	08:00 AM	08:00 PM	3.126969	101.6768848
2	Subway Jln Pinang	G9, Wisma UOA II, 19, Jalan Pinang, Kuala Lumpur, 50450	Monday	Saturday	08:00 AM	09:00 PM	3.1525875	101.712256
3	Subway UOA Damansara	Unit 50-G-5, Ground Floor, Wisma UOA Damansara, No. 50, Jalan Dungun, Kuala Lumpur, 50490	Monday	Saturday	08:00 AM	08:30 PM	3.1517288	101.6660061
*/


In [43]:
from langchain_core.tools import tool

@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    # The docstring above is left verbatim: the @tool decorator is expected to
    # use it as the tool description shown to the model.
    import ast  # local import: only this tool needs it

    result = db.run_no_throw(query)
    if not result:
        return "Error: Query failed. Please rewrite your query and try again."

    # Convert raw result to more readable format
    try:
        # The result is expected to be a string like "[(name, closing_time), ...]".
        # literal_eval parses Python literals only, so text coming back from
        # the database can never be executed as code (eval could).
        rows = ast.literal_eval(result)
        formatted = "\n".join([f"- {name}: {time}" for name, time in rows])
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a literal (e.g. an error message) or not two-column rows:
        # hand the raw text back unchanged.
        return result
    return f"Query Results:\n{formatted}"


# Smoke test. SELECT * returns nine-column rows, which cannot be unpacked as
# (name, time) pairs, so the tool falls back to the raw result string.
print(db_query_tool.invoke("SELECT * FROM outlets LIMIT 10;"))

[(1, 'Subway Menara UOA Bangsar', 'Jalan Bangsar Utama 1, Unit 1-2-G, Menara UOA Bangsar, Kuala Lumpur, 59000', 'Monday', 'Sunday', '08:00 AM', '08:00 PM', 3.126969, 101.6768848), (2, 'Subway Jln Pinang', 'G9, Wisma UOA II, 19, Jalan Pinang, Kuala Lumpur, 50450', 'Monday', 'Saturday', '08:00 AM', '09:00 PM', 3.1525875, 101.712256), (3, 'Subway UOA Damansara', 'Unit 50-G-5, Ground Floor, Wisma UOA Damansara, No. 50, Jalan Dungun, Kuala Lumpur, 50490', 'Monday', 'Saturday', '08:00 AM', '08:30 PM', 3.1517288, 101.6660061), (4, 'Subway Mont Kiara', 'E-01-16 ,Block E, Plaza Mont Kiara, 2 Jalan Kiara, Mont Kiara, Kuala Lumpur, 50480', 'Monday', 'Sunday', '10:00 AM', '10:00 PM', 3.1658129, 101.6510419), (5, 'Subway Avenue K', 'Lot UC-8 & 9, Upper Concourse Level, Avenue K, No. 156, Jalan Ampang, Kuala Lumpur, 50450', 'Monday', 'Sunday', '08:00 AM', '10:00 PM', 3.159418, 101.7134125), (6, 'Subway Berjaya Times Square', 'LG-08A, Berjaya Times Square, No. 1, Jalan Imbi, Kuala Lumpur, 55100', 'Mo

In [44]:
from langchain_core.prompts import ChatPromptTemplate

# System prompt for the review step: the model re-reads a candidate query,
# fixes the listed classes of mistakes, and then calls the execution tool.
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

# The "placeholder" entry is filled with the "messages" list passed to invoke().
query_check_prompt = ChatPromptTemplate.from_messages(
    [("system", query_check_system), ("placeholder", "{messages}")]
)
# Only db_query_tool is bound and tool_choice="required" is set, so the reply
# is a tool call (see the AIMessage in the cell output).
query_check = query_check_prompt | llm.bind_tools(
    [db_query_tool], tool_choice="required"
)

# Try the checker on a known-good query.
query_check.invoke({"messages": [("user", "SELECT * FROM outlets LIMIT 10;")]})

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_qnw7', 'function': {'arguments': '{"query":"SELECT * FROM outlets LIMIT 10;"}', 'name': 'db_query_tool'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 359, 'prompt_tokens': 303, 'total_tokens': 662, 'completion_time': 1.305454545, 'prompt_time': 0.01484716, 'queue_time': 0.237069632, 'total_time': 1.3203017049999999}, 'model_name': 'deepseek-r1-distill-llama-70b', 'system_fingerprint': 'fp_d7b20c4b1a', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-51666bd9-20ad-43ea-ae8e-72d59daa46d4-0', tool_calls=[{'name': 'db_query_tool', 'args': {'query': 'SELECT * FROM outlets LIMIT 10;'}, 'id': 'call_qnw7', 'type': 'tool_call'}], usage_metadata={'input_tokens': 303, 'output_tokens': 359, 'total_tokens': 662})

In [None]:
from typing import Annotated, Literal

from langchain_core.messages import AIMessage

from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages


# Define the state for the agent
class State(TypedDict):
    """Graph state shared by every node: the running list of messages."""

    # add_messages is attached as annotation metadata; presumably LangGraph uses
    # it to merge each node's returned messages into this list (confirm in docs).
    messages: Annotated[list[AnyMessage], add_messages]


# Define a new graph whose nodes all read and write the State schema above.
workflow = StateGraph(State)


# Add a node for the first tool call
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """Seed the conversation with a hard-coded call to ``sql_db_list_tables``.

    No model is consulted here; the incoming ``state`` is not read.
    """
    list_tables_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "tool_abcd123",
    }
    seed_message = AIMessage(content="", tool_calls=[list_tables_call])
    return {"messages": [seed_message]}


def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """
    Pass the most recent message through the ``query_check`` chain and return
    the chain's reply as the next message.
    """
    latest_message = state["messages"][-1]
    reviewed = query_check.invoke({"messages": [latest_message]})
    return {"messages": [reviewed]}


# Entry node: emits the hard-coded sql_db_list_tables call.
workflow.add_node("first_tool_call", first_tool_call)

# Add nodes for the first two tools
# (each wrapped so that a tool failure is reported back as a message).
workflow.add_node(
    "list_tables_tool", create_tool_node_with_fallback([list_tables_tool])
)
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))

# Add a node for a model to choose the relevant tables based on the question and available tables
# Only the schema tool is bound to this model instance.
model_get_schema = llm.bind_tools(
    [get_schema_tool]
)
workflow.add_node(
    "model_get_schema",
    # The node feeds the whole message history to the model and appends its reply.
    lambda state: {
        "messages": [model_get_schema.invoke(state["messages"])],
    },
)


# Describe a tool to represent the end state
# The model is bound to this schema via bind_tools further down, so the class
# docstring and field description are left untouched (they are likely part of
# the tool definition the model sees).
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    # Required field (Field(...)): the answer text returned to the user.
    final_answer: str = Field(..., description="The final answer to the user")


# Add a node for a model to generate a query based on the question and schema
# The system prompt below covers both phases of this node: writing the query
# as plain text, and later summarising the executed query's results.
query_gen_system = """You are an expert in SQL with a strong attention to detail. Your role is to generate accurate and efficient SQLite queries based on the user's question and return a well-structured response.

Guidelines for Query Generation:
- Ensure SQL syntax correctness: Generate a valid SQLite query that answers the question precisely.
- Select only relevant columns: Avoid querying all columns unless explicitly required.
- Optimize results: If applicable, order results by a meaningful column to provide the most useful insights.
- Handle errors gracefully: If a query fails, modify and retry until a correct result is obtained.
- Avoid empty results: If the query returns an empty set, refine it to provide relevant data. However, do not fabricate information. If sufficient data is unavailable, state so clearly.
- Read-only operations: Do not modify the database (i.e., no INSERT, UPDATE, DELETE, DROP, etc.).

Response Formatting:
- Once the query returns results, analyze them and summarize the findings in clear, natural language.
- Example:
  - User question: "Which outlets close the latest?"
  - Response: "The outlets that close the latest are Outlet A at 11:00 PM and Outlet B at 11:30 PM."
- If insufficient data exists, state: "There is not enough information available to answer this question."
"""
# The "placeholder" entry is filled from the "messages" key of the input.
query_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", query_gen_system), ("placeholder", "{messages}")]
)
# SubmitFinalAnswer is the only tool bound here; query_gen_node rejects any
# other tool name the model might call.
query_gen = query_gen_prompt | llm.bind_tools(
    [SubmitFinalAnswer]
)


def query_gen_node(state: State):
    """Ask the query-generation chain for its next message.

    The model's message is always returned. For every tool call that is not
    ``SubmitFinalAnswer`` an error ToolMessage is appended after it, so the
    model is told about the mistake on the next turn.
    """
    message = query_gen.invoke(state)

    # Sometimes, the LLM will hallucinate and call the wrong tool. We need to catch this and return an error message.
    tool_messages = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
            tool_call_id=tc["id"],
        )
        for tc in (message.tool_calls or [])
        if tc["name"] != "SubmitFinalAnswer"
    ]
    return {"messages": [message] + tool_messages}


# Node that writes the query or, once results are available, the final answer.
workflow.add_node("query_gen", query_gen_node)

# Add a node for the model to check the query before executing it
workflow.add_node("correct_query", model_check_query)

# Add node for executing the query
# (db_query_tool wrapped so failures come back as error messages).
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))


# Define a conditional edge to decide whether to continue or end the workflow
def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
    """Pick the node that follows ``query_gen`` from the last message.

    A message carrying tool calls ends the run; a message whose content
    starts with "Error:" sends control back to ``query_gen``; anything else
    goes to ``correct_query``.
    """
    last_message = state["messages"][-1]
    # If there is a tool call, then we finish
    if getattr(last_message, "tool_calls", None):
        return END
    return "query_gen" if last_message.content.startswith("Error:") else "correct_query"


# Specify the edges between the nodes
# Fixed prefix: list tables -> model picks tables -> fetch schema -> query_gen.
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")
# After query_gen, should_continue picks END, "correct_query" or "query_gen".
workflow.add_conditional_edges(
    "query_gen",
    should_continue,
)
# A checked query is executed and its result is fed back into query_gen.
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")

# Compile the workflow into a runnable
# NOTE: `app` here is the compiled graph, not a web application.
app = workflow.compile()

In [53]:
# Run the full graph on a sample question.
messages = app.invoke(
    {"messages": [("user", "Which are the outlets that closes the latest?")]}
)
# The final message is expected to carry a SubmitFinalAnswer tool call; this
# lookup raises if the run ended without one.
json_str = messages["messages"][-1].tool_calls[0]["args"]["final_answer"]
print(json_str)

The outlet that closes the latest is Subway Jln Pinang at 09:00 PM.


In [51]:
# Second sample question, answered through the same graph.
messages = app.invoke(
    {"messages": [("user", "How many outlets are located in Bangsar ")]}
)
# As above: assumes the last message contains a SubmitFinalAnswer tool call.
json_str = messages["messages"][-1].tool_calls[0]["args"]["final_answer"]
print(json_str)

There is 1 outlet located in Bangsar.


Ref: https://langchain-ai.github.io/langgraph/tutorials/sql-agent/

In [None]:
# -------------------- Chatbot Workflow Setup -------------------- #

# Initialize SQLDatabase and Groq LLM model
# DB_PATH is the SQLite path constant defined at the top of the file; the
# previous name used here (`DATABASE`) is not defined anywhere in it.
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
# The API key is read from the GROQ_API_KEY environment variable.
llm = ChatGroq(
    model="deepseek-r1-distill-llama-70b",
    temperature=0,
    api_key=os.getenv("GROQ_API_KEY"),
)

# Create fallback handler for tool errors
def handle_tool_error(state: dict[str, Any]) -> dict:
    """Turn a failed tool run into one error ToolMessage per pending tool call.

    Args:
        state: Mapping holding the raised exception under "error" and the
            message history under "messages"; the last message must expose
            ``tool_calls``.

    Returns:
        ``{"messages": [...]}`` with one ToolMessage per tool call, each
        carrying the error text and the id of the call it answers.
    """
    # Built-in `dict[...]` is used in the signature (as elsewhere in this file)
    # because typing.Dict is never imported.
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }

def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Build a ToolNode for ``tools`` and attach ``handle_tool_error`` as its
    fallback, so tool failures reach the agent as messages.
    """
    node = ToolNode(tools)
    fallback = RunnableLambda(handle_tool_error)
    # The caught exception is passed to the fallback under the "error" key.
    return node.with_fallbacks([fallback], exception_key="error")

# Setup the SQL Database Toolkit and extract specific tools
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
# next() raises StopIteration if no tool with the given name is present.
list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")

# Define a custom SQL query tool
@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    """
    # The docstring above is left verbatim: the @tool decorator is expected to
    # use it as the tool description shown to the model.
    import ast  # local import: only this tool needs it

    result = db.run_no_throw(query)
    if not result:
        return "Error: Query failed. Please rewrite your query and try again."
    try:
        # literal_eval parses Python literals only, so text coming back from
        # the database can never be executed as code (eval could).
        rows = ast.literal_eval(result)
        formatted = "\n".join([f"- {name}: {time}" for name, time in rows])
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a literal (e.g. an error message) or not two-column rows:
        # hand the raw text back unchanged.
        return result
    return f"Query Results:\n{formatted}"

# Define prompt and tool for query check before execution
# The model re-reads a candidate query, fixes the listed classes of mistakes,
# and then calls the execution tool.
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""
# The "placeholder" entry is filled with the "messages" list passed to invoke().
query_check_prompt = ChatPromptTemplate.from_messages(
    [("system", query_check_system), ("placeholder", "{messages}")]
)
# Only db_query_tool is bound, with tool_choice="required".
query_check = query_check_prompt | llm.bind_tools(
    [db_query_tool], tool_choice="required"
)

# Define the state for the agent
class State(TypedDict):
    """Graph state shared by every node: the running list of messages."""

    # Built-in `list[...]` is used (as elsewhere in this file) because
    # typing.List is never imported. add_messages is attached as annotation
    # metadata; presumably LangGraph uses it to merge node outputs into the list.
    messages: Annotated[list[AnyMessage], add_messages]

# Create the workflow state graph; every node reads and writes the State schema.
workflow = StateGraph(State)

# Node: First tool call to list tables
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """Seed the conversation with a hard-coded call to ``sql_db_list_tables``.

    The incoming ``state`` is not read. The return annotation uses the
    built-in ``list`` because ``typing.List`` is never imported in this file.
    """
    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}],
            )
        ]
    }

# Node: Model check to validate the query before execution
def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """Pass the most recent message through ``query_check`` and return its reply.

    The return annotation uses the built-in ``list`` because ``typing.List``
    is never imported in this file.
    """
    return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}

# Entry node plus the two toolkit tools, each wrapped so that a tool failure
# is reported back as a message.
workflow.add_node("first_tool_call", first_tool_call)
workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))

# Node: Get schema using the model and schema tool
# Only the schema tool is bound to this model instance; the node feeds it the
# whole message history and appends its reply.
model_get_schema = llm.bind_tools([get_schema_tool])
workflow.add_node(
    "model_get_schema",
    lambda state: {"messages": [model_get_schema.invoke(state["messages"])]},
)

# Define the final answer submission tool structure
# The model is bound to this schema via bind_tools further down; no class
# docstring is added here because it would likely alter the tool definition
# the model sees.
class SubmitFinalAnswer(BaseModel):
    # Required field (Field(...)): the answer text returned to the user.
    final_answer: str = Field(..., description="The final answer to the user")

# Node: Generate query and final answer based on the question and schema
# The system prompt covers both phases of this node: writing the query as
# plain text, and later summarising the executed query's results.
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.

When generating the query:

Output the SQL query that answers the input question without a tool call.

Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set. 
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
When you have the query results, analyze them and present the answer in a clear, natural language format. For example:
'The outlets that close the latest are: Outlet A at 11:00 PM, Outlet B at 11:30 PM.'

Always use this format for final answers."""
# The "placeholder" entry is filled from the "messages" key of the input.
query_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", query_gen_system), ("placeholder", "{messages}")]
)
# SubmitFinalAnswer is the only tool bound here; query_gen_node rejects any
# other tool name the model might call.
query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])

def query_gen_node(state: State):
    """Ask the query-generation chain for its next message.

    The model's message is always returned. For every tool call that is not
    ``SubmitFinalAnswer`` an error ToolMessage is appended after it.
    """
    message = query_gen.invoke(state)
    # Collect one correction per wrongly named tool call (none if there are no calls).
    tool_messages = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
            tool_call_id=tc["id"],
        )
        for tc in (message.tool_calls or [])
        if tc["name"] != "SubmitFinalAnswer"
    ]
    return {"messages": [message] + tool_messages}

# query_gen writes the query / final answer, correct_query reviews the query,
# execute_query runs it through db_query_tool (with the error fallback).
workflow.add_node("query_gen", query_gen_node)
workflow.add_node("correct_query", model_check_query)
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]: # type: ignore
    """Pick the node that follows ``query_gen`` from the last message.

    Tool calls present -> END; content starting with "Error:" -> back to
    ``query_gen``; otherwise -> ``correct_query``.
    """
    last_message = state["messages"][-1]
    if getattr(last_message, "tool_calls", None):
        return END
    return "query_gen" if last_message.content.startswith("Error:") else "correct_query"

# Fixed prefix: list tables -> model picks tables -> fetch schema -> query_gen.
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")
# After query_gen, should_continue picks END, "correct_query" or "query_gen".
workflow.add_conditional_edges("query_gen", should_continue)
# A checked query is executed and its result is fed back into query_gen.
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")

# Compile the workflow into a runnable.
# Named workflow_app so it does not collide with the web application's `app`.
workflow_app = workflow.compile()

# ---------------------------- FastAPI Endpoint ---------------------------- #

# Request body of POST /query.
class QueryRequest(BaseModel):
    # The user's natural-language question, forwarded to the workflow as the user message.
    query: str

# The route decorator below needs `app` to be a FastAPI application. Earlier in
# this file `app` is bound to the compiled LangGraph workflow, which has no
# `.post`, so a web application is created here unless one already exists.
# NOTE(review): this rebinds `app`; earlier cells that call `app.invoke` must
# be re-run from the graph-building cell after this one has executed.
if not isinstance(globals().get("app"), FastAPI):
    app = FastAPI()


@app.post("/query", response_model=str, summary="Execute a SQL query via LangGraph workflow")
def run_query(request: QueryRequest):
    """
    API endpoint to execute a SQL query via LangGraph workflow.
    The workflow will process the input question, generate and validate the query,
    execute it, and return the final answer.
    """
    # Imported here because only `FastAPI` is imported from fastapi at the top
    # of the file; without it the error path below raises NameError.
    from fastapi import HTTPException

    try:
        state = {"messages": [("user", request.query)]}
        result_state = workflow_app.invoke(state)
        # Directly extract the final answer string from the tool call.
        # Raises (and is reported as a 500) if the run ended without a
        # SubmitFinalAnswer tool call on the last message.
        final_answer = result_state["messages"][-1].tool_calls[0]["args"]["final_answer"]
        if not final_answer:
            raise ValueError("Final answer missing in the tool call.")
        return final_answer
    except Exception as e:
        # Any failure is reported as HTTP 500; the cause is chained for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

In [85]:
# Reconnect to the SQLite file and display the schema text (table definition
# plus sample rows) that the wrapper reports.
db = SQLDatabase.from_uri("sqlite:///subway.db")
db.get_table_info()

'\nCREATE TABLE outlets (\n\tid INTEGER, \n\tname TEXT, \n\taddress TEXT, \n\twork_day_start TEXT, \n\twork_day_end TEXT, \n\tstart_time TEXT, \n\tend_time TEXT, \n\tlatitude REAL, \n\tlongitude REAL, \n\tPRIMARY KEY (id)\n)\n\n/*\n3 rows from outlets table:\nid\tname\taddress\twork_day_start\twork_day_end\tstart_time\tend_time\tlatitude\tlongitude\n1\tSubway Menara UOA Bangsar\tJalan Bangsar Utama 1, Unit 1-2-G, Menara UOA Bangsar, Kuala Lumpur, 59000\tMonday\tSunday\t08:00 AM\t08:00 PM\t3.126969\t101.6768848\n2\tSubway Jln Pinang\tG9, Wisma UOA II, 19, Jalan Pinang, Kuala Lumpur, 50450\tMonday\tSaturday\t08:00 AM\t09:00 PM\t3.1525875\t101.712256\n3\tSubway UOA Damansara\tUnit 50-G-5, Ground Floor, Wisma UOA Damansara, No. 50, Jalan Dungun, Kuala Lumpur, 50490\tMonday\tSaturday\t08:00 AM\t08:30 PM\t3.1517288\t101.6660061\n*/'