This notebook contains the code for a custom agent that combines SQL tools, the ReAct prompting strategy, and an additional custom retriever tool. The LLM used in this notebook is GPT-4o.

Install the required libraries:

In [1]:
%pip install --upgrade --quiet langchain langchain-openai langchain-community faiss-cpu

Note: you may need to restart the kernel to use updated packages.


Import the necessary modules:

In [2]:
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings

from langchain.agents import tool
from langchain.agents import AgentExecutor
from langchain.agents import create_react_agent
from langchain.agents import ZeroShotAgent
from langchain.agents import AgentType

from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser

from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages

from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_toolkits import create_retriever_tool

from langchain.agents.agent_toolkits.sql.prompt import SQL_FUNCTIONS_SUFFIX
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX
from langchain.agents.agent_toolkits.sql.prompt import SQL_SUFFIX

from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS

from langchain.sql_database import SQLDatabase

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import MessagesPlaceholder
from langchain_core.prompts import PromptTemplate

from langchain_community.vectorstores import FAISS
from langchain_community.agent_toolkits import create_sql_agent

import ast
import re

Fill in your OpenAI API key in the cell below so that it is set as an environment variable.

In [3]:
#for gpt-4o:
%env OPENAI_API_KEY=

env: OPENAI_API_KEY=<redacted>


Create the large language model that is going to be used:

In [4]:
llm = ChatOpenAI(model="gpt-4o", temperature=0)

Set up the database that is going to be used:

In [5]:
db = SQLDatabase.from_uri("sqlite:///wijkpaspoort.sqlite3")

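Set up the SQL tools for the agent to use. The search_proper_nouns tool, which is created further down, is also listed in the tool_names variable. Note that create_react_agent fills the {tool_names} placeholder from the tools it receives, so this variable serves only as a reference: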

In [6]:
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
tool_names = "search_proper_nouns, sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker"

Create the prompt that is sent to the agent. This time the agent is also instructed to check whether the correct nouns are used:

In [7]:
template = """You are an agent designed to interact with a SQL database.
Given an input question, first check if the correct nouns are used, if not, use the most similar noun that you can find with the search_proper_nouns tool.
Then create a syntactically correct sqlite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.
When there are less than 10 examples available in the database, limit the results to the number of examples that are available in the database.
You can order the results by a relevant column to return the most interesting examples in the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
ALWAYS come up with the final answer AFTER you have used the sql_db_query tool. 
DO NOT come up with the answer without constructing a SQL query.
If you have retrieved SQL-output in your observation and you are able to answer the input question with the SQL-output, your next thought may be that you know the final answer.
You MUST double check your query with the sql_db_query_checker tool before executing it. 
Make sure the input to the sql_db_query tool starts with 'SELECT' and does not start with something like: ' ```sql '.
Delete preceding backticks (```) and the 'sql' keyword when constructing a query.
Also make sure the SQL-query ends with a semicolon.
If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the final answer.

You have access to the following tools for interacting with the database:
{tools}

Use exactly the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question: {input}
Thought:{agent_scratchpad}"""

react_prompt = PromptTemplate.from_template(template)
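
The template above has four placeholders: {tools}, {tool_names}, {input} and {agent_scratchpad}. The sketch below illustrates how they could be filled, using plain string formatting instead of LangChain. The stub tools, the shortened template and the render_prompt helper are hypothetical and only serve as an illustration; in the notebook itself this is handled by create_react_agent and the AgentExecutor.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for LangChain tool objects; only .name and .description are used.
tools_stub = [
    SimpleNamespace(name="sql_db_query", description="Run a SQL query and return the result."),
    SimpleNamespace(name="search_proper_nouns", description="Look up the closest valid proper noun."),
]

# Shortened version of the ReAct template, with the same four placeholders.
short_template = (
    "You have access to the following tools:\n{tools}\n\n"
    "Action: the action to take, should be one of [{tool_names}]\n\n"
    "Question: {input}\n"
    "Thought:{agent_scratchpad}"
)


def render_prompt(template, tools, question, scratchpad=""):
    """Fill the placeholders from a list of tools and a question."""
    tools_text = "\n".join(f"{t.name}: {t.description}" for t in tools)
    tool_names = ", ".join(t.name for t in tools)
    return template.format(
        tools=tools_text,
        tool_names=tool_names,
        input=question,
        agent_scratchpad=scratchpad,
    )


rendered = render_prompt(short_template, tools_stub, "How many schools are there?")
print(rendered)
```

On the first step the scratchpad is empty, so the prompt ends with "Thought:" and the model continues from there. On later steps the scratchpad holds the earlier Thought/Action/Observation text.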

In [8]:
react_prompt.pretty_print()

You are an agent designed to interact with a SQL database.
Given an input question, check if the correct nouns are used, if not use the most similar noun that you can find.
Then create a syntactically correct sqlite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.
When there are less than 10 examples available in the database, limit the results to the number of examples that are available in the database.
You can order the results by a relevant column to return the most interesting examples in the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
ALWAYS come up with the final answer AFTER you have used the sql_db_query tool. 
DO NOT come up with the answer without constructing a SQL query.
If you have retrieved SQL-output in your observation and you are able to answer the in

The cell below defines a function that flattens the output of a SQL query into a list of unique values, dropping empty entries and removing standalone numbers.

In [9]:
def query_as_list(database, query):
    results = database.run(query)
    results = [element for sublist in ast.literal_eval(results) for element in sublist if element]
    results = [re.sub(r"\b\d+\b", "", element).strip() for element in results]
    return list(set(results))
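
To see what query_as_list does without touching the real database, the sketch below runs the same function against a small stand-in. FakeDatabase and its rows are hypothetical; the only assumption is that run() returns the rows as a string, which is what the ast.literal_eval call in the function expects.

```python
import ast
import re


class FakeDatabase:
    """Hypothetical stand-in for SQLDatabase: run() returns the rows as a string."""

    def __init__(self, rows):
        self.rows = rows

    def run(self, query):
        # The query is ignored; the stored rows are returned in string form.
        return str(self.rows)


def query_as_list(database, query):
    results = database.run(query)
    # Flatten the list of row tuples and skip empty or NULL values.
    results = [element for sublist in ast.literal_eval(results) for element in sublist if element]
    # Remove standalone numbers and surrounding whitespace.
    results = [re.sub(r"\b\d+\b", "", element).strip() for element in results]
    # Deduplicate; note that a set does not preserve order.
    return list(set(results))


fake_db = FakeDatabase([("korfbal",), ("padel 2",), ("korfbal",), (None,), ("",)])
names = sorted(query_as_list(fake_db, "SELECT sport FROM sports_facilities"))
print(names)  # ['korfbal', 'padel']
```

Because the result passes through a set, the order of the returned list can differ between runs, which is why the example sorts it before printing.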

With this function, a list can be created of all the proper nouns that are present in a column.

In [10]:
municipalities = query_as_list(db, "SELECT municipality_name FROM municipalities")
district = query_as_list(db, "SELECT district_name FROM districts")
neighborhoods = query_as_list(db, "SELECT neighborhood_name FROM neighborhoods")
sports = query_as_list(db, "SELECT sport FROM sports_facilities")
sports[0:10]

['softbal',
 'padel',
 'schieten',
 'kunstwielrijden',
 'outdoor recreatie',
 'atletiek',
 'vechtsport',
 'korfbal',
 'bmx',
 'kegelen']

A vector database is created from the lists in the cell above. The tool built on top of it can query that vector store whenever the user includes a proper noun in their question.

In [11]:
vector_db = FAISS.from_texts(municipalities + district + neighborhoods + sports, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 5})

description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \
valid proper nouns. Use the noun most similar to the search."""

retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)
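
The idea behind the search_proper_nouns tool is that a misspelled noun is mapped to the closest valid values. The sketch below illustrates that idea with the standard library only. It uses character-level similarity from difflib as a stand-in, not the embedding similarity that FAISS and OpenAIEmbeddings provide, so the ranking can differ from the real tool. The closest_nouns helper is hypothetical; the sample values are taken from the sports list above.

```python
import difflib

# A few of the sports values shown in the output above.
valid_nouns = ["softbal", "padel", "schieten", "atletiek", "vechtsport", "korfbal", "bmx", "kegelen"]


def closest_nouns(query, candidates, k=5):
    """Return up to k candidates, ranked by character-level similarity to the query."""
    # cutoff=0.0 keeps every candidate eligible, so k results are always returned
    # when at least k candidates exist (similar to search_kwargs={"k": 5}).
    return difflib.get_close_matches(query.lower(), candidates, n=k, cutoff=0.0)


matches = closest_nouns("korfball", valid_nouns)
print(matches[0])  # korfbal
```

The agent receives all k candidates as the observation and is instructed by the tool description to pick the one most similar to its search.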

Create a ReAct agent with the create_react_agent function, using the prompt defined above. In addition to the SQL tools, the agent is given the retriever_tool.

In [12]:
react_agent = create_react_agent(llm=llm,
                                 tools=tools + [retriever_tool],
                                 prompt=react_prompt)

An AgentExecutor is then used to run the agent:

In [13]:
agent_executor = AgentExecutor(agent=react_agent,
                               tools=tools + [retriever_tool],
                               handle_parsing_errors=True,
                               verbose=True,
                              )
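
The executor repeatedly reads the model output, extracts the Action and Action Input, calls the matching tool and appends the observation. The sketch below illustrates the parsing step and shows why the prompt asks the model not to wrap queries in backticks. The regular expression and both helper functions are simplified, hypothetical versions written for this illustration; they are not LangChain's own parser.

```python
import re

# Simplified pattern for one ReAct step: "Action: <tool>" followed by "Action Input: <input>".
ACTION_PATTERN = re.compile(r"Action\s*:\s*(.*?)\s*Action\s*Input\s*:\s*(.*)", re.DOTALL)


def parse_react_step(text):
    """Split one model response into either a final answer or a tool call."""
    if "Final Answer:" in text:
        return {"type": "final", "output": text.split("Final Answer:", 1)[1].strip()}
    match = ACTION_PATTERN.search(text)
    if match is None:
        raise ValueError("Could not find 'Action:' and 'Action Input:' in the response")
    return {"type": "action", "tool": match.group(1).strip(), "tool_input": match.group(2).strip()}


def strip_sql_fence(tool_input):
    """Remove a leading ```sql fence and a trailing ``` fence, if present."""
    cleaned = re.sub(r"^```(?:sql)?\s*", "", tool_input.strip(), flags=re.IGNORECASE)
    cleaned = re.sub(r"\s*```$", "", cleaned)
    return cleaned.strip()


response = (
    "I should count the rows.\n\n"
    "Action: sql_db_query\n"
    "Action Input: ```sql\nSELECT COUNT(*) FROM addresses;\n```"
)
step = parse_react_step(response)
query = strip_sql_fence(step["tool_input"])
print(step["tool"], "->", query)  # sql_db_query -> SELECT COUNT(*) FROM addresses;
```

When a response cannot be parsed, handle_parsing_errors=True makes the executor pass the error back to the model as an observation instead of raising, so the agent gets another attempt.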

In [14]:
agent_executor.invoke({"input": "How many measurements are there for the municipality of 'Amsterdam'?",
                    })



> Entering new AgentExecutor chain...
To answer this question, I need to check the database schema to identify the relevant table and columns that store information about measurements and municipalities.

First, I will list all the tables in the database to identify which table might contain the relevant data.

Action: sql_db_list_tables
Action Input: ""
addresses, associated_fat, associated_liquid, clients, cluster_results, consumption_day, consumption_recipe, consumption_sup_nut, districts, household_income, measurements, municipalities, neighborhoods, participants, peoples_income, provinces, schools, sports_facilities, user_client, users
The `measurements` and `municipalities` tables seem relevant to the question. I will inspect the schema of these tables to understand their structure and identify the columns that store information about measurements and municipalities.

Action: sql_db_schema
Action Input: measurements, mun

{'input': "How many measurements are there for the municipality of 'Amsterdam'?",
 'output': "There are 6 measurements for the municipality of 'Amsterdam'."}

In [15]:
llm.invoke("Find all measurements for the municipality of 'Amsterdam'")

AIMessage(content="To provide a comprehensive overview of the measurements for the municipality of Amsterdam, we can look at various aspects such as its geographical size, population, and other relevant statistics. Here are some key measurements:\n\n### Geographical Measurements\n1. **Area**:\n   - **Total Area**: Approximately 219.3 square kilometers (84.6 square miles).\n   - **Land Area**: Around 166.8 square kilometers (64.4 square miles).\n   - **Water Area**: About 52.5 square kilometers (20.3 square miles).\n\n### Population Measurements\n2. **Population**:\n   - **Total Population**: As of 2023, the population of Amsterdam is approximately 872,000 people.\n   - **Population Density**: Roughly 5,230 people per square kilometer (13,540 people per square mile).\n\n### Economic Measurements\n3. **Gross Domestic Product (GDP)**:\n   - **GDP**: Amsterdam is a significant economic hub, with a GDP of around €150 billion.\n\n### Infrastructure Measurements\n4. **Transportation**:\n   - 

In [16]:
questions = [
    {"input": "How many addresses are in the database?"},
    {"input": "How many school are there?"},
    {"input": "List all measurements for the municipality of ‘’s-Hertogenbosch’"},
    {"input": "How many indor services are there for korfball?"},
    {"input": "How many measurements are there in the municipality of Staphorst where the percentage of people with excellent to very good health is above 60%?"},
    {"input": "How many measurements are there for the Vondelburt district where more than 1 out of 10 people is a heavy drinker%?"},
    {"input": "Can you show me the top 10 municipalities by their average percentage of weekly sporters?"},
    {"input": "List the top 10 districts by overwait and also list their respective values?"},
    {"input": "Can you find the top 10 municipalities with the highest median standardised income for me? Also show me how many people in that municipality have overweight and how many have a high risk of anxiety or depression."},
    {"input": "Give me the top 20 neighbourhoods in the Netherlands based on the amount of people who had very high stress in the past 4 weeks and have more than 20 percent of heavy drinkers. Also show me in which municipalities these neighborhoods are located."},
    {"input": "Can you show me the municipalities with an average disposable income between 100 and 120 and also show me the respective percentage of weekly sporters in that municipality?"},
    {"input": "Make a list with the top 10 neighborhoods and their municipalities with, on average, the highest percentage of weekly sporters but where the average standardized income of the corresponding municipality is below 15."},
]

In [17]:
for i, question in enumerate(questions):
    try:
        ans = agent_executor.invoke(question)
        print("Question {} : {}".format(i+1, question))
        print("Answer : ", ans.get("output"))
        print()
    except Exception as e:
        print("Question {} : {}".format(i+1, question))
        print("Error : ", e)
        print()



> Entering new AgentExecutor chain...
To determine the number of addresses in the database, I first need to identify the relevant table that stores address information. I will start by listing all the tables in the database.

Action: sql_db_list_tables
Action Input:
addresses, associated_fat, associated_liquid, clients, cluster_results, consumption_day, consumption_recipe, consumption_sup_nut, districts, household_income, measurements, municipalities, neighborhoods, participants, peoples_income, provinces, schools, sports_facilities, user_client, users
The table that likely contains address information is `addresses`. To determine the number of addresses in the database, I will count the entries in the `addresses` table.

Action: sql_db_query_checker
Action Input: SELECT COUNT(*) FROM addresses;
```sql
SELECT COUNT(*) FROM addresses;
```
The query has been verified as correct. I will now execute