# An agent for interacting with a SQL database

In this tutorial, we will walk through how to build an agent that can answer questions about a SQL database. 

At a high level, the agent will:
1. Fetch the available tables from the database
2. Decide which tables are relevant to the question
3. Fetch the DDL for the relevant tables
4. Generate a query based on the question and information from the DDL
5. Double-check the query for common mistakes using an LLM
6. Execute the query and return the results
7. Correct mistakes surfaced by the database engine until the query is successful
8. Formulate a response based on the results
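Steps 4 to 7 form a retry loop: propose a query, execute it, and feed any database error back until the query succeeds. The tutorial builds this loop out of LangGraph nodes. The sketch below shows only the control flow, using the standard-library `sqlite3` module. `run_with_retries`, the `propose_query` callback (a stand-in for the LLM), the in-memory table, and the `max_attempts` limit are hypothetical and exist only for illustration.

```python
import sqlite3
from typing import Callable, Optional


def run_with_retries(
    conn: sqlite3.Connection,
    propose_query: Callable[[Optional[str]], str],
    max_attempts: int = 3,
) -> list[tuple]:
    """Ask for a query, execute it, and feed any database error back until it succeeds.

    `propose_query` is a hypothetical stand-in for the LLM. It receives the
    previous error message (None on the first attempt) and returns SQL.
    """
    error: Optional[str] = None
    for _ in range(max_attempts):
        query = propose_query(error)
        try:
            return conn.execute(query).fetchall()
        except sqlite3.Error as exc:
            # Surface the engine's message so the next attempt can correct it.
            error = f"Error: {exc!r}"
    raise RuntimeError(f"Query still failing after {max_attempts} attempts: {error}")


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
    conn.executemany("INSERT INTO Artist VALUES (?, ?)", [(1, "AC/DC"), (2, "Accept")])

    # Scripted "LLM": the first query uses a wrong column, the second one is fixed.
    attempts = iter(["SELECT ArtistName FROM Artist", "SELECT Name FROM Artist ORDER BY ArtistId"])
    print(run_with_retries(conn, lambda err: next(attempts)))  # [('AC/DC',), ('Accept',)]
```

A real agent would also cap retries, as `max_attempts` does here, so that a query the model cannot fix does not loop forever.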

## Setup

Before we begin, please make sure you have set up the `.env` file in the project
directory as described in [`README.md`](README.md).

Next, we will load in the necessary environment variables (e.g., API keys) for this notebook:

In [1]:
import os
from dotenv import load_dotenv

_ = load_dotenv()

assert os.environ.get("GOOGLE_API_KEY")
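A bare `assert` is skipped when Python runs with `-O`, and its failure message does not say which variable is missing. One alternative, assuming you only need to confirm that variables are set and non-empty, is a small helper. `require_env` is a hypothetical name and is not part of `python-dotenv`.

```python
import os


def require_env(*names: str) -> dict[str, str]:
    """Return the requested environment variables, failing loudly if any are unset or empty."""
    missing = [name for name in names if not os.environ.get(name)]
    if missing:
        raise RuntimeError(
            "Missing environment variables: " + ", ".join(missing)
            + ". Check the .env file described in README.md."
        )
    return {name: os.environ[name] for name in names}


# Usage, after load_dotenv():
# keys = require_env("GOOGLE_API_KEY")
```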

## Configure the database

We will use SQLite for this tutorial. SQLite is a lightweight database that is easy to set up and use. We will load the `Chinook` sample database, which represents a digital media store; the code below expects the database file at `db/Chinook.db`.
More information about the database is available on the [SQLite Tutorial site](https://www.sqlitetutorial.net/sqlite-sample-database/).

We will use a handy SQL database wrapper available in the `langchain_community` package to interact with the database. The wrapper provides a simple interface to execute SQL queries and fetch results.
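For intuition, the table listing and query calls in the next cell roughly correspond to the standard-library `sqlite3` operations below. This is an approximation for illustration only. `SQLDatabase` itself is built on SQLAlchemy, and the helper names and the tiny in-memory `Artist` table (standing in for `db/Chinook.db`) are assumptions.

```python
import sqlite3


def usable_table_names(conn: sqlite3.Connection) -> list[str]:
    """List user tables, skipping SQLite's internal sqlite_* tables."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]


def run_query(conn: sqlite3.Connection, query: str) -> str:
    """Execute a query and return the rows as a string, or "" if there are none."""
    rows = conn.execute(query).fetchall()
    return str(rows) if rows else ""


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.executemany("INSERT INTO Artist VALUES (?, ?)", [(1, "AC/DC"), (2, "Accept")])

print(usable_table_names(conn))  # ['Artist']
print(run_query(conn, "SELECT * FROM Artist LIMIT 10;"))  # [(1, 'AC/DC'), (2, 'Accept')]
```

Returning results as a string, as the wrapper does in the output below, makes them easy to paste straight into an LLM prompt.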

In [2]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///db/Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

## Utility functions

We will define a few utility functions to help us with the agent implementation. Specifically, we will wrap a `ToolNode` with a fallback to handle errors and surface them to the agent.

In [3]:
from typing import Any

from langchain_core.messages import ToolMessage
from langchain_core.tools import BaseTool
from langchain_core.runnables import (
    RunnableLambda, 
    RunnableWithFallbacks
)
from langgraph.prebuilt import ToolNode


def create_tool_node_with_fallback(
    tools: list[BaseTool]
) -> RunnableWithFallbacks[Any, dict]:
    """
    Create a ToolNode with a fallback 
    to handle errors and surface them 
    to the agent.
    """
    return ToolNode(tools).with_fallbacks(
        fallbacks=[
            RunnableLambda(handle_tool_error)
        ],
        exception_key="error"
    )


def handle_tool_error(
    state: dict
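) -> dict:
    """
    Convert the exception caught by the fallback into one
    ToolMessage per pending tool call, so the agent sees
    the error and can retry.
    """
    error = state.get("error")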
    tool_call_list = state["messages"][-1].tool_calls
    
    return {
        "messages": [
            ToolMessage(
                content=(
                    f"Error: {repr(error)}\n"
                    "Please fix your mistakes."
                ),
                tool_call_id=tool_call["id"],
            )
            for tool_call in tool_call_list
        ]
    }
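Conceptually, `with_fallbacks(..., exception_key="error")` catches the exception raised by the `ToolNode`, adds it to the input dictionary under the `"error"` key, and passes that dictionary to the fallback. The plain-Python sketch below mimics that control flow with dictionaries in place of LangChain messages. `handle_tool_error_sketch`, `invoke_with_fallback`, and `failing_tool_node` are hypothetical names chosen so they do not overwrite the notebook's real `handle_tool_error`.

```python
from typing import Any, Callable


def handle_tool_error_sketch(state: dict[str, Any]) -> dict[str, Any]:
    """Reply to every pending tool call with the error, like the notebook's handler."""
    error = state.get("error")
    tool_calls = state["messages"][-1]["tool_calls"]
    return {
        "messages": [
            {
                "role": "tool",
                "content": f"Error: {error!r}\nPlease fix your mistakes.",
                "tool_call_id": call["id"],
            }
            for call in tool_calls
        ]
    }


def invoke_with_fallback(
    primary: Callable[[dict], dict],
    fallback: Callable[[dict], dict],
    state: dict[str, Any],
    exception_key: str = "error",
) -> dict[str, Any]:
    """Run `primary`; on failure, pass the state plus the exception to `fallback`."""
    try:
        return primary(state)
    except Exception as exc:
        return fallback({**state, exception_key: exc})


def failing_tool_node(state: dict[str, Any]) -> dict[str, Any]:
    # Stand-in for a ToolNode whose tool raised an exception.
    raise ValueError("no such table: Artists")


state = {"messages": [{"role": "ai", "tool_calls": [{"id": "call-1", "name": "db_query_tool"}]}]}
result = invoke_with_fallback(failing_tool_node, handle_tool_error_sketch, state)
print(result["messages"][0]["content"])
# Error: ValueError('no such table: Artists')
# Please fix your mistakes.
```

Answering every pending tool call matters because chat models generally expect each tool call in an AI message to be followed by a matching tool message.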

## Define tools for the agent

We will define a few tools that the agent will use to interact with the database.

1. `list_tables_tool`: Fetch the available tables from the database
2. `get_schema_tool`: Fetch the DDL for a table
3. `db_query_tool`: Execute the query and fetch the results, or return an error message if the query fails

For the first two tools, we will grab them from the `SQLDatabaseToolkit`, also available in the `langchain_community` package.
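The two toolkit tools return table names and, per table, the DDL plus a few sample rows. The following is a rough standard-library approximation of that behavior. The function names, the in-memory sample table, and the default of three sample rows (chosen to resemble the `Artist` output shown further down) are assumptions, and the real tools' output differs in details such as DDL formatting.

```python
import sqlite3


def list_tables(conn: sqlite3.Connection) -> str:
    """Comma-separated names of user tables (SQLite's internal sqlite_* tables skipped)."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    return ", ".join(name for (name,) in rows)


def table_schema(conn: sqlite3.Connection, table: str, sample_rows: int = 3) -> str:
    """CREATE TABLE statement for `table`, followed by up to `sample_rows` rows."""
    ddl = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    if ddl is None:
        return f"Error: table {table!r} does not exist"
    # The name was checked against sqlite_master; still escape it as a quoted identifier.
    quoted = '"' + table.replace('"', '""') + '"'
    cursor = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample_rows,))
    header = "\t".join(col[0] for col in cursor.description)
    rows = cursor.fetchall()
    body = "\n".join("\t".join(str(value) for value in row) for row in rows)
    return f"{ddl[0]}\n\n/*\n{len(rows)} rows from {table} table:\n{header}\n{body}\n*/"


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.executemany(
    "INSERT INTO Artist VALUES (?, ?)",
    [(1, "AC/DC"), (2, "Accept"), (3, "Aerosmith"), (4, "Alanis Morissette")],
)
print(list_tables(conn))
print(table_schema(conn, "Artist"))
```

Including a few sample rows alongside the DDL gives the model concrete hints about value formats that column types alone do not convey.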

In [5]:
from langchain_google_genai import ChatGoogleGenerativeAI


llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash"
)

In [19]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit


toolkit = SQLDatabaseToolkit(
    db=db, 
    llm=llm
)
tools = toolkit.get_tools()

[(tool.name, tool.description) for tool in tools]

[('sql_db_query',
  "Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields."),
 ('sql_db_schema',
  'Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3'),
 ('sql_db_list_tables',
  'Input is an empty string, output is a comma-separated list of tables in the database.'),
 ('sql_db_query_checker',
  'Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!')]

In [20]:
list_tables_tool = next(
    tool 
    for tool in tools 
    if tool.name == "sql_db_list_tables"
)

get_schema_tool = next(
    tool 
    for tool in tools 
    if tool.name == "sql_db_schema"
)

In [21]:
print(list_tables_tool.invoke(""))

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track


In [22]:
print(get_schema_tool.invoke("Artist"))


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


We will define the third tool, `db_query_tool`, ourselves: it executes the query against the database and returns the results.

In [27]:
from langchain_core.tools import tool


@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, 
    check the query, and try again.
    """
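    # run_no_throw returns an "Error: ..." string instead of raising when
    # the query fails, and an empty string when it succeeds with no rows.
    result = db.run_no_throw(query)

    if not result:
        # Still prompt a retry on an empty result set, but say accurately
        # what happened instead of reporting that the query failed.
        return "Error: Query returned no rows. Please rewrite your query and try again."

    return result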
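Reporting failures as a returned string instead of raising is what lets the tool hand database errors back to the agent. A standard-library sketch of that contract, assuming an `"Error: ..."` prefix for failures and `""` for an empty result set (this is not LangChain's implementation, and `run_no_throw_sketch` is a hypothetical name):

```python
import sqlite3


def run_no_throw_sketch(conn: sqlite3.Connection, query: str) -> str:
    """Execute `query`; return rows as a string, "" for no rows, or an error string."""
    try:
        rows = conn.execute(query).fetchall()
    except sqlite3.Error as exc:
        # Return the engine's message instead of raising, so a tool can pass it to the LLM.
        return f"Error: {exc}"
    return str(rows) if rows else ""


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("INSERT INTO Artist VALUES (1, 'AC/DC')")

print(run_no_throw_sketch(conn, "SELECT * FROM Artist"))  # [(1, 'AC/DC')]
print(repr(run_no_throw_sketch(conn, "SELECT * FROM Artist WHERE ArtistId = 2")))  # ''
print(run_no_throw_sketch(conn, "SELECT * FROM Artists"))  # Error: no such table: Artists
```

Because the three outcomes are distinguishable, a tool can treat them differently; `db_query_tool` flags the empty string as an error so that the agent revisits its query.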

In [26]:
print(
    db_query_tool.invoke(
        "SELECT * FROM Artist LIMIT 10;"
    )
)

[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]


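While not strictly a tool, we will prompt an LLM to check the query for common mistakes and later add this as a node in the workflow. Binding `db_query_tool` with `tool_choice="any"` forces the model to answer with a tool call, so its output is always a (possibly corrected) query ready to execute.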

In [30]:
from langchain_core.prompts import ChatPromptTemplate


query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

query_check_prompt = ChatPromptTemplate.from_messages(
    messages=[
        ("system", query_check_system), 
        ("placeholder", "{messages}")
    ]
)
query_check = (
    query_check_prompt 
    | llm.bind_tools(
        tools=[db_query_tool], 
        tool_choice="any",
    )
)

Key 'title' is not supported in schema, ignoring
Key 'title' is not supported in schema, ignoring
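Several items on the checklist are easy to confirm directly in SQLite. The following standard-library snippet uses an in-memory database rather than Chinook (`row_count` is a hypothetical helper) and shows why `NOT IN` with a `NULL`, `UNION` versus `UNION ALL`, and `BETWEEN` deserve a second look:

```python
import sqlite3

conn = sqlite3.connect(":memory:")


def row_count(query: str) -> int:
    """Number of rows a query returns."""
    return len(conn.execute(query).fetchall())


# NOT IN against a list containing NULL is never true, so no rows come back.
print(row_count("SELECT 1 WHERE 1 NOT IN (2, NULL)"))  # 0
print(row_count("SELECT 1 WHERE 1 NOT IN (2)"))        # 1

# UNION removes duplicate rows; UNION ALL keeps them.
print(row_count("SELECT 1 UNION SELECT 1"))      # 1
print(row_count("SELECT 1 UNION ALL SELECT 1"))  # 2

# BETWEEN includes both bounds, so it is wrong for an exclusive upper bound.
print(conn.execute("SELECT 5 BETWEEN 1 AND 5").fetchone()[0])  # 1
```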


In [31]:
query_check.invoke(
    input={
        "messages": [
            ("user", "SELECT * FROM Artist LIMIT 10;")
        ]
    }
)

AIMessage(content='', additional_kwargs={'function_call': {'name': 'db_query_tool', 'arguments': '{"query": "SELECT * FROM Artist LIMIT 10;"}'}}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-19a02cce-3366-4608-b13f-a249196a0c49-0', tool_calls=[{'name': 'db_query_tool', 'args': {'query': 'SELECT * FROM Artist LIMIT 10;'}, 'id': 'd64e26fe-cd07-47ca-8a44-1fc74b114b05', 'type': 'tool_call'}], usage_metadata={'input_tokens': 223, 'output_tokens': 25, 'total_tokens': 248})
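The checker answers with a tool call rather than text: the checked query sits in `tool_calls[0]["args"]["query"]`. In the full workflow a `ToolNode` executes such calls. To make the structure concrete, here is a plain-Python sketch that dispatches a tool call shaped like the one above to a registry of functions. `dispatch_tool_calls` and the echoing stand-in executor are hypothetical.

```python
from typing import Any, Callable


def dispatch_tool_calls(
    tool_calls: list[dict[str, Any]],
    registry: dict[str, Callable[..., str]],
) -> list[dict[str, str]]:
    """Run each tool call by name and pair its output with the call id."""
    results = []
    for call in tool_calls:
        func = registry.get(call["name"])
        if func is None:
            content = f"Error: unknown tool {call['name']!r}"
        else:
            content = func(**call["args"])
        results.append({"tool_call_id": call["id"], "content": content})
    return results


# Same shape as the tool_calls entry in the AIMessage above.
tool_calls = [
    {
        "name": "db_query_tool",
        "args": {"query": "SELECT * FROM Artist LIMIT 10;"},
        "id": "d64e26fe-cd07-47ca-8a44-1fc74b114b05",
        "type": "tool_call",
    }
]

# Stand-in executor that only echoes the query it received.
registry = {"db_query_tool": lambda query: f"would run: {query}"}
print(dispatch_tool_calls(tool_calls, registry))
```

Keeping the call `id` next to each result is what lets the model match every tool output to the call that produced it.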