# LLM-Powered SQL Query Assistant using Llama and LangChain

**Objectives:**
1. Create a SQL agent that can query a database
2. Use Langsmith to trace the agent
3. Use Langsmith to evaluate the agent
3. Evaluate the agent with metrics?
4. Visualise the output with dataframes, plots
5. Create a chat interface for the agent

## Install Dependencies


In [1]:
# !pip install -U langchain langchain_community langchain-ollama langchain_huggingface faiss-cpu pymysql pandas plotly nbformat gradio

## Import Ollama and set up LangSmith Integration


In [2]:
# Environment variables must be set before agent is created
from dotenv import load_dotenv
import os

# Load environment variables from .env file
env_path = './.env'
load_dotenv(dotenv_path=env_path)

# Verify the environment variables are set
print("LANGCHAIN_TRACING_V2:", os.getenv("LANGCHAIN_TRACING_V2"))
print("LANGCHAIN_API_KEY:", "***" if os.getenv("LANGCHAIN_API_KEY") else "Not set")
print("LANGCHAIN_PROJECT:", os.getenv("LANGCHAIN_PROJECT"))

LANGCHAIN_TRACING_V2: true
LANGCHAIN_API_KEY: ***
LANGCHAIN_PROJECT: pr-ajar-upward-57


In [3]:
# Import Ollama and test connection
from langchain_ollama import ChatOllama

# Create agent
llm = ChatOllama(model='llama3', base_url='http://localhost:11434')
llm.invoke('Hello, world!') # The run should be traced and appear in LangSmith

AIMessage(content="Hello there! It's nice to meet you. How can I help or chat with you today?", additional_kwargs={}, response_metadata={'model': 'llama3', 'created_at': '2024-12-20T12:09:56.809801Z', 'done': True, 'done_reason': 'stop', 'total_duration': 3428501500, 'load_duration': 2706176500, 'prompt_eval_count': 14, 'prompt_eval_duration': 171000000, 'eval_count': 21, 'eval_duration': 549000000, 'message': Message(role='assistant', content='', images=None, tool_calls=None)}, id='run-f0b8499b-e842-48de-87d1-e51c6a40905b-0', usage_metadata={'input_tokens': 14, 'output_tokens': 21, 'total_tokens': 35})

## Create SQL agent

In [4]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("mysql+pymysql://root:admin@localhost/test_data")

# Inspect the database
print("Dialect:", db.dialect)
print("Tables:", db.get_usable_table_names())

# Test query
db.run("select * from sales")


Dialect: mysql
Tables: ['sales']


"[(1, 'Laptop', 2, Decimal('1200.00'), datetime.date(2024, 1, 15)), (2, 'Headphones', 5, Decimal('200.00'), datetime.date(2024, 2, 10)), (3, 'Monitor', 3, Decimal('300.00'), datetime.date(2024, 2, 20)), (4, 'Keyboard', 10, Decimal('50.00'), datetime.date(2024, 3, 5)), (5, 'Mouse', 8, Decimal('25.00'), datetime.date(2024, 3, 10)), (6, 'Smartphone', 4, Decimal('800.00'), datetime.date(2024, 4, 1)), (7, 'Tablet', 6, Decimal('500.00'), datetime.date(2024, 4, 15)), (8, 'Printer', 2, Decimal('150.00'), datetime.date(2024, 5, 5)), (9, 'Scanner', 1, Decimal('100.00'), datetime.date(2024, 5, 15)), (10, 'Camera', 3, Decimal('750.00'), datetime.date(2024, 6, 1))]"

In [5]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_types import AgentType
from langchain.sql_database import SQLDatabase
from langchain.agents import AgentExecutor

# Create SQL agent
agent_executor = create_sql_agent(
    llm=llm,
    db=db,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    agent_executor_kwargs = {"return_intermediate_steps": True} # This will allow us to get access to the SQL query
)

# Run the agent with a query
response = agent_executor.invoke(
    "What are all the sales records in the database?"
)

print(response['output'])




[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mLet's start by listing the tables in the database.

Action: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3msales[0m[32;1m[1;3mThought: Now that I have a list of tables, I should query the schema of the table related to sales records. Let's check the schema of the "sales" table.

Action: sql_db_schema
Action Input: sales[0m[33;1m[1;3m
CREATE TABLE sales (
	sale_id INTEGER NOT NULL AUTO_INCREMENT, 
	product_name VARCHAR(255), 
	quantity INTEGER, 
	price_per_unit DECIMAL(10, 2), 
	sale_date DATE, 
	PRIMARY KEY (sale_id)
)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci

/*
3 rows from sales table:
sale_id	product_name	quantity	price_per_unit	sale_date
1	Laptop	2	1200.00	2024-01-15
2	Headphones	5	200.00	2024-02-10
3	Monitor	3	300.00	2024-02-20
*/[0m[32;1m[1;3mLet's think about what to do next...

Action: sql_db_query
Action Input: SELECT * FROM sales LIMIT 10[0m[36;1m[1;3m[(1, 'Laptop', 2

## Visualise Query Output

In [10]:
# Filter the log based on tool name to get SQL query
def get_query(response):
    queries = []
    for (log, output) in response["intermediate_steps"]:
        if log.tool == 'sql_db_query':
            queries.append(log.tool_input)

    return queries[0] # return first one

get_query(response)

'SELECT * FROM sales LIMIT 10'

In [7]:
import pandas as pd

df = pd.read_sql_query(queries[0], db._engine)

df

Unnamed: 0,sale_id,product_name,quantity,price_per_unit,sale_date
0,1,Laptop,2,1200.0,2024-01-15
1,2,Headphones,5,200.0,2024-02-10
2,3,Monitor,3,300.0,2024-02-20
3,4,Keyboard,10,50.0,2024-03-05
4,5,Mouse,8,25.0,2024-03-10
5,6,Smartphone,4,800.0,2024-04-01
6,7,Tablet,6,500.0,2024-04-15
7,8,Printer,2,150.0,2024-05-05
8,9,Scanner,1,100.0,2024-05-15
9,10,Camera,3,750.0,2024-06-01


In [13]:
import plotly.express as px

def plot_sales_hist():
# 1. Bar chart of total sales by product
   sales_by_product = df.assign(total_sales=df['quantity'] * df['price_per_unit'])\
      .groupby('product_name')['total_sales'].sum().reset_index()
   fig1 = px.bar(sales_by_product, 
               x='product_name', 
               y='total_sales',
               title='Total Sales by Product',
               labels={'total_sales': 'Total Sales ($)', 'product_name': 'Product'})
   # fig1.show()
   return fig1

plot_sales_hist()

## Create Agent Interface

In [15]:
# Combine the query filtering and visualisation processing into a callable function
def process_query(question):
    """Process the user's question and return response with visualization"""
    # Get response from SQL agent
    response = agent_executor.invoke(question)

    # Get the SQL query from intermediate steps
    query = get_query(response)

    # Run query and assign output to a DataFrame
    df = pd.read_sql_query(query, db._engine)

    # Create a hist plot
    fig = plot_sales_hist()

    return response['output'], fig

process_query("What are the total sales?")



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mLet's start by listing all the tables in the database.

Action: sql_db_list_tables
Action Input: empty string[0m[38;5;200m[1;3msales[0m[32;1m[1;3mThought: Now that I have a list of tables, I can look at the schema to see what columns are available. This will help me decide what query to run.
Action: sql_db_schema
Action Input: sales[0m[33;1m[1;3m
CREATE TABLE sales (
	sale_id INTEGER NOT NULL AUTO_INCREMENT, 
	product_name VARCHAR(255), 
	quantity INTEGER, 
	price_per_unit DECIMAL(10, 2), 
	sale_date DATE, 
	PRIMARY KEY (sale_id)
)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci

/*
3 rows from sales table:
sale_id	product_name	quantity	price_per_unit	sale_date
1	Laptop	2	1200.00	2024-01-15
2	Headphones	5	200.00	2024-02-10
3	Monitor	3	300.00	2024-02-20
*/[0m[32;1m[1;3mThought: Now that I have the schema of the sales table, I can see what columns are available and decide on a query to run.

Act

('The final answer to the original input question is 13850.00.',
 Figure({
     'data': [{'alignmentgroup': 'True',
               'hovertemplate': 'Product=%{x}<br>Total Sales ($)=%{y}<extra></extra>',
               'legendgroup': '',
               'marker': {'color': '#636efa', 'pattern': {'shape': ''}},
               'name': '',
               'offsetgroup': '',
               'orientation': 'v',
               'showlegend': False,
               'textposition': 'auto',
               'type': 'bar',
               'x': array(['Camera', 'Headphones', 'Keyboard', 'Laptop', 'Monitor', 'Mouse',
                           'Printer', 'Scanner', 'Smartphone', 'Tablet'], dtype=object),
               'xaxis': 'x',
               'y': array([2250., 1000.,  500., 2400.,  900.,  200.,  300.,  100., 3200., 3000.]),
               'yaxis': 'y'}],
     'layout': {'barmode': 'relative',
                'legend': {'tracegroupgap': 0},
                'template': '...',
                'title': {

In [16]:
import gradio as gr

demo = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(label="Ask a question about the data"),
    outputs=[
        gr.Textbox(label="Response"),
        gr.Plot(label="Visualisation")
    ],
    title="Sales Data Analysis Assistant",
    description="Ask questions about the sales data and get answers with visualisations",
    examples=["What are the total sales for each product?"]
)

# Launch the interface
demo.launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.






[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mLet's get started.

Thought: Since we're looking for total sales for each product, it seems like we need to interact with a table that contains both product information and sale data.

Action: sql_db_list_tables
Action Input: empty string (as this is the default input)[0m[38;5;200m[1;3msales[0m[32;1m[1;3mLet's continue!

Thought: Since "sales" seems relevant for our question, let's take a look at its schema to see what columns it has and what kind of information it contains.

Action: sql_db_schema
Action Input: sales[0m[33;1m[1;3m
CREATE TABLE sales (
	sale_id INTEGER NOT NULL AUTO_INCREMENT, 
	product_name VARCHAR(255), 
	quantity INTEGER, 
	price_per_unit DECIMAL(10, 2), 
	sale_date DATE, 
	PRIMARY KEY (sale_id)
)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci

/*
3 rows from sales table:
sale_id	product_name	quantity	price_per_unit	sale_date
1	Laptop	2	1200.00	2024-01-15
2	Headphones	5	200.00

Traceback (most recent call last):
  File "c:\Users\Ahmed\.conda\envs\sql-agent\Lib\site-packages\gradio\queueing.py", line 625, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Ahmed\.conda\envs\sql-agent\Lib\site-packages\gradio\route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Ahmed\.conda\envs\sql-agent\Lib\site-packages\gradio\blocks.py", line 2047, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Ahmed\.conda\envs\sql-agent\Lib\site-packages\gradio\blocks.py", line 1594, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Ahmed\.conda\envs\sql-agent\Lib\site-packages\anyio\to_thread.py", line 56, in run_sync
    r



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mLet's get started.

Thought: I think there might be a table like "products" or "inventory" that has information about products, and maybe another table like "sales" or "orders" that has sales data. Let me check the database first to see what tables are available.

Action: sql_db_list_tables
Action Input: empty string[0m[38;5;200m[1;3msales[0m[32;1m[1;3mThought: Okay, so it looks like there is a table called "sales". Now I should think about which columns in this table would be relevant for answering the question about total sales for each product.

Action: sql_db_schema
Action Input: sales[0m[33;1m[1;3m
CREATE TABLE sales (
	sale_id INTEGER NOT NULL AUTO_INCREMENT, 
	product_name VARCHAR(255), 
	quantity INTEGER, 
	price_per_unit DECIMAL(10, 2), 
	sale_date DATE, 
	PRIMARY KEY (sale_id)
)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci

/*
3 rows from sales table:
sale_id	product_name	quantity	pr

Traceback (most recent call last):
  File "c:\Users\Ahmed\.conda\envs\sql-agent\Lib\site-packages\gradio\queueing.py", line 625, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Ahmed\.conda\envs\sql-agent\Lib\site-packages\gradio\route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Ahmed\.conda\envs\sql-agent\Lib\site-packages\gradio\blocks.py", line 2047, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Ahmed\.conda\envs\sql-agent\Lib\site-packages\gradio\blocks.py", line 1594, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Ahmed\.conda\envs\sql-agent\Lib\site-packages\anyio\to_thread.py", line 56, in run_sync
    r

## Create few shot examples
This will help save tokens

In [26]:
# Create some few shot examples

examples = [
    {
        "input": "What are all the sales records in the database?",
        "output": "SELECT * FROM sales"
    }
]

In [24]:
from langchain_huggingface import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-l6-v2')

In [31]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples, # Few shot examples
    embeddings, # Embedding model
    FAISS, # Vector store
    k=2, # Number of examples to return
    input_keys=["input"] # Input keys from examples
)

example_selector.vectorstore.search("How many sales are there?", search_type="mmr")


[Document(id='2c231c85-c9c1-4024-9012-1b9b63abe09a', metadata={'input': 'What are all the sales records in the database?', 'output': 'SELECT * FROM sales'}, page_content='What are all the sales records in the database?')]