1. Data
2. LLM
3. Chain
4. Agent

## Data

In [1]:
from sqlite3 import connect
from langchain_community.utilities import SQLDatabase

In [2]:
db = SQLDatabase.from_uri("sqlite:///chinook.db", sample_rows_in_table_info=3)
print(len(db.get_usable_table_names()), db.get_usable_table_names())

11 ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
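`get_usable_table_names()` lists the tables that LangChain will expose to the model. For comparison, here is a minimal standard-library sketch that reads the same kind of list from SQLite's `sqlite_master` catalog. It only illustrates the idea and is not LangChain's implementation. The two-table in-memory schema is a hypothetical stand-in for `chinook.db`.

```python
import sqlite3

def list_tables(con: sqlite3.Connection) -> list[str]:
    """Return user table names, sorted, skipping SQLite's internal sqlite_* tables."""
    rows = con.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
        "ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]

# Toy in-memory database standing in for chinook.db (hypothetical schema).
con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT NOT NULL,
                        ArtistId INTEGER NOT NULL REFERENCES Artist (ArtistId));
""")
tables = list_tables(con)
print(len(tables), tables)  # 2 ['Album', 'Artist']
```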


In [3]:
print(db.table_info)


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Empl
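The `table_info` text pairs each table's `CREATE TABLE` statement with `sample_rows_in_table_info` example rows inside a `/* ... */` comment. The following minimal sketch mimics that layout with the standard library. The function name `describe_table` is hypothetical, LangChain's exact formatting may differ, and the toy rows are placeholders (the first three are copied from the `Artist` sample above).

```python
import sqlite3

def describe_table(con: sqlite3.Connection, table: str, sample_rows: int = 3) -> str:
    """Return the table's DDL followed by up to `sample_rows` example rows."""
    (ddl,) = con.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    # Table names cannot be bound as parameters, so quote the identifier instead.
    quoted = '"' + table.replace('"', '""') + '"'
    cur = con.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample_rows,))
    header = "\t".join(col[0] for col in cur.description)
    rows = ["\t".join(str(v) for v in row) for row in cur.fetchall()]
    body = "\n".join([f"{len(rows)} rows from {table} table:", header, *rows])
    return f"{ddl}\n\n/*\n{body}\n*/"

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
con.executemany(
    "INSERT INTO Artist VALUES (?, ?)",
    [(1, "AC/DC"), (2, "Accept"), (3, "Aerosmith"), (4, "Example Artist")],
)
info = describe_table(con, "Artist")
print(info)
```

Only the first three rows make it into the description, which keeps the schema context the model receives short.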

In [4]:
con = connect("chinook.db")

In [5]:
cur = con.execute("Select Count(Distinct(AlbumId)) from Album;")

In [6]:
print(cur.fetchall())

[(347,)]
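Because `AlbumId` is the primary key of `Album` (see the schema above), `COUNT(DISTINCT AlbumId)` gives the same result as a plain `COUNT(*)`. The minimal sketch below demonstrates this on a toy table. It also unwraps the single-row result with `fetchone()[0]` and uses `contextlib.closing` to close the connection, because using a `sqlite3.Connection` as a context manager only commits or rolls back and does not close it.

```python
import sqlite3
from contextlib import closing

with closing(sqlite3.connect(":memory:")) as con:
    con.execute("CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT NOT NULL)")
    con.executemany(
        "INSERT INTO Album (Title) VALUES (?)",
        [("For Those About To Rock We Salute You",), ("Balls to the Wall",), ("Restless and Wild",)],
    )
    # A primary key is already unique, so DISTINCT adds nothing here.
    n_distinct = con.execute("SELECT COUNT(DISTINCT AlbumId) FROM Album").fetchone()[0]
    n_rows = con.execute("SELECT COUNT(*) FROM Album").fetchone()[0]

print(n_distinct, n_rows)  # 3 3
```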


## LLM

In [7]:
from langchain.callbacks import get_openai_callback
from langchain_openai import ChatOpenAI
import os

In [8]:
llm = ChatOpenAI(model="gpt-3.5-turbo", 
                 temperature=0, 
                 openai_api_key=os.environ.get("OPENAI_API_KEY"))

In [9]:
llm

ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x0000029D7BEC5C40>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x0000029D7BED8400>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_proxy='')
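`os.environ.get("OPENAI_API_KEY")` returns `None` when the variable is unset, so this cell does not catch a missing key. A small, hypothetical helper can fail early with a clear message instead. `require_env` is our own name and not part of LangChain.

```python
import os

def require_env(name: str) -> str:
    """Return the value of an environment variable, or fail early with a clear message."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value

# Hypothetical usage with the cell above:
# llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0,
#                  openai_api_key=require_env("OPENAI_API_KEY"))
```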

## Chain

In [10]:
from langchain_experimental.sql.base import SQLDatabaseChain

In [11]:
sql_db_chain = SQLDatabaseChain.from_llm(llm=llm, db=db, verbose=True)

In [12]:
with get_openai_callback() as cb:
    result = sql_db_chain.invoke("Which country's customers spent the most?")
print("LLM Result: ", result['result'])
print(cb)



> Entering new SQLDatabaseChain chain...
Which country's customers spent the most?
SQLQuery:SELECT c."Country", SUM(i."Total") AS total_spent
FROM "Customer" c
JOIN "Invoice" i ON c."CustomerId" = i."CustomerId"
GROUP BY c."Country"
ORDER BY total_spent DESC
LIMIT 1;
SQLResult: [('USA', 523.0600000000003)]
Answer:The country whose customers spent the most is USA, with a total amount spent of $523.06.
> Finished chain.
LLM Result:  The country whose customers spent the most is USA, with a total amount spent of $523.06.
Tokens Used: 4884
	Prompt Tokens: 4810
	Completion Tokens: 74
Successful Requests: 2
Total Cost (USD): $0.007363000000000001
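`get_openai_callback` counts prompt and completion tokens separately (4810 + 74 = 4884 here) and prices each type at its own rate. The total above can be reproduced by assuming gpt-3.5-turbo rates of $0.0015 per 1K prompt tokens and $0.002 per 1K completion tokens. These rates are inferred from the printed numbers and were not read from LangChain's price table. Current OpenAI pricing may differ.

```python
# Assumed per-1K-token rates, inferred from the callback output above.
PROMPT_RATE_PER_1K = 0.0015
COMPLETION_RATE_PER_1K = 0.002

def estimate_cost(prompt_tokens: int, completion_tokens: int) -> float:
    """Return the estimated USD cost of a run, pricing each token type per 1K tokens."""
    return (prompt_tokens * PROMPT_RATE_PER_1K
            + completion_tokens * COMPLETION_RATE_PER_1K) / 1000

# Token counts from the SQLDatabaseChain run above.
print(round(estimate_cost(4810, 74), 7))  # 0.007363
```

Because prompt tokens dominate, shrinking the prompt (for example, the schema text sent with every call) has the largest effect on cost.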


In [13]:
with get_openai_callback() as cb:
    result = sql_db_chain.invoke("List the total sales per country. Which country's customers spent the most?")
print("LLM Result: ", result['result'])
print(cb)



> Entering new SQLDatabaseChain chain...
List the total sales per country. Which country's customers spent the most?
SQLQuery:SELECT c."Country", SUM(i."Total") AS "TotalSales"
FROM "Customer" c
JOIN "Invoice" i ON c."CustomerId" = i."CustomerId"
GROUP BY c."Country"
ORDER BY "TotalSales" DESC
LIMIT 1;
SQLResult: [('USA', 523.0600000000003)]
Answer:The country whose customers spent the most is USA with a total sales of $523.06.
> Finished chain.
LLM Result:  The country whose customers spent the most is USA with a total sales of $523.06.
Tokens Used: 4898
	Prompt Tokens: 4825
	Completion Tokens: 73
Successful Requests: 2
Total Cost (USD): $0.0073834999999999994
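The question asks for total sales per country, but the generated query still ends in `LIMIT 1`, so only the top row is returned. Dropping the `LIMIT` returns the full ranking. The sketch below runs the same query shape without `LIMIT` on a toy `Customer`/`Invoice` pair. The rows and totals are invented for illustration and are not Chinook data.

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
    CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total REAL);
    INSERT INTO Customer VALUES (1, 'USA'), (2, 'Canada'), (3, 'USA');
    INSERT INTO Invoice VALUES (1, 1, 10.0), (2, 2, 12.5), (3, 3, 5.0);
""")

# Same shape as the generated query, minus LIMIT 1, so every country is listed.
per_country = con.execute("""
    SELECT c."Country", SUM(i."Total") AS "TotalSales"
    FROM "Customer" c
    JOIN "Invoice" i ON c."CustomerId" = i."CustomerId"
    GROUP BY c."Country"
    ORDER BY "TotalSales" DESC
""").fetchall()
print(per_country)  # [('USA', 15.0), ('Canada', 12.5)]
```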


In [14]:
print(sql_db_chain.llm_chain.prompt.template)

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: 
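The printed template is a format string. Placeholders such as `{top_k}` are filled in before the text is sent to the model. The minimal sketch below shows that substitution with `str.format`. The abbreviated template is a hypothetical stand-in, the `{table_info}` and `{input}` placeholder names are assumptions because the printed template is cut off above, and LangChain's prompt classes handle this step for you. Since the schema text for every usable table is substituted into each call, this is likely why each chain run above used roughly 4,800 prompt tokens.

```python
# Abbreviated, hypothetical stand-in for the chain's prompt template.
TEMPLATE = (
    "You are a SQLite expert. Query for at most {top_k} results using the LIMIT clause.\n"
    "Only use the following tables:\n{table_info}\n\n"
    "Question: {input}\nSQLQuery: "
)

prompt = TEMPLATE.format(
    top_k=5,
    table_info='CREATE TABLE "Artist" ("ArtistId" INTEGER NOT NULL, "Name" NVARCHAR(120))',
    input="How many artists are there?",
)
print(prompt)
```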

## Agent

Advantages of SQL Agents:

1. They save tokens by retrieving the schema only for the relevant tables.
2. They can recover from errors: the agent runs a generated query, catches the traceback, and regenerates a corrected query.
3. A chain follows a predefined sequence of steps, whereas an agent lets the LLM reason dynamically about which tool to run next.
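Point 2 describes a loop: run the generated query, catch the error, and feed the error message back to get a corrected query. The following minimal sketch implements that loop with the standard library. `regenerate` is a hypothetical stand-in for the LLM call. Here it is hard-coded to fix the `no such table: customers` error that the ReAct agent runs into at the end of this section.

```python
import sqlite3
from typing import Callable

def run_with_retries(con: sqlite3.Connection, query: str,
                     regenerate: Callable[[str, str], str], max_attempts: int = 3):
    """Execute `query`; on a database error, ask `regenerate` for a new query and retry."""
    for _ in range(max_attempts):
        try:
            return con.execute(query).fetchall()
        except sqlite3.Error as exc:
            # In an agent, this error text goes back to the LLM as an observation.
            query = regenerate(query, str(exc))
    raise RuntimeError(f"Query still failing after {max_attempts} attempts")

def fake_regenerate(query: str, error: str) -> str:
    """Hypothetical stand-in for the LLM: fixes one known wrong table name."""
    if "no such table: customers" in error:
        return query.replace("customers", "Customer")
    return query

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT)")
con.execute("INSERT INTO Customer VALUES (1, 'USA')")
result = run_with_retries(con, "SELECT Country FROM customers", fake_regenerate)
print(result)  # [('USA',)]
```

Capping the number of attempts keeps a model that never fixes its query from looping forever.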

In [15]:
from langchain_community.agent_toolkits import create_sql_agent
from langchain.agents import AgentType

In [16]:
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True, stream_runnable=False)

In [17]:
with get_openai_callback() as cb:
    result = agent_executor.invoke("Which country's customers spent the most?")
print("LLM Result: ", result['output'])
print(cb)



> Entering new SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice, InvoiceLine'}`

CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	
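The agent first lists the tables and then requests the schema only for `Customer, Invoice, InvoiceLine`. This is how it avoids sending all eleven table definitions. The sketch below, built on the standard library, shows that narrowing step. It parses the comma-separated input shown above, checks that each table exists, and returns only those `CREATE` statements. `schema_for` and its error string are hypothetical and do not reproduce LangChain's `sql_db_schema` tool.

```python
import sqlite3

def schema_for(con: sqlite3.Connection, table_names: str) -> str:
    """Return CREATE statements for a comma-separated list of tables only."""
    requested = [name.strip() for name in table_names.split(",") if name.strip()]
    existing = {name for (name,) in con.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table'")}
    missing = [name for name in requested if name not in existing]
    if missing:
        return f"Error: tables not found: {', '.join(missing)}"
    ddl = []
    for name in requested:
        (sql,) = con.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (name,)
        ).fetchone()
        ddl.append(sql)
    return "\n\n".join(ddl)

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
    CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total REAL);
    CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT);
""")
subset = schema_for(con, "Customer, Invoice")
print(subset)  # Track's definition is left out
```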

In [18]:
with get_openai_callback() as cb:
    result = agent_executor.invoke("List the total sales per country. Which country's customers spent the most?")
print("LLM Result: ", result['output'])
print(cb)



> Entering new SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice, InvoiceLine'}`

CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	

In [19]:
agent_executor

AgentExecutor(name='SQL Agent Executor', verbose=True, agent=RunnableMultiActionAgent(runnable=RunnableAssign(mapper={
  agent_scratchpad: RunnableLambda(lambda x: format_to_openai_tool_messages(x['intermediate_steps']))
})
| ChatPromptTemplate(input_variables=['agent_scratchpad', 'input'], input_types={'agent_scratchpad': typing.List[typing.Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]}, messages=[SystemMessage(content='You are an agent designed to interact with a SQL database.\nGiven an input question, create a syntactically correct sqlite query to run, then look at the results of the query and return the answer.\nUnless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.\nYou can order the results by

In [20]:
print(agent_executor.agent.runnable.get_prompts()[0].messages[0].content)

You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct sqlite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't k

In [21]:
agent_executor.agent.runnable.to_json()['kwargs']['middle'][1].kwargs['tools']

[{'type': 'function',
  'function': {'name': 'sql_db_query',
   'description': "Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.",
   'parameters': {'type': 'object',
    'properties': {'query': {'description': 'A detailed and correct SQL query.',
      'type': 'string'}},
    'required': ['query']}}},
 {'type': 'function',
  'function': {'name': 'sql_db_schema',
   'description': 'Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3',
   'parameters': {'type': 'object',
    'properties': {'table_names': {'descript
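Each entry above is an OpenAI function-calling schema. When the model chooses a tool, the executor looks up the matching function and calls it with the JSON-encoded arguments. The toy dispatcher below illustrates that lookup using the standard library. The handler functions reuse the tool names but are hypothetical stand-ins for LangChain's tools. The tool-call shape follows the OpenAI format: a `function` object with a `name` and a JSON `arguments` string.

```python
import json
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
con.execute("INSERT INTO Artist VALUES (1, 'AC/DC'), (2, 'Accept')")

def sql_db_list_tables() -> str:
    """Hypothetical handler: comma-separated table names."""
    rows = con.execute("SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name")
    return ", ".join(name for (name,) in rows)

def sql_db_query(query: str) -> str:
    """Hypothetical handler: run a query and return its rows as text."""
    try:
        return str(con.execute(query).fetchall())
    except sqlite3.Error as exc:
        # Returned as text so the model can read the error and retry.
        return f"Error: {exc}"

# Map tool names (as advertised to the model) to Python callables.
TOOLS = {"sql_db_list_tables": sql_db_list_tables, "sql_db_query": sql_db_query}

def dispatch(tool_call: dict) -> str:
    """Run one tool call of the form {'function': {'name': ..., 'arguments': '<json>'}}."""
    fn = tool_call["function"]
    args = json.loads(fn["arguments"] or "{}")
    return TOOLS[fn["name"]](**args)

print(dispatch({"type": "function",
                "function": {"name": "sql_db_list_tables", "arguments": "{}"}}))  # Artist
print(dispatch({"type": "function",
                "function": {"name": "sql_db_query",
                             "arguments": json.dumps({"query": "SELECT COUNT(*) FROM Artist"})}}))  # [(2,)]
```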

In [22]:
react_agent_executor = create_sql_agent(llm, db=db, agent_type="zero-shot-react-description", verbose=True, stream_runnable=False)

In [23]:
with get_openai_callback() as cb:
    result = react_agent_executor.invoke("Which country's customers spent the most?")
print("LLM Result: ", result['output'])
print(cb)



> Entering new SQL Agent Executor chain...
We need to find the country with the highest total spending by customers.
Action: sql_db_query_checker
Action Input: SELECT country, SUM(amount) AS total_spent FROM customers GROUP BY country ORDER BY total_spent DESC LIMIT 1;
SELECT country, SUM(amount) AS total_spent
FROM customers
GROUP BY country
ORDER BY total_spent DESC
LIMIT 1;
The query looks correct, I will now execute it to get the final answer.
Action: sql_db_query
Action Input: SELECT country, SUM(amount) AS total_spent FROM customers GROUP BY country ORDER BY total_spent DESC LIMIT 1;
Error: (sqlite3.OperationalError) no such table: customers
[SQL: SELECT country, SUM(amount) AS total_spent FROM customers GROUP BY country ORDER BY total_spent DESC LIMIT 1;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
I need to check the list of tables in the database to see if the 'customers