A standard text-to-SQL pipeline is brittle, since the generated SQL query can be incorrect. Even worse, an incorrect query may run without raising an error and silently return wrong or useless output.

👉 Instead, an agent system can critically inspect the outputs and decide whether the query needs to be changed, which can substantially improve its accuracy.
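To make the silent-failure case concrete, here is a minimal sketch using Python's built-in `sqlite3` module. The table contents and both queries are illustrative assumptions, not part of the notebook's data. The wrong query is valid SQL, so it executes without error and simply returns a wrong number.

```python
import sqlite3

# Throwaway in-memory database with a small receipts-like table (illustrative data).
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE receipts (receipt_id INTEGER, price REAL, tip REAL)")
con.executemany(
    "INSERT INTO receipts VALUES (?, ?, ?)",
    [(1, 10.0, 1.0), (2, 20.0, 3.0)],
)

# Question: "What is the average tip?"
wrong_query = "SELECT AVG(price) FROM receipts"  # wrong column, but still valid SQL
right_query = "SELECT AVG(tip) FROM receipts"

# Both queries run without raising; only the second answers the question.
wrong_answer = con.execute(wrong_query).fetchone()[0]
right_answer = con.execute(right_query).fetchone()[0]

print(wrong_answer)  # 15.0
print(right_answer)  # 2.0
```

Nothing in the database layer flags the first result as wrong, which is why a step that inspects the output is useful.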

In [1]:
# setup sql db
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    Float,
    insert,
    inspect,
    text,
)

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# create the receipts SQL table
table_name = "receipts"
receipts = Table(
    table_name,
    metadata_obj,
    Column("receipt_id", Integer, primary_key=True),
    Column("customer_name", String(16), primary_key=True),
    Column("price", Float),
    Column("tip", Float),
)
metadata_obj.create_all(engine)

In [2]:
rows = [
    {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
    {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
    {"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
    {"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
]
for row in rows:
    stmt = insert(receipts).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

In [3]:
with engine.connect() as con:
    rows = con.execute(text("""SELECT * from receipts"""))
    for row in rows:
        print(row)

(1, 'Alan Payne', 12.06, 1.2)
(2, 'Alex Mason', 23.86, 0.24)
(3, 'Woodrow Wilson', 53.43, 5.43)
(4, 'Margaret James', 21.11, 1.0)


### Building the agent

The tool’s description attribute will be embedded in the LLM’s prompt by the agent system: it gives the LLM information about how to use the tool. So that is where we want to describe the SQL table.
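As a rough illustration of why the description matters, the sketch below assembles a prompt from a tool name and description. The `build_system_prompt` helper and its template are hypothetical; the actual prompt used by the agent library is different and more elaborate.

```python
def build_system_prompt(tool_name: str, tool_description: str) -> str:
    """Hypothetical template; the real agent prompt is not reproduced here."""
    return (
        "You can call the following tool from your code:\n"
        f"- {tool_name}: {tool_description}\n"
        "Write Python code that calls the tool to answer the user's question."
    )


# Illustrative description that names the table and lists some of its columns.
description = (
    "Allows you to perform SQL queries on the table. "
    "The table is named 'receipts'.\n"
    "Columns:\n"
    "  - receipt_id: INTEGER\n"
    "  - price: FLOAT"
)

prompt = build_system_prompt("sql_engine", description)
print(prompt)
```

Whatever the exact template, the model only knows about table and column names if they appear in text like this.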



In [4]:
inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]

table_description = "Columns:\n" + "\n".join([f"  - {name}: {col_type}" for name, col_type in columns_info])
print(table_description)

Columns:
  - receipt_id: INTEGER
  - customer_name: VARCHAR(16)
  - price: FLOAT
  - tip: FLOAT


In [5]:
!pip install -q "transformers[agents]" --upgrade

In [11]:
from transformers.agents import tool

@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'receipts'.

    Args:
        query: The query to perform. This should be correct SQL.
    """
    output = ""
    with engine.connect() as con:
        rows = con.execute(text(query))
        for row in rows:
            output += "\n" + str(row)
    return output

We use the ReactCodeAgent, which is transformers.agents’ main agent class: an agent that writes actions in code and can iterate on previous output according to the ReAct framework.
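The "iterate on previous output" idea can be sketched without any LLM. In the toy loop below, `scripted_model` is a hypothetical stand-in for the model: it first proposes a query with a wrong column name, then a corrected one once it has seen an error. This is only an assumption-laden illustration of the act/observe cycle, not how `ReactCodeAgent` is implemented.

```python
import sqlite3

# Small in-memory table for the demonstration (illustrative data).
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE receipts (receipt_id INTEGER, customer_name TEXT, price REAL)")
con.execute("INSERT INTO receipts VALUES (1, 'Alan Payne', 12.06)")
con.execute("INSERT INTO receipts VALUES (2, 'Woodrow Wilson', 53.43)")


def run_sql(query: str) -> str:
    """Execute a query and return the rows, or the error message, as a string."""
    try:
        return str(con.execute(query).fetchall())
    except sqlite3.Error as exc:
        return f"Error: {exc}"


def scripted_model(history: list) -> str:
    """Hypothetical stand-in for an LLM: a bad query first, a fixed one after an error."""
    if not any(obs.startswith("Error") for obs in history):
        return "SELECT name FROM receipts ORDER BY price DESC LIMIT 1"
    return "SELECT customer_name FROM receipts ORDER BY price DESC LIMIT 1"


def react_loop(max_steps: int = 3) -> str:
    history = []
    for _ in range(max_steps):
        query = scripted_model(history)  # act: propose a query
        observation = run_sql(query)     # observe: run it and capture the result
        history.append(observation)      # the next step can see what happened
        if not observation.startswith("Error"):
            return observation
    return "No answer within the step budget"


result = react_loop()
print(result)
```

The first attempt fails because the column `name` does not exist; the error text becomes the observation that drives the second, successful attempt.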



In [12]:
from transformers.agents import ReactCodeAgent, HfApiEngine

agent = ReactCodeAgent(
    tools=[sql_engine],
    llm_engine=HfApiEngine("meta-llama/Meta-Llama-3-8B-Instruct"),
)



### Increasing difficulty

Now let’s make it more challenging! We want our agent to handle joins across multiple tables.



In [13]:
# create the waiters SQL table; a distinct variable keeps `receipts` bound to the first table
table_name = "waiters"
waiters = Table(
    table_name,
    metadata_obj,
    Column("receipt_id", Integer, primary_key=True),
    Column("waiter_name", String(16), primary_key=True),
)
metadata_obj.create_all(engine)

rows = [
    {"receipt_id": 1, "waiter_name": "Corey Johnson"},
    {"receipt_id": 2, "waiter_name": "Michael Watts"},
    {"receipt_id": 3, "waiter_name": "Michael Watts"},
    {"receipt_id": 4, "waiter_name": "Margaret James"},
]
for row in rows:
    stmt = insert(waiters).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

In [14]:
updated_description = """Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:"""

inspector = inspect(engine)
for table in ["receipts", "waiters"]:
    columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(table)]

    table_description = f"Table '{table}':\n"

    table_description += "Columns:\n" + "\n".join([f"  - {name}: {col_type}" for name, col_type in columns_info])
    updated_description += "\n\n" + table_description

print(updated_description)

Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:

Table 'receipts':
Columns:
  - receipt_id: INTEGER
  - customer_name: VARCHAR(16)
  - price: FLOAT
  - tip: FLOAT

Table 'waiters':
Columns:
  - receipt_id: INTEGER
  - waiter_name: VARCHAR(16)


In [16]:
!huggingface-cli login




In [17]:
sql_engine.description = updated_description

agent = ReactCodeAgent(
    tools=[sql_engine],
    llm_engine=HfApiEngine("Qwen/Qwen2.5-72B-Instruct"),
)

agent.run("which waiter got more total money from tips")

which waiter got more total money from tips
=== Agent thoughts:
Thought: To determine which waiter earned the most in tips, I need to join the `waiters` and `receipts` tables, group by `waiter_name`, and sum the `tip` column. Then, I'll find the waiter with the highest total tip.
>>> Agent is executing the code below:
query = """
SELECT w.waiter_name, SUM(r.tip) as total_tips
FROM waiters w
JOIN receipts r ON w.receipt_id = r.receipt_id
GROUP BY w.waiter_name
ORDER BY total_tips DESC
LIMIT 1;
"""
top_waiter = sql_engine(query=query)
print("Top 

'Michael Watts'
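The agent's answer can be cross-checked by hand. The sketch below rebuilds only the relevant columns with the built-in `sqlite3` module (an assumption made to keep the check standalone; the notebook itself uses SQLAlchemy) and runs the same join without the `LIMIT`, so every waiter's total is visible.

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE receipts (receipt_id INTEGER, tip REAL)")
con.execute("CREATE TABLE waiters (receipt_id INTEGER, waiter_name TEXT)")

# Same receipt ids, tips and waiter names as in the tables above.
con.executemany(
    "INSERT INTO receipts VALUES (?, ?)",
    [(1, 1.20), (2, 0.24), (3, 5.43), (4, 1.00)],
)
con.executemany(
    "INSERT INTO waiters VALUES (?, ?)",
    [
        (1, "Corey Johnson"),
        (2, "Michael Watts"),
        (3, "Michael Watts"),
        (4, "Margaret James"),
    ],
)

# Total tips per waiter, highest first.
totals = con.execute(
    """
    SELECT w.waiter_name, ROUND(SUM(r.tip), 2) AS total_tips
    FROM waiters w
    JOIN receipts r ON w.receipt_id = r.receipt_id
    GROUP BY w.waiter_name
    ORDER BY total_tips DESC
    """
).fetchall()

for name, total in totals:
    print(name, total)
```

Michael Watts served receipts 2 and 3, so his total is 0.24 + 5.43, about 5.67, ahead of Corey Johnson (1.20) and Margaret James (1.00). This matches the agent's answer.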