# TPC-DS Text-to-SQL Execution Benchmark (Qwen Baseline)

This notebook benchmarks text-to-SQL models on TPC-DS and reports execution accuracy.


In [None]:
!git clone https://github.com/VuThanhLam124/Capstone-NLUS-VDD.git

In [None]:
cd Capstone-NLUS-VDD

In [None]:
!pip install -r requirements.txt
!pip -q install sqlglot

In [None]:
from pathlib import Path
import json
import os
import random
import time
import re
import gc
import math
import unicodedata
from decimal import Decimal
from datetime import date, datetime

import duckdb
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig


def find_repo_root(start: Path) -> Path:
    """Return the closest directory at or above *start* containing ``research_pipeline``.

    *start* itself is checked first, then each of its parents in order.
    When none of them contains that entry, *start* is returned unchanged.
    """
    candidates = (start, *start.parents)
    return next(
        (folder for folder in candidates if (folder / "research_pipeline").exists()),
        start,
    )

# --- Paths -----------------------------------------------------------------
# Repository root is discovered from the current working directory.
REPO_ROOT = find_repo_root(Path.cwd())
# DuckDB file holding the TPC-DS data.
DB_PATH = REPO_ROOT / "research_pipeline" / "data" / "ecommerce_dw.duckdb"
# Benchmark JSON files in priority order; the first one that exists is used.
BENCHMARK_CANDIDATES = [
    REPO_ROOT / "research_pipeline" / "data" / "test_queries_vi_200_v2.json",
    REPO_ROOT / "research_pipeline" / "data" / "test_queries_vi_200.json",
    REPO_ROOT / "research_pipeline" / "test_queries.json",
]
# Directory that receives the result CSV files.
OUTPUT_DIR = REPO_ROOT / "research_pipeline"
# Optional suffix appended to output file names by make_output_path().
RUN_ID = None  # set like "run1" or time.strftime("%Y%m%d_%H%M%S")

# --- Models ----------------------------------------------------------------
# Model specs keyed by a short name; each spec is consumed by load_model_and_tokenizer().
MODEL_CHOICES = {
    "qwen_3_4b_text_to_sql": {
        "type": "lora_causal",
        "adapter_id": "Ellbendls/Qwen-3-4b-Text_to_SQL",
        "base_id": "Qwen/Qwen3-4B-Instruct-2507",
        "tokenizer_id": "Ellbendls/Qwen-3-4b-Text_to_SQL",
        "allow_vocab_shrink": True,
    },
}

# Models to run when RUN_ALL_MODELS is True; otherwise only MODEL_CHOICE is run.
MODEL_ORDER = ["qwen_3_4b_text_to_sql"]
RUN_ALL_MODELS = False
MODEL_CHOICE = "qwen_3_4b_text_to_sql"
# When True, a model whose benchmark run raises is skipped instead of aborting the driver loop.
CONTINUE_ON_ERROR = True

# --- Sampling and generation -------------------------------------------------
MAX_SAMPLES = 200  # set None to run full benchmark
# Seed for the random subsample drawn when MAX_SAMPLES is set.
SAMPLE_SEED = 42
DEFAULT_LIMIT = None  # set to an int to force LIMIT on both GT and generated SQL
# Maximum number of tables included in the schema snippet of a prompt.
MAX_TABLES = 8
# Generation settings forwarded to model.generate().
MAX_NEW_TOKENS = 256
NUM_BEAMS = 1

# --- Repair pass -------------------------------------------------------------
# When enabled, SQL that is rejected or fails to execute is sent back to the model for a fix.
REPAIR_ON_ERROR = True
# Repair is attempted only when this is greater than zero.
REPAIR_MAX_ATTEMPTS = 1
# Model choices for which the repair pass is allowed.
REPAIR_ONLY_FOR = {"qwen_3_4b_text_to_sql"}

# --- Hardware ----------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 4-bit quantized loading is requested whenever CUDA is available.
USE_4BIT = torch.cuda.is_available()
print(f"Using device: {DEVICE}")


def make_output_path(stem: str) -> Path:
    """Return the CSV path for *stem* inside OUTPUT_DIR.

    The file name is ``<stem>.csv``, or ``<stem>_<RUN_ID>.csv`` when RUN_ID is truthy.
    """
    filename = f"{stem}_{RUN_ID}.csv" if RUN_ID else f"{stem}.csv"
    return OUTPUT_DIR / filename

# Destination CSV for the combined per-question results of every benchmarked model.
OUTPUT_CSV_ALL = make_output_path("benchmark_text_to_sql_all")


In [None]:
# Whether this cell may generate the TPC-DS database itself; when False, a missing
# database file is reported as an error instead.
AUTO_SETUP_DB = True
# Scale factor passed to setup_tpcds_db() (forwarded to dsdgen).
SETUP_SCALE_FACTOR = 1
# Passed to setup_tpcds_db() as force_recreate (drop existing tables, then regenerate).
FORCE_RECREATE_DB = False

def setup_tpcds_db(db_path: Path, scale_factor: int = 1, force_recreate: bool = False) -> None:
    """Populate the DuckDB file at *db_path* with TPC-DS data.

    The parent directory is created if needed and the ``tpcds`` extension is
    installed and loaded. If the database already has tables, generation is
    skipped unless *force_recreate* is set, in which case every listed table
    is dropped first. The connection is always closed on exit.

    Args:
        db_path: Location of the DuckDB database file.
        scale_factor: Value interpolated into the ``dsdgen(sf=...)`` call.
        force_recreate: Drop existing tables and regenerate instead of skipping.
    """
    db_path.parent.mkdir(parents=True, exist_ok=True)
    connection = duckdb.connect(str(db_path))
    try:
        for statement in ("INSTALL tpcds;", "LOAD tpcds;"):
            connection.execute(statement)

        existing = [row[0] for row in connection.execute("SHOW TABLES").fetchall()]
        if existing:
            if not force_recreate:
                print(f"Found {len(existing)} tables. Skip generation.")
                return
            for name in existing:
                connection.execute(f"DROP TABLE {name}")

        print(f"Generating TPC-DS (sf={scale_factor})...")
        began = time.time()
        connection.execute(f"CALL dsdgen(sf={scale_factor});")
        print(f"Data generation completed in {time.time() - began:.2f}s")
    finally:
        connection.close()

# Build the database when the file is missing, or rebuild it when FORCE_RECREATE_DB
# is set. Guarding on the missing file alone would make FORCE_RECREATE_DB a no-op:
# setup would only ever run against a brand-new file, which has no tables to drop.
if FORCE_RECREATE_DB or not DB_PATH.exists():
    if AUTO_SETUP_DB:
        setup_tpcds_db(DB_PATH, scale_factor=SETUP_SCALE_FACTOR, force_recreate=FORCE_RECREATE_DB)
    elif not DB_PATH.exists():
        # Automatic setup is disabled and there is nothing to benchmark against.
        raise FileNotFoundError(f"TPC-DS DuckDB not found: {DB_PATH}")


In [None]:
# Read-only connection used below for schema inspection and for running
# ground-truth and generated queries.
con = duckdb.connect(str(DB_PATH), read_only=True)

# Maps each table name to the list of its column names
# (the first field of every DESCRIBE row).
schema_map = {}
for (table_name,) in con.execute("SHOW TABLES").fetchall():
    columns = [r[0] for r in con.execute(f"DESCRIBE {table_name}").fetchall()]
    schema_map[table_name] = columns


def strip_accents(text: str) -> str:
    """Return *text* with diacritics removed.

    The text is NFD-decomposed and every combining mark (category ``Mn``) is
    dropped. The Vietnamese letters "đ"/"Đ" are folded to "d"/"D" explicitly:
    their stroke is part of the base letter rather than a combining mark, so
    NFD leaves them intact and an ASCII-only tokenizer would otherwise discard
    them (turning "đơn" into "on").
    """
    decomposed = unicodedata.normalize("NFD", text)
    stripped = "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")
    return stripped.replace("\u0111", "d").replace("\u0110", "D")


def tokenize(text: str) -> list[str]:
    """Split *text* into lowercase, accent-free word tokens.

    Runs of ``[a-z0-9_]`` are extracted from the lowercased, accent-stripped
    text, each run is further split on underscores, and pieces of length one
    or zero are discarded.
    """
    normalized = strip_accents(text.lower())
    return [
        piece
        for chunk in re.findall(r"[a-z0-9_]+", normalized)
        for piece in chunk.split("_")
        if len(piece) > 1
    ]


# Accent-free Vietnamese terms mapped to the English schema vocabulary.
# Compound keys (e.g. "doanhthu" / "doanh_thu") describe two-word phrases;
# expand_tokens() matches them against adjacent token pairs.
SYNONYMS = {
    "khach": "customer",
    "khachhang": "customer",
    "khach_hang": "customer",
    "khachhangs": "customer",
    "sanpham": "item",
    "san_pham": "item",
    "hang": "item",
    "danhmuc": "category",
    "danh_muc": "category",
    "bang": "state",
    "tinh": "state",
    "cuahang": "store",
    "cua_hang": "store",
    "doanhthu": "revenue",
    "doanh_thu": "revenue",
    "soluong": "quantity",
    "so_luong": "quantity",
    "gia": "price",
    "thang": "month",
    "nam": "year",
    "quy": "quarter",
}


def expand_tokens(tokens: list[str]) -> set[str]:
    """Return *tokens* as a set, enriched with their SYNONYMS translations.

    Each token is looked up on its own. In addition, every pair of adjacent
    tokens is looked up both concatenated ("doanhthu") and underscore-joined
    ("doanh_thu"). Without the pair lookup the compound keys in SYNONYMS are
    unreachable, because the tokenizer only ever yields single words.

    The result is always a superset of the input tokens.
    """
    ordered = list(tokens)
    expanded = set(ordered)

    for tok in ordered:
        mapped = SYNONYMS.get(tok)
        if mapped:
            expanded.add(mapped)

    # Two-word phrases: try both spellings used as SYNONYMS keys.
    for left, right in zip(ordered, ordered[1:]):
        for compound in (left + right, f"{left}_{right}"):
            mapped = SYNONYMS.get(compound)
            if mapped:
                expanded.add(mapped)

    return expanded


# Per-table vocabulary: the tokens of the table name plus the tokens of every
# column name. select_tables_for_question() scores tables against this.
table_tokens = {}
for table, cols in schema_map.items():
    tokens = set(tokenize(table))
    for col in cols:
        tokens.update(tokenize(col))
    table_tokens[table] = tokens


def select_tables_for_question(question: str, max_tables: int = 8) -> list[str]:
    """Pick the tables most likely to be relevant to *question*.

    Tables are ranked by how many expanded question tokens overlap their
    vocabulary in ``table_tokens`` (ties resolved by descending table name),
    and the top *max_tables* with a positive score are kept. Keyword rules
    then append specific tables that exist in ``schema_map`` and are not yet
    selected. Finally the list is cut to *max_tables* again, so rule-added
    tables can be trimmed; a falsy *max_tables* disables that final cut.

    Returns:
        Table names in selection order.
    """
    q_tokens = expand_tokens(tokenize(question))

    ranked = sorted(
        ((len(q_tokens & vocab), name) for name, vocab in table_tokens.items()),
        reverse=True,
    )
    selected = [name for overlap, name in ranked if overlap > 0][:max_tables]

    # (trigger tokens, tables to append) — evaluated in this exact order,
    # which determines the order of appended tables.
    forced_rules = (
        ({"year", "month", "quarter", "date"}, ("date_dim",)),
        ({"customer"}, ("customer", "customer_address")),
        ({"state"}, ("customer_address", "store")),
        ({"store"}, ("store_sales", "store")),
        ({"web"}, ("web_sales", "web_site")),
        ({"catalog"}, ("catalog_sales", "call_center")),
        ({"call"}, ("call_center",)),
        ({"inventory"}, ("inventory",)),
        ({"item", "product", "category"}, ("item",)),
        ({"sales", "revenue", "quantity", "price"}, ("store_sales",)),
    )
    for triggers, required in forced_rules:
        if q_tokens.isdisjoint(triggers):
            continue
        for name in required:
            if name in schema_map and name not in selected:
                selected.append(name)

    cap = max_tables if max_tables else len(selected)
    return selected[:cap]


def build_schema_snippet(question: str, max_tables: int = 8) -> tuple[list[str], str]:
    """Render the schema of the tables relevant to *question* as prompt text.

    Tables come from select_tables_for_question(); when it selects nothing,
    every table in ``schema_map`` is used. Each table is rendered as::

        TABLE name (
          column_a
          column_b
        )

    with a blank line between tables. Lines are joined with newlines; joining
    with an empty string would fuse the whole schema into a single run of text.

    Returns:
        ``(tables, schema_text)``: the table names used and the rendered snippet.
    """
    tables = select_tables_for_question(question, max_tables=max_tables)
    if not tables:
        tables = list(schema_map.keys())

    lines: list[str] = []
    for table in tables:
        lines.append(f"TABLE {table} (")
        lines.extend(f"  {col}" for col in schema_map[table])
        # The empty entry becomes the blank separator line between tables.
        lines.extend((")", ""))
    return tables, "\n".join(lines).strip()

# Report how many tables were read into schema_map.
print(f"Loaded schema for {len(schema_map)} tables.")


In [None]:
# Use the first benchmark file that exists, in BENCHMARK_CANDIDATES priority order.
benchmark_path = None
for candidate in BENCHMARK_CANDIDATES:
    if candidate.exists():
        benchmark_path = candidate
        break

if benchmark_path is None:
    raise FileNotFoundError("No benchmark JSON found.")

# Read explicitly as UTF-8: the questions contain non-ASCII text, and the
# platform default encoding is not guaranteed to be UTF-8.
raw_items = json.loads(benchmark_path.read_text(encoding="utf-8"))

# Keep only entries that have both a question ("text" or "question") and a
# reference SQL; entries without an "id" get a positional one ("q1", "q2", ...).
items = []
for item in raw_items:
    question = item.get("text") or item.get("question")
    sql = item.get("sql")
    if not question or not sql:
        continue
    items.append({
        "id": item.get("id", f"q{len(items)+1}"),
        "text": question,
        "sql": sql,
    })

# Optional seeded subsample so repeated runs evaluate the same questions.
if MAX_SAMPLES:
    random.seed(SAMPLE_SEED)
    items = random.sample(items, min(MAX_SAMPLES, len(items)))

print(f"Benchmark items: {len(items)} from {benchmark_path}")


In [None]:
# sqlglot is optional. When the import fails, the name is bound to None and
# normalize_sql() falls back to whitespace/lowercase normalization.
try:
    import sqlglot
except Exception:
    sqlglot = None
    print("sqlglot not installed: SQL normalization will be simple.")


In [None]:
def load_model_and_tokenizer(spec: dict):
    """Load the tokenizer and model described by *spec*.

    ``spec["type"]`` selects the loading path:

    * ``"seq2seq"``: ``AutoModelForSeq2SeqLM`` from ``spec["id"]``.
    * ``"lora_causal"``: a causal base model (``spec["base_id"]`` or the base
      recorded in the adapter config) with the PEFT adapter
      ``spec["adapter_id"]`` applied on top.
    * anything else: a plain causal model from ``spec["id"]``.

    Causal models are loaded with 4-bit quantization when ``USE_4BIT`` is set.
    If the tokenizer has no pad token, the EOS token is used instead, and the
    model is switched to eval mode.

    Args:
        spec: One entry of ``MODEL_CHOICES``. Optional keys: ``tokenizer_id``
            and ``allow_vocab_shrink``.

    Returns:
        ``(tokenizer, model, model_kind, model_id)`` where ``model_kind`` is
        ``"seq2seq"`` or ``"causal"`` and ``model_id`` is a display string.
    """
    model_type = spec["type"]
    quant_config = BitsAndBytesConfig(load_in_4bit=True) if USE_4BIT else None
    on_gpu = DEVICE == "cuda"
    device_map = "auto" if on_gpu else None
    weights_dtype = torch.float16 if on_gpu else torch.float32
    allow_vocab_shrink = spec.get("allow_vocab_shrink", False)

    def resolve_tokenizer_and_config(model_id: str, tokenizer_id: str | None = None, allow_shrink: bool = False):
        """Load tokenizer and config, aligning config.vocab_size with the tokenizer.

        The vocab size is raised whenever the tokenizer is larger, and lowered
        only when *allow_shrink* is set.
        """
        tokenizer_id = tokenizer_id or model_id
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, use_fast=True, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        if len(tokenizer) != config.vocab_size and (len(tokenizer) > config.vocab_size or allow_shrink):
            print(f"Adjusting vocab_size for {model_id}: {config.vocab_size} -> {len(tokenizer)}")
            config.vocab_size = len(tokenizer)
        return tokenizer, config

    def causal_kwargs(config) -> dict:
        """Keyword arguments shared by both causal loading paths."""
        kwargs = dict(
            config=config,
            device_map=device_map,
            dtype=weights_dtype,
            trust_remote_code=True,
            ignore_mismatched_sizes=True,
        )
        if quant_config is not None:
            kwargs["quantization_config"] = quant_config
        return kwargs

    if model_type == "seq2seq":
        # Imported here because the notebook's top-level transformers import
        # does not include this class; without it this branch raises NameError.
        from transformers import AutoModelForSeq2SeqLM

        tokenizer, config = resolve_tokenizer_and_config(
            spec["id"], spec.get("tokenizer_id"), allow_vocab_shrink
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(
            spec["id"],
            config=config,
            device_map=device_map,
            dtype=weights_dtype,
            trust_remote_code=True,
        )
        model_kind = "seq2seq"
        model_id = spec["id"]
    elif model_type == "lora_causal":
        from peft import PeftConfig, PeftModel

        adapter_id = spec["adapter_id"]
        peft_config = PeftConfig.from_pretrained(adapter_id)
        base_id = spec.get("base_id") or peft_config.base_model_name_or_path
        tokenizer, config = resolve_tokenizer_and_config(
            base_id, spec.get("tokenizer_id"), allow_vocab_shrink
        )
        model = AutoModelForCausalLM.from_pretrained(base_id, **causal_kwargs(config))
        # Make the embedding matrix match the tokenizer before the adapter is applied.
        if allow_vocab_shrink and len(tokenizer) != model.get_input_embeddings().weight.shape[0]:
            model.resize_token_embeddings(len(tokenizer))
        model = PeftModel.from_pretrained(model, adapter_id)
        model_kind = "causal"
        model_id = f"{base_id} + {adapter_id}"
    else:
        tokenizer, config = resolve_tokenizer_and_config(
            spec["id"], spec.get("tokenizer_id"), allow_vocab_shrink
        )
        model = AutoModelForCausalLM.from_pretrained(spec["id"], **causal_kwargs(config))
        model_kind = "causal"
        model_id = spec["id"]

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.eval()
    return tokenizer, model, model_kind, model_id


In [None]:
# Keywords that make is_safe_select() reject a statement (matched as whole
# words, case-insensitively).
_FORBIDDEN_SQL = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|COPY|PRAGMA|ATTACH|DETACH|EXPORT|IMPORT|CALL)\b",
    re.IGNORECASE,
)

# System message used for the initial SQL generation request.
SYSTEM_PROMPT = (
    "You translate user questions into SQL for DuckDB (TPC-DS). "
    "Return only SQL, no markdown, no explanations. "
    "Use only tables and columns from the schema."
)
# System message used when asking the model to fix a failing query.
REPAIR_PROMPT = (
    "You are fixing SQL for DuckDB (TPC-DS). "
    "Return only corrected SQL, no markdown, no explanations."
)
# Single-string prompt used for seq2seq models.
SEQ2SEQ_PROMPT_TEMPLATE = "translate to SQL:\n{question}\n\nSCHEMA:\n{schema}\n\nSQL:"

def extract_sql(text: str) -> str:
    """Extract a single SQL statement from raw model output.

    Steps, in order:

    1. If the text contains a fenced code block, keep only its content.
       If the text merely *starts* with an opening fence that is never closed
       (typical when generation stops at the token limit), drop that fence.
    2. Drop a leading ``sql:`` label.
    3. Keep only what precedes the first semicolon.

    Returns:
        The stripped statement without a trailing semicolon.
    """
    text = text.strip()
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", text, flags=re.IGNORECASE | re.DOTALL)
    if fenced:
        text = fenced.group(1).strip()
    else:
        # Unterminated fence: only handled at the very start, so a stray
        # fence elsewhere in the text is left exactly as before.
        opening = re.match(r"```(?:sql)?\s*", text, flags=re.IGNORECASE)
        if opening:
            text = text[opening.end():].strip()
    if text.lower().startswith("sql:"):
        text = text[4:].strip()
    if ";" in text:
        text = text.split(";", 1)[0].strip()
    return text

def is_safe_select(sql: str) -> bool:
    """Tell whether *sql* looks like a read-only query.

    ``--`` line comments are removed first. The statement is rejected when
    nothing remains, when any keyword from ``_FORBIDDEN_SQL`` occurs anywhere
    in it, or when its first word is not ``SELECT`` or ``WITH``.
    """
    without_comments = re.sub(r"--.*?$", "", sql, flags=re.MULTILINE).strip()
    if not without_comments or _FORBIDDEN_SQL.search(without_comments):
        return False
    leading_keyword = re.split(r"\s+", without_comments, maxsplit=1)[0].upper()
    return leading_keyword in {"SELECT", "WITH"}

def ensure_limit(sql: str, limit: int | None) -> str:
    if limit is None:
        return sql
    s = sql.strip().rstrip(";").strip()
    if re.search(r"\bLIMIT\b", s, flags=re.IGNORECASE):
        return s
    return f"{s}\nLIMIT {limit}"

def has_order_by(sql: str) -> bool:
    """Return True when *sql* contains an ``ORDER BY`` phrase (case-insensitive)."""
    return bool(re.search(r"\border\s+by\b", sql, flags=re.IGNORECASE))

def normalize_sql(sql: str) -> str:
    """Return a canonical text form of *sql* for exact-match comparison.

    When sqlglot is available the statement is parsed and re-rendered in the
    DuckDB dialect. If sqlglot is missing or fails on the input, the fallback
    collapses whitespace runs to single spaces and lowercases the text.
    """
    if sqlglot is not None:
        try:
            parsed = sqlglot.parse_one(sql, read="duckdb")
            return parsed.sql(dialect="duckdb", pretty=False)
        except Exception:
            # Best effort on purpose: any sqlglot failure uses the textual fallback.
            pass
    collapsed = re.sub(r"\s+", " ", sql.strip())
    return collapsed.lower()

def build_prompt(question: str, schema_text: str, tokenizer, model_kind: str) -> str:
    """Build the generation prompt for *question* over *schema_text*.

    Seq2seq models get ``SEQ2SEQ_PROMPT_TEMPLATE``. Causal models get a
    system + user conversation rendered through the tokenizer's chat template
    when it has one, otherwise the system prompt and user text joined by a
    blank line.
    """
    if model_kind == "seq2seq":
        return SEQ2SEQ_PROMPT_TEMPLATE.format(question=question, schema=schema_text)

    user_message = f"SCHEMA:\n{schema_text}\n\nQUESTION:\n{question}\n\nSQL:"
    if not getattr(tokenizer, "chat_template", None):
        return f"{SYSTEM_PROMPT}\n\n{user_message}"

    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_message},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

def build_repair_prompt(question: str, schema_text: str, bad_sql: str, error: str, tokenizer) -> str:
    """Build the prompt asking the model to fix *bad_sql*.

    The user message carries the schema, the question, the failing SQL and
    the error text. It is rendered through the tokenizer's chat template when
    one exists, otherwise prefixed with ``REPAIR_PROMPT`` and a blank line.
    """
    context_part = f"SCHEMA:\n{schema_text}\n\nQUESTION:\n{question}\n\n"
    failure_part = f"BROKEN_SQL:\n{bad_sql}\n\nERROR:\n{error}\n\nFIXED_SQL:"
    user_message = context_part + failure_part

    if not getattr(tokenizer, "chat_template", None):
        return f"{REPAIR_PROMPT}\n\n{user_message}"

    conversation = [
        {"role": "system", "content": REPAIR_PROMPT},
        {"role": "user", "content": user_message},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

def run_generation(prompt: str, tokenizer, model, model_kind: str) -> str:
    """Generate text for *prompt* and return the SQL extracted from it.

    Decoding is deterministic (``do_sample=False``) with ``NUM_BEAMS`` beams
    and at most ``MAX_NEW_TOKENS`` new tokens. For causal models the prompt
    tokens are cut off the output before decoding; seq2seq output is decoded
    as is.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

    # Pad with EOS when the tokenizer defines one, otherwise with its pad token.
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id if eos_id is None else eos_id

    generation_args = dict(
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        num_beams=NUM_BEAMS,
        pad_token_id=pad_id,
    )
    with torch.no_grad():
        output_ids = model.generate(**encoded, **generation_args)

    sequence = output_ids[0]
    if model_kind != "seq2seq":
        prompt_length = encoded["input_ids"].shape[1]
        sequence = sequence[prompt_length:]
    return extract_sql(tokenizer.decode(sequence, skip_special_tokens=True))

def generate_sql(question: str, schema_text: str, tokenizer, model, model_kind: str) -> str:
    """Generate SQL for *question* and apply ``DEFAULT_LIMIT`` via ensure_limit()."""
    raw_sql = run_generation(
        build_prompt(question, schema_text, tokenizer, model_kind),
        tokenizer,
        model,
        model_kind,
    )
    return ensure_limit(raw_sql, DEFAULT_LIMIT)

def repair_sql(question: str, schema_text: str, bad_sql: str, error: str, tokenizer, model, model_kind: str) -> str | None:
    """Ask the model to fix *bad_sql* given the *error* it produced.

    Returns None for seq2seq models (no repair prompt is built for them);
    otherwise the regenerated SQL with ``DEFAULT_LIMIT`` applied.
    """
    if model_kind == "seq2seq":
        return None
    fixed_sql = run_generation(
        build_repair_prompt(question, schema_text, bad_sql, error, tokenizer),
        tokenizer,
        model,
        model_kind,
    )
    return ensure_limit(fixed_sql, DEFAULT_LIMIT)


In [None]:
def normalize_value(v):
    """Convert one result cell into a form that compares reliably.

    * float: NaN becomes the string ``"nan"``; other floats are rounded to
      6 decimal places.
    * Decimal: rounded to 6 places and converted to float.
    * date / datetime: ISO-8601 string.
    * anything else is returned unchanged.
    """
    if isinstance(v, float):
        return "nan" if math.isnan(v) else round(v, 6)
    if isinstance(v, Decimal):
        return float(round(v, 6))
    if isinstance(v, (datetime, date)):
        return v.isoformat()
    return v


def _cell_sort_key(value):
    """Return a key that gives normalized cells of any type a total order.

    None sorts first, then numbers by value, then strings, then every other
    type grouped by type name and ordered by ``repr``.
    """
    if value is None:
        return (0, "", 0)
    if isinstance(value, (int, float)):
        return (1, "", value)
    if isinstance(value, str):
        return (2, "str", value)
    return (3, type(value).__name__, repr(value))


def normalize_rows(rows, keep_order: bool):
    """Normalize every cell of *rows*, optionally canonicalizing row order.

    Args:
        rows: Iterable of row tuples, or None.
        keep_order: When True, rows keep their original order. When False,
            they are sorted so that two results with the same rows in a
            different order compare equal.

    Returns:
        A list of tuples, or None when *rows* is None.

    Sorting uses an explicit key because plain tuple comparison raises
    TypeError as soon as a column mixes NULLs (None) or the ``"nan"`` marker
    with ordinary values.
    """
    if rows is None:
        return None
    norm = [tuple(normalize_value(x) for x in row) for row in rows]
    if keep_order:
        return norm
    return sorted(norm, key=lambda row: tuple(_cell_sort_key(cell) for cell in row))

def run_sql(con, sql: str):
    """Execute *sql* on *con* and fetch every row.

    Returns:
        ``(rows, None)`` on success, or ``(None, message)`` where *message*
        is the string form of whatever exception the execution raised.
    """
    try:
        rows = con.execute(sql).fetchall()
    except Exception as exc:
        return None, str(exc)
    return rows, None

def classify_error(err: str | None) -> str | None:
    if err is None:
        return None
    if "Binder Error" in err:
        return "binder"
    if "Parser Error" in err:
        return "parser"
    if "Catalog Error" in err:
        return "catalog"
    return "exec_error"


In [None]:
# Execute every ground-truth query once, keyed by question id, so each model
# run compares against the same cached results.
gt_cache = {}
for item in items:
    qid = item["id"]
    gt_sql = ensure_limit(item["sql"], DEFAULT_LIMIT)
    gt_res, gt_err = run_sql(con, gt_sql)
    gt_cache[qid] = {
        "sql": gt_sql,
        "res": gt_res,
        "err": gt_err,
        "has_order": has_order_by(gt_sql),
        # Both comparison forms are precomputed; they stay None when the
        # ground-truth query itself failed.
        "norm_sorted": normalize_rows(gt_res, keep_order=False) if gt_err is None else None,
        "norm_ordered": normalize_rows(gt_res, keep_order=True) if gt_err is None else None,
    }
print(f"Ground-truth cached: {len(gt_cache)}")


In [None]:
def _execute_candidate(sql: str):
    """Validate and run one generated statement.

    Returns ``(valid, rows, error, exec_seconds)``. Statements rejected by
    is_safe_select() are not executed and come back as
    ``(False, None, "INVALID_SQL", None)``.
    """
    if not is_safe_select(sql):
        return False, None, "INVALID_SQL", None
    started = time.time()
    rows, error = run_sql(con, sql)
    return True, rows, error, time.time() - started


def run_benchmark_for_model(model_choice: str):
    """Benchmark one model over every item in ``items``.

    For each question the function builds a schema snippet, generates SQL,
    validates and executes it, optionally asks the model to repair a failing
    statement (up to ``REPAIR_MAX_ATTEMPTS`` times, stopping at the first
    attempt that executes cleanly), and compares the outcome with the cached
    ground truth in ``gt_cache``.

    Args:
        model_choice: Key into ``MODEL_CHOICES``.

    Returns:
        A pandas DataFrame with one row per question (SQL texts, validity,
        exact/execution match flags, errors, timings, schema tables used).

    The model reference is dropped and the CUDA cache emptied when the
    function exits, including when the loop raises.
    """
    spec = MODEL_CHOICES[model_choice]
    print(f"Loading model choice: {model_choice}")
    tokenizer, model, model_kind, model_id = load_model_and_tokenizer(spec)

    repair_enabled = REPAIR_ON_ERROR and model_choice in REPAIR_ONLY_FOR
    results = []
    try:
        for idx, item in enumerate(items, 1):
            qid = item["id"]
            question = item["text"]
            gt = gt_cache[qid]

            schema_tables, schema_text = build_schema_snippet(question, max_tables=MAX_TABLES)

            # gen_time covers the first generation only, not repair attempts.
            start = time.time()
            gen_sql = generate_sql(question, schema_text, tokenizer, model, model_kind)
            gen_time = time.time() - start

            valid_sql, gen_res, gen_err, exec_time = _execute_candidate(gen_sql)

            # An invalid statement carries the "INVALID_SQL" error, so checking
            # gen_err covers both rejected and failing SQL.
            repair_used = False
            attempts_left = REPAIR_MAX_ATTEMPTS if repair_enabled else 0
            while attempts_left > 0 and gen_err is not None:
                attempts_left -= 1
                repair_sql_text = repair_sql(
                    question, schema_text, gen_sql, gen_err, tokenizer, model, model_kind
                )
                if not repair_sql_text:
                    break
                repair_used = True
                gen_sql = repair_sql_text
                valid_sql, gen_res, gen_err, exec_time = _execute_candidate(gen_sql)

            executed_ok = valid_sql and gen_err is None

            exact_match = False
            if executed_ok:
                exact_match = normalize_sql(gen_sql) == normalize_sql(gt["sql"])

            exec_match = False
            if executed_ok and gt["err"] is None:
                # NOTE(review): an ORDER BY in the generated SQL alone forces an
                # ordered comparison even when the ground truth has none — confirm
                # this is intended, since unordered ground-truth rows have no
                # guaranteed order.
                keep_order = gt["has_order"] or has_order_by(gen_sql)
                gt_norm = gt["norm_ordered"] if keep_order else gt["norm_sorted"]
                gen_norm = normalize_rows(gen_res, keep_order=keep_order)
                exec_match = gt_norm == gen_norm

            results.append({
                "id": qid,
                "question": question,
                "gt_sql": gt["sql"],
                "gen_sql": gen_sql,
                "valid_sql": valid_sql,
                "exact_match": exact_match,
                "exec_match": exec_match,
                "gen_error": gen_err,
                "gen_error_type": classify_error(gen_err),
                "gt_error": gt["err"],
                "gen_time_sec": gen_time,
                "exec_time_sec": exec_time,
                "model_choice": model_choice,
                "model_id": model_id,
                "schema_tables": ",".join(schema_tables),
                "schema_table_count": len(schema_tables),
                "repair_used": repair_used,
            })

            if idx % 10 == 0:
                print(f"Processed {idx}/{len(items)}")
    finally:
        # Release the model even if a question raised, so a following model
        # does not load on top of leftover GPU memory.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return pd.DataFrame(results)


In [None]:
# Run either every model in MODEL_ORDER or only MODEL_CHOICE.
model_choices = MODEL_ORDER if RUN_ALL_MODELS else [MODEL_CHOICE]
all_results = []
summary_rows = []

for choice in model_choices:
    try:
        results_df = run_benchmark_for_model(choice)
    except Exception as e:
        # A failing model is reported and skipped when CONTINUE_ON_ERROR is set;
        # otherwise the exception is re-raised.
        print(f"Model {choice} failed: {e}")
        if CONTINUE_ON_ERROR:
            continue
        raise

    # Per-model results are written to their own CSV.
    out_path = make_output_path(f"benchmark_text_to_sql_{choice}")
    results_df.to_csv(out_path, index=False)
    all_results.append(results_df)

    valid_mask = results_df["valid_sql"]
    exec_success = results_df["gen_error"].isna()
    # Execution accuracy over all questions.
    exec_acc_all = results_df["exec_match"].mean() if not results_df.empty else 0.0
    # Execution accuracy restricted to valid generated SQL whose ground truth ran without error.
    valid_exec_mask = valid_mask & results_df["gt_error"].isna()
    exec_acc_valid = results_df.loc[valid_exec_mask, "exec_match"].mean() if valid_exec_mask.any() else 0.0
    # Exact-match rate over valid generated SQL only.
    exact_match_rate = results_df.loc[valid_mask, "exact_match"].mean() if valid_mask.any() else 0.0

    summary_rows.append({
        "model_choice": choice,
        "model_id": results_df["model_id"].iloc[0] if not results_df.empty else None,
        "total": len(results_df),
        "valid_sql_rate": float(valid_mask.mean()) if not results_df.empty else 0.0,
        "exec_success_rate": float(exec_success.mean()) if not results_df.empty else 0.0,
        "exec_acc_all": exec_acc_all,
        "exec_acc_valid": exec_acc_valid,
        "exact_match_rate": exact_match_rate,
        "avg_gen_time_sec": float(results_df["gen_time_sec"].mean()) if not results_df.empty else 0.0,
        "avg_exec_time_sec": float(results_df["exec_time_sec"].dropna().mean()) if results_df["exec_time_sec"].notna().any() else 0.0,
        "invalid_sql": int((results_df["gen_error"] == "INVALID_SQL").sum()),
        # Counts every row with a generation-side error, including "INVALID_SQL" rows.
        "gen_exec_errors": int(results_df["gen_error"].notna().sum()),
        "gt_exec_errors": int(results_df["gt_error"].notna().sum()),
        "output_csv": str(out_path),
    })

if not all_results:
    raise RuntimeError("No model results produced.")

# Combined per-question results of all models that completed.
combined_df = pd.concat(all_results, ignore_index=True)
combined_df.to_csv(OUTPUT_CSV_ALL, index=False)

summary_df = pd.DataFrame(summary_rows)
print("Summary")
print(summary_df)
print(f"Combined results saved to: {OUTPUT_CSV_ALL}")
