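# User Guide
## Project Overview
A SQL generation and execution tool built on MySQL and a large language model (LLM). It supports:

- **Automatic schema discovery**: reads every table and its columns from the connected MySQL database
- **Natural language to SQL**: sends the user's Chinese or English question to an LLM (DeepSeek's `deepseek-chat`, called through the OpenAI-compatible Python SDK), which returns a standard SQL query
- **SQL execution and result display**: runs the generated SQL on MySQL and returns the result as a pandas DataFrame
- **Tagged SQL extraction**: when the text wraps SQL in `<sql>…</sql>` tags, the SQL inside the first tagged block is extracted and executed; text without tags is executed as-is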
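The notebook's `MySQLClient.extract_sql` returns only the first `<sql>` block. If a model reply ever contains several blocks, a minimal sketch for collecting all of them could look like the following. `extract_all_sql` and `SQL_TAG` are hypothetical names; the regex is the same one `extract_sql` uses, with `findall()` in place of `search()`.

```python
import re

# Same pattern as MySQLClient.extract_sql; findall() returns every captured
# group instead of only the first match.
SQL_TAG = re.compile(r"<sql>(.*?)</sql>", re.IGNORECASE | re.DOTALL)


def extract_all_sql(text: str) -> list[str]:
    """Hypothetical helper: stripped contents of every non-empty <sql>...</sql> block, in order."""
    return [body.strip() for body in SQL_TAG.findall(text) if body.strip()]


reply = "Here you go:\n<sql>\nSELECT 1;\n</sql>\nand also\n<SQL>SELECT 2;</SQL>"
for statement in extract_all_sql(reply):
    print(statement)
```

Running each extracted statement in its own `cursor.execute` call keeps the code from depending on multi-statement support being enabled on the PyMySQL connection.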

# Setup
## 1. Install Dependencies

In [25]:
```python
%pip install -q openai pymysql pandas
```

Note: you may need to restart the kernel to use updated packages.


## 2. Initialize the MySQL Connection

In [38]:
```python
import re
from typing import Optional

import pandas as pd
import pymysql
from pymysql.cursors import DictCursor


class MySQLClient:
    """
    Wraps a MySQL connection: opening and closing it, reading the schema,
    and running SQL.
    """
    def __init__(
        self,
        host: str,
        port: int,
        user: str,
        password: str,
        database: str,
        autocommit: bool = True,
        charset: str = 'utf8mb4'
    ):
        self.connection = pymysql.connect(
            host=host,
            port=port,
            user=user,
            password=password,
            database=database,
            charset=charset,
            cursorclass=DictCursor,  # rows are returned as dicts keyed by column name
            autocommit=autocommit
        )
        print(f"MySQL connection initialized: {user}@{host}:{port}/{database}")

    def get_schema_info(self) -> str:
        """
        List every table in the current database with its columns and return
        one formatted string: a block per table, separated by blank lines.
        """
        schema_info: list[str] = []
        with self.connection.cursor() as cursor:
            # List all tables (each row is a single-key dict)
            cursor.execute("SHOW TABLES;")
            tables = [list(row.values())[0] for row in cursor.fetchall()]
            # Fetch the columns of each table
            for table_name in tables:
                cursor.execute(f"SHOW COLUMNS FROM `{table_name}`;")
                columns = cursor.fetchall()
                # Format as "Table: <name>" followed by "  - <field> (<type>)" lines
                table_info = f"Table: {table_name}\n"
                table_info += "\n".join(
                    f"  - {col['Field']} ({col['Type']})"
                    for col in columns
                )
                schema_info.append(table_info)
        return "\n\n".join(schema_info)

    def close(self):
        """Close the database connection (safe to call more than once)."""
        if self.connection is not None:
            self.connection.close()
            self.connection = None
            print("MySQL connection closed")

    def run_sql(self, sql: str) -> pd.DataFrame:
        """
        Execute SQL with a raw PyMySQL cursor and wrap the rows in a DataFrame.
        Text containing <sql> tags is accepted; only the first tagged block runs.
        """
        # 1) Extract the SQL inside the first <sql> tag; fall back to the raw text
        real_sql = self.extract_sql(sql) or sql

        # 2) Execute and fetch all rows (dicts, because of DictCursor)
        with self.connection.cursor() as cursor:
            cursor.execute(real_sql)
            rows = cursor.fetchall()

        # 3) Wrap in a DataFrame (empty for statements that return no rows)
        return pd.DataFrame(rows)

    def extract_sql(self, text: str) -> Optional[str]:
        """
        Return the content of the first <sql>...</sql> block in `text`,
        or None if there is no such block.
        """
        pattern = re.compile(r'<sql>(.*?)</sql>', re.IGNORECASE | re.DOTALL)
        match = pattern.search(text)
        if match:
            # strip() removes leading/trailing whitespace and newlines
            return match.group(1).strip()
        return None
```


```text
MySQL connection initialized: root@localhost:3306/bjxoj
Table: check_in
  - id (bigint)
  - checkInTime (datetime)
  - userId (bigint)
  - createTime (datetime)
  - updateTime (datetime)
  - isDelete (tinyint)

Table: comment
  - id (bigint unsigned)
  - postId (bigint unsigned)
  - userId (bigint unsigned)
  - content (text)
  - parentId (bigint unsigned)
  - rootId (bigint unsigned)
  - status (enum('pending','approved','deleted'))
  - isTop (tinyint unsigned)
  - likeCount (int unsigned)
  - reportCount (int unsigned)
  - ipAddress (varchar(45))
  - userAgent (varchar(500))
  - createTime (timestamp)
  - updateTime (timestamp)

Table: post
  - id (bigint)
  - title (varchar(512))
  - content (text)
  - tags (varchar(1024))
  - thumbNum (int unsigned)
  - favourNum (int unsigned)
  - viewsNum (int unsigned)
  - userId (bigint)
  - createTime (datetime)
  - updateTime (datetime)
  - isDelete (tinyint)
  - questionId (bigint)
  - type (varchar(255))
  - description (text)
  - coverUrl (varchar(255))
```

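The schema string has a simple, fixed layout, so it can be reproduced without a live database, for example to test prompts offline or to send only the tables a question needs. The sketch below assumes the input is a mapping from table name to `SHOW COLUMNS` rows (dicts with `Field` and `Type` keys, as `DictCursor` returns them). `format_schema` and its `include` filter are hypothetical additions, not part of the notebook.

```python
from typing import Optional


def format_schema(tables: dict[str, list[dict]], include: Optional[set[str]] = None) -> str:
    """
    Hypothetical offline counterpart of MySQLClient.get_schema_info.

    `tables` maps a table name to its SHOW COLUMNS rows (dicts with at least
    'Field' and 'Type'). If `include` is given, only those tables are kept,
    which shortens the prompt sent to the model.
    """
    blocks = []
    for table_name, columns in tables.items():
        if include is not None and table_name not in include:
            continue
        lines = [f"Table: {table_name}"]
        lines += [f"  - {col['Field']} ({col['Type']})" for col in columns]
        blocks.append("\n".join(lines))
    # Blank line between tables, matching get_schema_info
    return "\n\n".join(blocks)


sample = {
    "check_in": [{"Field": "id", "Type": "bigint"}, {"Field": "userId", "Type": "bigint"}],
    "post": [{"Field": "id", "Type": "bigint"}, {"Field": "title", "Type": "varchar(512)"}],
}
print(format_schema(sample, include={"post"}))
```

For tables with at least one column this yields the same `Table: …` / `  - field (type)` layout as `get_schema_info`.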


## 3. Text-to-SQL Prompt

In [21]:
```python
def generate_prompt(schema, query):
    """Build the text-to-SQL prompt from the schema string and the user's question."""
    return f"""
        You are an AI assistant that converts natural language queries into SQL.
        Given the following SQL database schema:

        <schema>
        {schema}
        </schema>

        Convert the following natural language query into SQL:
        <query>
        {query}
        </query>

        Provide only the SQL query in your response, enclosed within <sql> tags.
    """
```

```text
MySQL connection initialized: root@localhost:3306/bjxoj
Tables in the current database: ['check_in', 'comment', 'post', 'post_favour', 'post_thumb', 'post_view', 'question', 'question_submit', 'user', 'user_copy1', 'user_copy2']
['Table: check_in\n  - id (bigint)\n  - checkInTime (datetime)\n  - userId (bigint)\n  - createTime (datetime)\n  - updateTime (datetime)\n  - isDelete (tinyint)', "Table: comment\n  - id (bigint unsigned)\n  - postId (bigint unsigned)\n  - userId (bigint unsigned)\n  - content (text)\n  - parentId (bigint unsigned)\n  - rootId (bigint unsigned)\n  - status (enum('pending','approved','deleted'))\n  - isTop (tinyint unsigned)\n  - likeCount (int unsigned)\n  - reportCount (int unsigned)\n  - ipAddress (varchar(45))\n  - userAgent (varchar(500))\n  - createTime (timestamp)\n  - updateTime (timestamp)", 'Table: post\n  - id (bigint)\n  - title (varchar(512))\n  - content (text)\n  - tags (varchar(1024))\n  - thumbNum (int unsigned)\n  - favourNum (int unsigned)\n  - viewsNum (int unsigned)\n  - userId
```
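Because `generate_prompt` is an indented f-string, every template line reaches the model with eight leading spaces while the schema lines do not. Models usually cope, but a cleaner variant is easy. The sketch below is a hypothetical `generate_prompt_dedented`; it also names MySQL explicitly, on the assumption that this nudges the model toward MySQL syntax. The dedent happens before substitution, because `textwrap.dedent` only strips indentation shared by all lines and an unindented schema would otherwise defeat it.

```python
import textwrap

# Dedent the template *before* substituting values: dedent() removes only the
# indentation common to all non-blank lines, so an unindented multi-line
# schema inserted first would stop it from removing anything.
PROMPT_TEMPLATE = textwrap.dedent("""\
    You are an AI assistant that converts natural language queries into SQL
    for a MySQL database.
    Given the following SQL database schema:

    <schema>
    {schema}
    </schema>

    Convert the following natural language query into SQL:
    <query>
    {query}
    </query>

    Provide only the SQL query in your response, enclosed within <sql> tags.
""")


def generate_prompt_dedented(schema: str, query: str) -> str:
    """Hypothetical variant of generate_prompt without the leading indentation."""
    # str.format only parses braces in the template, so braces inside the schema are safe.
    return PROMPT_TEMPLATE.format(schema=schema, query=query)


print(generate_prompt_dedented("Table: user\n  - id (bigint)", "How many users are there?"))
```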

## 4. LLM Interface

In [24]:
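```python
import os

from openai import OpenAI

# DeepSeek's API is OpenAI-compatible, so the openai SDK works with a custom base_url.
# The key is read from an environment variable (the name DEEPSEEK_API_KEY is our own
# choice) so that a real key is not hard-coded; the placeholder is only a fallback.
api_key = os.environ.get("DEEPSEEK_API_KEY", "xxxxxxxxxxxxxxxxxxxxxxxxxx")
base_url = "https://api.deepseek.com"
client = OpenAI(api_key=api_key, base_url=base_url)

def generate_sql(prompt: str) -> str:
    """Send the prompt to deepseek-chat and return the raw reply text (expected to contain <sql> tags)."""
    response = client.chat.completions.create(
        model="deepseek-chat",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=1000
    )
    return response.choices[0].message.content
```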


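`generate_sql` returns the model's raw text, and `run_sql` falls back to executing that raw text when no `<sql>` tag is found (`self.extract_sql(sql) or sql`), so a reply written as prose would be sent to MySQL verbatim. A minimal sketch of a stricter wrapper follows. `ask_for_sql`, its `retries` parameter and the stub replies are hypothetical; any prompt-to-text function, such as `generate_sql`, can be passed as `complete`.

```python
import re
from typing import Callable

SQL_TAG = re.compile(r"<sql>(.*?)</sql>", re.IGNORECASE | re.DOTALL)


def ask_for_sql(prompt: str, complete: Callable[[str], str], retries: int = 1) -> str:
    """
    Hypothetical wrapper around a prompt -> reply function such as generate_sql.

    Returns the SQL inside the first <sql> block (which must be non-empty).
    If the model ignores the tag instruction, it asks again up to `retries`
    more times and then raises, instead of handing the whole reply to MySQL.
    """
    if retries < 0:
        raise ValueError("retries must be >= 0")
    reply = ""
    for _ in range(retries + 1):
        reply = complete(prompt)
        match = SQL_TAG.search(reply or "")  # tolerate an empty/None reply
        if match and match.group(1).strip():
            return match.group(1).strip()
    raise ValueError(f"no <sql> block after {retries + 1} attempt(s); last reply: {reply!r}")


# Offline usage: a stub stands in for the API call and forgets the tags once.
stub_replies = iter(["Sure, here it is: SELECT 1", "<sql>\nSELECT COUNT(*) FROM user;\n</sql>"])
print(ask_for_sql("any prompt", lambda _prompt: next(stub_replies)))
```

With the real model, the call would be `ask_for_sql(prompt, generate_sql)`, and the returned string can go straight to `db.run_sql`.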

## 5. Execute the SQL

In [41]:
```python
db = MySQLClient(
    host='localhost',
    port=3306,
    user='root',
    password='xxxxxxxxxxxxxx',
    database='xxxxxxxxxxxxx'
)
schema = db.get_schema_info()
# "Show every user's submission records (only the important fields, not every column)"
user_query = "查看所有用户的提交记录信息（只显示重要信息，不要显示所有信息）"
prompt = generate_prompt(schema, user_query)
sql = generate_sql(prompt)
print(sql)
result = db.run_sql(sql)
print("Query result:")
display(result)  # display() is available inside IPython/Jupyter
```


```text
MySQL connection initialized: root@localhost:3306/bjxoj
<sql>
SELECT 
    u.userName AS '用户名',
    q.title AS '题目名称',
    qs.language AS '编程语言',
    qs.status AS '提交状态',
    qs.createTime AS '提交时间'
FROM 
    question_submit qs
JOIN 
    user u ON qs.userId = u.id
JOIN 
    question q ON qs.questionId = q.id
WHERE 
    qs.isDelete = 0
ORDER BY 
    qs.createTime DESC;
</sql>
Query result:
```


| | 用户名 | 题目名称 | 编程语言 | 提交状态 | 提交时间 |
|---|---|---|---|---|---|
| 0 | 我不是人 | 逃离迷宫 | cpp | 2 | 2025-04-01 23:21:42 |
| 1 | bjxoj | 寻宝之旅 | cpp | 2 | 2025-03-29 14:56:10 |
| 2 | bjxoj | 逃离迷宫 | cpp | 2 | 2025-03-29 14:54:09 |
| 3 | 我不是人 | a+b | cpp | 2 | 2025-03-19 14:48:54 |
| 4 | bjxoj | 二叉树的节点深度统计 | cpp | 2 | 2025-03-12 23:26:22 |
| ... | ... | ... | ... | ... | ... |
| 96 | bjxoj | a+b | java | 3 | 2024-05-10 13:26:09 |
| 97 | bjxoj | a+b | java | 3 | 2024-05-10 13:23:29 |
| 98 | bjxoj | a+b | java | 3 | 2024-05-10 13:20:54 |
| 99 | bjxoj | 两数之和 | cpp | 2 | 2024-05-05 16:25:30 |
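This run connects as `root` with `autocommit=True` (the `MySQLClient` default), and `run_sql` executes whatever the model returns. A coarse pre-execution check can catch obviously unwanted statements. The sketch below is a hypothetical `is_read_only`; the accepted keyword set is an assumption, and the check errs on the side of rejecting. It is not a SQL parser: a `SELECT` can still be expensive or, with `INTO OUTFILE`, write files, so a MySQL account granted only `SELECT` remains the real safeguard.

```python
import re

# Leading keywords this sketch accepts (an assumption; adjust as needed).
# EXPLAIN/DESCRIBE and WITH are left out on purpose: in MySQL 8 they can
# also wrap or precede data-modifying statements.
READ_ONLY_KEYWORDS = {"select", "show"}


def is_read_only(sql: str) -> bool:
    """
    Hypothetical, deliberately conservative pre-execution check.

    Accepts exactly one statement (one optional trailing ';') whose first
    keyword is in READ_ONLY_KEYWORDS. Anything it cannot classify, such as a
    leading comment or a ';' inside a string literal, is rejected.
    """
    text = sql.strip().rstrip(";").strip()
    if not text or ";" in text:
        return False  # empty input, or more than one statement
    keyword = re.match(r"[A-Za-z]+", text)
    return keyword is not None and keyword.group(0).lower() in READ_ONLY_KEYWORDS


print(is_read_only("SELECT id FROM post ORDER BY id DESC;"))  # True
print(is_read_only("DELETE FROM post;"))                     # False
print(is_read_only("SELECT 1; DROP TABLE post"))            # False
```

One way to use it: compute `real_sql = db.extract_sql(sql) or sql`, and call `db.run_sql(sql)` only when `is_read_only(real_sql)` is true.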
