## SQL Database

In [1]:
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///chinook.db", sample_rows_in_table_info = 3)
print(len(db.get_usable_table_names()), db.get_usable_table_names())

11 ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [2]:
print(db.table_info)


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Empl

## Language Model

In [3]:
from langchain_openai import ChatOpenAI
import os

llm = ChatOpenAI(model="gpt-3.5-turbo", 
                 temperature=0, 
                 openai_api_key=os.environ.get("OPENAI_API_KEY"))

In [4]:
llm

ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x0000027637B38190>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x0000027637B40E20>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_proxy='')

## Agent without few-shot examples

In [5]:
from langchain_community.agent_toolkits import create_sql_agent

In [6]:
agent = create_sql_agent(
    llm=llm,
    db=db,
    verbose=True,
    agent_type="openai-tools",
    stream_runnable = False
)

In [7]:
print(agent.agent.runnable.get_prompts()[0].messages[0].content)

You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct sqlite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't k

In [8]:
from langchain.callbacks import get_openai_callback

In [9]:
with get_openai_callback() as cb:
    out = agent.invoke({"input": "How many artists are there?"})
print(out)
print(cb)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'Artist'}`


[0m[33;1m[1;3m
CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/[0m[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) AS TotalArtists FROM Artist'}`


[0m[36;1m[1;3m[(275,)][0m[32;1m[1;3mThere are 275 artists in the database.[0m

[1m> Finished chain.[0m
{'input': 'How many artists are there?', 'output': 'There are 275 artists in the database.'}
Tokens Used: 2702
	Prompt Tokens: 2641
	Completion Tokens: 61
Successful Requests: 4
Total Cost (USD): $0.0040835


## Agent with few-shot examples

In [10]:
examples = [
    {   "input": "List all artists.", 
        "query": "SELECT * FROM Artist;"},
    {
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {
        "input": "List all tracks in the 'Rock' genre.",
        "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    },
    {
        "input": "Find the total duration of all tracks.",
        "query": "SELECT SUM(Milliseconds) FROM Track;",
    },
    {
        "input": "List all customers from Canada.",
        "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
    },
    {
        "input": "How many tracks are there in the album with ID 5?",
        "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    },
    {
        "input": "Find the total number of Albums.",
        "query": "SELECT COUNT(DISTINT(AlbumId)) FROM Invoice;",
    },
    {
        "input": "List all tracks that are longer than 5 minutes.",
        "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
    },
    {
        "input": "Who are the top 5 customers by total purchase?",
        "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    },
    {
        "input": "How many employees are there",
        "query": 'SELECT COUNT(*) FROM "Employee"',
    },
]
print(len(examples))

10


In [12]:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

In [13]:
system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:"""

few_shot_prompt = FewShotPromptTemplate(
    examples = examples,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k"],
    prefix=system_prefix,
    suffix=""
)

In [14]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate

In [15]:
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

In [16]:
print(full_prompt.messages[0].to_json()['kwargs'])

{'prompt': FewShotPromptTemplate(input_variables=['dialect', 'top_k'], examples=[{'input': 'List all artists.', 'query': 'SELECT * FROM Artist;'}, {'input': "Find all albums for the artist 'AC/DC'.", 'query': "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');"}, {'input': "List all tracks in the 'Rock' genre.", 'query': "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');"}, {'input': 'Find the total duration of all tracks.', 'query': 'SELECT SUM(Milliseconds) FROM Track;'}, {'input': 'List all customers from Canada.', 'query': "SELECT * FROM Customer WHERE Country = 'Canada';"}, {'input': 'How many tracks are there in the album with ID 5?', 'query': 'SELECT COUNT(*) FROM Track WHERE AlbumId = 5;'}, {'input': 'Find the total number of Albums.', 'query': 'SELECT COUNT(DISTINT(AlbumId)) FROM Invoice;'}, {'input': 'List all tracks that are longer than 5 minutes.', 'query': 'SELECT * FROM Track WHERE Milliseconds > 300

In [17]:
prompt_val = full_prompt.invoke(
    {
        "input": "How many arists are there",
        "top_k": 5,
        "dialect": "SQLite",
        "agent_scratchpad": [],
    }
)
print(prompt_val.to_string())

System: You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't 

In [18]:
agent_with_few_shots = create_sql_agent(
    llm=llm,
    db=db,
    prompt=full_prompt,
    verbose=True,
    agent_type="openai-tools",
    stream_runnable = False
)

In [19]:
with get_openai_callback() as cb:
    out = agent_with_few_shots.invoke({"input": "How many artists are there?"})
print(out)
print(cb)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) FROM Artist;'}`


[0m[36;1m[1;3m[(275,)][0m[32;1m[1;3mThere are 275 artists in the database.[0m

[1m> Finished chain.[0m
{'input': 'How many artists are there?', 'output': 'There are 275 artists in the database.'}
Tokens Used: 1731
	Prompt Tokens: 1701
	Completion Tokens: 30
Successful Requests: 2
Total Cost (USD): $0.0026115


## Agent with Dynamic few-shot examples

In [20]:
from langchain_openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key=os.environ.get("OPENAI_API_KEY"))

In [22]:
len(embeddings.embed_query("Hi"))

1536

In [23]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    FAISS,
    k=3,
    input_keys=["input"],
    )

In [24]:
example_selector.vectorstore.search("How many arists are there?", search_type = "similarity")

[Document(page_content='List all artists.', metadata={'input': 'List all artists.', 'query': 'SELECT * FROM Artist;'}),
 Document(page_content='Find the total number of Albums.', metadata={'input': 'Find the total number of Albums.', 'query': 'SELECT COUNT(DISTINT(AlbumId)) FROM Invoice;'}),
 Document(page_content='How many employees are there', metadata={'input': 'How many employees are there', 'query': 'SELECT COUNT(*) FROM "Employee"'}),
 Document(page_content='How many tracks are there in the album with ID 5?', metadata={'input': 'How many tracks are there in the album with ID 5?', 'query': 'SELECT COUNT(*) FROM Track WHERE AlbumId = 5;'})]

In [25]:
dynamic_few_shot_prompt = FewShotPromptTemplate(
    example_selector = example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k"],
    prefix=system_prefix,
    suffix=""
)

In [26]:
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=dynamic_few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

In [27]:
prompt_val = full_prompt.invoke(
    {
        "input": "How many arists are there?",
        "top_k": 5,
        "dialect": "SQLite",
        "agent_scratchpad": [],
    }
)
print(prompt_val.to_string())

System: You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't 

In [28]:
dynamic_few_shot_agent = create_sql_agent(
    llm=llm,
    db=db,
    prompt=full_prompt,
    verbose=True,
    agent_type="openai-tools",
    stream_runnable = False
)

In [29]:
with get_openai_callback() as cb:
    out = dynamic_few_shot_agent.invoke({"input": "How many artists are there?"})
print(out)
print(cb)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) FROM Artist;'}`


[0m[36;1m[1;3m[(275,)][0m[32;1m[1;3mThere are 275 artists in the database.[0m

[1m> Finished chain.[0m
{'input': 'How many artists are there?', 'output': 'There are 275 artists in the database.'}
Tokens Used: 1311
	Prompt Tokens: 1281
	Completion Tokens: 30
Successful Requests: 2
Total Cost (USD): $0.0019815
