# Text-to-SQL agent

1.1 插入数据

In [6]:
import os
from urllib.parse import quote

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    Float,
    insert,
)

# Use the synchronous pymysql driver here rather than aiomysql.
#
# Connection settings are read from the environment so that no credential is
# stored in the notebook. The non-secret settings fall back to the values this
# notebook was originally written against; the password has no fallback.
MYSQL_USER = os.environ.get("MYSQL_USER", "root")
MYSQL_HOST = os.environ.get("MYSQL_HOST", "192.168.10.60")
MYSQL_PORT = os.environ.get("MYSQL_PORT", "3306")
MYSQL_DATABASE = os.environ.get("MYSQL_DATABASE", "shenye_cost")

MYSQL_PASSWORD = os.environ.get("MYSQL_PASSWORD")
if MYSQL_PASSWORD is None:
    raise RuntimeError(
        "Set the MYSQL_PASSWORD environment variable before running this notebook."
    )

# safe="" percent-encodes every reserved character, including "/", which
# quote() leaves untouched by default and which would otherwise corrupt the URL.
encoded_password = quote(MYSQL_PASSWORD, safe="")
MYSQL_URL = (
    f"mysql+pymysql://{MYSQL_USER}:{encoded_password}"
    f"@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DATABASE}?charset=utf8mb4"
)
SQL_ECHO = False

# Synchronous engine shared by every later cell.
engine = create_engine(MYSQL_URL, echo=SQL_ECHO, pool_pre_ping=True)

# Registry that the Table definitions below attach themselves to.
metadata_obj = MetaData()

In [8]:


# Define the demo table.
# extend_existing=True lets this cell be re-run in the same kernel: without it,
# declaring "receipts" a second time on the same MetaData raises an error.
table_name = "receipts"
receipts = Table(
    table_name,
    metadata_obj,
    Column("receipt_id", Integer, primary_key=True),
    Column("customer_name", String(16)),
    Column("price", Float),
    Column("tip", Float),
    extend_existing=True,
)

# Emit CREATE TABLE for every table registered on metadata_obj.
metadata_obj.create_all(engine)

# Seed rows.
rows = [
    {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
    {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
    {"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
    {"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
]
seed_ids = [row["receipt_id"] for row in rows]

# engine.begin() commits on success and rolls back if anything raises.
with engine.begin() as connection:
    # Remove earlier copies of the seed rows first, so re-running this cell
    # does not fail on a duplicate primary key. Only the seed ids are touched.
    connection.execute(receipts.delete().where(receipts.c.receipt_id.in_(seed_ids)))
    # Passing the list of dicts inserts all rows with a single execute call.
    connection.execute(insert(receipts), rows)

print("✅ 数据插入成功！")

✅ 数据插入成功！


# 构建 agent

In [7]:
from sqlalchemy import inspect

# Read the column names and types of 'receipts' back from the database and
# render them as an indented bullet list.
inspector = inspect(engine)
columns_info = []
for col in inspector.get_columns("receipts"):
    columns_info.append((col["name"], col["type"]))

column_lines = [f"  - {name}: {col_type}" for name, col_type in columns_info]
table_description = "Columns:\n" + "\n".join(column_lines)
print(table_description)

Columns:
  - receipt_id: INTEGER
  - customer_name: VARCHAR(16)
  - price: FLOAT
  - tip: FLOAT


1.2 构建工具

用 `@tool` 定义工具时，函数需要满足两点：

- docstring 中包含 `Args:` 段落，逐一说明每个参数；
- 输入参数和返回值都带有 type hints。


In [8]:
from smolagents import tool
from sqlalchemy import text

@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'receipts'. Its description is as follows:
        Columns:
        - receipt_id: INTEGER
        - customer_name: VARCHAR(16)
        - price: FLOAT
        - tip: FLOAT

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # The docstring above is kept verbatim on purpose: @tool turns it into the
    # tool's description and input schema, so its wording is runtime data.
    #
    # Each result row is rendered with str() and prefixed by a newline, so a
    # non-empty result starts with "\n" and an empty result is "".
    with engine.connect() as con:
        result = con.execute(text(query))
        return "".join("\n" + str(row) for row in result)

In [11]:
from smolagents import CodeAgent, InferenceClientModel
import os


# Route outbound HTTP(S) traffic through the local proxy.
os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"

# The Hugging Face token is deliberately not written in this cell. Provide it
# from outside the notebook, e.g. export HF_TOKEN before starting Jupyter or
# log in with the Hugging Face CLI. Any token that has ever been saved inside
# a notebook should be treated as leaked and revoked.
agent = CodeAgent(
    tools=[sql_engine],
    model=InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
)
agent.run("你能告诉我获得最昂贵收据的客户的名字吗？")

'Woodrow Wilson'

# level 2:表连接

In [12]:
# Second demo table: which waiter served which receipt.
# It is bound to its own name so that `receipts` keeps referring to the
# receipts table. extend_existing=True keeps the cell re-runnable in one kernel.
table_name = "waiters"
waiters = Table(
    table_name,
    metadata_obj,
    Column("receipt_id", Integer, primary_key=True),
    Column("waiter_name", String(16), primary_key=True),
    extend_existing=True,
)
metadata_obj.create_all(engine)

rows = [
    {"receipt_id": 1, "waiter_name": "Corey Johnson"},
    {"receipt_id": 2, "waiter_name": "Michael Watts"},
    {"receipt_id": 3, "waiter_name": "Michael Watts"},
    {"receipt_id": 4, "waiter_name": "Margaret James"},
]
seed_ids = [row["receipt_id"] for row in rows]

# A single transaction for the whole seed: it commits on success and rolls
# back if anything raises. No nested engine.begin() per row is needed.
with engine.begin() as connection:
    # Remove earlier copies of the seed rows first, so re-running this cell
    # does not fail on a duplicate primary key. Only the seed ids are touched.
    connection.execute(waiters.delete().where(waiters.c.receipt_id.in_(seed_ids)))
    # Passing the list of dicts inserts all rows with a single execute call.
    connection.execute(insert(waiters), rows)


print("✅ 数据插入成功！")

✅ 数据插入成功！


In [13]:
# Build a tool description covering both tables: a fixed introduction followed
# by one "Table '<name>'" section per table, separated by blank lines.
inspector = inspect(engine)

updated_description = """Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:"""

for table in ("receipts", "waiters"):
    # (name, type) pairs as reported by the database for this table.
    columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(table)]
    column_lines = "\n".join(f"  - {name}: {col_type}" for name, col_type in columns_info)

    table_description = f"Table '{table}':\n" + "Columns:\n" + column_lines
    updated_description = updated_description + "\n\n" + table_description

print(updated_description)

Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:

Table 'receipts':
Columns:
  - receipt_id: INTEGER
  - customer_name: VARCHAR(16)
  - price: FLOAT
  - tip: FLOAT

Table 'waiters':
Columns:
  - receipt_id: INTEGER
  - waiter_name: VARCHAR(16)


In [14]:
# Overwrite the tool's description (originally taken from the function's
# docstring, which only mentions 'receipts') with the text in
# updated_description, which lists both 'receipts' and 'waiters'.
sql_engine.description = updated_description

In [18]:
# Inspect the tool object: print every public attribute together with its value.
print("=== Tool 对象的所有属性 ===")
# Names starting with an underscore are treated as private and skipped.
public_attrs = (name for name in dir(sql_engine) if not name.startswith('_'))
for attr in public_attrs:
    print(f"{attr}: {getattr(sql_engine, attr)}")

=== Tool 对象的所有属性 ===
description: Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:

Table 'receipts':
Columns:
  - receipt_id: INTEGER
  - customer_name: VARCHAR(16)
  - price: FLOAT
  - tip: FLOAT

Table 'waiters':
Columns:
  - receipt_id: INTEGER
  - waiter_name: VARCHAR(16)
forward: <function sql_engine at 0x000001A4ACFD4D60>
from_code: <bound method Tool.from_code of <class 'smolagents.tools.tool.<locals>.SimpleTool'>>
from_dict: <bound method Tool.from_dict of <class 'smolagents.tools.tool.<locals>.SimpleTool'>>
from_gradio: <function Tool.from_gradio at 0x000001A4A9B16980>
from_hub: <bound method Tool.from_hub of <class 'smolagents.tools.tool.<locals>.SimpleTool'>>
from_langchain: <function Tool.from_langchain at 0x000001A4A9B16A20>
from_space: <function Tool.from_space at 0x000001A4A9B168E0>
inputs: {'query': {'type': 'string', 'description': 'The query to perform. 

In [19]:
# Summarise the tool: name, description, input schema, output type and the
# wrapped function, one item per line.
print(
    f"名称: {sql_engine.name}",
    f"描述:\n{sql_engine.description}",
    f"输入类型: {sql_engine.inputs}",
    f"输出类型: {sql_engine.output_type}",
    f"原始函数: {sql_engine.forward}",
    sep="\n",
)

名称: sql_engine
描述:
Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:

Table 'receipts':
Columns:
  - receipt_id: INTEGER
  - customer_name: VARCHAR(16)
  - price: FLOAT
  - tip: FLOAT

Table 'waiters':
Columns:
  - receipt_id: INTEGER
  - waiter_name: VARCHAR(16)
输入类型: {'query': {'type': 'string', 'description': 'The query to perform. This should be correct SQL.'}}
输出类型: string
原始函数: <function sql_engine at 0x000001A4ACFD4D60>


In [20]:


# Build a fresh agent now that sql_engine.description has been replaced.
# NOTE(review): presumably the agent reads the tool description when it is
# constructed, which is why it is rebuilt here - confirm against smolagents.
qwen_coder = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct")
agent = CodeAgent(tools=[sql_engine], model=qwen_coder)

# Answering this question needs both tables (tips come from 'receipts',
# waiter names from 'waiters').
agent.run("哪位服务员从小费中获得的总金额更多？")

'Michael Watts'