<a href="https://colab.research.google.com/github/2018141043089/2018141043089.github.io/blob/master/text2sql/text2sql.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wget
!pip uninstall bitsandbytes
!pip install -U bitsandbytes

Collecting wget
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9656 sha256=bfdd2571d75e9ab42c918c1b568efd9980e180ba37a122da2f6fa450648abef2
  Stored in directory: /root/.cache/pip/wheels/8b/f1/7f/5c94f0a7a505ca1c81cd1d9208ae2064675d97582078e6c769
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2
[0mCollecting bitsandbytes
  Downloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Downloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl (122.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.4/122.4 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.44.1


In [2]:
!pip install datasets transformers peft torch pandas sqlalchemy

Collecting datasets
  Downloading datasets-3.0.2-py3-none-any.whl.metadata (20 kB)
Collecting peft
  Downloading peft-0.13.2-py3-none-any.whl.metadata (13 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.2-py3-none-any.whl (472 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m472.7/472.7 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading peft-0.13.2-py3-none-any.whl (320 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m320.7/320.7 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 

In [3]:
import bitsandbytes as bnb
import accelerate

import pandas as pd
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)
from sqlalchemy import create_engine
from pathlib import Path
import wget
import zipfile

def prepare_synthea_data(data_dir):
    """准备Synthea数据集并创建SQLite数据库"""
    # 保持原有数据准备逻辑不变
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    zip_path = data_dir / 'synthea_sample_data_csv_apr2020.zip'
    if not zip_path.exists():
        print("下载Synthea数据集...")
        synthea_url = "https://synthetichealth.github.io/synthea-sample-data/downloads/synthea_sample_data_csv_apr2020.zip"
        wget.download(synthea_url, str(zip_path))

    extract_dir = data_dir / 'synthea_data'
    if not extract_dir.exists():
        print("解压数据...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(str(extract_dir))

    db_path = data_dir / 'synthea_demo.db'
    engine = create_engine(f'sqlite:///{str(db_path)}')

    tables = ['patients', 'encounters', 'procedures']
    for table in tables:
        print(f"导入{table}表...")
        csv_path = extract_dir / 'csv' / f'{table}.csv'
        df = pd.read_csv(csv_path)
        df.to_sql(table, engine, if_exists='replace', index=False)

    return engine

def create_training_data():
    """创建训练数据集，使用更结构化的prompt模板"""
    print("准备训练数据...")
    synthea_data = [
        {"question": "How many unique patients are in the database?",
         "query": "SELECT COUNT(DISTINCT Id) FROM patients"},
        {"question": "What is the average age of patients?",
         "query": "SELECT AVG((julianday('now') - julianday(BIRTHDATE))/365.25) FROM patients"},
        {"question": "How many patients have had more than 5 encounters?",
         "query": "SELECT COUNT(*) FROM (SELECT PATIENT FROM encounters GROUP BY PATIENT HAVING COUNT(*) > 5)"},
        {"question": "What is the most common encounter type?",
         "query": "SELECT ENCOUNTERCLASS, COUNT(*) as count FROM encounters GROUP BY ENCOUNTERCLASS ORDER BY count DESC LIMIT 1"},
        {"question": "What percentage of patients are female?",
         "query": "SELECT (COUNT(*) * 100.0 / (SELECT COUNT(*) FROM patients)) as percentage FROM patients WHERE GENDER = 'F'"}
    ]

    # 加载更多Spider数据，但仍然控制数量以避免OOM
    try:
        print("加载Spider数据集...")
        spider_dataset = load_dataset("spider")
        spider_subset = spider_dataset['train'].select(range(100))  # 增加数据量到100

        spider_data = []
        for item in spider_subset:
            # 使用更结构化的prompt模板
            prompt = (
                "### Instruction: Convert the following question to SQL query.\n\n"
                f"### Question: {item['question']}\n\n"
                "### Response: Here's the SQL query:\n"
            )
            target = f"{item['query']}\n\n### End"
            spider_data.append({"prompt": prompt, "target": target})
    except Exception as e:
        print(f"加载Spider数据集时出错: {e}")
        spider_data = []

    synthea_formatted = []
    for item in synthea_data:
        prompt = (
            "### Instruction: Convert the following medical-related question to SQL query.\n\n"
            f"### Question: {item['question']}\n\n"
            "### Response: Here's the SQL query:\n"
        )
        target = f"{item['query']}\n\n### End"
        synthea_formatted.append({"prompt": prompt, "target": target})

    return spider_data + synthea_formatted

def setup_model_and_tokenizer():
    """设置模型和tokenizer，使用CodeLlama作为基础模型"""
    print("设置模型和tokenizer...")
    try:
        model_name = "codellama/CodeLlama-7b-hf"

        # 配置4bit量化
        from transformers import BitsAndBytesConfig
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            cache_dir="./model_cache"
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            cache_dir="./model_cache"
        )

        # 准备模型进行4bit训练
        model = prepare_model_for_kbit_training(model)

        # 配置LoRA
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],  # CodeLlama特定的target modules
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )

        model = get_peft_model(model, lora_config)

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.pad_token_id

        return model, tokenizer
    except Exception as e:
        print(f"设置模型时出错: {e}")
        raise

def preprocess_function(examples, tokenizer):
    """改进的预处理函数"""
    max_length = 512

    texts = []
    for prompt, target in zip(examples["prompt"], examples["target"]):
        combined_text = f"{prompt}{target}"
        texts.append(combined_text)

    encoded = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors=None
    )

    encoded["labels"] = encoded["input_ids"].copy()

    return encoded

def main():
    try:
        if not torch.cuda.is_available():
            raise RuntimeError("This script requires a GPU to run effectively")

        print("使用设备: cuda")

        base_dir = Path("./sql_training")
        base_dir.mkdir(exist_ok=True)

        print("准备数据...")
        engine = prepare_synthea_data(base_dir / "data")
        combined_training_data = create_training_data()

        model, tokenizer = setup_model_and_tokenizer()

        print("处理数据集...")
        dataset = Dataset.from_dict({
            "prompt": [item['prompt'] for item in combined_training_data],
            "target": [item['target'] for item in combined_training_data]
        })

        processed_dataset = dataset.map(
            lambda x: preprocess_function(x, tokenizer),
            batched=True,
            remove_columns=dataset.column_names
        )

        print("配置训练参数...")
        training_args = TrainingArguments(
            output_dir=str(base_dir / "results"),
            run_name="sql_training_run",
            num_train_epochs=3,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=2e-4,
            fp16=True,
            logging_steps=10,
            save_steps=100,
            warmup_ratio=0.05,
            weight_decay=0.01,
            max_grad_norm=0.3,
            lr_scheduler_type="cosine",
            report_to=["none"],
        )

        print("初始化训练器...")
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=processed_dataset,
            data_collator=data_collator,
        )

        print("开始训练...")
        trainer.train()

        print("保存模型...")
        save_dir = base_dir / "fine_tuned_model"
        model.save_pretrained(str(save_dir))
        tokenizer.save_pretrained(str(save_dir))

        print("训练完成！")

    except Exception as e:
        print(f"运行过程中出错: {e}")
        raise

if __name__ == "__main__":
    main()

使用设备: cuda
准备数据...
下载Synthea数据集...
解压数据...
导入patients表...
导入encounters表...
导入procedures表...
准备训练数据...
加载Spider数据集...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/5.51k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/831k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/126k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1034 [00:00<?, ? examples/s]

设置模型和tokenizer...


tokenizer_config.json:   0%|          | 0.00/749 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/637 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

处理数据集...


Map:   0%|          | 0/105 [00:00<?, ? examples/s]

配置训练参数...
初始化训练器...
开始训练...


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss
10,1.9076
20,1.0366
30,0.8169


config.json:   0%|          | 0.00/637 [00:00<?, ?B/s]

保存模型...
训练完成！


    test_questions = [
        "How many patients are there in total?",
        "What is the average age of female patients?",
        "What are the top 3 most common procedures?",
        "How many encounters were there in the year 2020?",
        "What is the distribution of patient genders?",
        "Show me the oldest patient's details",
        "What is the average length of hospital stays?",
        "How many patients have diabetes?",
        "What is the most common age group for hospital visits?",
        "Which day of the week has the most patient encounters?"
    ]

In [10]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
from difflib import SequenceMatcher

class SQLGenerationTester:
    def __init__(self, model_path: str):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # 加载tokenizer和模型
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=True,
            load_in_4bit=True
        )
        self.model.eval()

    def generate_sql(self, question: str) -> str:
        prompt = (
            "### Instruction: Convert the following medical-related question to SQL query.\n\n"
            f"### Question: {question}\n\n"
            "### Response: Here's the SQL query:\n"
        )

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                num_return_sequences=1,
                temperature=0.1,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        try:
            sql_query = generated_text.split("Here's the SQL query:\n")[1].split("\n\n### End")[0].strip()
        except IndexError:
            sql_query = "Error: Could not parse generated SQL query"

        return sql_query

    def normalize_sql(self, sql: str) -> str:
        """规范化SQL查询"""
        return sql.strip().upper()

    def calculate_similarity(self, sql1: str, sql2: str) -> float:
        """计算两个SQL查询的相似度"""
        return SequenceMatcher(None, sql1, sql2).ratio()

    def test_queries(self, test_cases: List[Dict[str, str]]) -> List[Dict]:
        results = []

        for case in test_cases:
            question = case['question']
            print(f"\nTesting question: {question}")

            generated_sql = self.generate_sql(question)
            print(f"Generated SQL: {generated_sql}")

            normalized_generated = self.normalize_sql(generated_sql)
            normalized_expected = self.normalize_sql(case.get('expected_sql', ''))

            similarity = self.calculate_similarity(normalized_generated, normalized_expected)

            test_result = {
                'question': question,
                'generated_sql': generated_sql,
                'expected_sql': case.get('expected_sql', 'Not provided'),
                'sql_similarity': similarity
            }

            print(f"SQL similarity: {similarity:.2f}")

            results.append(test_result)

        return results

def main():
    # 测试用例
    test_cases = [
        {
            'question': "How many patients are older than 65 years?",
            'expected_sql': "SELECT COUNT(*) FROM patients WHERE (julianday('now') - julianday(BIRTHDATE))/365.25 > 65"
        },
        {
            'question': "What is the average length of hospital stays for emergency visits?",
            'expected_sql': "SELECT AVG(julianday(STOP) - julianday(START)) FROM encounters WHERE ENCOUNTERCLASS = 'emergency'"
        },
        {
            'question': "List the top 5 most common procedures performed",
            'expected_sql': "SELECT DESCRIPTION, COUNT(*) as count FROM procedures GROUP BY DESCRIPTION ORDER BY count DESC LIMIT 5"
        },
        {
            'question': "How many patients have diabetes?",
            'expected_sql': "SELECT COUNT(DISTINCT PATIENT) FROM encounters WHERE DESCRIPTION LIKE '%diabetes%'"
        },
        {
            'question': "What is the gender distribution of patients over 50?",
            'expected_sql': "SELECT GENDER, COUNT(*) as count FROM patients WHERE (julianday('now') - julianday(BIRTHDATE))/365.25 > 50 GROUP BY GENDER"
        }
    ]

    model_path = "./sql_training/fine_tuned_model"

    tester = SQLGenerationTester(model_path)

    print("Starting SQL generation tests...")
    results = tester.test_queries(test_cases)

    # 计算平均相似度
    avg_similarity = sum(r['sql_similarity'] for r in results) / len(results)

    print(f"\nTest Summary:")
    print(f"Total test cases: {len(results)}")
    print(f"Average SQL similarity: {avg_similarity:.2f}")

if __name__ == "__main__":
    main()

Using device: cuda


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Starting SQL generation tests...

Testing question: How many patients are older than 65 years?
Generated SQL: SELECT COUNT(*) FROM patient WHERE age > 65
SQL similarity: 0.64

Testing question: What is the average length of hospital stays for emergency visits?
Generated SQL: SELECT  AVG(hospital_stay_length)  FROM  emergency_visits
SQL similarity: 0.47

Testing question: List the top 5 most common procedures performed
Generated SQL: SELECT TOP 5 Procedure FROM Procedure ORDER BY Procedure DESC
SQL similarity: 0.55

Testing question: How many patients have diabetes?
Generated SQL: SELECT COUNT(*) FROM patient WHERE diabetes = 1
SQL similarity: 0.56

Testing question: What is the gender distribution of patients over 50?
Generated SQL: SELECT gender FROM patients WHERE age > 50
SQL similarity: 0.50

Test Summary:
Total test cases: 5
Average SQL similarity: 0.54
