1. Configure Langsmith
2. Import LLM
3. Import Data
4. Dynamic few-shot prompt
5. Custom SQL Tools
6. ReAct Agent Executor
7. Persistent Memory
8. Showcase in Gradio UI

Langsmith Configuration

In [1]:
import os
langsmith_api_key = os.environ.get("LANGSMITH_API_KEY")

os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Local SQL Agent"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key

LLM

In [2]:
from langchain_ollama import ChatOllama

llm = ChatOllama(model="llama3.1:8b-instruct-q4_0")  #"llama3.1"

Database

In [3]:
from langchain_community.utilities import SQLDatabase
# db = SQLDatabase.from_uri("sqlite:///chinook.db", sample_rows_in_table_info = 3)

server = 'localhost'
database = 'chinook'
driver = 'ODBC Driver 17 for SQL Server'

# Create the connection URL
connection_url = f'mssql+pyodbc://{server}/{database}?driver={driver}&Trusted_Connection=yes'

db = SQLDatabase.from_uri(connection_url, sample_rows_in_table_info = 3)

In [4]:
print(db.table_info)


CREATE TABLE [Album] (
	[AlbumId] INTEGER NOT NULL, 
	[Title] NVARCHAR(160) COLLATE SQL_Latin1_General_CP1_CI_AS NOT NULL, 
	[ArtistId] INTEGER NOT NULL, 
	CONSTRAINT [PK_Album] PRIMARY KEY ([AlbumId]), 
	CONSTRAINT [FK_AlbumArtistId] FOREIGN KEY([ArtistId]) REFERENCES [Artist] ([ArtistId])
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE [Artist] (
	[ArtistId] INTEGER NOT NULL, 
	[Name] NVARCHAR(120) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, 
	CONSTRAINT [PK_Artist] PRIMARY KEY ([ArtistId])
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE [Customer] (
	[CustomerId] INTEGER NOT NULL, 
	[FirstName] NVARCHAR(40) COLLATE SQL_Latin1_General_CP1_CI_AS NOT NULL, 
	[LastName] NVARCHAR(20) COLLATE SQL_Latin1_General_CP1_CI_AS NOT NULL, 
	[Company] NVARCHAR(80) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, 
	[Address] NVARCHAR(70) COLLATE SQL_La

Few Shot Examples

In [8]:
# examples = [
#     {   "input": "List all artists.", 
#         "query": "SELECT * FROM Artist;"},
#     {
#         "input": "Find all albums for the artist 'AC/DC'.",
#         "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
#     },
#     {
#         "input": "List all tracks in the 'Rock' genre.",
#         "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
#     },
#     {
#         "input": "Find the total duration of all tracks.",
#         "query": "SELECT SUM(Milliseconds) FROM Track;",
#     },
#     {
#         "input": "List all customers from Canada.",
#         "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
#     },
#     {
#         "input": "How many tracks are there in the album with ID 5?",
#         "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
#     },
#     {
#         "input": "Find the total number of Albums.",
#         "query": "SELECT COUNT(DISTINT(AlbumId)) FROM Invoice;",
#     },
#     {
#         "input": "List all tracks that are longer than 5 minutes.",
#         "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
#     },
#     {
#         "input": "Who are the top 5 customers by total purchase?",
#         "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
#     },
#     {
#         "input": "How many employees are there",
#         "query": 'SELECT COUNT(*) FROM "Employee"',
#     },
# ]

examples = [
    {   "input": "List all artists.", 
        "query": "SELECT * FROM Artist;"},
    {
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {
        "input": "List all tracks in the 'Rock' genre.",
        "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    },
    {
        "input": "Find the total duration of all tracks.",
        "query": "SELECT SUM(Milliseconds) FROM Track;",
    },
    {
        "input": "List all customers from Canada.",
        "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
    },
    {
        "input": "How many tracks are there in the album with ID 5?",
        "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    },
    {
        "input": "Find the total number of Albums.",
        "query": "SELECT COUNT(DISTINCT AlbumId) FROM Album;",  # Corrected the query
    },
    {
        "input": "List all tracks that are longer than 5 minutes.",
        "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
    },
    {
        "input": "Who are the top 5 customers by total purchase?",
        "query": "SELECT TOP 5 CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC;",  # Replaced LIMIT with TOP
    },
    {
        "input": "How many employees are there",
        "query": "SELECT COUNT(*) FROM Employee;",  # Removed quotes around table name
    },
]
print(len(examples))

10


Dynamic Example Selector

In [9]:
from langchain_huggingface import HuggingFaceEmbeddings

In [10]:
embeddings = HuggingFaceEmbeddings(model_name = 'sentence-transformers/all-MiniLM-L6-v2')

In [11]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    FAISS,
    k=2,
    input_keys=["input"],
    )

example_selector.vectorstore.search("How many arists are there?", search_type = "mmr")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


[Document(metadata={'input': 'How many tracks are there in the album with ID 5?', 'query': 'SELECT COUNT(*) FROM Track WHERE AlbumId = 5;'}, page_content='How many tracks are there in the album with ID 5?'),
 Document(metadata={'input': 'How many employees are there', 'query': 'SELECT COUNT(*) FROM Employee;'}, page_content='How many employees are there'),
 Document(metadata={'input': 'Who are the top 5 customers by total purchase?', 'query': 'SELECT TOP 5 CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC;'}, page_content='Who are the top 5 customers by total purchase?'),
 Document(metadata={'input': 'Find the total duration of all tracks.', 'query': 'SELECT SUM(Milliseconds) FROM Track;'}, page_content='Find the total duration of all tracks.')]

Prompt

In [16]:
# system_prefix = """You are an agent designed to interact with a SQL database.
# Given an input question, create a syntactically correct sqlite query to run, then look at the results of the query and return the answer.
# Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
# You can order the results by a relevant column to return the most interesting examples in the database.
# Never query for all the columns from a specific table, only ask for the relevant columns given the question.

# You have access to the following tools for interacting with the database:

# {tools}

# Use the following format:

# Question: the input question you must answer
# Thought: you should always think about what to do
# Action: the action to take, should be one of {tool_names}
# Action Input: the input to the action
# Observation: the result of the action
# ... (this Thought/Action/Action Input/Observation can repeat N times)
# Thought: I now know the final answer
# Final Answer: the final answer to the original input question

# You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

# If the question does not seem related to the database, just return "I don't know" as the answer.
# If you see you are repeating yourself, just provide final answer and exit.

# Here are some examples of user inputs and their corresponding SQL queries:"""


system_prefix = """You are an agent designed to interact with a Microsoft SQL database.
Given an input question, create a syntactically correct MS SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

You have access to the following tools for interacting with the database:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of {tool_names}
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.
If you see you are repeating yourself, just provide final answer and exit.

Here are some examples of user inputs and their corresponding SQL queries:"""

In [17]:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

dynamic_few_shot_prompt = FewShotPromptTemplate(
    example_selector = example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input"],
    prefix=system_prefix,
    suffix=""
)

In [18]:
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate

full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=dynamic_few_shot_prompt),
        ("human", "{input}"),
        ("system", "{agent_scratchpad}"),
    ]
)

Custom Tools

In [19]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool, InfoSQLDatabaseTool, ListSQLDatabaseTool, QuerySQLCheckerTool

tools = [QuerySQLDataBaseTool(db = db), InfoSQLDatabaseTool(db = db), ListSQLDatabaseTool(db = db), QuerySQLCheckerTool(db = db, llm = llm)]
print(QuerySQLDataBaseTool(db = db).description)

prompt_val = full_prompt.invoke(
    {
        "input": "How many arists are there?",
        "tool_names" : [tool.name for tool in tools],
        "tools" : [tool.name + " - " + tool.description.strip() for tool in tools],
        "agent_scratchpad": [],
    }
)

print(prompt_val.to_string())


    Execute a SQL query against the database and get back the result..
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    
System: You are an agent designed to interact with a Microsoft SQL database.
Given an input question, create a syntactically correct MS SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

You have access to the following tools for interacting with the database:

['sql_db_query - Execute a SQL query against the database and get back the result..\n    If the query is not correct, an error message will be retur

In [26]:
tools = [QuerySQLDataBaseTool(db = db), InfoSQLDatabaseTool(db = db), ListSQLDatabaseTool(db = db), QuerySQLCheckerTool(db = db, llm = llm)]
print("QuerySQLDataBaseTool", QuerySQLDataBaseTool(db = db).description)
print("InfoSQLDatabaseTool", InfoSQLDatabaseTool(db = db).description)
print("ListSQLDatabaseTool", ListSQLDatabaseTool(db = db).description)
print("QuerySQLCheckerTool", QuerySQLCheckerTool(db = db, llm = llm).description)

QuerySQLDataBaseTool 
    Execute a SQL query against the database and get back the result..
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    
InfoSQLDatabaseTool Get the schema and sample rows for the specified SQL tables.
ListSQLDatabaseTool Input is an empty string, output is a comma-separated list of tables in the database.
QuerySQLCheckerTool 
    Use this tool to double check if your query is correct before executing it.
    Always use this tool before executing a query with sql_db_query!
    


Agent Executor

In [23]:
from langchain.agents import AgentExecutor, create_react_agent
agent = create_react_agent(llm, tools, full_prompt)

agent_executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True, verbose = True)

History Management

In [24]:
last_k_messages = 4


from langchain_community.chat_message_histories import SQLChatMessageHistory

def get_session_history(session_id):
    chat_message_history = SQLChatMessageHistory(
    session_id=session_id, connection = "sqlite:///memory.db", table_name = "local_table"
    )

    messages = chat_message_history.get_messages()
    chat_message_history.clear()
    
    for message in messages[-last_k_messages:]:
        chat_message_history.add_message(message)
    
    print("chat_message_history ", chat_message_history)
    return chat_message_history


from langchain_core.runnables.history import RunnableWithMessageHistory

agent_with_chat_history = RunnableWithMessageHistory(
    agent_executor,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
)

Gradio UI

In [25]:
import gradio as gr
import uuid


with gr.Blocks() as demo:
    
    state = gr.State("")
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])


    def respond(message, chatbot_history, session_id):
        if not chatbot_history:
            session_id = uuid.uuid4().hex

        print("Session ID: ", session_id)

        response = agent_with_chat_history.invoke(
                                        {"input": message},
                                        {"configurable": {"session_id": session_id}},
                                        )

        chatbot_history.append((message, response['output']))
        return "", chatbot_history, session_id

    msg.submit(respond, [msg, chatbot, state], [msg, chatbot, state])

demo.launch()


Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.




Session ID:  e0fbcebbed2c4f10846f08b76e41da56
chat_message_history  


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThought: I need to create a SQL query to find out how many artists there are in the database. I will use the sql_db_query_checker tool to double check if my query is correct.
Action: sql_db_query_checker
Action Input: SELECT COUNT(*) FROM Artist[0m[36;1m[1;3mSELECT COUNT(*) FROM Artist[0m[32;1m[1;3mAction: sql_db_query
Action Input: SELECT COUNT(*) FROM Artist ORDER BY ArtistId LIMIT 5[0m[36;1m[1;3mError: (pyodbc.ProgrammingError) ('42000', "[42000] [Microsoft][ODBC Driver 17 for SQL Server][SQL Server]Incorrect syntax near 'LIMIT'. (102) (SQLExecDirectW)")
[SQL: SELECT COUNT(*) FROM Artist ORDER BY ArtistId LIMIT 5]
(Background on this error at: https://sqlalche.me/e/20/f405)[0m[32;1m[1;3mAction: sql_db_query_checker
Action Input: SELECT COUNT(*) FROM Artist[0m[36;1m[1;3mSELECT COUNT(*) FROM Artist[0m[32;1m[1;3mThought: I should remove the