# Research Benchmark: Text-to-SQL Evaluation

**Purpose**: Evaluate six experimental conditions (B0–B5) on the text-to-SQL test set and generate the results for the paper.

| Condition | Description |
|-----------|-------------|
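| B0 | Baseline: base Qwen + Full Schema |
| B1 | Base Qwen + Dynamic Schema Selection |
| B2 | Fine-tuned (QLoRA) + Full Schema |
| B3 | Full: Fine-tuned + Dynamic Schema Selection + Repair |
| B4 | Fine-tuned + Dynamic Schema Selection + Schema Enrichment (content samples) |
| B5 | Fine-tuned + Dynamic Schema Selection + RAG Few-shot (retrieved examples) |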


In [None]:
# Kaggle setup
# Clone the project repository. stderr is discarded and `|| true` keeps the
# cell going even when the clone command fails.
!git clone https://github.com/VuThanhLam124/Capstone-NLUS-VDD.git 2>/dev/null || true
# Switch the notebook's working directory to the checkout.
%cd Capstone-NLUS-VDD
# Install the repository's requirements, then the additional packages
# (sqlglot and peft are imported by the cells below).
!pip -q install -r requirements.txt
!pip -q install sqlglot peft scikit-learn

In [None]:
from pathlib import Path
import json
import os
import time
import re
import gc
import math
import unicodedata
from decimal import Decimal
from datetime import date, datetime

import duckdb
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, PeftConfig

# sqlglot is an optional dependency: when it is not installed the name is
# bound to None so later code can feature-test with `if sqlglot:`.
# Only ImportError is caught so that Ctrl-C or an unrelated failure during
# import is not silently swallowed.
try:
    import sqlglot
except ImportError:
    sqlglot = None

# Make the current working directory importable, then try to pull in the
# project's RAG retriever; TextToSQLRetriever is None when it is unavailable.
import sys
sys.path.append(os.getcwd())
try:
    from research_pipeline.rag_retriever import TextToSQLRetriever
except ImportError:
    TextToSQLRetriever = None
    print("RAG module NOT found.")
else:
    print("RAG module loaded.")

# 4-bit loading is enabled exactly when CUDA is available, and the device
# label follows the same check.
USE_4BIT = torch.cuda.is_available()
DEVICE = "cuda" if USE_4BIT else "cpu"
print(f"Device: {DEVICE}")

## Configuration

In [None]:
# ========== PATHS ==========
def find_repo_root(start: Path) -> Path:
    """Return the closest directory at or above *start* that contains ``research_pipeline``.

    *start* itself is checked first, then each of its parents from nearest to
    farthest. When none of them contains a ``research_pipeline`` entry,
    *start* is returned unchanged.
    """
    candidates = (start, *start.parents)
    return next(
        (folder for folder in candidates if (folder / "research_pipeline").exists()),
        start,
    )

REPO_ROOT = find_repo_root(Path.cwd())

# Every input and output of the benchmark lives under <repo>/research_pipeline.
_PIPELINE_DIR = REPO_ROOT / "research_pipeline"
_DATA_DIR = _PIPELINE_DIR / "data"

DB_PATH = _DATA_DIR / "ecommerce_dw.duckdb"
TEST_DATA_PATH = _DATA_DIR / "test.csv"
DB_CONTENT_PATH = _DATA_DIR / "db_content_samples.json"
RAG_INDEX_DIR = _PIPELINE_DIR / "rag_index"
RESULTS_DIR = _PIPELINE_DIR / "results"
# Create the results directory up front so later CSV writes have a target.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# ========== MODEL ==========
BASE_ID = "Qwen/Qwen3-4B-Instruct-2507"
ADAPTER_ID = "Ellbendls/Qwen-3-4b-Text_to_SQL"
LOCAL_ADAPTER = _PIPELINE_DIR / "adapters" / "qwen_lora_v1"

# ========== BENCHMARK CONFIG ==========
MAX_SAMPLES = None      # falsy -> the whole test set is used
MAX_NEW_TOKENS = 256    # generation budget per query
NUM_BEAMS = 1
MAX_TABLES = 8          # cap passed to dynamic schema selection

# ========== CONDITIONS TO RUN ==========
RUN_CONDITIONS = ["B0", "B1", "B2", "B3", "B4", "B5"]

print(f"Repo: {REPO_ROOT}")
print(f"Test data: {TEST_DATA_PATH}")

## Setup Database & Load Resources

In [None]:
def setup_tpcds_db(db_path: Path, scale_factor: int = 1) -> None:
    """Create the DuckDB file at *db_path* and populate it with TPC-DS data if empty.

    The ``tpcds`` extension is installed and loaded, and ``dsdgen`` is called
    only when ``SHOW TABLES`` returns nothing. The connection is always closed
    before returning.

    Args:
        db_path: Location of the DuckDB database file; missing parent
            directories are created.
        scale_factor: Value passed to ``dsdgen`` as ``sf``.
    """
    db_path.parent.mkdir(parents=True, exist_ok=True)
    con = duckdb.connect(str(db_path))
    try:
        con.execute("INSTALL tpcds;")
        con.execute("LOAD tpcds;")
        if not con.execute("SHOW TABLES").fetchall():
            print(f"Generating TPC-DS (sf={scale_factor})...")
            con.execute(f"CALL dsdgen(sf={scale_factor});")
        # Count the tables *after* generation; counting the pre-generation
        # listing would report "0 tables" for a freshly created database.
        table_count = len(con.execute("SHOW TABLES").fetchall())
        print(f"Database ready: {table_count} tables")
    finally:
        con.close()

# Generate the database on first use, then open it read-only for the benchmark.
if not DB_PATH.exists():
    setup_tpcds_db(DB_PATH)

con = duckdb.connect(str(DB_PATH), read_only=True)

# Build schema map: table name -> list of (column name, column type) pairs,
# taken from the first two fields of each DESCRIBE row.
schema_map = {}
for table_row in con.execute("SHOW TABLES").fetchall():
    table_name = table_row[0]
    described = con.execute(f"DESCRIBE {table_name}").fetchall()
    schema_map[table_name] = [(info[0], info[1]) for info in described]

# Load DB Content Samples (for B4); stays empty when the JSON file is absent.
db_content = {}
if DB_CONTENT_PATH.exists():
    db_content = json.loads(DB_CONTENT_PATH.read_text(encoding="utf-8"))
    print(f"Loaded content samples for {len(db_content)} tables")

# Load RAG Retriever (for B5); stays None when the module or index is missing
# or when loading raises.
retriever = None
if TextToSQLRetriever and RAG_INDEX_DIR.exists():
    try:
        retriever = TextToSQLRetriever.load(RAG_INDEX_DIR)
    except Exception as load_error:
        print(f"Failed to load RAG Retriever: {load_error}")
    else:
        print("RAG Retriever loaded successfully")

## Schema Selection Utilities

In [None]:
# "đ"/"Đ" (U+0111/U+0110) are standalone letters with no canonical
# decomposition, so NFD + mark stripping leaves them untouched; fold them
# explicitly to plain ASCII.
_STROKE_D_FOLD = str.maketrans({"\u0111": "d", "\u0110": "D"})


def strip_accents(text: str) -> str:
    """Return *text* with diacritics removed.

    The text is NFD-decomposed and every non-spacing mark (category ``Mn``) is
    dropped. The Vietnamese letter d-with-stroke is additionally mapped to
    ``d``/``D``, because it would otherwise survive and split words apart in
    the ASCII-only tokenizer regex.
    """
    decomposed = unicodedata.normalize("NFD", text)
    without_marks = "".join(
        ch for ch in decomposed if unicodedata.category(ch) != "Mn"
    )
    return without_marks.translate(_STROKE_D_FOLD)

def tokenize(text: str) -> list[str]:
    """Split *text* into lowercase, accent-free word pieces.

    Runs of ``[a-z0-9_]`` are extracted from the lowercased, accent-stripped
    text and further split on underscores; pieces of length 0 or 1 are dropped.
    """
    folded = strip_accents(text.lower())
    pieces = (
        piece
        for word in re.findall(r"[a-z0-9_]+", folded)
        for piece in word.split("_")
    )
    return [piece for piece in pieces if len(piece) > 1]

# Accent-free Vietnamese terms mapped to the English words they should match.
# Multi-word terms are stored with the words concatenated (e.g. "doanhthu").
SYNONYMS = {
    "khach": "customer", "khachhang": "customer", "sanpham": "item",
    "hang": "item", "danhmuc": "category", "bang": "state", "tinh": "state",
    "cuahang": "store", "doanhthu": "revenue", "soluong": "quantity",
    "gia": "price", "thang": "month", "nam": "year", "quy": "quarter",
}


def expand_tokens(tokens: list[str]) -> set[str]:
    """Return *tokens* as a set, extended with their ``SYNONYMS`` translations.

    Each token is looked up on its own, and each pair of adjacent tokens is
    looked up concatenated. The pair lookup is what makes the compound keys
    reachable: a question containing "doanh thu" arrives here as the two
    tokens ``"doanh"`` and ``"thu"``, never as ``"doanhthu"``.

    Args:
        tokens: Tokens in their original order (order matters for the
            adjacent-pair lookup).

    Returns:
        The original tokens plus every synonym that matched.
    """
    expanded = set(tokens)
    for tok in tokens:
        if tok in SYNONYMS:
            expanded.add(SYNONYMS[tok])
    for left, right in zip(tokens, tokens[1:]):
        compound = left + right
        if compound in SYNONYMS:
            expanded.add(SYNONYMS[compound])
    return expanded

# Pre-compute table tokens: for every table, the set of tokens found in the
# table name and in all of its column names.
table_tokens = {
    name: set(tokenize(name)).union(
        *(tokenize(column_name) for column_name, _ in columns)
    )
    for name, columns in schema_map.items()
}

def select_tables(question: str, max_tables: int = 8) -> list[str]:
    """Dynamic schema selection based on question tokens.

    Tables are ranked by how many expanded question tokens they share with
    their pre-computed token set. A few keyword rules add tables that must be
    present (``date_dim``, ``customer``/``customer_address``,
    ``store``/``store_sales``) whenever those tables exist in ``schema_map``.

    The required tables are reserved a slot *before* the ranked list is cut to
    ``max_tables``. Appending them after the cut and truncating again would
    discard them whenever the ranked list alone already fills the budget.

    Args:
        question: Natural-language question.
        max_tables: Maximum number of table names to return.

    Returns:
        Selected table names: ranked tables in score order, followed by
        required tables that had no token overlap.
    """
    limit = max(max_tables, 0)
    q_tokens = expand_tokens(tokenize(question))

    scored = [(len(q_tokens & tokens), table) for table, tokens in table_tokens.items()]
    scored.sort(reverse=True)
    ranked = [table for score, table in scored if score > 0]

    required: list[str] = []

    def ensure(*tables: str) -> None:
        # Only tables that actually exist in the schema can be required.
        for table in tables:
            if table in schema_map and table not in required:
                required.append(table)

    if q_tokens & {"year", "month", "quarter", "date"}:
        ensure("date_dim")
    if "customer" in q_tokens:
        ensure("customer", "customer_address")
    if "store" in q_tokens:
        ensure("store_sales", "store")
    if q_tokens & {"sales", "revenue"}:
        ensure("store_sales")

    must_have = required[:limit]
    # Remaining budget goes to the best-ranked tables that are not required.
    room = limit - len(must_have)
    fillers = [table for table in ranked if table not in must_have][:room]
    chosen = set(must_have) | set(fillers)

    # Keep the ranked order first, then required tables without any overlap,
    # which matches the untruncated ordering.
    return [table for table in ranked if table in chosen] + [
        table for table in must_have if table not in ranked
    ]

def build_schema_text(tables: list[str], with_content: bool = False) -> str:
    """Render *tables* as ``TABLE name ( ... )`` blocks for the prompt.

    Each column is written as ``  name type``. With ``with_content`` set,
    columns that have entries in ``db_content`` get a trailing
    ``-- sample: [...]`` comment showing at most three values. Tables missing
    from ``schema_map`` render with no columns.
    """
    rendered = []
    for name in tables:
        rendered.append(f"TABLE {name} (")
        table_samples = db_content.get(name, {}) if with_content else {}
        for column, column_type in schema_map.get(name, []):
            samples = table_samples.get(column) if with_content else None
            if samples:
                # Limit to 3 samples to save context
                preview = ", ".join(f"'{str(s)}'" for s in samples[:3])
                suffix = f" -- sample: [{preview}]"
            else:
                suffix = ""
            rendered.append(f"  {column} {column_type}{suffix}")
        rendered += [")", ""]
    return "\n".join(rendered).strip()

def build_full_schema() -> str:
    """Render every table in ``schema_map`` (without content samples)."""
    every_table = [*schema_map]
    return build_schema_text(every_table)

## SQL Utilities

In [None]:
# Case-insensitive match for data/schema-modifying keywords as whole words.
_FORBIDDEN = re.compile(r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE)\b", re.I)

# System message used for the initial question -> SQL generation.
SYSTEM_PROMPT = (
    "You translate user questions into SQL for DuckDB (TPC-DS). "
    "Return only SQL, no markdown, no explanations."
)
# System message used when asking the model to fix a failed query.
REPAIR_PROMPT = (
    "Fix the broken SQL for DuckDB (TPC-DS). Return only corrected SQL."
)

def extract_sql(text: str) -> str:
    """Extract a single SQL statement from raw model output.

    Steps:
      1. If the text contains a complete Markdown code fence (optionally
         tagged ``sql``), keep only the fence body.
      2. Otherwise remove any stray fence markers. This covers output that
         was cut off before the closing fence, or that only carries a
         trailing fence; leaving the backticks in would make the query fail.
      3. Keep everything before the first ``;``.

    Returns:
        The stripped statement text (possibly empty).
    """
    text = text.strip()
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", text, re.I | re.S)
    if fenced:
        text = fenced.group(1).strip()
    elif "```" in text:
        text = re.sub(r"```(?:sql)?", "", text, flags=re.I).strip()
    statement, _, _ = text.partition(";")
    return statement.strip()

def is_valid_sql(sql: str) -> bool:
    """Return True when *sql* looks like a read-only query.

    ``--`` line comments are removed first. The remaining text must be
    non-empty, must not contain any keyword matched by ``_FORBIDDEN``, and
    must begin with the keyword ``SELECT`` or ``WITH`` (case-insensitive).

    The leading keyword is matched with a word boundary rather than by
    splitting on whitespace, so forms such as ``SELECT*`` or ``SELECT(`` are
    accepted too.
    """
    body = re.sub(r"--.*$", "", sql, flags=re.M).strip()
    if not body or _FORBIDDEN.search(body):
        return False
    return re.match(r"(?:SELECT|WITH)\b", body, re.I) is not None

def normalize_sql(sql: str) -> str:
    """Return a canonical form of *sql* for exact-match comparison.

    When sqlglot is available the query is parsed and re-rendered in the
    DuckDB dialect. If sqlglot is missing or raises, the fallback collapses
    whitespace runs to single spaces and lowercases the text.
    """
    if sqlglot:
        try:
            return sqlglot.parse_one(sql, read="duckdb").sql(dialect="duckdb", pretty=False)
        except Exception:
            # Best effort: unparsable SQL falls through to the textual
            # normalisation below. `Exception` (not a bare except) lets
            # KeyboardInterrupt/SystemExit propagate.
            pass
    return re.sub(r"\s+", " ", sql.strip()).lower()

def normalize_value(v):
    """Convert one result cell into a value that compares reliably.

    * float NaN  -> the string ``"nan"``
    * float      -> rounded to 6 decimal places
    * Decimal    -> rounded to 6 decimal places, then converted to float
    * datetime / date -> ISO-8601 string
    * anything else is returned unchanged
    """
    if isinstance(v, float):
        return "nan" if math.isnan(v) else round(v, 6)
    if isinstance(v, Decimal):
        return float(round(v, 6))
    if isinstance(v, (datetime, date)):
        return v.isoformat()
    return v


def _cell_sort_key(cell):
    """Return a key that puts any normalised cell into one total order."""
    if cell is None:
        return (0, 0)
    if isinstance(cell, (bool, int, float)):
        # Numbers share one rank so that 1 and 1.0 sort identically.
        return (1, cell)
    if isinstance(cell, str):
        return (2, cell)
    return (3, repr(cell))


def normalize_rows(rows, keep_order: bool):
    """Normalise every cell of *rows* and optionally sort the rows.

    Args:
        rows: Iterable of row sequences, or None.
        keep_order: When True the row order is preserved; otherwise rows are
            sorted so that two results can be compared as multisets.

    Returns:
        A list of tuples, or None when *rows* is None.

    Sorting uses a type-aware key: plain ``sorted()`` raises ``TypeError`` as
    soon as a column mixes NULL (``None``) with values, or strings with
    numbers.
    """
    if rows is None:
        return None
    norm = [tuple(normalize_value(x) for x in row) for row in rows]
    if keep_order:
        return norm
    return sorted(norm, key=lambda row: tuple(_cell_sort_key(c) for c in row))

def has_order_by(sql: str) -> bool:
    """Report whether the words ``ORDER BY`` occur anywhere in *sql* (case-insensitive)."""
    order_clause = re.search(r"\border\s+by\b", sql, re.I)
    return order_clause is not None

def run_sql(con, sql: str):
    """Execute *sql* on *con* and return ``(rows, error)``.

    On success ``rows`` is the ``fetchall()`` result and ``error`` is None.
    On any exception ``rows`` is None and ``error`` is the exception message.
    """
    try:
        rows = con.execute(sql).fetchall()
    except Exception as exc:
        return None, str(exc)
    return rows, None

## Model Loading

In [None]:
def load_model(use_adapter: bool = False, adapter_path: str = None):
    """Load Qwen model with optional LoRA adapter.

    Args:
        use_adapter: When True, the base model named in the adapter's PEFT
            config (falling back to ``BASE_ID``) is loaded and the adapter is
            applied on top; the tokenizer is taken from the adapter location.
        adapter_path: Adapter directory or hub id; ``ADAPTER_ID`` is used when
            this is empty.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode.

    The checkpoint is loaded with its own configuration, and the embedding
    matrix is resized *afterwards* if its row count differs from the
    tokenizer length. Overriding ``config.vocab_size`` before loading,
    combined with ``ignore_mismatched_sizes=True``, makes the embedding shapes
    disagree with the checkpoint, so the pretrained embedding weights are
    skipped and left freshly initialised.
    """
    quant_config = BitsAndBytesConfig(load_in_4bit=True) if USE_4BIT else None

    if use_adapter:
        adapter_id = adapter_path or ADAPTER_ID
        peft_config = PeftConfig.from_pretrained(adapter_id)
        base_id = peft_config.base_model_name_or_path or BASE_ID
        tokenizer_source = adapter_id
    else:
        base_id = BASE_ID
        tokenizer_source = base_id
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source, use_fast=True, trust_remote_code=True
    )

    model_kwargs = dict(
        device_map="auto" if DEVICE == "cuda" else None,
        trust_remote_code=True,
    )
    if quant_config is not None:
        model_kwargs["quantization_config"] = quant_config

    model = AutoModelForCausalLM.from_pretrained(base_id, **model_kwargs)

    # Resizing after loading keeps the pretrained rows and only adds or drops
    # rows at the end of the matrix.
    if len(tokenizer) != model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))

    if use_adapter:
        model = PeftModel.from_pretrained(model, adapter_id)

    # generate() needs a pad token; reuse EOS when the tokenizer has none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model.eval()
    return tokenizer, model

def unload_model(model):
    """Drop this function's reference to *model*, run GC and flush the CUDA cache.

    ``del`` only removes the local name, so any reference the caller still
    holds keeps the model alive.
    """
    del model
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

## Generation Functions

In [None]:
def build_prompt(question: str, schema_text: str, tokenizer, examples: list = None) -> str:
    """Assemble the generation prompt for one question.

    The user message contains the schema, an optional few-shot section built
    from ``examples`` (each item is indexed with ``'question'`` and ``'sql'``)
    and the question. If the tokenizer exposes a chat template, the system and
    user messages are rendered through it with the generation prompt appended;
    otherwise the system prompt and user text are joined by a blank line.
    """
    shots = ""
    if examples:
        shots = "HERE ARE SOME EXAMPLES:\n" + "".join(
            f"Q: {ex['question']}\nSQL: {ex['sql']}\n\n" for ex in examples
        )

    user = f"SCHEMA:\n{schema_text}\n\n{shots}QUESTION:\n{question}\n\nSQL:"

    if not getattr(tokenizer, "chat_template", None):
        return f"{SYSTEM_PROMPT}\n\n{user}"

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

def build_repair_prompt(question: str, schema_text: str, bad_sql: str, error: str, tokenizer) -> str:
    """Assemble the prompt asking the model to fix a failed query.

    The user message lists the schema, the question, the broken SQL and the
    error text. It is rendered through the tokenizer's chat template when one
    exists, otherwise appended to ``REPAIR_PROMPT`` after a blank line.
    """
    user = f"SCHEMA:\n{schema_text}\n\nQUESTION:\n{question}\n\nBROKEN_SQL:\n{bad_sql}\n\nERROR:\n{error}\n\nFIXED_SQL:"

    if not getattr(tokenizer, "chat_template", None):
        return f"{REPAIR_PROMPT}\n\n{user}"

    messages = [
        {"role": "system", "content": REPAIR_PROMPT},
        {"role": "user", "content": user},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

def generate(prompt: str, tokenizer, model) -> str:
    """Run deterministic decoding on *prompt* and return the extracted SQL.

    Only the newly generated tokens (everything after the prompt length) are
    decoded, with special tokens skipped, and passed to ``extract_sql``.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    prompt_length = encoded["input_ids"].shape[1]

    with torch.no_grad():
        sequences = model.generate(
            **encoded,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            num_beams=NUM_BEAMS,
            pad_token_id=tokenizer.eos_token_id,
        )

    completion_ids = sequences[0][prompt_length:]
    completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return extract_sql(completion)

## Setup Test Data & Cache

In [None]:
# Load the test set and keep only rows that have both a question and a
# reference query.
test_df = pd.read_csv(TEST_DATA_PATH).dropna(
    subset=["Transcription", "SQL Ground Truth"]
)

if MAX_SAMPLES:
    test_df = test_df.head(MAX_SAMPLES)

print(f"Test samples: {len(test_df)}")

# Build Ground Truth cache: run every reference query once and store its raw
# rows, error text and both normalised forms, keyed by DataFrame index label.
gt_cache = {}
for idx, raw_sql in test_df["SQL Ground Truth"].items():
    gt_sql = raw_sql.strip()
    gt_res, gt_err = run_sql(con, gt_sql)
    succeeded = not gt_err
    gt_cache[idx] = {
        "sql": gt_sql,
        "res": gt_res,
        "err": gt_err,
        "has_order": has_order_by(gt_sql),
        "norm_sorted": normalize_rows(gt_res, False) if succeeded else None,
        "norm_ordered": normalize_rows(gt_res, True) if succeeded else None,
    }

print(f"Ground truth cached: {len(gt_cache)}")

## Run Benchmarks

In [None]:
def run_condition(condition: str, tokenizer, model) -> pd.DataFrame:
    """Run benchmark for a single condition.

    Condition flags:
      * B1, B3, B4, B5 - dynamic schema selection (otherwise the full schema)
      * B4             - schema text enriched with content samples
      * B5             - few-shot examples from the RAG retriever (k=3)
      * B3             - one repair attempt when the first query is invalid
                         or fails to execute

    Args:
        condition: Condition label, e.g. ``"B0"``.
        tokenizer: Tokenizer matching *model*.
        model: Loaded generation model.

    Returns:
        One row per test sample with the columns ``id``, ``condition``,
        ``valid_sql``, ``exact_match``, ``exec_match``, ``gen_time_ms`` and
        ``repair_used``. ``gen_time_ms`` covers prompt building plus the
        first generation only.
    """
    use_dynamic = condition in ("B1", "B3", "B4", "B5")
    use_content = condition == "B4"
    use_rag = condition == "B5"
    use_repair = condition == "B3"

    if use_rag and not retriever:
        # Without this warning the condition would silently run with no
        # few-shot examples and still be reported under the RAG label.
        print(f"  [{condition}] WARNING: RAG retriever unavailable; running without few-shot examples")

    full_schema = build_full_schema() if not use_dynamic else None
    total = len(test_df)
    results = []

    # `done` is a positional counter: the DataFrame index labels have gaps
    # after dropna(), so they cannot be used for progress reporting.
    for done, (idx, row) in enumerate(test_df.iterrows(), start=1):
        question = row["Transcription"]
        gt = gt_cache[idx]

        # Schema
        if use_dynamic:
            tables = select_tables(question, MAX_TABLES)
            schema_text = build_schema_text(tables, with_content=use_content)
        else:
            schema_text = full_schema

        # Few-shot examples (RAG)
        examples = None
        if use_rag and retriever:
            examples = retriever.retrieve(question, k=3)

        # Generate
        start = time.time()
        prompt = build_prompt(question, schema_text, tokenizer, examples)
        gen_sql = generate(prompt, tokenizer, model)
        gen_time = (time.time() - start) * 1000

        valid = is_valid_sql(gen_sql)
        gen_res, gen_err = run_sql(con, gen_sql) if valid else (None, "INVALID_SQL")

        # Repair (B3)
        repair_used = False
        if use_repair and (not valid or gen_err):
            repair_prompt = build_repair_prompt(question, schema_text, gen_sql, gen_err or "Invalid", tokenizer)
            gen_sql = generate(repair_prompt, tokenizer, model)
            valid = is_valid_sql(gen_sql)
            gen_res, gen_err = run_sql(con, gen_sql) if valid else (None, "INVALID_SQL")
            repair_used = True

        # Metric Check
        executed = valid and not gen_err
        exact = normalize_sql(gen_sql) == normalize_sql(gt["sql"]) if executed else False
        exec_match = False
        if executed and not gt["err"]:
            # Row order is only meaningful when the reference query asks for
            # one. If only the generated query has ORDER BY, the reference
            # rows come back in arbitrary order and an ordered comparison
            # would report a spurious mismatch.
            keep = gt["has_order"]
            gt_norm = gt["norm_ordered"] if keep else gt["norm_sorted"]
            gen_norm = normalize_rows(gen_res, keep)
            exec_match = gt_norm == gen_norm

        results.append({
            "id": row.get("ID", idx),
            "condition": condition,
            "valid_sql": valid,
            "exact_match": exact,
            "exec_match": exec_match,
            "gen_time_ms": gen_time,
            "repair_used": repair_used
        })

        if done % 20 == 0:
            print(f"  [{condition}] Processed {done}/{total}")

    return pd.DataFrame(results)

In [None]:
all_results = []

# Model choice per condition: B0 and B1 run on the base model; B2, B3, B4 and
# B5 run on the fine-tuned (adapter) model.
_ADAPTER_CONDITIONS = ("B2", "B3", "B4", "B5")
adapter_path = str(LOCAL_ADAPTER) if LOCAL_ADAPTER.exists() else ADAPTER_ID

tokenizer = model = None
loaded_with_adapter = None

for cond in RUN_CONDITIONS:
    print(f"\n{'='*40}\nRunning: {cond}\n{'='*40}")

    use_adapter = cond in _ADAPTER_CONDITIONS

    # Reload only when the required model differs from the one in memory.
    if model is None or use_adapter != loaded_with_adapter:
        # Drop this cell's references *before* loading the next model.
        # Passing the model to a helper that `del`s its own parameter does
        # not free anything while these names still point at it, so two
        # models would be resident at once.
        tokenizer = model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"Loading model (adapter={use_adapter})...")
        tokenizer, model = load_model(use_adapter=use_adapter, adapter_path=adapter_path if use_adapter else None)
        loaded_with_adapter = use_adapter

    results_df = run_condition(cond, tokenizer, model)
    results_df.to_csv(RESULTS_DIR / f"{cond}_results.csv", index=False)
    all_results.append(results_df)

# Release the last model once every condition has run.
tokenizer = model = None
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# One summary line per condition: percentages with one decimal place, mean
# latency rounded to whole milliseconds.
summary_rows = [
    {
        "Condition": cond_df["condition"].iloc[0],
        "Valid (%)": f"{cond_df['valid_sql'].mean()*100:.1f}",
        "Exact Match (%)": f"{cond_df['exact_match'].mean()*100:.1f}",
        "Exec Acc (%)": f"{cond_df['exec_match'].mean()*100:.1f}",
        "Latency (ms)": f"{cond_df['gen_time_ms'].mean():.0f}",
    }
    for cond_df in all_results
]

sum_df = pd.DataFrame(summary_rows)
sum_df.to_csv(RESULTS_DIR / "summary.csv", index=False)
print(sum_df.to_string(index=False))

In [None]:
# Close the DuckDB connection opened for the benchmark.
con.close()
print("Done!")