# SQL Agent - OpenAI
- https://python.langchain.com/docs/use_cases/sql/agents/

In [1]:
import re
import ast
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)

print('Load .env variables:', load_dotenv())

Load .env variables: True


# Load Database

Setup:

- Download [``Chinook_Sqlite.sql``](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql)
- Run ``sqlite3 Chinook.db``
- Run ``.read Chinook_Sqlite.sql``
- Test ``SELECT * FROM Artist LIMIT 10``;

In [2]:
db = SQLDatabase.from_uri('sqlite:///Chinook.db')
print(db.dialect)
print(db.get_usable_table_names())
db.run('SELECT * FROM Artist LIMIT 10;')

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

# Prepare Agent

In [3]:
llm = ChatOpenAI(model='gpt-3.5-turbo', temperature=0)
agent_executor = create_sql_agent(llm, db=db,
                                  agent_type='openai-tools',
                                  verbose=True)

# Execute

In [4]:
agent_executor.invoke(
    "List the total sales per country. Which country's customers spent the most?"
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{'tool_input': ''}`


[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice, InvoiceLine'}`


[0m[33;1m[1;3m
CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1

{'input': "List the total sales per country. Which country's customers spent the most?",
 'output': 'The total sales per country are as follows:\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n\nThe country whose customers spent the most is the USA with a total sales amount of $523.06.'}

# Using a dynamic few-shot prompt

In [5]:
examples = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {
        "input": "List all tracks in the 'Rock' genre.",
        "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    },
    {
        "input": "Find the total duration of all tracks.",
        "query": "SELECT SUM(Milliseconds) FROM Track;",
    },
    {
        "input": "List all customers from Canada.",
        "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
    },
    {
        "input": "How many tracks are there in the album with ID 5?",
        "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    },
    {
        "input": "Find the total number of invoices.",
        "query": "SELECT COUNT(*) FROM Invoice;",
    },
    {
        "input": "List all tracks that are longer than 5 minutes.",
        "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
    },
    {
        "input": "Who are the top 5 customers by total purchase?",
        "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    },
    {
        "input": "Which albums are from the year 2000?",
        "query": "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';",
    },
    {
        "input": "How many employees are there",
        "query": 'SELECT COUNT(*) FROM "Employee"',
    },
]

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=['input'],
)

In [6]:
system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:"""

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k"],
    prefix=system_prefix,
    suffix="",
)

In [7]:
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

In [8]:
# Example formatted prompt
prompt_val = full_prompt.invoke(
    {
        "input": "How many arists are there",
        "top_k": 5,
        "dialect": "SQLite",
        "agent_scratchpad": [],
    }
)
print(prompt_val.to_string())

System: You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't 

In [9]:
agent = create_sql_agent(
    llm=llm,
    db=db,
    prompt=full_prompt,
    verbose=True,
    agent_type="openai-tools",
)

In [10]:
agent.invoke({"input": "How many artists are there?"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) FROM Artist;'}`


[0m[36;1m[1;3m[(275,)][0m[32;1m[1;3mThere are 275 artists in the database.[0m

[1m> Finished chain.[0m


{'input': 'How many artists are there?',
 'output': 'There are 275 artists in the database.'}

# Dealing with high-cardinality columns

In [11]:
def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return list(set(res))

artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
albums[:5]

['For Those About To Rock We Salute You',
 'Blizzard of Ozz',
 'The Colour And The Shape',
 'B-Sides -',
 'War']

In [12]:
vector_db = FAISS.from_texts(artists + albums, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 5})
description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \
valid proper nouns. Use the noun most similar to the search."""
retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)

In [13]:
system = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the "search_proper_nouns" tool! 

You have access to the following tables: {table_names}

If the question does not seem related to the database, just return "I don't know" as the answer."""

prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{input}"), MessagesPlaceholder("agent_scratchpad")]
)
agent = create_sql_agent(
    llm=llm,
    db=db,
    extra_tools=[retriever_tool],
    prompt=prompt,
    agent_type="openai-tools",
    verbose=True,
)

In [14]:
agent.invoke({"input": "How many albums does alis in chain have?"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `search_proper_nouns` with `{'query': 'alis in chain'}`


[0m[36;1m[1;3mAlice In Chains

Aisha Duo

Xis

Da Lama Ao Caos

A-Sides[0m[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': "SELECT COUNT(*) AS album_count FROM Album WHERE ArtistId IN (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')"}`


[0m[36;1m[1;3m[(1,)][0m[32;1m[1;3mAlice In Chains has 1 album.[0m

[1m> Finished chain.[0m


{'input': 'How many albums does alis in chain have?',
 'output': 'Alice In Chains has 1 album.'}

In [15]:
agent.invoke({"input": "How many albums does Tom Koopman have?"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `search_proper_nouns` with `{'query': 'Tom Koopman'}`


[0m[36;1m[1;3mTon Koopman

Tim Maia

Wilhelm Kempff

Gene Krupa

Adrian Leaper & Doreen de Feis[0m[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': "SELECT COUNT(*) AS album_count FROM Album WHERE ArtistId IN (SELECT ArtistId FROM Artist WHERE Name = 'Ton Koopman')"}`


[0m[36;1m[1;3m[(1,)][0m[32;1m[1;3mTom Koopman has 1 album.[0m

[1m> Finished chain.[0m


{'input': 'How many albums does Tom Koopman have?',
 'output': 'Tom Koopman has 1 album.'}

# Final Query

In [16]:
agent.invoke(
    "List the total sales per country. Which country's customers spent the most?"
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT c.Country, SUM(i.Total) AS Total_Sales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.Country\nORDER BY Total_Sales DESC;'}`


[0m[36;1m[1;3m[('USA', 523.06), ('Canada', 303.96), ('France', 195.1), ('Brazil', 190.1), ('Germany', 156.48), ('United Kingdom', 112.86), ('Czech Republic', 90.24), ('Portugal', 77.24), ('India', 75.26), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.62), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.62), ('Spain', 37.62), ('Poland', 37.62), ('Italy', 37.62), ('Denmark', 37.62), ('Belgium', 37.62), ('Australia', 37.62), ('Argentina', 37.62)][0m[32;1m[1;3mThe total sales per country are as follows:
1. USA: $523.06
2. Canada: $303.96
3. France: $195.10
4. Brazil: $190.10
5. Germany: $156.48

The country whose customers spent the most is the USA with a total 

{'input': "List the total sales per country. Which country's customers spent the most?",
 'output': 'The total sales per country are as follows:\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n\nThe country whose customers spent the most is the USA with a total sales amount of $523.06.'}

# Extra

In [19]:
agent.invoke(
    "Which is the table of the schema with more tuples?"
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) AS num_rows, tbl_name FROM (SELECT * FROM Album UNION ALL SELECT * FROM Artist UNION ALL SELECT * FROM Customer UNION ALL SELECT * FROM Employee UNION ALL SELECT * FROM Genre UNION ALL SELECT * FROM Invoice UNION ALL SELECT * FROM InvoiceLine UNION ALL SELECT * FROM MediaType UNION ALL SELECT * FROM Playlist UNION ALL SELECT * FROM PlaylistTrack UNION ALL SELECT * FROM Track) GROUP BY tbl_name ORDER BY num_rows DESC LIMIT 1;'}`


[0m[36;1m[1;3mError: (sqlite3.OperationalError) SELECTs to the left and right of UNION ALL do not have the same number of result columns
[SQL: SELECT COUNT(*) AS num_rows, tbl_name FROM (SELECT * FROM Album UNION ALL SELECT * FROM Artist UNION ALL SELECT * FROM Customer UNION ALL SELECT * FROM Employee UNION ALL SELECT * FROM Genre UNION ALL SELECT * FROM Invoice UNION ALL SELECT * FROM InvoiceLine UNION ALL SELECT * FROM MediaType UN

{'input': 'Which is the table of the schema with more tuples?',
 'output': 'The table with the most tuples in the schema is PlaylistTrack, which has 8715 rows.'}

In [21]:
agent.invoke(
    "Provide the quantity of tuples per table in the database."
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': "SELECT 'Album' AS Table_Name, COUNT(*) AS Tuple_Count FROM Album UNION ALL SELECT 'Artist' AS Table_Name, COUNT(*) AS Tuple_Count FROM Artist UNION ALL SELECT 'Customer' AS Table_Name, COUNT(*) AS Tuple_Count FROM Customer UNION ALL SELECT 'Employee' AS Table_Name, COUNT(*) AS Tuple_Count FROM Employee UNION ALL SELECT 'Genre' AS Table_Name, COUNT(*) AS Tuple_Count FROM Genre UNION ALL SELECT 'Invoice' AS Table_Name, COUNT(*) AS Tuple_Count FROM Invoice UNION ALL SELECT 'InvoiceLine' AS Table_Name, COUNT(*) AS Tuple_Count FROM InvoiceLine UNION ALL SELECT 'MediaType' AS Table_Name, COUNT(*) AS Tuple_Count FROM MediaType UNION ALL SELECT 'Playlist' AS Table_Name, COUNT(*) AS Tuple_Count FROM Playlist UNION ALL SELECT 'PlaylistTrack' AS Table_Name, COUNT(*) AS Tuple_Count FROM PlaylistTrack UNION ALL SELECT 'Track' AS Table_Name, COUNT(*) AS Tuple_Count FROM Track"}`


[0m[36;1m

{'input': 'Provide the quantity of tuples per table in the database.',
 'output': 'Here are the quantities of tuples per table in the database:\n\n- Album: 347 tuples\n- Artist: 275 tuples\n- Customer: 59 tuples\n- Employee: 8 tuples\n- Genre: 25 tuples\n- Invoice: 412 tuples\n- InvoiceLine: 2240 tuples\n- MediaType: 5 tuples\n- Playlist: 18 tuples\n- PlaylistTrack: 8715 tuples\n- Track: 3503 tuples'}