In [1]:
import getpass
import os
import time
from dotenv import load_dotenv

# Load the .env file
_ = load_dotenv()

### Model

In [2]:
from agno.models.ollama import Ollama

sql_ollama_model=Ollama(id="qwen2.5-coder:7b")

### Text 2 sql agent

In [3]:
from agno.agent import Agent

instructions="""
You are an expert in SQL. You can create sql queries from natural language following these schemas:

CREATE TABLE IF NOT EXISTS Customers(
        CustomerID INTEGER PRIMARY KEY AUTOINCREMENT,
        FirstName TEXT NOT NULL,
        LastName TEXT NOT NULL,
        Email TEXT UNIQUE NOT NULL,
        Phone TEXT,
        Address TEXT
);

-- Create Products table
CREATE TABLE IF NOT EXISTS Products (
        ProductID INTEGER PRIMARY KEY AUTOINCREMENT,
        Name TEXT NOT NULL,
        Description TEXT,
        Price REAL NOT NULL,
        StockQuantity INTEGER NOT NULL,
        Category TEXT NOT NULL
    );


-- Create Orders table
CREATE TABLE IF NOT EXISTS Orders (
        OrderID INTEGER PRIMARY KEY AUTOINCREMENT,
        CustomerID INTEGER NOT NULL,
        OrderDate DATE NOT NULL,
        TotalAmount REAL NOT NULL,
        FOREIGN KEY (CustomerID) REFERENCES Customers(CustomerID)
    );

-- Create OrderDetails table
CREATE TABLE IF NOT EXISTS OrderDetails (
        OrderDetailID INTEGER PRIMARY KEY AUTOINCREMENT,
        OrderID INTEGER NOT NULL,
        ProductID INTEGER NOT NULL,
        Quantity INTEGER NOT NULL,
        Subtotal REAL NOT NULL,
        FOREIGN KEY (OrderID) REFERENCES Orders(OrderID),
        FOREIGN KEY (ProductID) REFERENCES Products(ProductID)
    );

Return only the sql query. Don't explain the query and don't put the sql query in ```sql\n```.
"""
text2sql_agent = Agent(model=sql_ollama_model, name="text2sql", instructions=instructions, show_tool_calls=True)

### Tests

In [4]:
def write_sql_query(query:str):
    """ Function to write sql query """
    start_time = time.perf_counter()

    response = text2sql_agent.run(query)

    end_time = time.perf_counter()
    
    latency = end_time - start_time
    
    print(f"Execution time: {latency} seconds")
    answer = response.content
    
    return answer
    

In [5]:
query = "¿Cuáles son los detalles de todos los clientes?"
write_sql_query(query) #SELECT * FROM Customers;

Execution time: 109.26334430000861 seconds


'SELECT * FROM Customers;'

In [6]:
query = "¿Cuáles son los productos disponibles en la categoría 'Women'?"
write_sql_query(query)#"SELECT * FROM Products WHERE Category = 'Women';",

Execution time: 6.971404300013091 seconds


"SELECT Name FROM Products WHERE Category = 'Women'"

In [7]:
query = "¿Cuál es el precio total de todos los pedidos realizados por un cliente específico con ID 5?"
write_sql_query(query)#"SELECT SUM(TotalAmount) FROM Orders WHERE CustomerID = 5;",

Execution time: 11.981810800003586 seconds


'SELECT SUM(TotalAmount) FROM Orders WHERE CustomerID = 5'

In [8]:
query = "¿Cuántos productos están en stock con una cantidad mayor a 20 unidades?"
write_sql_query(query)#"SELECT * FROM Products WHERE StockQuantity > 20;"

Execution time: 10.258081999985734 seconds


'SELECT COUNT(*) FROM Products WHERE StockQuantity > 20'

In [9]:
query = "¿Qué productos han sido comprados en el pedido con ID 3?"
write_sql_query(query)#"SELECT p.Name, od.Quantity, od.Subtotal FROM OrderDetails od INNER JOIN Products p ON od.ProductID = p.ProductID WHERE od.OrderID = 3;",

Execution time: 14.596958100009942 seconds


'SELECT P.ProductID, P.Name FROM Products AS P JOIN OrderDetails AS OD ON P.ProductID = OD.ProductID WHERE OD.OrderID = 3;'

### Benchmark

In [10]:
import json
import glob

def load_benchmark(benchmark_path):
    """ This function loads the benchmark dataset"""
    questions = []
    expected_sql = []
    with open(benchmark_path, 'r', encoding='utf-8') as f:
        try:
            datos = json.load(f)
            questions.extend(datos.keys())
            expected_sql.extend(datos.values())
        except Exception as e:
            print(f,e)
    return questions, expected_sql

questions , expected_sql = load_benchmark("../tests/dataset_shop_queries.json")

In [11]:
questions

['¿Cuáles son los detalles de todos los clientes?',
 "¿Cuáles son los productos disponibles en la categoría 'Women'?",
 '¿Cuál es el precio total de todos los pedidos realizados por un cliente específico con ID 5?',
 '¿Cuántos productos están en stock con una cantidad mayor a 20 unidades?',
 '¿Qué productos han sido comprados en el pedido con ID 3?',
 "¿Cuál es el nombre y apellido del cliente que tiene el correo electrónico 'juan.perez@email.com'?",
 '¿Qué pedidos fueron realizados el 14 de febrero de 2025?',
 '¿Cuántos productos en total fueron comprados en el pedido con ID 1?',
 '¿Qué productos tienen un precio superior a 120?',
 '¿Cuáles son los detalles de los pedidos y sus productos asociados para el cliente con ID 2?']

In [12]:
expected_sql

['SELECT * FROM Customers;',
 "SELECT * FROM Products WHERE Category = 'Women';",
 'SELECT SUM(TotalAmount) FROM Orders WHERE CustomerID = 5;',
 'SELECT * FROM Products WHERE StockQuantity > 20;',
 'SELECT p.Name, od.Quantity, od.Subtotal FROM OrderDetails od INNER JOIN Products p ON od.ProductID = p.ProductID WHERE od.OrderID = 3;',
 "SELECT FirstName, LastName FROM Customers WHERE Email = 'juan.perez@email.com';",
 "SELECT * FROM Orders WHERE OrderDate = '2025-02-14';",
 'SELECT SUM(Quantity) FROM OrderDetails WHERE OrderID = 1;',
 'SELECT * FROM Products WHERE Price > 120;',
 'SELECT o.OrderID, o.OrderDate, p.Name, od.Quantity, od.Subtotal FROM Orders o INNER JOIN OrderDetails od ON o.OrderID = od.OrderID INNER JOIN Products p ON od.ProductID = p.ProductID WHERE o.CustomerID = 2;']

In [13]:
import pandas as pd
import time

def eval_benchmark(benchmark_path:str)->pd.DataFrame:
	""" This function evaluates the benchmark dataset"""

	# Load the benchmark dataset
	questions , expected_sql = load_benchmark(benchmark_path)
	# Create a dataframe to store the results
	df = pd.DataFrame({"question":questions, "expected_sql":expected_sql})

	latencies = []
	answers = []
	for index, row in df.iterrows():
		try:
			question = row["question"]
			expected_sql = row["expected_sql"]
			
			# Measure the latency
			start_time = time.time()
			# Invoke the agent
			response = text2sql_agent.run(query)
			predicted_sql = response.content
			# Measure the latency
			latency = time.time() - start_time
			latencies.append(latency)
			# Store the predicted SQL query
			answers.append(predicted_sql)
		except Exception as e:
			answers.append("ERROR")
			latencies.append(-1)
			print(f"Error processing question: {question}")

	# Store the results in the dataframe
	df["predicted_sql"] = answers
	df["latency"] = latencies
	return df

In [15]:
df_eval = eval_benchmark( "../tests/dataset_shop_queries.json")

In [17]:
df_eval.head(10)

Unnamed: 0,question,expected_sql,predicted_sql,latency
0,¿Cuáles son los detalles de todos los clientes?,SELECT * FROM Customers;,"SELECT P.ProductID, P.Name FROM Products P JOI...",11.027225
1,¿Cuáles son los productos disponibles en la ca...,SELECT * FROM Products WHERE Category = 'Women';,"SELECT p.Name, od.Quantity \nFROM OrderDetails...",12.204067
2,¿Cuál es el precio total de todos los pedidos ...,SELECT SUM(TotalAmount) FROM Orders WHERE Cust...,"SELECT P.ProductID, P.Name, OD.Quantity FROM O...",12.165747
3,¿Cuántos productos están en stock con una cant...,SELECT * FROM Products WHERE StockQuantity > 20;,"SELECT P.ProductID, P.Name, OD.Quantity FROM P...",13.884605
4,¿Qué productos han sido comprados en el pedido...,"SELECT p.Name, od.Quantity, od.Subtotal FROM O...","SELECT P.ProductID, P.Name, OD.Quantity FROM O...",14.324975
5,¿Cuál es el nombre y apellido del cliente que ...,"SELECT FirstName, LastName FROM Customers WHER...","SELECT P.ProductID, P.Name FROM Products P JOI...",12.782388
6,¿Qué pedidos fueron realizados el 14 de febrer...,SELECT * FROM Orders WHERE OrderDate = '2025-0...,"SELECT P.ProductID, P.Name FROM Products P JOI...",12.512871
7,¿Cuántos productos en total fueron comprados e...,SELECT SUM(Quantity) FROM OrderDetails WHERE O...,"SELECT P.ProductID, P.Name FROM Products P JOI...",11.449392
8,¿Qué productos tienen un precio superior a 120?,SELECT * FROM Products WHERE Price > 120;,"SELECT P.ProductID, P.Name, OD.Quantity FROM O...",12.63259
9,¿Cuáles son los detalles de los pedidos y sus ...,"SELECT o.OrderID, o.OrderDate, p.Name, od.Quan...","SELECT P.ProductID, P.Name, OD.Quantity, OD.Su...",12.51413


In [18]:
model_id = "qwen2.5-coder"

df_eval.to_csv(f"../tests/df_benchmark_shop_results_{model_id}.csv", index=False)

In [19]:
import pandas as pd

model_id = "qwen2.5-coder"

df_eval = pd.read_csv(f"../tests/df_benchmark_shop_results_{model_id}.csv")

In [20]:
import plotly.graph_objects as go

# Crear figura
fig = go.Figure(data=[go.Bar(x=df_eval.question, y=df_eval.latency)])

# Configurar ejes
fig.update_layout(
    title=f'Our Text2SQL Benchmark dataset to evaluate the performance of : {model_id}',
    xaxis_title='Question',
    yaxis_title='Latency (s)'
)
fig.update_layout(title=dict(font_size=12, font_weight='bold', x=0.5))
fig.update_traces(marker_line_color='#e69138', marker_line_width=2, marker_color="#f35a02")
fig.update_layout(plot_bgcolor='rgba(0,0,0,0)', yaxis=dict(gridcolor='lightgrey'), height=600, width=1000)

# Mostrar figura
fig.show()