1. Import Llama3
2. Fetch data
3. Initialize few shot examples
4. convert to embeddings
5. few custom tools
6. Prompt
7. Agent with ReAct
8. Agent Executer

Import Llama3

In [2]:
from langchain_community.llms import Ollama

In [3]:
llm = Ollama(model = "llama3")

In [3]:
llm.invoke("Hi")

"Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?"

Get the Data

In [4]:
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///chinook.db", sample_rows_in_table_info = 3)

In [5]:
print(db.table_info)


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Empl

Initialize Few Shots

In [6]:
examples = [
    {   "input": "List all artists.", 
        "query": "SELECT * FROM Artist;"},
    {
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {
        "input": "List all tracks in the 'Rock' genre.",
        "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    },
    {
        "input": "Find the total duration of all tracks.",
        "query": "SELECT SUM(Milliseconds) FROM Track;",
    },
    {
        "input": "List all customers from Canada.",
        "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
    },
    {
        "input": "How many tracks are there in the album with ID 5?",
        "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    },
    {
        "input": "Find the total number of Albums.",
        "query": "SELECT COUNT(DISTINT(AlbumId)) FROM Invoice;",
    },
    {
        "input": "List all tracks that are longer than 5 minutes.",
        "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
    },
    {
        "input": "Who are the top 5 customers by total purchase?",
        "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    },
    {
        "input": "How many employees are there",
        "query": 'SELECT COUNT(*) FROM "Employee"',
    },
]
print(len(examples))

10


Few Shots to Embeddings

In [7]:
from langchain_community.embeddings import OllamaEmbeddings

In [8]:
embeddings = (
    OllamaEmbeddings(model = "llama3")
)

In [9]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    FAISS,
    k=3,
    input_keys=["input"],
    )

In [11]:
matched_queries = example_selector.vectorstore.search("How many arists are there?", search_type = "mmr")
print(matched_queries)

[Document(page_content='How many tracks are there in the album with ID 5?', metadata={'input': 'How many tracks are there in the album with ID 5?', 'query': 'SELECT COUNT(*) FROM Track WHERE AlbumId = 5;'}), Document(page_content='List all artists.', metadata={'input': 'List all artists.', 'query': 'SELECT * FROM Artist;'}), Document(page_content='How many employees are there', metadata={'input': 'How many employees are there', 'query': 'SELECT COUNT(*) FROM "Employee"'})]


In [12]:
for doc in matched_queries:
    print(doc.page_content)

How many tracks are there in the album with ID 5?
List all artists.
How many employees are there


Custom Tools

In [13]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool, InfoSQLDatabaseTool, ListSQLDatabaseTool, QuerySQLCheckerTool

sql_db_query =  QuerySQLDataBaseTool(db = db)
sql_db_schema =  InfoSQLDatabaseTool(db = db)
sql_db_list_tables =  ListSQLDatabaseTool(db = db)
sql_db_query_checker = QuerySQLCheckerTool(db = db, llm = llm)

tools = [sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker]

for tool in tools:
    print(tool.name + " - " + tool.description.strip() + "\n")

sql_db_query - Execute a SQL query against the database and get back the result..
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.

sql_db_schema - Get the schema and sample rows for the specified SQL tables.

sql_db_list_tables - Input is an empty string, output is a comma-separated list of tables in the database.

sql_db_query_checker - Use this tool to double check if your query is correct before executing it.
    Always use this tool before executing a query with sql_db_query!



Prompts

In [14]:
system_prefix = """
Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Here are some examples of user inputs and their corresponding SQL queries:

"""

In [15]:
suffix = """
Begin!

Question: {input}
Thought:{agent_scratchpad}
"""

In [17]:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate, ChatPromptTemplate
from langchain_core.prompts import SystemMessagePromptTemplate

In [18]:
dynamic_few_shot_prompt_template = FewShotPromptTemplate(
    example_selector = example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input"],
    prefix=system_prefix,
    suffix=suffix
)

In [19]:
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=dynamic_few_shot_prompt_template),
    ]
)

In [20]:
prompt_val = full_prompt.invoke(
    {
        "input": "How many arists are there?",
        "tool_names" : [tool.name for tool in tools],
        "tools" : [tool.name + " - " + tool.description.strip() for tool in tools],
        "agent_scratchpad": [],
    }
)

print(prompt_val.to_string())

System: 
Answer the following questions as best you can. You have access to the following tools:

['sql_db_query - Execute a SQL query against the database and get back the result..\n    If the query is not correct, an error message will be returned.\n    If an error is returned, rewrite the query, check the query, and try again.', 'sql_db_schema - Get the schema and sample rows for the specified SQL tables.', 'sql_db_list_tables - Input is an empty string, output is a comma-separated list of tables in the database.', 'sql_db_query_checker - Use this tool to double check if your query is correct before executing it.\n    Always use this tool before executing a query with sql_db_query!']

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [['sql_db_query', 'sql_db_schema', 'sql_db_list_tables', 'sql_db_query_checker']]
Action Input: the input to the action
Observation: the

Create Agent

In [21]:
from langchain.agents import AgentExecutor, create_react_agent
agent = create_react_agent(llm, tools, full_prompt)

In [22]:
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)

In [None]:
import langchain
langchain.debug = True

In [26]:
agent_executor.invoke({"input": "Provide the number of customers with respect to each country."})

[32;1m[1;3m[chain/start][0m [1m[chain:AgentExecutor] Entering Chain run with input:
[0m{
  "input": "Provide the number of customers with respect to each country."
}
[32;1m[1;3m[chain/start][0m [1m[chain:AgentExecutor > chain:RunnableSequence] Entering Chain run with input:
[0m{
  "input": ""
}
[32;1m[1;3m[chain/start][0m [1m[chain:AgentExecutor > chain:RunnableSequence > chain:RunnableAssign<agent_scratchpad>] Entering Chain run with input:
[0m{
  "input": ""
}
[32;1m[1;3m[chain/start][0m [1m[chain:AgentExecutor > chain:RunnableSequence > chain:RunnableAssign<agent_scratchpad> > chain:RunnableParallel<agent_scratchpad>] Entering Chain run with input:
[0m{
  "input": ""
}
[32;1m[1;3m[chain/start][0m [1m[chain:AgentExecutor > chain:RunnableSequence > chain:RunnableAssign<agent_scratchpad> > chain:RunnableParallel<agent_scratchpad> > chain:RunnableLambda] Entering Chain run with input:
[0m{
  "input": ""
}
[36;1m[1;3m[chain/end][0m [1m[chain:AgentExecutor > c

In [None]:
agent_executor.invoke({"input": "How many artists are there?"})