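# Text-to-SQL

In the data world, alongside the unstructured data that vector databases handle, relational databases (such as MySQL, PostgreSQL, and SQLite) remain a mainstay for storing and managing structured data.

Text-to-SQL uses a large language model (LLM) to translate a user's natural-language question directly into a SQL query that can be executed against the database.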

![](images/4_3_1.webp)

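## Challenges

- Hallucination: the LLM may "imagine" tables or columns that do not exist in the database, which makes the generated SQL invalid.

- Insufficient understanding of the schema: the LLM must accurately understand each table's structure, the meaning of each column, and the relationships between tables in order to generate correct JOIN and WHERE clauses.

- Ambiguous user input: a question may contain typos or informal, underspecified wording (for example, "Who was last month's top seller?" leaves both the exact time range and the ranking metric implicit), so the model needs some tolerance for errors and some ability to infer intent.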
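One hedged way to catch the hallucination problem before a query runs is to let the database compile the statement first. The sketch below assumes SQLite; `validate_sql` is a hypothetical helper name and the `users` table is a toy schema. Prefixing a statement with `EXPLAIN` makes SQLite prepare it against the real schema without running the query itself, so a reference to a table or column that does not exist surfaces as an error.

```python
import sqlite3
from typing import Optional


def validate_sql(conn: sqlite3.Connection, sql: str) -> Optional[str]:
    """Return None if SQLite can compile the statement, else the error message.

    Hypothetical helper for illustration. EXPLAIN makes SQLite prepare the
    statement against the current schema without executing the query itself.
    """
    try:
        conn.execute(f"EXPLAIN {sql.strip().rstrip(';')}")
        return None
    except sqlite3.Error as exc:
        return str(exc)


# Toy schema used only for this example
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, city TEXT)")

valid_error = validate_sql(conn, "SELECT name FROM users WHERE city = 'Beijing';")
hallucinated_error = validate_sql(conn, "SELECT nickname FROM users")

print(valid_error)         # None: the statement compiles
print(hallucinated_error)  # error text that names the unknown column
```

The returned message can be passed back to the model as repair feedback instead of being shown to the user.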

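## Optimization strategies

- Provide an accurate database schema: give the LLM the CREATE TABLE statements of the relevant tables so that it knows the table names, column names, data types, and foreign-key relationships.

- Provide a few high-quality examples: include several question-SQL pairs in the prompt so that the model can learn how to build queries for similar questions.

- Use RAG to enrich the context: build a dedicated "knowledge base" for the database that contains:

    - The DDL (data definition language) of each table.

    - Detailed descriptions of tables and columns: plain-language explanations of what each table is for and what each column means in business terms.

    - Synonyms and business terms, for example mapping the user's word "spending" to the database's cost column.

    - Complex query examples: question-SQL pairs that involve JOIN, GROUP BY, or subqueries.

When a user asks a question, the system first retrieves the most relevant entries from this knowledge base (such as related table structures, column descriptions, and similar question-SQL pairs), then combines them with the user's question into a richer prompt, and finally hands that prompt to the LLM to generate the SQL query.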
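The retrieve-then-prompt flow can be sketched without any external services. This is a minimal illustration, not the framework built later: word overlap stands in for vector search, and the `KNOWLEDGE` entries, the `SYNONYMS` map, and the function names are all assumptions made for the example.

```python
from typing import Dict, List, Set

# Tiny in-memory stand-in for the knowledge base (contents are illustrative)
KNOWLEDGE: List[Dict[str, str]] = [
    {"type": "ddl", "content": "CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER, cost REAL)"},
    {"type": "description", "content": "orders.cost: amount the user spent on the order"},
    {"type": "qsql", "content": "Question: total cost per user\nSQL: SELECT user_id, SUM(cost) FROM orders GROUP BY user_id"},
    {"type": "ddl", "content": "CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT, stock INTEGER)"},
]

# Business synonyms mapped to column names (hypothetical mapping)
SYNONYMS = {"spending": "cost", "spent": "cost"}


def tokenize(text: str) -> Set[str]:
    """Lowercase, split on non-alphanumeric characters, and apply the synonym map."""
    words = "".join(c.lower() if c.isalnum() else " " for c in text).split()
    return {SYNONYMS.get(w, w) for w in words}


def retrieve(question: str, top_k: int = 3) -> List[Dict[str, str]]:
    """Rank entries by word overlap with the question (toy stand-in for vector search)."""
    q = tokenize(question)
    ranked = sorted(KNOWLEDGE, key=lambda e: len(q & tokenize(e["content"])), reverse=True)
    return ranked[:top_k]


def build_prompt(question: str) -> str:
    """Combine the retrieved context and the question into a single prompt."""
    context = "\n\n".join(f"[{e['type']}] {e['content']}" for e in retrieve(question))
    return f"Database information:\n{context}\n\nQuestion: {question}\nSQL:"


prompt = build_prompt("How much spending does each user have?")
print(prompt)
```

Because "spending" is mapped to `cost`, the three `orders` entries outrank the unrelated `products` table and only they reach the prompt.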

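## Building a Text2SQL framework

- Knowledge base: Milvus

- Semantic retrieval model: BGE-M3

- Large language model: DeepSeek

- Database: SQLite

### Knowledge base module

This module is the core of the framework: it stores and retrieves SQL-related knowledge.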

In [1]:
import json
import os
from typing import List, Dict, Any
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
from pymilvus.model.hybrid import BGEM3EmbeddingFunction


class SimpleKnowledgeBase:
    """Knowledge base that stores and retrieves Text2SQL context in Milvus."""
    
    def __init__(self, milvus_uri: str = "http://localhost:19530"):
        self.milvus_uri = milvus_uri
        self.client = MilvusClient(uri=milvus_uri)
        self.embedding_function = BGEM3EmbeddingFunction(model_name="models/bge/bge-m3", use_fp16=False, device="cpu")
        self.collection_name = "text2sql_kb"
        self._setup_collection()
    
    def _setup_collection(self):
        """Create the collection, dropping any existing one with the same name."""
        if self.client.has_collection(self.collection_name):
            self.client.drop_collection(self.collection_name)
        
        # Define the fields
        fields = [
            FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100),
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=4096),
            # The type field distinguishes the kinds of knowledge:
            #   ddl: data definition language (table definitions)
            #   qsql: question-to-SQL pairs
            #   description: table and column descriptions
            FieldSchema(name="type", dtype=DataType.VARCHAR, max_length=32),  # ddl, qsql, description
            FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=self.embedding_function.dim["dense"])
        ]
        
        schema = CollectionSchema(fields, description="Text2SQL knowledge base")
        
        # Create the collection
        self.client.create_collection(
            collection_name=self.collection_name,
            schema=schema,
            consistency_level="Strong"
        )
        
        # Create the vector index
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="dense_vector",
            index_type="AUTOINDEX",
            metric_type="IP"
        )
        
        self.client.create_index(
            collection_name=self.collection_name,
            index_params=index_params
        )
    
    def load_data(self):
        """Load all knowledge-base data from the data directory under the current working directory."""
        data_dir = os.path.join(os.getcwd(), "data")
        
        # Load DDL data
        ddl_path = os.path.join(data_dir, "ddl_examples.json")
        if os.path.exists(ddl_path):
            with open(ddl_path, 'r', encoding='utf-8') as f:
                ddl_data = json.load(f)
            self._add_ddl_data(ddl_data)
        
        # Load Q->SQL data
        qsql_path = os.path.join(data_dir, "qsql_examples.json")
        if os.path.exists(qsql_path):
            with open(qsql_path, 'r', encoding='utf-8') as f:
                qsql_data = json.load(f)
            self._add_qsql_data(qsql_data)
        
        # Load table and column descriptions
        desc_path = os.path.join(data_dir, "db_descriptions.json")
        if os.path.exists(desc_path):
            with open(desc_path, 'r', encoding='utf-8') as f:
                desc_data = json.load(f)
            self._add_description_data(desc_data)
        
        # Load the collection into memory
        self.client.load_collection(collection_name=self.collection_name)
        print("知识库数据加载完成")
    
    def _add_ddl_data(self, data: List[Dict]):
        """Add DDL entries."""
        contents = []
        types = []
        
        for item in data:
            content = f"表名: {item.get('table_name', '')}\n"
            content += f"DDL: {item.get('ddl_statement', '')}\n"
            content += f"描述: {item.get('description', '')}"
            
            contents.append(content)
            types.append("ddl")
        
        self._insert_data(contents, types)
    
    def _add_qsql_data(self, data: List[Dict]):
        """Add Q->SQL example entries."""
        contents = []
        types = []
        
        for item in data:
            content = f"问题: {item.get('question', '')}\n"
            content += f"SQL: {item.get('sql', '')}"
            
            contents.append(content)
            types.append("qsql")
        
        self._insert_data(contents, types)
    
    def _add_description_data(self, data: List[Dict]):
        """Add table and column description entries."""
        contents = []
        types = []
        
        for item in data:
            content = f"表名: {item.get('table_name', '')}\n"
            content += f"表描述: {item.get('table_description', '')}\n"
            
            columns = item.get('columns', [])
            if columns:
                content += "字段信息:\n"
                for col in columns:
                    content += f"  - {col.get('name', '')}: {col.get('description', '')} ({col.get('type', '')})\n"
            
            contents.append(content)
            types.append("description")
        
        self._insert_data(contents, types)
    
    def _insert_data(self, contents: List[str], types: List[str]):
        """Embed the texts and insert them into the collection."""
        if not contents:
            return
        
        # Generate embeddings
        embeddings = self.embedding_function(contents)
        
        # Build the rows to insert; each row is a dict
        data_to_insert = []
        for i in range(len(contents)):
            data_to_insert.append({
                "content": contents[i],
                "type": types[i],
                "dense_vector": embeddings["dense"][i]
            })
        
        # Insert the rows
        self.client.insert(
            collection_name=self.collection_name,
            data=data_to_insert
        )
    
    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Search for the entries most relevant to the query."""
        self.client.load_collection(collection_name=self.collection_name)
            
        query_embeddings = self.embedding_function([query])
        
        search_results = self.client.search(
            collection_name=self.collection_name,
            data=query_embeddings["dense"],
            anns_field="dense_vector",
            search_params={"metric_type": "IP"},  # inner-product similarity
            limit=top_k,
            output_fields=["content", "type"]
        )
        
        results = []
        for hit in search_results[0]:
            results.append({
                "content": hit["entity"]["content"],
                "type": hit["entity"]["type"],
                "score": hit["distance"]
            })
        
        return results
    
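    def cleanup(self):
        """Release resources by dropping the collection."""
        try:
            self.client.drop_collection(self.collection_name)
        except Exception:
            # Best-effort cleanup: ignore failures such as a missing collection
            pass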
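
The three JSON files that `load_data` reads are not shown in this section. Judging only from the `.get(...)` keys used in the `_add_*` methods, their shape is probably close to the sketch below. All values are made up for illustration, and the files are written to a temporary directory here, whereas `load_data` looks in a `data` folder under the current working directory.

```python
import json
import os
import tempfile

# Keys mirror the .get(...) lookups in SimpleKnowledgeBase; values are illustrative
ddl_examples = [{
    "table_name": "users",
    "ddl_statement": "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, city TEXT)",
    "description": "Registered users",
}]
qsql_examples = [{
    "question": "How many users are there in each city?",
    "sql": "SELECT city, COUNT(*) FROM users GROUP BY city",
}]
db_descriptions = [{
    "table_name": "users",
    "table_description": "One row per registered user",
    "columns": [
        {"name": "id", "description": "Unique user ID", "type": "INTEGER"},
        {"name": "city", "description": "City the user lives in", "type": "TEXT"},
    ],
}]

# Write the sample files into a throwaway "data" directory
data_dir = os.path.join(tempfile.mkdtemp(), "data")
os.makedirs(data_dir, exist_ok=True)
files = {
    "ddl_examples.json": ddl_examples,
    "qsql_examples.json": qsql_examples,
    "db_descriptions.json": db_descriptions,
}
for filename, payload in files.items():
    with open(os.path.join(data_dir, filename), "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

print(sorted(os.listdir(data_dir)))
```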


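### SQL generation module

The SQL generation module converts a natural-language question into a SQL query and can also repair a query that failed to execute.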

In [2]:
import os
from typing import List, Dict, Any
from langchain_deepseek import ChatDeepSeek
from langchain_core.messages import HumanMessage  # langchain_core is installed as a dependency of langchain_deepseek
from dotenv import load_dotenv
load_dotenv()

class SimpleSQLGenerator:
    """Simplified SQL generator."""
    
    def __init__(self, api_key: str = None):
        self.llm = ChatDeepSeek(
            model="deepseek-chat",
            temperature=0,
            api_key=api_key or os.getenv("DEEPSEEK_API_KEY")
        )
    
    def generate_sql(self, user_query: str, knowledge_results: List[Dict[str, Any]]) -> str:
        """Generate a SQL statement for the user's question."""
        # Build the context from the retrieved knowledge
        context = self._build_context(knowledge_results)
        
        # Build the prompt
        prompt = f"""你是一个SQL专家。请根据以下信息将用户问题转换为SQL查询语句。

        数据库信息：
        {context}

        用户问题：{user_query}

        要求：
        1. 只返回SQL语句，不要包含任何解释
        2. 确保SQL语法正确
        3. 使用上下文中提供的表名和字段名
        4. 如果需要JOIN，请根据表结构进行合理关联

        SQL语句："""

        messages = [HumanMessage(content=prompt)]
        response = self.llm.invoke(messages)
        
        # Strip Markdown code fences from the model output
        sql = response.content.strip()
        if sql.startswith("```sql"):
            sql = sql[6:]
        if sql.startswith("```"):
            sql = sql[3:]
        if sql.endswith("```"):
            sql = sql[:-3]
        
        return sql.strip()
    
    def fix_sql(self, original_sql: str, error_message: str, knowledge_results: List[Dict[str, Any]]) -> str:
        """Ask the LLM to repair a SQL statement, given the execution error."""
        context = self._build_context(knowledge_results)
        
        prompt = f"""请修复以下SQL语句的错误。

        数据库信息：
        {context}

        原始SQL：
        {original_sql}

        错误信息：
        {error_message}

        请返回修复后的SQL语句（只返回SQL，不要解释）："""

        messages = [HumanMessage(content=prompt)]
        response = self.llm.invoke(messages)
        
        # Strip Markdown code fences from the model output
        fixed_sql = response.content.strip()
        if fixed_sql.startswith("```sql"):
            fixed_sql = fixed_sql[6:]
        if fixed_sql.startswith("```"):
            fixed_sql = fixed_sql[3:]
        if fixed_sql.endswith("```"):
            fixed_sql = fixed_sql[:-3]
        
        return fixed_sql.strip()
    
    def _build_context(self, knowledge_results: List[Dict[str, Any]]) -> str:
        """Build the context text from the retrieved knowledge."""
        context = ""
        
        # Group the results by type
        ddl_info = []
        qsql_examples = []
        descriptions = []
        
        for result in knowledge_results:
            if result["type"] == "ddl":
                ddl_info.append(result["content"])
            elif result["type"] == "qsql":
                qsql_examples.append(result["content"])
            elif result["type"] == "description":
                descriptions.append(result["content"])
        
        # Assemble the context, one section per type
        if ddl_info:
            context += "=== 表结构信息 ===\n"
            context += "\n".join(ddl_info) + "\n\n"
        
        if descriptions:
            context += "=== 表和字段描述 ===\n"
            context += "\n".join(descriptions) + "\n\n"
        
        if qsql_examples:
            context += "=== 查询示例 ===\n"
            context += "\n".join(qsql_examples) + "\n\n"
        
        return context 
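
`generate_sql` and `fix_sql` repeat the same prefix and suffix checks to remove Markdown code fences. A shared helper is one way to avoid the duplication. The sketch below is an assumption rather than part of the original class: `extract_sql` and `FENCE` are hypothetical names, and the regular expression also handles replies in which the model adds text before or after the fenced block.

```python
import re

# Three backticks, built programmatically so this example stays easy to embed
FENCE = "`" * 3

# A fenced block: opening fence, optional language tag, newline, body, closing fence
_FENCED_BLOCK = re.compile(FENCE + r"[a-zA-Z]*[ \t]*\n(.*?)" + FENCE, re.DOTALL)


def extract_sql(llm_output: str) -> str:
    """Return the body of the first fenced block, or the cleaned text if there is none."""
    match = _FENCED_BLOCK.search(llm_output)
    if match:
        return match.group(1).strip()
    # No complete fenced block: drop stray backticks and surrounding whitespace
    return llm_output.strip().strip("`").strip()


fenced_reply = f"{FENCE}sql\nSELECT name FROM users;\n{FENCE}"
chatty_reply = f"Here is the query:\n{FENCE}\nSELECT 2;\n{FENCE}\nHope it helps."

print(extract_sql(fenced_reply))
print(extract_sql(chatty_reply))
print(extract_sql("  SELECT 3;  "))
```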


### Agent module

The agent coordinates the full workflow: knowledge-base retrieval, SQL generation, and execution.

In [3]:
import sqlite3
import os
from typing import Dict, Any, List, Tuple

class SimpleText2SQLAgent:
    """Text2SQL agent."""
    
    def __init__(self, milvus_uri: str = "http://localhost:19530", api_key: str = None):
        """Initialize the agent."""
        self.knowledge_base = SimpleKnowledgeBase(milvus_uri)
        self.sql_generator = SimpleSQLGenerator(api_key)
        self.db_path = None
        self.connection = None
        
        # Configuration parameters
        self.max_retry_count = 3
        self.top_k_retrieval = 5
        self.max_result_rows = 100
    
    def connect_database(self, db_path: str) -> bool:
        """Connect to the SQLite database."""
        try:
            self.db_path = db_path
            self.connection = sqlite3.connect(db_path)
            print(f"成功连接到数据库: {db_path}")
            return True
        except Exception as e:
            print(f"数据库连接失败: {str(e)}")
            return False
    
    def load_knowledge_base(self):
        """Load the knowledge base."""
        self.knowledge_base.load_data()
    
    def query(self, user_question: str) -> Dict[str, Any]:
        """Run a Text2SQL query: retrieve context, generate SQL, execute with retries."""
        if not self.connection:
            return {
                "success": False,
                "error": "数据库未连接",
                "sql": None,
                "results": None
            }
        
        print(f"\n=== 处理查询: {user_question} ===")
        
        # 1. Retrieve from the knowledge base
        print("检索知识库...")
        knowledge_results = self.knowledge_base.search(user_question, self.top_k_retrieval)
        print(f"检索到 {len(knowledge_results)} 条相关信息")
        
        # 2. Generate the SQL
        print("生成SQL...")
        sql = self.sql_generator.generate_sql(user_question, knowledge_results)
        print(f"生成的SQL: {sql}")
        
        # 3. Execute the SQL (with retries)
        retry_count = 0
        while retry_count < self.max_retry_count:
            print(f"执行SQL (尝试 {retry_count + 1}/{self.max_retry_count})...")
            
            success, result = self._execute_sql(sql)
            
            if success:
                print("SQL执行成功!")
                return {
                    "success": True,
                    "error": None,
                    "sql": sql,
                    "results": result,
                    "retry_count": retry_count
                }
            else:
                print(f"SQL执行失败: {result}")
                
                if retry_count < self.max_retry_count - 1:
                    print("尝试修复SQL...")
                    sql = self.sql_generator.fix_sql(sql, result, knowledge_results)
                    print(f"修复后的SQL: {sql}")
                
                retry_count += 1
        
        return {
            "success": False,
            "error": f"超过最大重试次数 ({self.max_retry_count})",
            "sql": sql,
            "results": None,
            "retry_count": retry_count
        }
    
    def _execute_sql(self, sql: str) -> Tuple[bool, Any]:
        """Execute a SQL statement and return (success, result or error message)."""
        try:
            cursor = self.connection.cursor()
            
            # Normalize surrounding whitespace and drop a trailing semicolon
            sql = sql.strip().rstrip(';').strip()
            
            # Cap the result size when a SELECT has no LIMIT of its own
            # (a plain substring check, so this is only a heuristic)
            if sql.upper().startswith('SELECT') and 'LIMIT' not in sql.upper():
                sql = f"{sql} LIMIT {self.max_result_rows}"
            
            cursor.execute(sql)
            
            if cursor.description is not None:
                # Row-returning statement (SELECT, including WITH ... SELECT)
                columns = [desc[0] for desc in cursor.description]
                rows = cursor.fetchall()
                
                # Convert each row tuple into a column-name -> value dict
                results = [dict(zip(columns, row)) for row in rows]
                
                cursor.close()
                return True, {
                    "columns": columns,
                    "rows": results,
                    "count": len(results)
                }
            else:
                # Statement that returns no rows (INSERT, UPDATE, DDL, ...)
                self.connection.commit()
                cursor.close()
                return True, "SQL执行成功"
        
        except Exception as e:
            return False, str(e)
    
    def add_example(self, question: str, sql: str):
        """Add a new Q->SQL example."""
        # Simplified version: append directly to the JSON file.
        # os.getcwd() matches SimpleKnowledgeBase.load_data; __file__ is not defined inside a notebook cell
        data_dir = os.path.join(os.getcwd(), "data")
        qsql_path = os.path.join(data_dir, "qsql_examples.json")
        
        try:
            import json
            
            # Read the existing data
            if os.path.exists(qsql_path):
                with open(qsql_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            else:
                data = []
            
            # Append the new example
            data.append({
                "question": question,
                "sql": sql,
                "database": "sqlite"
            })
            
            # Save, creating the data directory if it does not exist yet
            os.makedirs(data_dir, exist_ok=True)
            with open(qsql_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            
            print(f"已添加新示例: {question}")
            
        except Exception as e:
            print(f"添加示例失败: {str(e)}")
    
    def get_table_info(self) -> List[Dict[str, Any]]:
        """Return the tables in the database together with their column metadata."""
        if not self.connection:
            return []
        
        try:
            cursor = self.connection.cursor()
            
            # Get all table names
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = cursor.fetchall()
            
            table_info = []
            for table in tables:
                table_name = table[0]
                
                # Get the table structure (the name is quoted in case it contains special characters)
                cursor.execute(f'PRAGMA table_info("{table_name}")')
                columns = cursor.fetchall()
                
                table_info.append({
                    "table_name": table_name,
                    "columns": [
                        {
                            "name": col[1],
                            "type": col[2],
                            "nullable": not col[3],
                            "default": col[4],
                            "primary_key": bool(col[5])
                        }
                        for col in columns
                    ]
                })
            
            cursor.close()
            return table_info
            
        except Exception as e:
            print(f"获取表信息失败: {str(e)}")
            return []
    
    def cleanup(self):
        """Release resources: close the database connection and drop the collection."""
        if self.connection:
            self.connection.close()
            self.connection = None
            print("数据库连接已关闭")
        
        self.knowledge_base.cleanup()
        print("知识库已清理") 


## Main function

In [4]:
import os
import sqlite3

def setup_demo():
    """Set up the demo environment; return (agent, db_path), or None on failure."""
    print("=== Text2SQL框架演示 ===\n")
    
    # Check the API key
    api_key = os.getenv("DEEPSEEK_API_KEY")
    if not api_key:
        print("先设置DEEPSEEK_API_KEY环境变量")
        return None
    
    # Create the demo database
    print("创建演示数据库...")
    db_path = create_demo_database()
    
    # Initialize the Text2SQL agent
    print("初始化Text2SQL代理...")
    agent = SimpleText2SQLAgent(api_key=api_key)
    
    # Connect to the database
    print("连接数据库...")
    if not agent.connect_database(db_path):
        print("数据库连接失败!")
        return None
    
    # Load the knowledge base
    print("加载知识库...")
    try:
        agent.load_knowledge_base()
        print("知识库加载成功!")
    except Exception as e:
        print(f"知识库加载失败: {str(e)}")
        return None
    
    return agent, db_path


def create_demo_database():
    """Create the demo SQLite database and return its path."""
    db_path = "text2sql_demo.db"
    
    if os.path.exists(db_path):
        os.remove(db_path)
    
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    
    # Create the users table
    cursor.execute("""
        CREATE TABLE users (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            email TEXT UNIQUE,
            age INTEGER,
            city TEXT
        )
    """)
    
    # Create the products table
    cursor.execute("""
        CREATE TABLE products (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            category TEXT,
            price REAL,
            stock INTEGER
        )
    """)
    
    # Create the orders table
    cursor.execute("""
        CREATE TABLE orders (
            id INTEGER PRIMARY KEY,
            user_id INTEGER,
            product_id INTEGER,
            quantity INTEGER,
            order_date TEXT,
            total_price REAL,
            FOREIGN KEY (user_id) REFERENCES users(id),
            FOREIGN KEY (product_id) REFERENCES products(id)
        )
    """)
    
    # Insert sample data
    users_data = [
        (1, '张三', 'zhangsan@email.com', 25, '北京'),
        (2, '李四', 'lisi@email.com', 32, '上海'),
        (3, '王五', 'wangwu@email.com', 28, '广州'),
        (4, '赵六', 'zhaoliu@email.com', 35, '深圳'),
        (5, '陈七', 'chenqi@email.com', 29, '杭州'),
    ]
    
    products_data = [
        (1, 'iPhone 15', '电子产品', 7999.0, 50),
        (2, 'MacBook Pro', '电子产品', 12999.0, 20),
        (3, 'Nike运动鞋', '服装', 599.0, 100),
        (4, '办公椅', '家具', 899.0, 30),
        (5, '台灯', '家具', 199.0, 80),
        (6, 'iPad', '电子产品', 3999.0, 40),
        (7, 'Adidas外套', '服装', 399.0, 60),
    ]
    
    orders_data = [
        (1, 1, 1, 1, '2024-01-15', 7999.0),
        (2, 2, 3, 2, '2024-01-16', 1198.0),
        (3, 3, 5, 1, '2024-01-17', 199.0),
        (4, 1, 2, 1, '2024-01-18', 12999.0),
        (5, 4, 4, 1, '2024-01-19', 899.0),
        (6, 5, 6, 1, '2024-01-20', 3999.0),
        (7, 2, 7, 1, '2024-01-21', 399.0),
    ]
    
    cursor.executemany("INSERT INTO users VALUES (?, ?, ?, ?, ?)", users_data)
    cursor.executemany("INSERT INTO products VALUES (?, ?, ?, ?, ?)", products_data)
    cursor.executemany("INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?)", orders_data)
    
    conn.commit()
    conn.close()
    
    print(f"演示数据库已创建: {db_path}")
    return db_path


def run_demo_queries(agent):
    """Run the demo queries."""
    demo_questions = [
        "查询所有用户的姓名和邮箱",
        "年龄大于30的用户有哪些",
        "哪些产品的库存少于50",
        "查询来自北京的用户的所有订单",
        "统计每个城市的用户数量",
        "查询价格在500-8000之间的产品"
    ]
    
    print("\n开始运行演示查询...\n")
    
    success_count = 0
    
    for i, question in enumerate(demo_questions, 1):
        print(f"问题 {i}: {question}")
        print("-" * 60)
        
        try:
            result = agent.query(question)
            
            if result["success"]:
                print(f"成功! SQL: {result['sql']}")
                
                if isinstance(result["results"], dict) and "rows" in result["results"]:
                    count = result["results"]["count"]
                    print(f"返回 {count} 行数据")
                    
                    # Show the first 2 rows
                    if count > 0:
                        for j, row in enumerate(result["results"]["rows"][:2]):
                            row_str = " | ".join(f"{k}: {v}" for k, v in row.items())
                            print(f"  {j+1}. {row_str}")
                        
                        if count > 2:
                            print(f"  ... 还有 {count - 2} 行")
                else:
                    print(f"结果: {result['results']}")
                
                success_count += 1
                
            else:
                print(f"失败: {result['error']}")
                print(f"SQL: {result['sql']}")
                
        except Exception as e:
            print(f"执行错误: {str(e)}")
        
        print()
    
    # Print summary statistics
    total_count = len(demo_questions)
    print(f"成功: {success_count}/{total_count}")


def cleanup(agent, db_path):
    """Clean up: release the agent's resources and delete the demo database."""
    print("\n清理资源...")
    
    if agent:
        agent.cleanup()
    
    if os.path.exists(db_path):
        os.remove(db_path)
        print(f"已删除演示数据库: {db_path}")


In [5]:
def main():
    """Entry point."""
    # Set up the demo environment
    setup_result = setup_demo()
    
    if setup_result is None:
        return
    
    agent, db_path = setup_result
    
    try:
        # Run the demo queries
        run_demo_queries(agent)
        
    finally:
        # Always clean up, even if a query raised
        cleanup(agent, db_path)


In [6]:
if __name__ == "__main__":
    main() 


=== Text2SQL框架演示 ===

创建演示数据库...
演示数据库已创建: text2sql_demo.db
初始化Text2SQL代理...



You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


连接数据库...
成功连接到数据库: text2sql_demo.db
加载知识库...
知识库数据加载完成
知识库加载成功!

开始运行演示查询...

问题 1: 查询所有用户的姓名和邮箱
------------------------------------------------------------

=== 处理查询: 查询所有用户的姓名和邮箱 ===
检索知识库...
检索到 5 条相关信息
生成SQL...
生成的SQL: SELECT name, email FROM users;
执行SQL (尝试 1/3)...
SQL执行成功!
成功! SQL: SELECT name, email FROM users;
返回 5 行数据
  1. name: 张三 | email: zhangsan@email.com
  2. name: 李四 | email: lisi@email.com
  ... 还有 3 行

问题 2: 年龄大于30的用户有哪些
------------------------------------------------------------

=== 处理查询: 年龄大于30的用户有哪些 ===
检索知识库...
检索到 5 条相关信息
生成SQL...
生成的SQL: SELECT * FROM users WHERE age > 30;
执行SQL (尝试 1/3)...
SQL执行成功!
成功! SQL: SELECT * FROM users WHERE age > 30;
返回 2 行数据
  1. id: 2 | name: 李四 | email: lisi@email.com | age: 32 | city: 上海
  2. id: 4 | name: 赵六 | email: zhaoliu@email.com | age: 35 | city: 深圳

问题 3: 哪些产品的库存少于50
------------------------------------------------------------

=== 处理查询: 哪些产品的库存少于50 ===
检索知识库...
检索到 5 条相关信息
生成SQL...
生成的SQL: SELECT * FROM products WHERE s