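This notebook contains the code for a custom LangChain agent that answers questions about a SQLite database using SQL tools. It uses only basic prompting: a single system prompt with instructions. The LLM used in this notebook is OpenAI's gpt-3.5-turbo.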

Install the required libraries:

In [1]:
%pip install --upgrade --quiet langchain langchain-community langchain-openai

Note: you may need to restart the kernel to use updated packages.


Import the necessary modules:

In [2]:
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

Fill in your OpenAI API key in the cell below so that it is stored in the OPENAI_API_KEY environment variable, which ChatOpenAI reads by default.

In [3]:
# For gpt-3.5-turbo:
%env OPENAI_API_KEY=
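Typing the key into a %env line saves it in the notebook file, and the cell output echoes it back. A minimal alternative sketch, assuming you would rather be prompted for the key at runtime, uses the standard-library getpass module and only prompts when OPENAI_API_KEY is not already set. The helper name ensure_api_key is hypothetical.

```python
import getpass
import os


def ensure_api_key(var_name="OPENAI_API_KEY", prompt_fn=getpass.getpass):
    """Set var_name in os.environ, prompting only if it is missing or empty.

    prompt_fn is injectable so the helper can be exercised without a terminal.
    """
    if not os.environ.get(var_name):
        # getpass does not echo the typed key back into the output
        os.environ[var_name] = prompt_fn(f"Enter your {var_name}: ")
    return os.environ[var_name]
```

In the notebook you would call ensure_api_key() once before creating the ChatOpenAI model.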



Create the large language model (LLM) that the agent will use; temperature=0 makes its output as deterministic as possible:

In [4]:
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

Connect to the SQLite database that the agent will query:

In [5]:
db = SQLDatabase.from_uri("sqlite:///wijkpaspoort.sqlite3")

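Set up the SQL tools for the agent to use:

In [6]:
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
# Derive the names from the toolkit instead of hard-coding them:
# sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker
tool_names = ", ".join(tool.name for tool in tools)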
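To see roughly what the agent gets back from sql_db_list_tables and sql_db_schema, the sketch below reproduces that kind of output with plain sqlite3: a comma-separated list of table names, and a table's CREATE TABLE statement followed by a few sample rows. It is an approximation for illustration, not LangChain's implementation, so the exact formatting differs; the schools table and its rows are invented. Looking up the schema first is what lets the agent avoid guessing column names such as school_id.

```python
import sqlite3


def list_tables(conn):
    """Comma-separated, sorted table names (similar to sql_db_list_tables)."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return ", ".join(name for (name,) in rows)


def table_info(conn, table, sample_rows=3):
    """CREATE statement plus a few sample rows (approximating sql_db_schema)."""
    (create_sql,) = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    # Identifiers cannot be bound as parameters, so the table name is quoted instead
    cursor = conn.execute(f'SELECT * FROM "{table}" LIMIT ?', (sample_rows,))
    header = "\t".join(col[0] for col in cursor.description)
    rows = ["\t".join(str(v) for v in row) for row in cursor.fetchall()]
    return "\n".join([create_sql, "", header, *rows])


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE schools (name TEXT, municipality TEXT)")
conn.executemany(
    "INSERT INTO schools VALUES (?, ?)",
    [("School A", "Staphorst"), ("School B", "Utrecht")],
)
print(list_tables(conn))  # schools
print(table_info(conn, "schools"))
```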

Bind the SQL tools to the LLM so that it can request tool calls:

In [7]:
llm_with_sql_tools = llm.bind_tools(tools)
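bind_tools converts each LangChain tool into OpenAI's function-calling format and attaches the list to every request. The sketch below builds one such definition by hand to show the shape the model sees; the function_tool helper is hypothetical and the description strings are assumptions, not the exact wording LangChain sends. The {'query': ...} arguments visible in the agent logs later in this notebook correspond to the query parameter in this schema.

```python
import json


def function_tool(name, description, properties, required):
    """Build a tool definition in OpenAI's chat-completions 'tools' format."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            # The tool's arguments are described with a JSON Schema object
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


sql_query_tool = function_tool(
    name="sql_db_query",
    # Assumed wording, for illustration only
    description="Execute a SQL query against the database and return the result.",
    properties={"query": {"type": "string", "description": "A valid SQLite query."}},
    required=["query"],
)
print(json.dumps(sql_query_tool, indent=2))
```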

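Create the prompt template that is sent to the LLM at every step of the agent. The {input} placeholder receives the user's question, and agent_scratchpad holds the tool calls and tool results from earlier steps: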

In [8]:
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.
When there are fewer than 10 examples available in the database, limit the results to the number of examples that are available in the database.
You can order the results by a relevant column to return the most interesting examples in the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
ALWAYS come up with the final answer AFTER you have used the sql_db_query tool.
DO NOT come up with the answer without constructing a SQL query.
If you have retrieved SQL-output in your observation and you are able to answer the input question with the SQL-output, your next thought may be that you know the final answer.
You MUST double check your query with the sql_db_query_checker tool before executing it.
Make sure the input to the sql_db_query tool starts with 'SELECT' and does not start with something like: ' ```sql '.
Delete preceding backticks (```) and the 'sql' keyword when constructing a query.
Also make sure the SQL query ends with a semicolon.
If you get an error while executing a query, rewrite the query and try again.

DO NOT execute any DML or DDL statements (INSERT, UPDATE, DELETE, DROP, etc.) against the database.

If the question does not seem related to the database, just return "I don't know" as the final answer.

You have access to the following tools for interacting with the database:
"sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker"

Begin!""",),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)
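Several of the prompt's rules (strip a leading ```sql fence, start with SELECT, end with a semicolon, no DML or DDL) could also be enforced in code before a query reaches the database. A minimal sketch, assuming a hypothetical clean_query helper that you would apply yourself; LangChain's sql_db_query tool does not call it.

```python
import re

# Statement keywords the prompt forbids, matched as whole words, case-insensitively
FORBIDDEN = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|ATTACH)\b", re.IGNORECASE
)


def clean_query(raw):
    """Normalise an LLM-generated query and reject anything but a SELECT."""
    query = raw.strip()
    # Remove a surrounding Markdown code fence such as ```sql ... ```
    query = re.sub(r"^```(?:sql)?\s*", "", query, flags=re.IGNORECASE)
    query = re.sub(r"\s*```$", "", query).strip()
    if not query.upper().startswith("SELECT"):
        raise ValueError(f"Only SELECT queries are allowed: {query!r}")
    if FORBIDDEN.search(query):
        raise ValueError(f"Query contains a forbidden keyword: {query!r}")
    # Guarantee exactly one trailing semicolon
    return query.rstrip(";").rstrip() + ";"


print(clean_query("```sql\nSELECT COUNT(*) FROM measurements\n```"))
```

The keyword check is deliberately crude: it rejects queries that start with WITH, and it can reject a harmless query whose string literal contains one of these words, so a read-only database connection remains the more robust safeguard.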

In [9]:
prompt.pretty_print()



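In the cell below, the agent is built as a chain. At each step, the user input and the intermediate steps so far (converted into chat messages by format_to_openai_tool_messages) are filled into the prompt, which is sent to the LLM with the SQL tools bound to it. OpenAIToolsAgentOutputParser then turns the model's response into either a tool call to execute or, when the model answers directly, the final natural-language answer.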

In [10]:
agent = (
    {
        "input": lambda x: x["input"],
        "agent_scratchpad": lambda x: format_to_openai_tool_messages(x["intermediate_steps"]),
    }
    | prompt
    | llm_with_sql_tools
    | OpenAIToolsAgentOutputParser()
)
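Conceptually, format_to_openai_tool_messages turns each (action, observation) pair in intermediate_steps into two chat messages: an assistant message carrying the tool call and a tool message carrying the result. The sketch below builds that sequence in OpenAI's raw chat format with plain dictionaries; it illustrates the idea rather than reproducing LangChain's message classes, and the Step tuple is a hypothetical stand-in for LangChain's ToolAgentAction. The example values come from the first run later in this notebook.

```python
import json
from typing import NamedTuple


class Step(NamedTuple):
    """Hypothetical stand-in for one (ToolAgentAction, observation) pair."""
    tool_call_id: str
    tool: str
    tool_input: dict
    observation: str


def scratchpad_messages(steps):
    """Convert past tool calls and their results into OpenAI chat messages."""
    messages = []
    for step in steps:
        # The assistant turn that requested the tool call
        messages.append({
            "role": "assistant",
            "content": None,
            "tool_calls": [{
                "id": step.tool_call_id,
                "type": "function",
                "function": {"name": step.tool, "arguments": json.dumps(step.tool_input)},
            }],
        })
        # The tool result, linked back to the call through its id
        messages.append({
            "role": "tool",
            "tool_call_id": step.tool_call_id,
            "content": step.observation,
        })
    return messages


steps = [Step("call_9utWYf63RJDRYsrMsy9KcnQa", "sql_db_query",
              {"query": "SELECT COUNT(*) FROM measurements;"}, "[(104046,)]")]
for message in scratchpad_messages(steps):
    print(message)
```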

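The agent is then wrapped in an AgentExecutor, which runs the loop: it calls the agent, executes any tool the agent requests, passes the result back as an intermediate step, and stops once the agent returns a final answer: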

In [11]:
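agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    handle_parsing_errors=True,      # feed output-parsing errors back to the agent instead of raising
    return_intermediate_steps=True,  # include every (action, observation) pair in the result
    verbose=True,                    # print each step while the agent runs
)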
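The loop that AgentExecutor runs can be sketched in a few lines of plain Python. Here the agent is any callable that returns either a tool request or a final answer; ToolCall, Finish and run_agent are hypothetical names, and max_iterations mirrors the AgentExecutor parameter of the same name. Parsing-error handling, callbacks and the real stop message are left out.

```python
from dataclasses import dataclass


@dataclass
class ToolCall:
    tool: str
    tool_input: dict


@dataclass
class Finish:
    output: str


def run_agent(agent, tools, user_input, max_iterations=15):
    """Call the agent repeatedly, executing requested tools, until it finishes."""
    intermediate_steps = []
    for _ in range(max_iterations):
        decision = agent(user_input, intermediate_steps)
        if isinstance(decision, Finish):
            return {"input": user_input, "output": decision.output,
                    "intermediate_steps": intermediate_steps}
        # Run the requested tool and remember the (action, observation) pair
        observation = tools[decision.tool](**decision.tool_input)
        intermediate_steps.append((decision, observation))
    # Placeholder message used by this sketch when the iteration limit is hit
    return {"input": user_input, "output": "Stopped: iteration limit reached.",
            "intermediate_steps": intermediate_steps}


def scripted_agent(user_input, steps):
    # Fake agent: query once, then answer from the last observation
    if not steps:
        return ToolCall("sql_db_query", {"query": "SELECT COUNT(*) FROM measurements;"})
    return Finish(f"There are {steps[-1][1]} measurements.")


fake_tools = {"sql_db_query": lambda query: 104046}
result = run_agent(scripted_agent, fake_tools, "How many measurements are there?")
print(result["output"])  # There are 104046 measurements.
```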

In [12]:
agent_executor.invoke({"input": "How many measurements are there?"})



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) FROM measurements;'}`


[0m[36;1m[1;3m[(104046,)][0m[32;1m[1;3mThere are 104,046 measurements in the database.[0m

[1m> Finished chain.[0m


{'input': 'How many measurements are there?',
 'output': 'There are 104,046 measurements in the database.',
 'intermediate_steps': [(ToolAgentAction(tool='sql_db_query', tool_input={'query': 'SELECT COUNT(*) FROM measurements;'}, log="\nInvoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) FROM measurements;'}`\n\n\n", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_9utWYf63RJDRYsrMsy9KcnQa', 'function': {'arguments': '{"query":"SELECT COUNT(*) FROM measurements;"}', 'name': 'sql_db_query'}, 'type': 'function'}]}, response_metadata={'finish_reason': 'tool_calls'}, id='run-61590323-43fd-41c4-a5ae-1929d61698da', tool_calls=[{'name': 'sql_db_query', 'args': {'query': 'SELECT COUNT(*) FROM measurements;'}, 'id': 'call_9utWYf63RJDRYsrMsy9KcnQa'}], tool_call_chunks=[{'name': 'sql_db_query', 'args': '{"query":"SELECT COUNT(*) FROM measurements;"}', 'id': 'call_9utWYf63RJDRYsrMsy9KcnQa', 'index': 0}])], tool_call_id='call_9utWYf63RJDRYsrMs
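Because return_intermediate_steps=True, the result dictionary keeps every tool call, so a small helper can pull out the SQL the agent actually ran, which is handy when checking answers. In the real result each step is a (ToolAgentAction, observation) tuple with .tool and .tool_input attributes, as the output above shows; executed_queries is a hypothetical helper and FakeAction is a stand-in so the sketch runs without LangChain.

```python
from dataclasses import dataclass


@dataclass
class FakeAction:
    """Hypothetical stand-in for LangChain's ToolAgentAction."""
    tool: str
    tool_input: dict


def executed_queries(result):
    """Return (query, observation) pairs for every sql_db_query call."""
    pairs = []
    for action, observation in result.get("intermediate_steps", []):
        if action.tool == "sql_db_query":
            pairs.append((action.tool_input.get("query"), observation))
    return pairs


result = {
    "input": "How many measurements are there?",
    "output": "There are 104,046 measurements in the database.",
    "intermediate_steps": [
        (FakeAction("sql_db_query", {"query": "SELECT COUNT(*) FROM measurements;"}), "[(104046,)]"),
    ],
}
print(executed_queries(result))
```

With the real executor you would pass the dictionary returned by agent_executor.invoke(...) instead of the hand-built result.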

In [13]:
llm.invoke("How many measurements are there?")

AIMessage(content='There is not enough information provided to determine the number of measurements. Please provide more context or details.', response_metadata={'token_usage': {'completion_tokens': 20, 'prompt_tokens': 13, 'total_tokens': 33}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-2e0227e9-ea46-451f-a83e-46455010885d-0', usage_metadata={'input_tokens': 13, 'output_tokens': 20, 'total_tokens': 33})

In [14]:
questions = [
    {"input": "How many addresses are in the database?"},
    {"input": "How many school are there?"},
    {"input": "List all measurements for the municipality of ‘’s-Hertogenbosch’"},
    {"input": "How many indor services are there for korfball?"},
    {"input": "How many measurements are there in the municipality of Staphorst where the percentage of people with excellent to very good health is above 60%?"},
    {"input": "How many measurements are there for the Vondelburt district where more than 1 out of 10 people is a heavy drinker%?"},
    {"input": "Can you show me the top 10 municipalities by their average percentage of weekly sporters?"},
    {"input": "List the top 10 districts by overwait and also list their respective values?"},
    {"input": "Can you find the top 10 municipalities with the highest median standardised income for me? Also show me how many people in that municipality have overweight and how many have a high risk of anxiety or depression."},
    {"input": "Give me the top 20 neighbourhoods in the Netherlands based on the amount of people who had very high stress in the past 4 weeks and have more than 20 percent of heavy drinkers. Also show me in which municipalities these neighborhoods are located."},
    {"input": "Can you show me the municipalities with an average disposable income between 100 and 120 and also show me the respective percentage of weekly sporters in that municipality?"},
    {"input": "Make a list with the top 10 neighborhoods and their municipalities with, on average, the highest percentage of weekly sporters but where the average standardized income of the corresponding municipality is below 15."},
]

In [15]:
for i, question in enumerate(questions, start=1):
    try:
        ans = agent_executor.invoke(question)
        print(f"Question {i} : {question}")
        print("Answer : ", ans.get("output"))
    except Exception as e:
        # Report the error and continue with the next question
        print(f"Question {i} : {question}")
        print("Error : ", e)
    print()
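To compare runs, it can help to collect the results instead of only printing them. A sketch, assuming any executor object with an invoke method like agent_executor; summarize_runs and its record fields are hypothetical, and EchoExecutor is a fake used only to make the example runnable.

```python
def summarize_runs(executor, questions):
    """Invoke the executor on each question and record output, errors and tool usage."""
    records = []
    for i, question in enumerate(questions, start=1):
        record = {"number": i, "question": question["input"], "output": None,
                  "error": None, "tool_calls": 0}
        try:
            result = executor.invoke(question)
            record["output"] = result.get("output")
            record["tool_calls"] = len(result.get("intermediate_steps", []))
        except Exception as exc:  # keep going if one question fails
            record["error"] = str(exc)
        records.append(record)
    return records


class EchoExecutor:
    """Fake executor: fails on empty questions, otherwise echoes the input."""

    def invoke(self, question):
        if not question["input"]:
            raise ValueError("empty question")
        return {"output": question["input"].upper(), "intermediate_steps": [("step", "obs")]}


records = summarize_runs(EchoExecutor(), [{"input": "how many schools?"}, {"input": ""}])
failed = [r["number"] for r in records if r["error"]]
print(failed)  # [2]
```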



> Entering new AgentExecutor chain...

Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) FROM addresses;'}`


[(7888477,)]
There are 7,888,477 addresses in the database.

> Finished chain.
Question 1 : {'input': 'How many addresses are in the database?'}
Answer :  There are 7,888,477 addresses in the database.



> Entering new AgentExecutor chain...

Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(DISTINCT school_id) AS total_schools FROM schools;'}`


Error: (sqlite3.OperationalError) no such column: school_id
[SQL: SELECT COUNT(DISTINCT school_id) AS total_schools FROM schools;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Invoking: `sql_db_list_tables` with `{}`


addresses, associated_fat, associated_liquid, clients, cluster_results, consumption_day, consumption_recipe, consumption_sup_nut, districts, househo