## 12. Complex SQL Tests: Comparing the Models on Challenging Scenarios

# Fine-Tuning Experiment: SQL Query Generation

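This notebook shows how to fine-tune the Qwen2.5-1.5B model (the `Qwen/Qwen2.5-1.5B-Instruct` checkpoint) so that it generates SQL queries from a database schema and a natural-language question.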

## 1. Environment Setup

In [1]:
# Install the required packages
!pip install torch transformers datasets peft bitsandbytes accelerate scipy sentencepiece einops -q

In [1]:
import torch
import transformers
import peft

print(f"PyTorch version: {torch.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"PEFT version: {peft.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")


PyTorch version: 2.8.0+cu128
Transformers version: 4.55.0
PEFT version: 0.15.2
CUDA available: True
GPU: NVIDIA GeForce RTX 5080


## 2. Data Preparation

In [2]:
from datasets import load_dataset

# Load the first 1,000 examples of the training split
dataset = load_dataset("b-mc2/sql-create-context", split="train[:1000]")
print(f"Dataset size: {len(dataset)}")
print("\nSample example:")
print(dataset[0])

Dataset size: 1000

Sample example:
{'answer': 'SELECT COUNT(*) FROM head WHERE age > 56', 'question': 'How many heads of the departments are older than 56 ?', 'context': 'CREATE TABLE head (age INTEGER)'}


In [3]:
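# Format each example as an instruction-style prompt.
# The template headers stay in Chinese, exactly as used for training:
#   "任务：根据数据库结构和问题生成SQL查询" = "Task: generate a SQL query from the database schema and question"
#   "数据库结构" = "Database schema", "问题" = "Question", "SQL查询" = "SQL query"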
def format_instruction(example):
    instruction = f"""### 任务：根据数据库结构和问题生成SQL查询

### 数据库结构：
{example['context']}

### 问题：
{example['question']}

### SQL查询：
{example['answer']}"""
    return {"text": instruction}

# Apply the formatting
dataset = dataset.map(format_instruction)

# Inspect a formatted example (first 500 characters)
print(dataset[0]['text'][:500])

### 任务：根据数据库结构和问题生成SQL查询

### 数据库结构：
CREATE TABLE head (age INTEGER)

### 问题：
How many heads of the departments are older than 56 ?

### SQL查询：
SELECT COUNT(*) FROM head WHERE age > 56


## 3. Model Configuration

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Load the tokenizer and reuse the EOS token for padding
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# 4-bit quantization config: NF4 weights, double quantization, fp16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load the base model with 4-bit quantized weights
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
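A rough, weights-only estimate shows why loading in 4-bit matters on a single consumer GPU. The sketch assumes about 1.54 billion parameters for Qwen2.5-1.5B, the figure on its model card. It ignores quantization constants, modules that stay in 16-bit (for example the embeddings), LoRA weights, optimizer states and activations, so actual memory use is higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Memory needed to store the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

# Assumption: ~1.54e9 parameters for Qwen2.5-1.5B (model card figure)
NUM_PARAMS = 1.54e9

for label, bits in [("fp16", 16), ("NF4 (4-bit)", 4)]:
    print(f"{label:>12}: ~{weight_memory_gb(NUM_PARAMS, bits):.2f} GB")
```

By this estimate the 4-bit weights take about a quarter of the fp16 footprint (roughly 0.8 GB instead of 3.1 GB). That leaves most of the GPU memory for the rest of training.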

In [5]:
# Prepare the quantized model for k-bit training
model = prepare_model_for_kbit_training(model)

# LoRA configuration: rank-16 adapters on all attention and MLP projection layers
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Apply LoRA and report how many parameters are trainable
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 18,464,768 || all params: 1,562,179,072 || trainable%: 1.1820
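The trainable-parameter count printed above can be checked by hand. A LoRA adapter on a linear layer with input size d_in and output size d_out adds A (r × d_in) and B (d_out × r), which is r·(d_in + d_out) parameters. The sketch assumes the dimensions in Qwen2.5-1.5B's published config: hidden size 1536, intermediate size 8960, 28 layers, and 2 key/value heads with head dimension 128. Only the arithmetic is shown; no model is loaded.

```python
# Assumed Qwen2.5-1.5B dimensions (from its published config.json)
HIDDEN = 1536          # hidden_size
INTERMEDIATE = 8960    # intermediate_size
NUM_LAYERS = 28        # num_hidden_layers
KV_DIM = 2 * 128       # num_key_value_heads * head_dim (grouped-query attention)
R = 16                 # LoRA rank, as in lora_config

# (d_in, d_out) of each LoRA target module inside one decoder layer
target_shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in A (r x d_in) plus B (d_out x r)."""
    return r * (d_in + d_out)

per_layer = sum(lora_params(d_in, d_out, R) for d_in, d_out in target_shapes.values())
total = per_layer * NUM_LAYERS
print(f"LoRA params per layer: {per_layer:,}")
print(f"Total trainable params: {total:,}")
```

This reproduces the 18,464,768 reported by `print_trainable_parameters()`. About three quarters of the adapter parameters sit in the three MLP projections, because the intermediate size is almost six times the hidden size.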


## 4. Baseline Test Before Training

In [6]:
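# Test the SQL generation of the model before fine-tuning.
# LoRA's B matrices are initialized to zero, so the wrapped model still behaves like the base model here.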
test_prompt = """### 任务：根据数据库结构和问题生成SQL查询

### 数据库结构：
CREATE TABLE users (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    age INT,
    city VARCHAR(50)
);

### 问题：
查找年龄大于25岁的用户姓名

### SQL查询：
"""

inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        temperature=0.1,
        do_sample=True
    )

print("Output before training:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True).split("### SQL查询：")[-1])

Output before training:

```sql
SELECT name FROM users WHERE age > 25;
```

### 解释：
- `SELECT name`：选择用户的名字。
- `FROM users`：从users表中获取数据。
- `WHERE age > 2


## 5. Data Processing

In [7]:
# Split into training and validation sets (90/10)
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(eval_dataset)}")

Training set size: 900
Validation set size: 100


In [8]:
# Tokenization function: truncate or pad every example to 512 tokens
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )

# Tokenize both splits and drop the original text columns
train_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names
)

eval_dataset = eval_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=eval_dataset.column_names
)

Map: 100%|██████████| 100/100 [00:00<00:00, 7714.09 examples/s]


## 6. Training Configuration

In [9]:
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

OUTPUT_DIR = "./qwen2.5-sql-finetuned"

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=100,
    logging_steps=50,
    save_steps=200,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    fp16=True,
    gradient_checkpointing=True,
    report_to="none",
)

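# Data collator for causal LM: with mlm=False, labels are a copy of input_ids and padding positions are set to -100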
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)
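With `mlm=False`, `DataCollatorForLanguageModeling` builds `labels` by copying `input_ids`. It then sets every position equal to `tokenizer.pad_token_id` to -100, the label value the cross-entropy loss ignores. In section 3, `tokenizer.pad_token` was set to `tokenizer.eos_token`, so an EOS token would be masked exactly like padding. In addition, `format_instruction` does not append an EOS token to the answer, so the fine-tuning data gives the model no explicit end-of-answer target.

The pure-Python sketch below re-implements only that masking rule, not the library itself. It uses hypothetical token ids to make the effect visible. A common remedy, not used in this notebook, is a pad token distinct from EOS plus an EOS appended to each training text.

```python
from typing import List

IGNORE_INDEX = -100  # the label value PyTorch's cross-entropy loss skips by default

def causal_lm_labels(input_ids: List[int], pad_token_id: int) -> List[int]:
    """Mimic the mlm=False labelling rule: copy input_ids and replace every
    position equal to pad_token_id with IGNORE_INDEX."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]

# Hypothetical token ids: 42, 43, 44 stand for answer tokens, 99 for EOS.
EOS_ID = 99
PAD_ID = EOS_ID  # mirrors tokenizer.pad_token = tokenizer.eos_token

sequence = [42, 43, 44, EOS_ID, PAD_ID, PAD_ID]  # answer, EOS, then padding
labels = causal_lm_labels(sequence, PAD_ID)
print(labels)  # -> [42, 43, 44, -100, -100, -100]: EOS is masked along with the padding
```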

## 7. Start Training

In [None]:
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
    data_collator=data_collator,
)

# Start training
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss
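A quick count of optimizer steps puts the `TrainingArguments` above in perspective. The sketch assumes a single GPU and the 900-example training split. It rounds the last partial gradient-accumulation group up to a full step. The exact total reported by the Trainer may differ by a step or two depending on the installed version, so treat the numbers as approximate.

```python
import math

# Values from the TrainingArguments above and the 900-example training split
NUM_EXAMPLES = 900
PER_DEVICE_BATCH = 4
GRAD_ACCUM = 4
NUM_EPOCHS = 3
WARMUP_STEPS = 100
EVAL_STEPS = 100
SAVE_STEPS = 200

batches_per_epoch = math.ceil(NUM_EXAMPLES / PER_DEVICE_BATCH)  # 225 micro-batches
# Upper estimate: counts the final partial accumulation group as a full step
updates_per_epoch = math.ceil(batches_per_epoch / GRAD_ACCUM)
total_updates = updates_per_epoch * NUM_EPOCHS

print(f"Effective batch size:     {PER_DEVICE_BATCH * GRAD_ACCUM}")
print(f"Approx. optimizer steps:  {total_updates}")
print(f"Share of steps in warmup: {WARMUP_STEPS / total_updates:.0%}")
print(f"Step-based evaluations:   {total_updates // EVAL_STEPS}")
print(f"Step-based saves:         {total_updates // SAVE_STEPS}")
```

By this estimate, `warmup_steps=100` spans more than half of the run. Only one step-based evaluation (at step 100) fits, and `save_steps=200` is larger than the whole run. Fewer warmup steps (or `warmup_ratio`) and smaller `eval_steps`/`save_steps` would give the learning-rate schedule and `load_best_model_at_end` more to work with.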


In [None]:
# Save the LoRA adapter and the tokenizer
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Model saved to: {OUTPUT_DIR}")

## 8. Testing the Fine-Tuned Model

In [10]:
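# Load the fine-tuned model
from peft import PeftModel

# Reload the base model in fp16 (not quantized) so the adapter can be merged into it
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)

# Load the LoRA weights, then merge them into the base weights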
finetuned_model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
finetuned_model = finetuned_model.merge_and_unload()
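`merge_and_unload()` folds the adapter into the base weights. In PEFT's standard LoRA (without rsLoRA), an adapted linear layer computes W x + (lora_alpha / r) · B A x. Merging therefore replaces W with W + (lora_alpha / r) · B A. With `lora_alpha=32` and `r=16`, the scaling factor here is 2. Note that the merge is applied to the fp16 base model reloaded above, not to the 4-bit weights used during training.

The NumPy sketch below uses small random matrices (toy shapes, not the model's) to check that the merged and unmerged forward passes agree. LoRA dropout is inactive at inference, so it is left out.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, lora_alpha = 8, 6, 4, 8   # toy sizes chosen for illustration
scaling = lora_alpha / r                   # default LoRA scaling in PEFT

W = rng.normal(size=(d_out, d_in))         # frozen base weight
A = rng.normal(size=(r, d_in))             # LoRA down-projection
B = rng.normal(size=(d_out, r))            # LoRA up-projection (zero at init, learned during training)
x = rng.normal(size=(d_in,))

unmerged = W @ x + scaling * (B @ (A @ x))  # adapter applied as a side branch
W_merged = W + scaling * (B @ A)            # the weight after merging
merged = W_merged @ x

print(np.allclose(unmerged, merged))        # the two forward passes match
```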

In [11]:
# Test SQL generation on two new examples
test_cases = [
    {
        "context": """CREATE TABLE products (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    price DECIMAL(10,2),
    category VARCHAR(50)
);""",
        "question": "查找价格超过100的所有产品名称"  # Find the names of all products priced above 100
    },
    {
        "context": """CREATE TABLE employees (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    department VARCHAR(50),
    salary INT
);""",
        "question": "计算每个部门的平均工资"  # Compute the average salary of each department
    }
]

for i, test in enumerate(test_cases, 1):
    prompt = f"""### 任务：根据数据库结构和问题生成SQL查询

### 数据库结构：
{test['context']}

### 问题：
{test['question']}

### SQL查询：
"""
    
    inputs = tokenizer(prompt, return_tensors="pt").to(finetuned_model.device)
    
    with torch.no_grad():
        outputs = finetuned_model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    sql = response.split("### SQL查询：")[-1].strip()
    
    print(f"\nTest case {i}:")
    print(f"Question: {test['question']}")
    print(f"Generated SQL: {sql}")
    print("-" * 50)


Test case 1:
Question: 查找价格超过100的所有产品名称
Generated SQL: SELECT name FROM products WHERE price > 100;
--------------------------------------------------

Test case 2:
Question: 计算每个部门的平均工资
Generated SQL: SELECT department, AVG(salary) FROM employees GROUP BY department;
--------------------------------------------------


## 9. Performance Evaluation

In [12]:
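# A simple heuristic accuracy score (surface similarity, not execution accuracy)
def evaluate_model(model, tokenizer, test_dataset, num_samples=20):
    num_samples = min(num_samples, len(test_dataset))  # score, and divide by, only the samples that exist
    correct = 0
    
    for i in range(num_samples):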
        example = test_dataset[i]
        
        prompt = f"""### 任务：根据数据库结构和问题生成SQL查询

### 数据库结构：
{example['context']}

### 问题：
{example['question']}

### SQL查询：
"""
        
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.1,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_sql = response.split("### SQL查询：")[-1].strip().lower()
        expected_sql = example['answer'].strip().lower()
        
        # Simple similarity check
        if "select" in generated_sql and "from" in generated_sql:
            correct += 0.5  # at least produced a basic SELECT ... FROM structure
            
            # Whitespace-token overlap with the reference query
            expected_keywords = set(expected_sql.split())
            generated_keywords = set(generated_sql.split())
            overlap = len(expected_keywords & generated_keywords) / len(expected_keywords)
            
            if overlap > 0.6:
                correct += 0.5
    
    accuracy = (correct / num_samples) * 100
    return accuracy

# Evaluate on 20 rows (1000-1019) outside the first 1,000 rows used for training and validation
test_data = load_dataset("b-mc2/sql-create-context", split="train[1000:1020]")
accuracy = evaluate_model(finetuned_model, tokenizer, test_data)
print(f"Model accuracy (simple heuristic): {accuracy:.2f}%")

Model accuracy (simple heuristic): 87.50%
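The keyword-overlap score rewards surface similarity, not correctness. A cheap complementary check is whether the generated SQL executes at all against the schema in `context`. The sketch below uses Python's built-in `sqlite3` with an in-memory database. It assumes `context` is valid SQLite DDL with statements separated by semicolons; contexts in another format would need to be split first, and an unparseable schema is reported as "not executable". Because the tables are empty, the check catches syntax errors and unknown tables or columns, not wrong answers.

```python
import sqlite3

def is_executable(context: str, sql: str) -> bool:
    """Return True if `sql` runs against an empty in-memory SQLite database
    created from the CREATE TABLE statements in `context`."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(context)  # assumption: semicolon-separated DDL statements
        conn.execute(sql)            # raises on syntax errors and unknown tables/columns
        return True
    except (sqlite3.Error, sqlite3.Warning):
        # Older Python versions raise sqlite3.Warning for multiple statements;
        # newer ones raise sqlite3.ProgrammingError (a subclass of sqlite3.Error)
        return False
    finally:
        conn.close()

print(is_executable("CREATE TABLE head (age INTEGER)",
                    "SELECT COUNT(*) FROM head WHERE age > 56"))
```

Inside `evaluate_model`, `is_executable(example['context'], generated_sql)` could be recorded alongside the overlap score. Note that `generated_sql` is lower-cased there. SQLite tolerates this for keywords and unquoted identifiers, but it changes string literals.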


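## 10. Summary

In this experiment we:
1. Fine-tuned Qwen2.5-1.5B-Instruct efficiently with LoRA adapters on a 4-bit quantized base model
2. Taught the model to answer in the training format with a bare SQL query (the base model already wrote SQL, but wrapped it in Markdown and added explanations)
3. Trained with only about 6-8 GB of GPU memory

### Directions for improvement:
- Use more training data (only 1,000 examples of the dataset were loaded)
- Tune the LoRA parameters (rank r, target modules)
- Refine the prompt format
- Add more complex SQL scenarios (JOINs, subqueries, etc.)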

## 11. Comparison Test: Original vs. Fine-Tuned Model on 20 Examples

In [13]:
# 20 test examples for comparing the original model with the fine-tuned model
test_examples = [
    {
        "context": """CREATE TABLE employees (
            id INT PRIMARY KEY,
            name VARCHAR(100),
            department VARCHAR(50),
            salary DECIMAL(10,2),
            hire_date DATE
        );""",
        "question": "查找所有IT部门的员工姓名"  # Find the names of all employees in the IT department
    },
    {
        "context": """CREATE TABLE products (
            product_id INT PRIMARY KEY,
            product_name VARCHAR(100),
            category VARCHAR(50),
            price DECIMAL(10,2),
            stock INT
        );""",
        "question": "找出库存为0的产品名称和类别"  # Find the name and category of products with zero stock
    },
    {
        "context": """CREATE TABLE orders (
            order_id INT PRIMARY KEY,
            customer_id INT,
            order_date DATE,
            total_amount DECIMAL(10,2)
        );""",
        "question": "计算所有订单的总金额"
    },
    {
        "context": """CREATE TABLE students (
            student_id INT PRIMARY KEY,
            name VARCHAR(100),
            age INT,
            grade VARCHAR(10),
            gpa DECIMAL(3,2)
        );""",
        "question": "找出GPA大于3.5且年龄小于20的学生"  # Find students with a GPA above 3.5 who are younger than 20
    },
    {
        "context": """CREATE TABLE books (
            book_id INT PRIMARY KEY,
            title VARCHAR(200),
            author VARCHAR(100),
            year INT,
            price DECIMAL(10,2)
        );""",
        "question": "查询2020年后出版的书籍标题和作者"  # Get the title and author of books published after 2020
    },
    {
        "context": """CREATE TABLE customers (
            customer_id INT PRIMARY KEY,
            name VARCHAR(100),
            email VARCHAR(100),
            city VARCHAR(50),
            country VARCHAR(50)
        );""",
        "question": "统计每个国家的客户数量"  # Count the customers in each country
    },
    {
        "context": """CREATE TABLE transactions (
            trans_id INT PRIMARY KEY,
            account_id INT,
            amount DECIMAL(10,2),
            trans_type VARCHAR(20),
            trans_date DATE
        );""",
        "question": "找出所有取款(withdraw)交易的总金额"  # Compute the total amount of all withdrawal (withdraw) transactions
    },
    {
        "context": """CREATE TABLE inventory (
            item_id INT PRIMARY KEY,
            item_name VARCHAR(100),
            quantity INT,
            warehouse VARCHAR(50),
            last_updated DATE
        );""",
        "question": "查找北京仓库中数量少于50的物品"  # Find items in the Beijing warehouse with a quantity below 50
    },
    {
        "context": """CREATE TABLE flights (
            flight_id VARCHAR(10) PRIMARY KEY,
            departure_city VARCHAR(50),
            arrival_city VARCHAR(50),
            departure_time DATETIME,
            price DECIMAL(10,2)
        );""",
        "question": "找出从上海出发的所有航班"  # Find all flights departing from Shanghai
    },
    {
        "context": """CREATE TABLE sales (
            sale_id INT PRIMARY KEY,
            product_id INT,
            quantity INT,
            sale_date DATE,
            revenue DECIMAL(10,2)
        );""",
        "question": "计算每个产品的总销售数量"  # Compute the total quantity sold for each product
    },
    {
        "context": """CREATE TABLE users (
            user_id INT PRIMARY KEY,
            username VARCHAR(50),
            email VARCHAR(100),
            created_at DATE,
            is_active BOOLEAN
        );""",
        "question": "查找所有未激活的用户邮箱"  # Find the email addresses of all inactive users
    },
    {
        "context": """CREATE TABLE courses (
            course_id VARCHAR(10) PRIMARY KEY,
            course_name VARCHAR(100),
            credits INT,
            department VARCHAR(50),
            instructor VARCHAR(100)
        );""",
        "question": "找出学分大于3的计算机科学课程"  # Find computer science courses with more than 3 credits
    },
    {
        "context": """CREATE TABLE movies (
            movie_id INT PRIMARY KEY,
            title VARCHAR(200),
            director VARCHAR(100),
            year INT,
            rating DECIMAL(3,1)
        );""",
        "question": "查询评分高于8.0的电影标题和导演"  # Get the title and director of movies rated above 8.0
    },
    {
        "context": """CREATE TABLE employees (
            emp_id INT PRIMARY KEY,
            name VARCHAR(100),
            department VARCHAR(50),
            salary INT,
            bonus INT
        );""",
        "question": "计算每个部门的平均工资"  # Compute the average salary of each department
    },
    {
        "context": """CREATE TABLE projects (
            project_id INT PRIMARY KEY,
            project_name VARCHAR(100),
            start_date DATE,
            end_date DATE,
            budget DECIMAL(12,2),
            status VARCHAR(20)
        );""",
        "question": "找出状态为'进行中'且预算超过100万的项目"  # Find projects with status '进行中' (in progress) and a budget over 1,000,000
    },
    {
        "context": """CREATE TABLE vehicles (
            vehicle_id INT PRIMARY KEY,
            brand VARCHAR(50),
            model VARCHAR(50),
            year INT,
            price DECIMAL(10,2),
            type VARCHAR(20)
        );""",
        "question": "查询2022年后生产的SUV车型"  # Find SUV models produced after 2022
    },
    {
        "context": """CREATE TABLE restaurants (
            restaurant_id INT PRIMARY KEY,
            name VARCHAR(100),
            cuisine_type VARCHAR(50),
            city VARCHAR(50),
            rating DECIMAL(2,1)
        );""",
        "question": "找出北京所有评分超过4.5的中餐厅"  # Find all Chinese restaurants in Beijing rated above 4.5
    },
    {
        "context": """CREATE TABLE articles (
            article_id INT PRIMARY KEY,
            title VARCHAR(200),
            author VARCHAR(100),
            publish_date DATE,
            views INT,
            category VARCHAR(50)
        );""",
        "question": "查询浏览量最高的前10篇文章"  # Get the 10 most-viewed articles
    },
    {
        "context": """CREATE TABLE hotels (
            hotel_id INT PRIMARY KEY,
            hotel_name VARCHAR(100),
            city VARCHAR(50),
            stars INT,
            price_per_night DECIMAL(10,2)
        );""",
        "question": "找出五星级酒店中每晚价格低于500的酒店"  # Find five-star hotels with a nightly price below 500
    },
    {
        "context": """CREATE TABLE stock_prices (
            stock_id VARCHAR(10),
            date DATE,
            open_price DECIMAL(10,2),
            close_price DECIMAL(10,2),
            volume INT,
            PRIMARY KEY (stock_id, date)
        );""",
        "question": "计算每只股票的平均收盘价"  # Compute the average closing price of each stock
    }
]

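print("="*80)
print("Comparison test: original model vs. fine-tuned model")
print("="*80)

# Generate SQL with both models for one example; returns (base_sql, fine_sql)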
def test_and_compare(base_model, finetuned_model, tokenizer, example):
    prompt = f"""### 任务：根据数据库结构和问题生成SQL查询

### 数据库结构：
{example['context']}

### 问题：
{example['question']}

### SQL查询：
"""
    
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    
    # Generation with the original model
    inputs_base = {k: v.to(base_model.device) for k, v in inputs.items()}
    with torch.no_grad():
        base_outputs = base_model.generate(
            **inputs_base,
            max_new_tokens=100,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    base_response = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
    base_sql = base_response.split("### SQL查询：")[-1].strip()
    if ";" in base_sql:
        base_sql = base_sql.split(";")[0] + ";"
    
    # Generation with the fine-tuned model
    inputs_fine = {k: v.to(finetuned_model.device) for k, v in inputs.items()}
    with torch.no_grad():
        fine_outputs = finetuned_model.generate(
            **inputs_fine,
            max_new_tokens=100,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    fine_response = tokenizer.decode(fine_outputs[0], skip_special_tokens=True)
    fine_sql = fine_response.split("### SQL查询：")[-1].strip()
    if ";" in fine_sql:
        fine_sql = fine_sql.split(";")[0] + ";"
    
    return base_sql, fine_sql

# Load the original (not fine-tuned) model
print("Loading original model...")
original_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)

# Run the comparison
results = []
for i, example in enumerate(test_examples, 1):
    print(f"\nExample {i}: {example['question']}")
    print("-" * 60)
    
    base_sql, fine_sql = test_and_compare(original_model, finetuned_model, tokenizer, example)
    
    print(f"Original model: {base_sql[:100]}...")
    print(f"Fine-tuned model: {fine_sql[:100]}...")
    
    # Simple scoring: 1 if the output contains both SELECT and FROM, else 0
    base_score = 1 if "SELECT" in base_sql.upper() and "FROM" in base_sql.upper() else 0
    fine_score = 1 if "SELECT" in fine_sql.upper() and "FROM" in fine_sql.upper() else 0
    
    results.append({
        'question': example['question'],
        'base_valid': base_score,
        'fine_valid': fine_score,
        'base_sql': base_sql,
        'fine_sql': fine_sql
    })

# Summarize the results
print("\n" + "="*80)
print("Overall statistics")
print("="*80)

n = len(results)
base_valid_count = sum(r['base_valid'] for r in results)
fine_valid_count = sum(r['fine_valid'] for r in results)

print(f"Original model, valid SQL: {base_valid_count}/{n} ({base_valid_count / n:.0%})")
print(f"Fine-tuned model, valid SQL: {fine_valid_count}/{n} ({fine_valid_count / n:.0%})")
print(f"Relative improvement: {((fine_valid_count - base_valid_count) / max(base_valid_count, 1)) * 100:.1f}%")

# Show up to three examples where only the fine-tuned model produced valid SQL
print("\n" + "="*80)
print("Examples with the clearest improvement")
print("="*80)

improved = [r for r in results if r['fine_valid'] > r['base_valid']][:3]
for r in improved:
    print(f"\nQuestion: {r['question']}")
    print(f"Original: {r['base_sql'][:80]}...")
    print(f"Fine-tuned: {r['fine_sql'][:80]}...")

Comparison test: original model vs. fine-tuned model
Loading original model...

Example 1: 查找所有IT部门的员工姓名
------------------------------------------------------------
Original model: ```sql
SELECT name FROM employees WHERE department = 'IT';...
Fine-tuned model: SELECT name FROM employees WHERE department = 'IT'...

Example 2: 找出库存为0的产品名称和类别
------------------------------------------------------------
Original model: ```sql
SELECT product_name, category FROM products WHERE stock = 0;...
Fine-tuned model: SELECT product_name, category FROM products WHERE stock = 0...

Example 3: 计算所有订单的总金额
------------------------------------------------------------
Original model: ```sql
SELECT SUM(total_amount) AS total_order_amount 
FROM orders;...
Fine-tuned model: SELECT SUM(total_amount) FROM orders;...

Example 4: 找出GPA大于3.5且年龄小于20的学生
------------------------------------------------------------
Original model: ```sql
SELECT * FROM students WHERE gpa > 3.5 AND age < 20;...
Fine-tuned model: SELECT * FROM students WHERE gpa > 3.5 AND age < 20;...

Example 5: 查询2020年后出版的书籍标题和作者
------------------------------------------------------------
Original model: ```sql
SELECT tit
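The outputs above show that the original model usually wraps its answer in a Markdown `sql` code fence. The `SELECT`/`FROM` check cannot tell that apart from a bare query. A small helper that strips an optional fence and keeps the first statement makes the two models' outputs directly comparable. This is a minimal string-processing sketch: it would cut a statement early if a string literal contained a semicolon, and `extract_sql` is a hypothetical helper, not part of the notebook.

```python
import re

FENCE = "`" * 3  # a Markdown code fence, built this way to keep the literal out of the source

def extract_sql(response: str) -> str:
    """Return the first SQL statement in a model response.

    Handles bare SQL as well as SQL wrapped in a fenced block whose
    opening fence may carry a language tag such as `sql`.
    """
    text = response.strip()
    # Drop an opening fence plus an optional language tag, if present
    if text.startswith(FENCE):
        text = re.sub(r"^[A-Za-z]*\s*", "", text[len(FENCE):])
    # Discard everything from a closing fence onward (e.g. trailing explanations)
    text = text.split(FENCE, 1)[0]
    # Keep only the first statement, restoring its terminating semicolon
    if ";" in text:
        text = text.split(";", 1)[0] + ";"
    return text.strip()

print(extract_sql(FENCE + "sql\nSELECT name FROM employees WHERE department = 'IT';\n" + FENCE))
```

Inside `test_and_compare`, the split-at-`;` post-processing could then become `extract_sql(base_response.split("### SQL查询：")[-1])`, and likewise for `fine_response`.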

In [14]:
# More complex test scenarios, to probe what fine-tuning actually changes
complex_test_examples = [
    {
        "context": """CREATE TABLE orders (
            order_id INT PRIMARY KEY,
            customer_id INT,
            order_date DATE,
            total_amount DECIMAL(10,2)
        );
        CREATE TABLE order_items (
            item_id INT PRIMARY KEY,
            order_id INT,
            product_id INT,
            quantity INT,
            price DECIMAL(10,2),
            FOREIGN KEY (order_id) REFERENCES orders(order_id)
        );""",
        "question": "找出订单金额超过1000且包含超过5个商品的订单"  # Find orders with an amount over 1000 that contain more than 5 items
    },
    {
        "context": """CREATE TABLE employees (
            emp_id INT PRIMARY KEY,
            name VARCHAR(100),
            department VARCHAR(50),
            salary DECIMAL(10,2),
            manager_id INT
        );""",
        "question": "找出工资比他们经理高的员工"  # Find employees who earn more than their manager
    },
    {
        "context": """CREATE TABLE sales (
            sale_id INT PRIMARY KEY,
            product_id INT,
            sale_date DATE,
            quantity INT,
            revenue DECIMAL(10,2)
        );""",
        "question": "计算最近30天的滚动平均销售额"  # Compute the rolling average of sales over the last 30 days
    },
    {
        "context": """CREATE TABLE students (
            student_id INT PRIMARY KEY,
            name VARCHAR(100),
            major VARCHAR(50)
        );
        CREATE TABLE courses (
            course_id INT PRIMARY KEY,
            course_name VARCHAR(100),
            credits INT
        );
        CREATE TABLE enrollments (
            student_id INT,
            course_id INT,
            grade VARCHAR(2),
            FOREIGN KEY (student_id) REFERENCES students(student_id),
            FOREIGN KEY (course_id) REFERENCES courses(course_id)
        );""",
        "question": "找出选修了所有计算机科学课程的学生"  # Find students enrolled in every computer science course
    },
    {
        "context": """CREATE TABLE products (
            product_id INT PRIMARY KEY,
            name VARCHAR(100),
            category VARCHAR(50),
            price DECIMAL(10,2)
        );""",
        "question": "找出每个类别中价格排名前3的产品"  # Find the 3 highest-priced products in each category
    },
    {
        "context": """CREATE TABLE transactions (
            trans_id INT PRIMARY KEY,
            account_id INT,
            trans_date DATETIME,
            amount DECIMAL(10,2),
            balance DECIMAL(10,2)
        );""",
        "question": "检测可能的欺诈交易（1小时内超过3笔且总额超过10000）"  # Detect possible fraud (more than 3 transactions within 1 hour totaling over 10000)
    },
    {
        "context": """CREATE TABLE posts (
            post_id INT PRIMARY KEY,
            user_id INT,
            post_date DATE,
            likes INT,
            shares INT
        );
        CREATE TABLE comments (
            comment_id INT PRIMARY KEY,
            post_id INT,
            user_id INT,
            comment_text TEXT,
            FOREIGN KEY (post_id) REFERENCES posts(post_id)
        );""",
        "question": "找出互动率最高的帖子（点赞+分享+评论数）"  # Find the posts with the highest engagement (likes + shares + number of comments)
    },
    {
        "context": """CREATE TABLE inventory (
            product_id INT PRIMARY KEY,
            warehouse_id INT,
            quantity INT,
            last_restocked DATE
        );""",
        "question": "找出需要补货的产品（库存低于过去3个月平均销量的20%）"  # Find products to restock (stock below 20% of average sales over the past 3 months)
    },
    {
        "context": """CREATE TABLE users (
            user_id INT PRIMARY KEY,
            registration_date DATE,
            last_login DATE,
            subscription_type VARCHAR(20)
        );""",
        "question": "计算每月的用户留存率"  # Compute the monthly user retention rate
    },
    {
        "context": """CREATE TABLE flights (
            flight_id INT PRIMARY KEY,
            departure_city VARCHAR(50),
            arrival_city VARCHAR(50),
            departure_time DATETIME,
            arrival_time DATETIME,
            price DECIMAL(10,2)
        );""",
        "question": "找出所有可能的中转航班组合（A到C经过B）"  # Find all possible connecting-flight combinations (A to C via B)
    }
]

print("="*80)
print("Complex SQL test: original model vs. fine-tuned model")
print("="*80)

import re

# Detailed scoring function: awards points for SQL features that appear in the
# query (6 points maximum). It measures feature usage, not correctness.
def detailed_score(sql):
    score = 0
    sql_upper = sql.upper()
    
    # Basic SQL structure (1 point)
    if "SELECT" in sql_upper and "FROM" in sql_upper:
        score += 1
    
    # JOIN (1 point); also matches INNER/LEFT/RIGHT JOIN
    if "JOIN" in sql_upper:
        score += 1
    
    # Subquery (1 point): an opening parenthesis immediately followed by SELECT
    if re.search(r"\(\s*SELECT\b", sql_upper):
        score += 1
    
    # Aggregate function call (0.5 points); the pattern avoids matching e.g. COUNTRY or MINUTE
    if re.search(r"\b(COUNT|SUM|AVG|MAX|MIN)\s*\(", sql_upper):
        score += 0.5
    
    # GROUP BY (0.5 points)
    if "GROUP BY" in sql_upper:
        score += 0.5
    
    # HAVING clause (0.5 points)
    if "HAVING" in sql_upper:
        score += 0.5
    
    # Window function (1 point)
    if re.search(r"\b(ROW_NUMBER|RANK|DENSE_RANK)\s*\(|\bOVER\s*\(", sql_upper):
        score += 1
    
    # CTE, i.e. a WITH clause of the form "WITH name AS (" (0.5 points)
    if re.search(r"\bWITH\b.*?\bAS\s*\(", sql_upper, re.DOTALL):
        score += 0.5
    
    return score

# Run the complex tests
complex_results = []
for i, example in enumerate(complex_test_examples, 1):
    print(f"\nTest {i}: {example['question']}")
    print("-" * 60)
    
    # Generate with both models (test_and_compare builds the prompt itself)
    base_sql, fine_sql = test_and_compare(original_model, finetuned_model, tokenizer, example)
    
    # Detailed scoring
    base_score = detailed_score(base_sql)
    fine_score = detailed_score(fine_sql)
    
    print(f"Original model (score: {base_score:.1f}): {base_sql[:150]}...")
    print(f"Fine-tuned model (score: {fine_score:.1f}): {fine_sql[:150]}...")
    
    complex_results.append({
        'question': example['question'],
        'base_score': base_score,
        'fine_score': fine_score,
        'improvement': fine_score - base_score
    })

# Summarize the complex-test results
print("\n" + "="*80)
print("Complex SQL test statistics")
print("="*80)

avg_base_score = sum(r['base_score'] for r in complex_results) / len(complex_results)
avg_fine_score = sum(r['fine_score'] for r in complex_results) / len(complex_results)

print(f"Original model average score: {avg_base_score:.2f}/6")
print(f"Fine-tuned model average score: {avg_fine_score:.2f}/6")
print(f"Average improvement: {avg_fine_score - avg_base_score:.2f} points")

# Show up to three complex queries with the largest improvement
print("\n" + "="*80)
print("Complex queries with the largest improvement")
print("="*80)

complex_results.sort(key=lambda x: x['improvement'], reverse=True)
for r in complex_results[:3]:
    if r['improvement'] > 0:
        print(f"\nQuestion: {r['question']}")
        print(f"Original model score: {r['base_score']:.1f}")
        print(f"Fine-tuned model score: {r['fine_score']:.1f}")
        print(f"Improvement: +{r['improvement']:.1f} points")

Complex SQL test: original model vs. fine-tuned model

Test 1: 找出订单金额超过1000且包含超过5个商品的订单
------------------------------------------------------------
Original model (score: 3.5): ```sql
SELECT o.order_id, SUM(i.quantity * i.price) AS total_amount
FROM orders o
JOIN order_items i ON o.order_id = i.order_id
GROUP BY o.order_id
HA...
Fine-tuned model (score: 3.5): SELECT o.order_id, o.customer_id, o.order_date, o.total_amount
FROM orders AS o
JOIN order_items AS oi ON o.order_id = oi.order_id
GROUP BY o.order_id...

Test 2: 找出工资比他们经理高的员工
------------------------------------------------------------
Original model (score: 2.0): ```sql
SELECT e1.name AS employee_name, e1.salary AS employee_salary, e2.name AS manager_name, e2.salary AS manager_salary 
FROM employees e1 
JOIN em...
Fine-tuned model (score: 2.0): SELECT T1.name FROM employees AS T1 JOIN employees AS T2 ON T1.manager_id = T2.emp_id WHERE T1.salary > T2.salary;...

Test 3: 计算最近30天的滚动平均销售额
------------------------------------------------------------
Original model (score: 1.5): ```sql
SELECT 
    SUM(quantity * revenue) / COUNT(DISTINCT sale_