# Text-to-SQL: Fine-tuning & Evaluation on TPC-DS

This unified notebook provides:
- **Part 1**: QLoRA fine-tuning of Qwen for Text-to-SQL
- **Part 2**: Evaluation with execution accuracy metrics

**Target environment**: Kaggle (GPU T4/P100)


In [None]:
# Clone repository (Kaggle)
!git clone https://github.com/VuThanhLam124/Capstone-NLUS-VDD.git

In [None]:
%cd Capstone-NLUS-VDD

In [None]:
# Install dependencies
!pip install -r requirements.txt
!pip -q install -U "transformers>=4.43" "peft>=0.10" "bitsandbytes>=0.43" "accelerate>=0.30" "trl>=0.12" "datasets>=2.19" sqlglot

---
## Configuration

In [None]:
from pathlib import Path
import json
import os
import random
import time
import re
import gc
import math
import unicodedata
from decimal import Decimal
from datetime import date, datetime

import duckdb
import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    set_seed,
)
from peft import PeftModel, PeftConfig, prepare_model_for_kbit_training

try:
    from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
except Exception:
    SFTTrainer = None
    DataCollatorForCompletionOnlyLM = None

try:
    import sqlglot
except Exception:
    sqlglot = None
    print("sqlglot not installed: SQL normalization will be simple.")

# ========== PATH CONFIG ==========
def find_repo_root(start: Path) -> Path:
    """Return the closest directory at or above *start* containing ``research_pipeline``.

    The search begins at *start* itself and then walks its parents from the
    nearest to the filesystem root. When no candidate holds a
    ``research_pipeline`` entry, *start* is returned unchanged.
    """
    candidates = (start, *start.parents)
    return next(
        (folder for folder in candidates if (folder / "research_pipeline").exists()),
        start,
    )

# Repository root is discovered from the current working directory; every
# other path below is derived from it.
REPO_ROOT = find_repo_root(Path.cwd())
DB_PATH = REPO_ROOT / "research_pipeline" / "data" / "ecommerce_dw.duckdb"

# Fine-tune data: candidate CSV locations, tried in order (first existing wins).
FINETUNE_DATA_PATHS = [
    REPO_ROOT / "research_pipeline" / "data" / "data_finetune.csv",
    REPO_ROOT / "data" / "data_finetune.csv",
]

# Benchmark data: candidate JSON files, tried in order (first existing wins).
BENCHMARK_CANDIDATES = [
    REPO_ROOT / "research_pipeline" / "data" / "test_queries_vi_200_v2.json",
    REPO_ROOT / "research_pipeline" / "data" / "test_queries_vi_200.json",
    REPO_ROOT / "research_pipeline" / "test_queries.json",
]

# Directory that receives the evaluation results CSV.
OUTPUT_DIR = REPO_ROOT / "research_pipeline"

# ========== MODEL CONFIG ==========
# BASE_ID is the base checkpoint; ADAPTER_ID is a published PEFT adapter that
# fine-tuning continues from and that evaluation falls back to.
BASE_ID = "Qwen/Qwen3-4B-Instruct-2507"
ADAPTER_ID = "Ellbendls/Qwen-3-4b-Text_to_SQL"
# Training output directory; the fine-tuned adapter and tokenizer are saved here.
LOCAL_ADAPTER_DIR = REPO_ROOT / "research_pipeline" / "qwen_text_to_sql_lora_v2"

# ========== TRAINING CONFIG ==========
SEED = 42
MAX_FINETUNE_SAMPLES = None  # Set to int for quick debug
TRAIN_SPLIT = 0.9  # fraction of the shuffled dataset used for training
MAX_TABLES = 8  # cap on tables included in each schema snippet
MAX_SEQ_LEN = 768  # passed to the trainer as max_seq_length

BATCH_SIZE = 1
GRAD_ACCUM = 8
NUM_EPOCHS = 2
LEARNING_RATE = 2e-4
WARMUP_RATIO = 0.05

# ========== EVALUATION CONFIG ==========
MAX_EVAL_SAMPLES = 200  # Set None for full benchmark
SAMPLE_SEED = 42  # seeds `random` before sampling benchmark items
DEFAULT_LIMIT = None  # when not None, a LIMIT is appended to queries lacking one
MAX_NEW_TOKENS = 256
NUM_BEAMS = 1

REPAIR_ON_ERROR = True
# NOTE: the evaluation loop only checks that this is > 0, so at most one
# repair attempt is made regardless of the value.
REPAIR_MAX_ATTEMPTS = 1

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 4-bit quantization is only requested for evaluation when CUDA is available.
USE_4BIT = torch.cuda.is_available()

set_seed(SEED)
print(f"Device: {DEVICE}")
print(f"Repo root: {REPO_ROOT}")

---
## Setup TPC-DS Database

In [None]:
# Generate the TPC-DS database automatically when the DuckDB file is missing.
AUTO_SETUP_DB = True
# Scale factor forwarded to dsdgen(sf=...).
SETUP_SCALE_FACTOR = 1
# Forwarded to setup_tpcds_db(force_recreate=...): drop existing tables and regenerate.
FORCE_RECREATE_DB = False

def setup_tpcds_db(db_path: Path, scale_factor: int = 1, force_recreate: bool = False) -> None:
    """Create (or regenerate) the TPC-DS tables inside the DuckDB file at *db_path*.

    Args:
        db_path: Location of the DuckDB file; parent directories are created.
        scale_factor: Value passed to ``dsdgen(sf=...)``.
        force_recreate: When true, existing tables are dropped and data is
            generated again; otherwise generation is skipped if any table exists.
    """
    db_path.parent.mkdir(parents=True, exist_ok=True)
    con = duckdb.connect(str(db_path))
    try:
        for statement in ("INSTALL tpcds;", "LOAD tpcds;"):
            con.execute(statement)

        existing = [row[0] for row in con.execute("SHOW TABLES").fetchall()]
        if existing:
            if not force_recreate:
                print(f"Found {len(existing)} tables. Skip generation.")
                return
            # Rebuild requested: clear out every current table first.
            for name in existing:
                con.execute(f"DROP TABLE {name}")

        print(f"Generating TPC-DS (sf={scale_factor})...")
        began = time.time()
        con.execute(f"CALL dsdgen(sf={scale_factor});")
        elapsed = time.time() - began
        print(f"Data generation completed in {elapsed:.2f}s")
    finally:
        # Always release the connection, including on the early return above.
        con.close()

# Build the database when the file is missing, and also when a rebuild was
# explicitly requested via FORCE_RECREATE_DB (previously that flag was ignored
# whenever the file already existed, because setup was never reached).
if AUTO_SETUP_DB and (FORCE_RECREATE_DB or not DB_PATH.exists()):
    setup_tpcds_db(DB_PATH, scale_factor=SETUP_SCALE_FACTOR, force_recreate=FORCE_RECREATE_DB)
elif not DB_PATH.exists():
    # Automatic setup is disabled and there is nothing to work with.
    raise FileNotFoundError(f"TPC-DS DuckDB not found: {DB_PATH}")

---
## Schema Utilities

In [None]:
# Read-only connection; later cells reuse `con` to run ground-truth and
# generated queries, and the final cleanup cell closes it.
con = duckdb.connect(str(DB_PATH), read_only=True)

# schema_map: table name -> list of the first field of each DESCRIBE row
# (presumably the column name — confirm against DuckDB's DESCRIBE output).
schema_map = {}
for (table_name,) in con.execute("SHOW TABLES").fetchall():
    columns = [r[0] for r in con.execute(f"DESCRIBE {table_name}").fetchall()]
    schema_map[table_name] = columns


def strip_accents(text: str) -> str:
    """Return *text* with diacritical marks removed.

    The text is decomposed with NFD and every combining mark (Unicode
    category ``Mn``) is dropped. The Vietnamese letters "đ"/"Đ" have no NFD
    decomposition, so they are mapped to "d"/"D" explicitly; otherwise they
    would survive and split words in the ASCII-only tokenizer.
    """
    decomposed = unicodedata.normalize("NFD", text)
    stripped = "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")
    # Stroke letters are standalone code points, not base letter + combining mark.
    return stripped.replace("đ", "d").replace("Đ", "D")


def tokenize(text: str) -> list[str]:
    """Split *text* into lowercase, accent-free word tokens.

    Runs of ``[a-z0-9_]`` are extracted from the lowercased, accent-stripped
    text, each run is further split on underscores, and pieces of length one
    or zero are discarded.
    """
    folded = strip_accents(text.lower())
    return [
        piece
        for chunk in re.findall(r"[a-z0-9_]+", folded)
        for piece in chunk.split("_")
        if len(piece) > 1
    ]


# Maps accent-stripped question tokens (presumably Vietnamese, given the
# `test_queries_vi_*` benchmark file names) to English schema vocabulary.
# Looked up token-by-token in expand_tokens().
# NOTE(review): tokenize() splits on "_" before the lookup, so keys that
# contain "_" look unreachable through that path — confirm whether they are
# still needed.
SYNONYMS = {
    "khach": "customer",
    "khachhang": "customer",
    "khach_hang": "customer",
    "sanpham": "item",
    "san_pham": "item",
    "hang": "item",
    "danhmuc": "category",
    "danh_muc": "category",
    "bang": "state",
    "tinh": "state",
    "cuahang": "store",
    "cua_hang": "store",
    "doanhthu": "revenue",
    "doanh_thu": "revenue",
    "soluong": "quantity",
    "so_luong": "quantity",
    "gia": "price",
    "thang": "month",
    "nam": "year",
    "quy": "quarter",
}


def expand_tokens(tokens: list[str]) -> set[str]:
    """Return the set of *tokens* plus the ``SYNONYMS`` translation of each one.

    Tokens without a (non-empty) entry in ``SYNONYMS`` contribute only
    themselves.
    """
    originals = set(tokens)
    translated = {SYNONYMS[tok] for tok in tokens if SYNONYMS.get(tok)}
    return originals | translated


# table_tokens: table name -> set of tokens drawn from the table name and all
# of its column names; used to score tables against a question's tokens.
table_tokens = {}
for table, cols in schema_map.items():
    tokens = set(tokenize(table))
    for col in cols:
        tokens.update(tokenize(col))
    table_tokens[table] = tokens


def select_tables_for_question(question: str, max_tables: int = 8) -> list[str]:
    """Pick the tables most relevant to *question*.

    Tables are ranked by how many (synonym-expanded) question tokens also
    appear in the table's name/column tokens; only tables with a positive
    overlap are kept, capped at *max_tables*. Keyword rules then append
    specific tables that exist in ``schema_map`` and are not already chosen.

    Because rule-driven tables are appended after the ranked ones, the final
    cap can cut them off again when the ranked list already fills
    *max_tables*. A falsy *max_tables* disables the final cap.
    """
    q_tokens = expand_tokens(tokenize(question))

    # Highest overlap first; ties fall back to reverse table-name order.
    ranked = sorted(
        ((len(q_tokens & vocab), name) for name, vocab in table_tokens.items()),
        reverse=True,
    )
    selected = [name for overlap, name in ranked if overlap > 0][:max_tables]

    # (trigger tokens, tables to append) — applied strictly top to bottom.
    forced_rules = (
        ({"year", "month", "quarter", "date"}, ("date_dim",)),
        ({"customer"}, ("customer", "customer_address")),
        ({"state"}, ("customer_address", "store")),
        ({"store"}, ("store_sales", "store")),
        ({"web"}, ("web_sales", "web_site")),
        ({"catalog"}, ("catalog_sales", "call_center")),
        ({"call"}, ("call_center",)),
        ({"inventory"}, ("inventory",)),
        ({"item", "product", "category"}, ("item",)),
        ({"sales", "revenue", "quantity", "price"}, ("store_sales",)),
    )
    for triggers, forced in forced_rules:
        if q_tokens.isdisjoint(triggers):
            continue
        for name in forced:
            if name in schema_map and name not in selected:
                selected.append(name)

    cap = max_tables or len(selected)
    return selected[:cap]


def build_schema_snippet(question: str, max_tables: int = 8) -> tuple[list[str], str]:
    """Build a textual schema excerpt for the tables relevant to *question*.

    Args:
        question: Natural-language question used to select tables.
        max_tables: Cap forwarded to ``select_tables_for_question``.

    Returns:
        ``(tables, schema_text)`` where *tables* is the list of table names
        used (all tables when none were selected) and *schema_text* lists
        each table as ``TABLE name (`` / one indented column per line / ``)``,
        with a blank line between tables.
    """
    tables = select_tables_for_question(question, max_tables=max_tables)
    if not tables:
        # Nothing matched: fall back to the full schema.
        tables = list(schema_map.keys())

    lines = []
    for table in tables:
        lines.append(f"TABLE {table} (")
        lines.extend(f"  {col}" for col in schema_map[table])
        lines.append(")")
        lines.append("")
    # Join with newlines: joining with "" glued every table and column onto a
    # single line, defeating the per-line layout and the blank separators.
    return tables, "\n".join(lines).strip()


print(f"Loaded schema for {len(schema_map)} tables.")

---
## SQL Utilities

In [None]:
# Whole-word, case-insensitive keyword blocklist used by is_safe_select() to
# reject statements that are not plain reads.
_FORBIDDEN_SQL = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|COPY|PRAGMA|ATTACH|DETACH|EXPORT|IMPORT|CALL)\b",
    re.IGNORECASE,
)

# System message used for both training records and evaluation prompts.
SYSTEM_PROMPT = (
    "You translate user questions into SQL for DuckDB (TPC-DS). "
    "Return only SQL, no markdown, no explanations. "
    "Use only tables and columns from the schema."
)
# System message used when asking the model to fix a failed query.
REPAIR_PROMPT = (
    "You are fixing SQL for DuckDB (TPC-DS). "
    "Return only corrected SQL, no markdown, no explanations."
)


def extract_sql(text: str) -> str:
    """Pull a single SQL statement out of raw model output.

    Steps, in order: trim whitespace; if a fenced code block is present use
    its contents; drop a leading ``sql:`` label; keep only the text before
    the first semicolon.
    """
    candidate = text.strip()

    fenced = re.search(r"```(?:sql)?\s*(.*?)```", candidate, flags=re.IGNORECASE | re.DOTALL)
    if fenced is not None:
        candidate = fenced.group(1).strip()

    if candidate.lower().startswith("sql:"):
        candidate = candidate[4:].strip()

    head, separator, _ = candidate.partition(";")
    return head.strip() if separator else candidate


def is_safe_select(sql: str) -> bool:
    """Return True when *sql* looks like a read-only SELECT/WITH statement.

    ``--`` line comments are removed first. The statement is rejected when it
    is empty, contains any keyword from ``_FORBIDDEN_SQL``, or does not start
    with ``SELECT`` or ``WITH`` as its first whitespace-delimited word.
    """
    stripped = re.sub(r"--.*?$", "", sql, flags=re.MULTILINE).strip()
    if not stripped or _FORBIDDEN_SQL.search(stripped):
        return False
    leading_word = re.split(r"\s+", stripped, maxsplit=1)[0]
    return leading_word.upper() in ("SELECT", "WITH")


def ensure_limit(sql: str, limit: int | None) -> str:
    if limit is None:
        return sql
    s = sql.strip().rstrip(";").strip()
    if re.search(r"\bLIMIT\b", s, flags=re.IGNORECASE):
        return s
    return f"{s}\nLIMIT {limit}"


def has_order_by(sql: str) -> bool:
    """Report whether *sql* contains an ``ORDER BY`` clause (case-insensitive)."""
    found = re.search(r"\border\s+by\b", sql, flags=re.IGNORECASE)
    return bool(found)


def normalize_sql_text(sql: str) -> str:
    """Return a canonical text form of *sql* for exact-match comparison.

    When sqlglot is available the query is parsed and re-rendered in the
    DuckDB dialect. If sqlglot is missing or fails on this query, the
    fallback collapses whitespace and lowercases the text instead.
    """
    if sqlglot is not None:
        try:
            tree = sqlglot.parse_one(sql, read="duckdb")
            return tree.sql(dialect="duckdb", pretty=False)
        except Exception:
            # Best effort: any sqlglot failure falls through to the plain fallback.
            pass
    collapsed = re.sub(r"\s+", " ", sql.strip())
    return collapsed.lower()


def normalize_value(v):
    """Convert one result cell into a comparable, canonical Python value.

    Floats are rounded to 6 decimals (NaN becomes the string ``"nan"`` so it
    compares equal to itself), ``Decimal`` values are rounded and converted to
    float, dates/datetimes become ISO strings, everything else is returned
    unchanged.
    """
    if isinstance(v, float):
        if math.isnan(v):
            return "nan"
        return round(v, 6)
    if isinstance(v, Decimal):
        return float(round(v, 6))
    if isinstance(v, (datetime, date)):
        return v.isoformat()
    return v


def _value_sort_key(value) -> tuple:
    """Return a totally ordered sort key for one normalized cell."""
    if value is None:
        return (0, 0, "")
    if isinstance(value, (int, float)):
        return (1, value, "")
    return (2, 0, str(value))


def normalize_rows(rows, keep_order: bool):
    """Normalize every cell of *rows* and optionally sort the rows.

    Args:
        rows: Iterable of row tuples, or ``None``.
        keep_order: When true the row order is preserved; otherwise rows are
            sorted so that two result sets can be compared as multisets.

    Returns:
        A list of normalized tuples, or ``None`` when *rows* is ``None``.
    """
    if rows is None:
        return None
    norm = [tuple(normalize_value(x) for x in row) for row in rows]
    if keep_order:
        return norm
    # Plain sorted(norm) raises TypeError as soon as a column mixes None
    # (SQL NULL) or "nan" with numbers; the per-cell key ranks
    # None < numbers < everything else so any result set can be ordered.
    return sorted(norm, key=lambda row: tuple(_value_sort_key(cell) for cell in row))


def run_sql(con, sql: str):
    """Execute *sql* on *con* and return ``(rows, error)``.

    On success ``rows`` is the ``fetchall()`` result and ``error`` is ``None``;
    on any exception ``rows`` is ``None`` and ``error`` is the exception text.
    """
    try:
        rows = con.execute(sql).fetchall()
    except Exception as exc:
        return None, str(exc)
    return rows, None


def classify_error(err: str | None) -> str | None:
    if err is None:
        return None
    if "Binder Error" in err:
        return "binder"
    if "Parser Error" in err:
        return "parser"
    if "Catalog Error" in err:
        return "catalog"
    return "exec_error"

---
# PART 1: Fine-tuning

In [None]:
RUN_FINETUNING = True  # Set to False to skip fine-tuning and use pre-trained adapter

In [None]:
if RUN_FINETUNING:
    # Load fine-tune dataset
    def resolve_data_path(paths):
        """Return the first existing path in *paths*; raise FileNotFoundError if none exists."""
        for p in paths:
            if p.exists():
                return p
        raise FileNotFoundError(f"No data_finetune.csv found in: {paths}")

    DATA_PATH = resolve_data_path(FINETUNE_DATA_PATHS)
    print("Using data:", DATA_PATH)

    # Keep only rows that have both a question and a reference query, then
    # trim surrounding whitespace from both columns.
    df = pd.read_csv(DATA_PATH)
    df = df.dropna(subset=["Transcription", "SQL Ground Truth"]).copy()
    df["Transcription"] = df["Transcription"].astype(str).str.strip()
    df["SQL Ground Truth"] = df["SQL Ground Truth"].astype(str).str.strip()

    def normalize_sql_for_train(sql: str) -> str:
        """Strip *sql* and make sure it ends with a semicolon."""
        sql = sql.strip()
        if not sql.endswith(";"):
            sql = sql + ";"
        return sql

    df["SQL Ground Truth"] = df["SQL Ground Truth"].map(normalize_sql_for_train)
    # De-duplicate on the SQL text only: different questions that share the
    # same query collapse to a single row.
    df = df.drop_duplicates(subset=["SQL Ground Truth"]).reset_index(drop=True)

    if MAX_FINETUNE_SAMPLES is not None:
        df = df.sample(n=min(MAX_FINETUNE_SAMPLES, len(df)), random_state=SEED).reset_index(drop=True)

    print("Rows:", len(df))
    # NOTE(review): as an expression nested inside the `if`, this preview is
    # probably not rendered by the notebook — confirm.
    df.head()

In [None]:
if RUN_FINETUNING:
    # Prepare dataset for training
    # The tokenizer is loaded from the adapter repository, not from BASE_ID.
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def format_record(record) -> str:
        """Render one (question, SQL) record as a single training text.

        Uses the tokenizer's chat template when one is defined, otherwise a
        plain "system prompt + user prompt + SQL" string.
        """
        question = record["Transcription"]
        sql = record["SQL Ground Truth"]
        _, schema_text = build_schema_snippet(question, max_tables=MAX_TABLES)
        user = f"""SCHEMA:
{schema_text}

QUESTION:
{question}

SQL:"""

        if getattr(tokenizer, "chat_template", None):
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user},
                {"role": "assistant", "content": sql},
            ]
            # The assistant turn is already part of the text, so no
            # generation prompt is appended.
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )

        return f"""{SYSTEM_PROMPT}

{user} {sql}"""

    dataset = Dataset.from_pandas(df)

    def map_fn(record):
        """Map a raw record to the single "text" field used for training."""
        return {"text": format_record(record)}

    # Replace all original columns with the formatted "text" column.
    dataset = dataset.map(map_fn, remove_columns=dataset.column_names)
    dataset = dataset.shuffle(seed=SEED)

    # Head of the shuffled data is the training split, the remainder is eval.
    train_size = int(len(dataset) * TRAIN_SPLIT)
    train_dataset = dataset.select(range(train_size))
    eval_dataset = dataset.select(range(train_size, len(dataset)))

    print("Train size:", len(train_dataset))
    print("Eval size:", len(eval_dataset))

In [None]:
if RUN_FINETUNING:
    # Load model with QLoRA
    # 4-bit NF4 quantization with double quantization and float16 compute.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    # Align the configured vocabulary size with the tokenizer before loading,
    # so the model is built with a matching embedding size.
    config = AutoConfig.from_pretrained(BASE_ID, trust_remote_code=True)
    if len(tokenizer) != config.vocab_size:
        print(f"Adjust vocab_size {config.vocab_size} -> {len(tokenizer)}")
        config.vocab_size = len(tokenizer)

    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_ID,
        config=config,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
        ignore_mismatched_sizes=True,
    )

    # Second safety net: resize if the loaded embedding still differs.
    if len(tokenizer) != base_model.get_input_embeddings().weight.shape[0]:
        base_model.resize_token_embeddings(len(tokenizer))

    base_model = prepare_model_for_kbit_training(base_model)
    # Continue training from the published adapter rather than a fresh LoRA.
    model = PeftModel.from_pretrained(base_model, ADAPTER_ID, is_trainable=True)

    model.print_trainable_parameters()
    model.config.use_cache = False

In [None]:
if RUN_FINETUNING:
    # Train
    if SFTTrainer is None:
        raise RuntimeError("TRL is required. Install trl and restart kernel.")

    # NOTE(review): newer transformers releases appear to have renamed
    # `evaluation_strategy` to `eval_strategy` — confirm against the version
    # installed by the unpinned `pip install -U` cell.
    training_args = TrainingArguments(
        output_dir=str(LOCAL_ADAPTER_DIR),
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        num_train_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        warmup_ratio=WARMUP_RATIO,
        fp16=True,
        logging_steps=20,
        evaluation_strategy="steps",
        eval_steps=200,
        save_steps=200,
        save_total_limit=2,
        load_best_model_at_end=False,
        report_to="none",
        optim="paged_adamw_8bit",
    )

    # Marker that separates prompt from answer in each training text; it is
    # handed to the completion-only collator (presumably so that loss is
    # computed on the answer part only — confirm against the TRL docs).
    response_template = "<|im_start|>assistant\n" if getattr(tokenizer, "chat_template", None) else "SQL:"

    # Without the collator class, training falls back to the trainer's default collator.
    data_collator = None
    if DataCollatorForCompletionOnlyLM is not None:
        data_collator = DataCollatorForCompletionOnlyLM(
            response_template=response_template,
            tokenizer=tokenizer,
        )

    # NOTE(review): `tokenizer`, `dataset_text_field`, `max_seq_length` and
    # `packing` as SFTTrainer keyword arguments may not be accepted by recent
    # TRL releases — confirm against the installed version.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        dataset_text_field="text",
        max_seq_length=MAX_SEQ_LEN,
        packing=False,
        data_collator=data_collator,
    )

    train_result = trainer.train()
    print(train_result)

In [None]:
if RUN_FINETUNING:
    # Save adapter
    # The tokenizer is stored next to the adapter so evaluation can load both
    # from LOCAL_ADAPTER_DIR.
    LOCAL_ADAPTER_DIR.mkdir(parents=True, exist_ok=True)
    trainer.model.save_pretrained(LOCAL_ADAPTER_DIR)
    tokenizer.save_pretrained(LOCAL_ADAPTER_DIR)
    print("Saved adapter to", LOCAL_ADAPTER_DIR)

    # Cleanup
    # Drop the training objects before the evaluation cells load the model again.
    del model, base_model, trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

---
# PART 2: Evaluation

In [None]:
# Choose adapter source
# The local adapter is used only when fine-tuning ran in this session AND the
# output directory exists; with RUN_FINETUNING = False a previously saved
# local adapter is ignored and the Hub adapter is evaluated instead.
USE_LOCAL_ADAPTER = RUN_FINETUNING and LOCAL_ADAPTER_DIR.exists()

if USE_LOCAL_ADAPTER:
    EVAL_ADAPTER_ID = str(LOCAL_ADAPTER_DIR)
    print(f"Using local adapter: {EVAL_ADAPTER_ID}")
else:
    EVAL_ADAPTER_ID = ADAPTER_ID
    print(f"Using HuggingFace adapter: {EVAL_ADAPTER_ID}")

In [None]:
def load_model_and_tokenizer_for_eval():
    """Load the tokenizer and the adapter-wrapped model used for evaluation.

    The base checkpoint is taken from the adapter's PEFT config (falling back
    to ``BASE_ID``), loaded in 4-bit when ``USE_4BIT`` is set, resized to the
    tokenizer's vocabulary if needed, and wrapped with the adapter at
    ``EVAL_ADAPTER_ID``.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.
    """
    quant_config = BitsAndBytesConfig(load_in_4bit=True) if USE_4BIT else None

    adapter_cfg = PeftConfig.from_pretrained(EVAL_ADAPTER_ID)
    base_id = adapter_cfg.base_model_name_or_path or BASE_ID

    tok = AutoTokenizer.from_pretrained(EVAL_ADAPTER_ID, use_fast=True, trust_remote_code=True)
    base_cfg = AutoConfig.from_pretrained(base_id, trust_remote_code=True)

    vocab_len = len(tok)
    if vocab_len != base_cfg.vocab_size:
        print(f"Adjusting vocab_size: {base_cfg.vocab_size} -> {vocab_len}")
        base_cfg.vocab_size = vocab_len

    load_kwargs = {
        "config": base_cfg,
        "device_map": "auto" if DEVICE == "cuda" else None,
        "trust_remote_code": True,
        "ignore_mismatched_sizes": True,
    }
    if quant_config is not None:
        load_kwargs["quantization_config"] = quant_config

    net = AutoModelForCausalLM.from_pretrained(base_id, **load_kwargs)

    # Safety net in case the loaded embedding still differs from the tokenizer.
    if net.get_input_embeddings().weight.shape[0] != vocab_len:
        net.resize_token_embeddings(vocab_len)

    net = PeftModel.from_pretrained(net, EVAL_ADAPTER_ID)

    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    net.eval()

    return tok, net


# Rebinds the module-level `tokenizer` and `model` names that the prompt and
# generation helpers below read.
print("Loading model...")
tokenizer, model = load_model_and_tokenizer_for_eval()
print("Model loaded.")

In [None]:
# Load benchmark dataset
# Use the first candidate file that exists.
benchmark_path = next((c for c in BENCHMARK_CANDIDATES if c.exists()), None)

if benchmark_path is None:
    raise FileNotFoundError("No benchmark JSON found.")

# Decode explicitly as UTF-8: the questions contain non-ASCII text and the
# platform default encoding is not guaranteed to be UTF-8.
raw_items = json.loads(benchmark_path.read_text(encoding="utf-8"))

# Keep only entries that have both a question ("text" or "question") and a
# reference query; entries without an id get a positional one.
items = []
for item in raw_items:
    question = item.get("text") or item.get("question")
    sql = item.get("sql")
    if not question or not sql:
        continue
    items.append({
        "id": item.get("id", f"q{len(items)+1}"),
        "text": question,
        "sql": sql,
    })

# Optional reproducible subsample of the benchmark.
if MAX_EVAL_SAMPLES:
    random.seed(SAMPLE_SEED)
    items = random.sample(items, min(MAX_EVAL_SAMPLES, len(items)))

print(f"Benchmark items: {len(items)} from {benchmark_path}")

In [None]:
def build_prompt(question: str, schema_text: str) -> str:
    """Build the generation prompt for *question* given its schema excerpt.

    When the tokenizer defines a chat template, the system and user messages
    are rendered through it with a generation prompt appended; otherwise a
    plain-text prompt is returned.
    """
    user = f"SCHEMA:\n{schema_text}\n\nQUESTION:\n{question}\n\nSQL:"
    if not getattr(tokenizer, "chat_template", None):
        return f"{SYSTEM_PROMPT}\n\n{user}"

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )


def build_repair_prompt(question: str, schema_text: str, bad_sql: str, error: str) -> str:
    """Build the prompt asking the model to fix *bad_sql* given its *error*.

    Mirrors ``build_prompt``: rendered through the chat template when the
    tokenizer has one, plain text otherwise, but with ``REPAIR_PROMPT`` as
    the system message and the broken query and error included.
    """
    user = (
        f"SCHEMA:\n{schema_text}\n\nQUESTION:\n{question}\n\n"
        f"BROKEN_SQL:\n{bad_sql}\n\nERROR:\n{error}\n\nFIXED_SQL:"
    )
    if not getattr(tokenizer, "chat_template", None):
        return f"{REPAIR_PROMPT}\n\n{user}"

    messages = [
        {"role": "system", "content": REPAIR_PROMPT},
        {"role": "user", "content": user},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )


def run_generation(prompt: str) -> str:
    """Generate a completion for *prompt* and return the SQL extracted from it.

    Decoding is deterministic (``do_sample=False``) with ``NUM_BEAMS`` beams
    and at most ``MAX_NEW_TOKENS`` new tokens. Only the newly generated
    tokens are decoded.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

    # Prefer the EOS id for padding; fall back to the pad id when EOS is unset.
    pad_id = tokenizer.eos_token_id
    if pad_id is None:
        pad_id = tokenizer.pad_token_id

    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        sequences = model.generate(
            **encoded,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            num_beams=NUM_BEAMS,
            pad_token_id=pad_id,
        )

    # Slice off the prompt so only the completion is decoded.
    completion_ids = sequences[0][prompt_len:]
    decoded = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return extract_sql(decoded)


def generate_sql_eval(question: str, schema_text: str) -> str:
    """Generate SQL for *question* and apply ``DEFAULT_LIMIT`` via ``ensure_limit``."""
    raw_sql = run_generation(build_prompt(question, schema_text))
    return ensure_limit(raw_sql, DEFAULT_LIMIT)


def repair_sql_eval(question: str, schema_text: str, bad_sql: str, error: str) -> str | None:
    """Ask the model to repair *bad_sql* and apply ``DEFAULT_LIMIT`` to the result.

    The returned string may be empty when nothing usable was generated;
    callers treat a falsy result as "no repair available".
    """
    repaired = run_generation(build_repair_prompt(question, schema_text, bad_sql, error))
    return ensure_limit(repaired, DEFAULT_LIMIT)

In [None]:
# Cache ground truth results
# gt_cache: benchmark id -> reference query, its raw result / error, whether
# it has an ORDER BY, and the result normalized both as sorted rows and in
# original order. Items sharing an id overwrite each other here.
gt_cache = {}
for item in items:
    qid = item["id"]
    gt_sql = ensure_limit(item["sql"], DEFAULT_LIMIT)
    gt_res, gt_err = run_sql(con, gt_sql)
    gt_cache[qid] = {
        "sql": gt_sql,
        "res": gt_res,
        "err": gt_err,
        "has_order": has_order_by(gt_sql),
        # Normalized results are only computed when the reference query ran.
        "norm_sorted": normalize_rows(gt_res, keep_order=False) if gt_err is None else None,
        "norm_ordered": normalize_rows(gt_res, keep_order=True) if gt_err is None else None,
    }
print(f"Ground-truth cached: {len(gt_cache)}")

In [None]:
# Run evaluation
# One result record per benchmark item: generate SQL, validate and execute
# it, optionally attempt one repair, then compare against the ground truth.
results = []
for idx, item in enumerate(items, 1):
    qid = item["id"]
    question = item["text"]
    gt = gt_cache[qid]

    schema_tables, schema_text = build_schema_snippet(question, max_tables=MAX_TABLES)

    # gen_time covers the first generation only, not a later repair call.
    start = time.time()
    gen_sql = generate_sql_eval(question, schema_text)
    gen_time = time.time() - start

    repair_used = False
    # Only queries that pass the read-only check are ever executed.
    valid_sql = is_safe_select(gen_sql)
    if valid_sql:
        exec_start = time.time()
        gen_res, gen_err = run_sql(con, gen_sql)
        exec_time = time.time() - exec_start
    else:
        gen_res, gen_err, exec_time = None, "INVALID_SQL", None

    # Try repair on error
    # A single repair attempt is made; REPAIR_MAX_ATTEMPTS is only checked for > 0.
    if REPAIR_ON_ERROR and (not valid_sql or gen_err is not None) and REPAIR_MAX_ATTEMPTS > 0:
        repair_sql_text = repair_sql_eval(question, schema_text, gen_sql, gen_err)
        if repair_sql_text:
            # The repaired query replaces the original in the record.
            repair_used = True
            gen_sql = repair_sql_text
            valid_sql = is_safe_select(gen_sql)
            if valid_sql:
                exec_start = time.time()
                gen_res, gen_err = run_sql(con, gen_sql)
                exec_time = time.time() - exec_start
            else:
                gen_res, gen_err, exec_time = None, "INVALID_SQL", None

    # Exact match: normalized SQL text equality, only for queries that ran.
    exact_match = False
    if valid_sql and gen_err is None:
        exact_match = normalize_sql_text(gen_sql) == normalize_sql_text(gt["sql"])

    # Execution match: compare normalized result rows; row order matters as
    # soon as either query contains an ORDER BY.
    exec_match = False
    if valid_sql and gen_err is None and gt["err"] is None:
        keep_order = gt["has_order"] or has_order_by(gen_sql)
        gt_norm = gt["norm_ordered"] if keep_order else gt["norm_sorted"]
        gen_norm = normalize_rows(gen_res, keep_order=keep_order)
        exec_match = gt_norm == gen_norm

    results.append({
        "id": qid,
        "question": question,
        "gt_sql": gt["sql"],
        "gen_sql": gen_sql,
        "valid_sql": valid_sql,
        "exact_match": exact_match,
        "exec_match": exec_match,
        "gen_error": gen_err,
        "gen_error_type": classify_error(gen_err),
        "gt_error": gt["err"],
        "gen_time_sec": gen_time,
        "exec_time_sec": exec_time,
        "schema_tables": ",".join(schema_tables),
        "schema_table_count": len(schema_tables),
        "repair_used": repair_used,
    })

    if idx % 10 == 0:
        print(f"Processed {idx}/{len(items)}")

results_df = pd.DataFrame(results)
print(f"\nCompleted: {len(results_df)} samples")

In [None]:
# Calculate metrics
# valid_mask: generated SQL passed the read-only check.
# exec_success: no execution error was recorded for the generated SQL.
valid_mask = results_df["valid_sql"]
exec_success = results_df["gen_error"].isna()
# Execution accuracy over every sample (invalid or failing queries count as misses).
exec_acc_all = results_df["exec_match"].mean() if not results_df.empty else 0.0
# Execution accuracy restricted to valid queries whose ground truth also ran.
valid_exec_mask = valid_mask & results_df["gt_error"].isna()
exec_acc_valid = results_df.loc[valid_exec_mask, "exec_match"].mean() if valid_exec_mask.any() else 0.0
# Exact-match rate is measured over valid queries only.
exact_match_rate = results_df.loc[valid_mask, "exact_match"].mean() if valid_mask.any() else 0.0

summary = {
    "adapter": EVAL_ADAPTER_ID,
    "total_samples": len(results_df),
    "valid_sql_rate": float(valid_mask.mean()) if not results_df.empty else 0.0,
    "exec_success_rate": float(exec_success.mean()) if not results_df.empty else 0.0,
    "exec_accuracy_all": exec_acc_all,
    "exec_accuracy_valid": exec_acc_valid,
    "exact_match_rate": exact_match_rate,
    "avg_gen_time_sec": float(results_df["gen_time_sec"].mean()) if not results_df.empty else 0.0,
    "repair_used_count": int(results_df["repair_used"].sum()),
}

print("\n" + "="*50)
print("EVALUATION SUMMARY")
print("="*50)
# Floats are printed with four decimals, everything else as-is.
for k, v in summary.items():
    if isinstance(v, float):
        print(f"{k}: {v:.4f}")
    else:
        print(f"{k}: {v}")

In [None]:
# Save results
# Write the per-sample results next to the pipeline data.
output_csv = OUTPUT_DIR / "benchmark_text_to_sql_results.csv"
results_df.to_csv(output_csv, index=False)
print(f"Results saved to: {output_csv}")

# Show sample results
results_df[["id", "question", "valid_sql", "exec_match", "gen_error"]].head(10)

In [None]:
# Cleanup
# Drop the evaluation model, free cached GPU memory when CUDA is available,
# and close the DuckDB connection opened in the schema cell.
del model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
con.close()
print("Done!")