In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import json
import warnings

# Ignore one specific warning: any warning whose message matches the
# "Torch was not compiled with flash attention" pattern below.
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")

def load_sqlcoder_fixed(model_path=None):
    """
    Load the SQLCoder tokenizer and model from a local snapshot directory.

    Both objects are loaded with ``local_files_only=True``, so only files
    already present under ``model_path`` are used.

    Args:
        model_path: Directory holding the model snapshot. When ``None``
            (the default) the ``models--defog--sqlcoder-7b-2`` snapshot
            directory, relative to the current working directory, is used.

    Returns:
        A ``(tokenizer, model)`` tuple on success, or ``(None, None)`` when
        loading fails for any reason (the error is printed, not raised).
    """
    print("正在加载SQLCoder模型...")

    try:
        if model_path is None:
            # Join the components with os.path.join instead of hard-coding
            # backslashes, so the default path also resolves on non-Windows
            # systems (on Windows the resulting string is unchanged).
            model_path = os.path.join(
                "models--defog--sqlcoder-7b-2",
                "snapshots",
                "7e5b6f7981c0aa7d143f6bec6fa26625bdfcbe66",
            )
        print(f"从本地路径 {model_path} 加载模型...")

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            local_files_only=True
        )

        # Reuse the end-of-sequence token as the padding token.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

        # The deprecated `use_flash_attention_2=False` flag is intentionally
        # not passed: flash-attention 2 is opt-in, so omitting the flag keeps
        # it disabled.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
            local_files_only=True,
            low_cpu_mem_usage=True  # reduce peak CPU memory while loading
        )

        # Keep the model's pad_token_id in sync with the tokenizer's.
        model.config.pad_token_id = tokenizer.pad_token_id

        # Switch to evaluation mode.
        model.eval()

        print("✅ 模型加载成功!")
        return tokenizer, model

    except Exception as e:
        # Best-effort loader: report the failure and hand back a
        # (None, None) sentinel that callers check before generating.
        print(f"❌ 模型加载失败: {e}")
        return None, None

def generate_sql_sqlcoder(sqlcoder_tokenizer, sqlcoder_model, question, schema=""):
    """
    Generate a SQL query for a natural-language question using SQLCoder.

    Args:
        sqlcoder_tokenizer: Tokenizer used to encode the prompt and decode
            the output; ``None`` means the model was not loaded.
        sqlcoder_model: Model whose ``generate`` method produces the output
            tokens; ``None`` means the model was not loaded.
        question: Natural-language question.
        schema: Database schema text (optional). When empty, the schema
            section is left out of the prompt.

    Returns:
        The generated text with surrounding whitespace stripped. Errors are
        reported through the return value rather than raised:
        ``"模型未正确加载"`` when either the tokenizer or the model is
        ``None``, and ``"生成失败: <error>"`` when generation raises.
    """
    if sqlcoder_tokenizer is None or sqlcoder_model is None:
        return "模型未正确加载"

    # Build the prompt; the trailing "### SQL" header is where the model is
    # expected to continue.
    if schema:
        prompt = f"### Task\nGenerate a SQL query to answer this question: {question}\n\n### Database Schema\n{schema}\n\n### SQL\n"
    else:
        prompt = f"### Task\nGenerate a SQL query to answer this question: {question}\n\n### SQL\n"

    try:
        # Encode a single prompt (no padding needed) with an attention mask.
        # NOTE(review): prompts longer than 512 tokens are truncated; if the
        # tokenizer truncates on the right (presumably the default here),
        # the trailing "### SQL" header is what gets cut - confirm that
        # real prompts stay within the limit.
        inputs = sqlcoder_tokenizer(
            prompt,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=512,
            return_attention_mask=True
        )

        # Move the encoded inputs to the device of the model's first parameter.
        device = next(sqlcoder_model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Number of prompt tokens actually fed to the model (after any
        # truncation); used below to separate the prompt from the answer.
        prompt_token_count = inputs["input_ids"].shape[-1]

        # Generate without gradients, under autocast.
        with torch.no_grad(), torch.amp.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
            outputs = sqlcoder_model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.1,
                do_sample=True,
                pad_token_id=sqlcoder_tokenizer.pad_token_id,
                eos_token_id=sqlcoder_tokenizer.eos_token_id,
                num_return_sequences=1,
                repetition_penalty=1.2,
                use_cache=True  # enable the KV cache
            )

        # The returned sequence starts with the prompt tokens, so decode only
        # what follows them. Slicing by token count stays correct even when
        # the prompt was truncated or does not decode back to the exact same
        # characters, where searching the decoded text for "### SQL" or
        # slicing it by len(prompt) would return the wrong span.
        generated_ids = outputs[0][prompt_token_count:]
        sql = sqlcoder_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        return sql

    except Exception as e:
        print(f"生成SQL时出错: {e}")
        return f"生成失败: {str(e)}"

In [2]:
# Load the tokenizer and model; both are None if loading failed.
sqlcoder_tokenizer, sqlcoder_model = load_sqlcoder_fixed()

正在加载SQLCoder模型...
从本地路径 models--defog--sqlcoder-7b-2\snapshots\7e5b6f7981c0aa7d143f6bec6fa26625bdfcbe66 加载模型...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


✅ 模型加载成功!


In [3]:
# Smoke test: generate SQL for one sample problem, using the schema stored
# in schema.txt. Compare against None explicitly (matching the guard inside
# generate_sql_sqlcoder) instead of relying on object truthiness.
if sqlcoder_tokenizer is not None and sqlcoder_model is not None:
    # Schema-aware test. The encoding is given explicitly so reading the
    # file does not depend on the platform's default locale encoding.
    # NOTE(review): assumes schema.txt is UTF-8 (or plain ASCII) - confirm.
    with open('schema.txt', 'r', encoding='utf-8') as f:
        schema = f.read()

    test_q2 = """
找出所有从不点任何东西的顾客。

以 任意顺序 返回结果表。

结果格式如下所示。



示例 1：

输入：
Customers table:
+----+-------+
| id | name  |
+----+-------+
| 1  | Joe   |
| 2  | Henry |
| 3  | Sam   |
| 4  | Max   |
+----+-------+
Orders table:
+----+------------+
| id | customerId |
+----+------------+
| 1  | 3          |
| 2  | 1          |
+----+------------+
输出：
+-----------+
| Customers |
+-----------+
| Henry     |
| Max       |
+-----------+
"""
    result2 = generate_sql_sqlcoder(sqlcoder_tokenizer, sqlcoder_model, test_q2, schema)
    print(f"问题: {test_q2}")
    print(f"SQL: {result2}")
else:
    print("❌ 模型加载失败，无法进行测试")

问题: 
找出所有从不点任何东西的顾客。

以 任意顺序 返回结果表。

结果格式如下所示。



示例 1：

输入：
Customers table:
+----+-------+
| id | name  |
+----+-------+
| 1  | Joe   |
| 2  | Henry |
| 3  | Sam   |
| 4  | Max   |
+----+-------+
Orders table:
+----+------------+
| id | customerId |
+----+------------+
| 1  | 3          |
| 2  | 1          |
+----+------------+
输出：
+-----------+
| Customers |
+-----------+
| Henry     |
| Max       |
+-----------+

SQL: SELECT c.name FROM Customers c LEFT JOIN Orders o ON c.id = o.customerId WHERE o.customerId IS NULL ORDER BY c.name NULLS LAST
