# Agentic Evaluation (ReAct-style)

Refs/Docs: `REFERENCES.md#ref-yao2023-react`, `REFERENCES.md#ref-zhai2025-excot`, `REFERENCES.md#ref-zhong2020-ts`, `REFERENCES.md#ref-yu2018-spider`. Docs: HF Transformers, Cloud SQL Connector, SQLAlchemy.

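Say: This notebook adds a bounded ReAct-style loop on top of the ClassicModels benchmark and evaluates VA (the SQL executes), EM (normalized exact string match), EX (same result as gold on the base DB), and TS (same result as gold across test-suite replica DBs).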
Why: Execution feedback helps fix runnable-but-wrong SQL without changing model weights.

Plan (step-by-step):
1) Setup runtime
2) DB connection
3) Schema + test set
4) Model load
5) ReAct agent (`nl2sql/agent.py`)
6) Evaluate + save JSON

Code pointers: `nl2sql/agent.py`, `nl2sql/eval.py`, `nl2sql/query_runner.py`.


Refs/Docs (quick list):
- HF Transformers quantization docs
- PEFT + TRL docs
- Cloud SQL Connector docs
- SQLAlchemy engine/execute docs
- ReAct paper `REFERENCES.md#ref-yao2023-react`


## Setup (run first, then restart)

Docs: HF Transformers quantization docs; BitsAndBytes docs.
Say: One-time environment pinning so results are reproducible.
Why: Small version drift changes generation and metrics.
Code pointers: `requirements.txt`.


In [None]:

%%bash
set -e
export PIP_DEFAULT_TIMEOUT=120

# Clean conflicting preinstalls
pip uninstall -y torch torchvision torchaudio bitsandbytes triton transformers accelerate peft trl datasets numpy pandas fsspec requests google-auth || true

# Base deps
pip install -q --no-cache-dir --force-reinstall \
  numpy==1.26.4 pandas==2.2.1 fsspec==2024.5.0 requests==2.31.0 google-auth==2.43.0

# Torch + CUDA 12.1
pip install -q --no-cache-dir --force-reinstall \
  torch==2.3.1+cu121 torchvision==0.18.1+cu121 torchaudio==2.3.1+cu121 \
  --index-url https://download.pytorch.org/whl/cu121

# bitsandbytes + triton + HF stack
pip install -q --no-cache-dir --force-reinstall \
  bitsandbytes==0.43.3 triton==2.3.1 \
  transformers==4.44.2 accelerate==0.33.0 peft==0.17.0 trl==0.9.6 datasets==2.20.0

echo "Setup complete. Restart runtime once, then run the rest of the notebook top-to-bottom."
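After the restart, it can help to confirm that the installed versions really match the pins, since later `pip` calls (including `requirements.txt`) can re-resolve dependencies. A minimal sketch using the standard-library `importlib.metadata`; the `PINS` subset and the `check_pins` helper are hypothetical, not part of `nl2sql/`.

```python
from importlib.metadata import PackageNotFoundError, version

# Hypothetical subset of the pins in the setup cell; extend to match requirements.txt.
PINS = {
    "numpy": "1.26.4",
    "torch": "2.3.1+cu121",
    "bitsandbytes": "0.43.3",
    "transformers": "4.44.2",
    "peft": "0.17.0",
}


def check_pins(pins: dict) -> dict:
    """Return {package: problem} for every pin that is missing or has drifted."""
    problems = {}
    for pkg, wanted in pins.items():
        try:
            got = version(pkg)
        except PackageNotFoundError:
            problems[pkg] = "missing"
            continue
        # Compare only the public version; local tags such as "+cu121" are ignored.
        if got.split("+")[0] != wanted.split("+")[0]:
            problems[pkg] = f"installed {got}, pinned {wanted}"
    return problems


print(check_pins(PINS) or "All pinned versions match.")
```

An empty dict means every listed pin is satisfied; anything else is worth fixing before comparing metrics across runs.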


Model load: HF 4-bit NF4 + optional PEFT adapters; baseline decoding is deterministic.
Refs: `REFERENCES.md#ref-ding2023-peft`, `REFERENCES.md#ref-goswami2024-peft`.
Docs: HF quantization, PEFT.
Code pointers: `scripts/run_full_pipeline.py`, `nl2sql/llm.py`.


Docs: Colab + git clone workflow.
Say: Clone the repo so we reuse the same `nl2sql/` modules as the scripts.
Why: Keeps logic in code, not hidden in the notebook.
Code pointers: `nl2sql/`, `scripts/run_full_pipeline.py`.


In [None]:
# 0) Clone repo (Colab) + install deps
import os
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    if not os.path.exists('/content/NLtoSQL'):
        !git clone https://github.com/MacKenzieOBrian/NLtoSQL.git /content/NLtoSQL
    %cd /content/NLtoSQL
    !pip -q install -r requirements.txt
    import torch, transformers, accelerate, peft
    print('torch', torch.__version__, 'cuda', torch.cuda.is_available())
else:
    print('Not in Colab; using existing workspace')


Prompt/eval: build prompts (system + schema + k exemplars), generate SQL, postprocess, then compute VA/EM/EX.
Refs: `REFERENCES.md#ref-brown2020-gpt3`, `REFERENCES.md#ref-mosbach2023-icl`, `REFERENCES.md#ref-zhong2020-ts`.
Code pointers: `nl2sql/prompting.py`, `nl2sql/llm.py`, `nl2sql/postprocess.py`, `nl2sql/eval.py`.
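EX compares what the queries return, not how they are written. A minimal sketch of that comparison, assuming order-insensitive multiset equality of rows and using an in-memory SQLite table as a stand-in for ClassicModels; `execution_match` is a hypothetical name, and the real `nl2sql.eval.execution_accuracy` may normalize values or treat ORDER BY differently.

```python
import sqlite3
from collections import Counter


def execution_match(conn: sqlite3.Connection, pred_sql: str, gold_sql: str) -> bool:
    """Sketch of EX: True if pred and gold return the same multiset of rows."""
    try:
        pred_rows = conn.execute(pred_sql).fetchall()
    except sqlite3.Error:
        return False  # a non-executable prediction cannot match (VA = 0)
    gold_rows = conn.execute(gold_sql).fetchall()
    # Counter keeps duplicate rows, so DISTINCT vs non-DISTINCT still differ.
    return Counter(pred_rows) == Counter(gold_rows)


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE offices (officeCode TEXT, city TEXT, country TEXT)")
conn.executemany(
    "INSERT INTO offices VALUES (?, ?, ?)",
    [("1", "San Francisco", "USA"), ("2", "Boston", "USA"), ("4", "Paris", "France")],
)

gold = "SELECT city FROM offices WHERE country = 'USA'"
print(execution_match(conn, "SELECT city FROM offices WHERE country = 'USA' ORDER BY city", gold))  # True
print(execution_match(conn, "SELECT city, country FROM offices WHERE country = 'USA'", gold))  # False
```

The second prediction returns the right rows but an extra column, which is why projection drift shows up as an EX failure.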




### Reference notes (what this builds on)

Refs: `REFERENCES.md#ref-zhu2024-survey`, `REFERENCES.md#ref-hong2025-survey`, `REFERENCES.md#ref-yao2023-react`, `REFERENCES.md#ref-ojuri2025-agents`.
Docs: Cloud SQL Connector, SQLAlchemy creator pattern, HF quantization.
Code pointers: `nl2sql/db.py`, `nl2sql/schema.py`, `nl2sql/agent.py`, `nl2sql/eval.py`.


## Optional: use gcloud ADC (without a key)

Docs: GCP ADC (Application Default Credentials).
Say: Use ADC if you do not want to pass a service account JSON file (no key file required).
Code pointers: `nl2sql/db.py` (`create_engine_with_connector`).


In [None]:
# Run this only if you prefer gcloud-based ADC (no JSON key)
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    %pip install -q --upgrade google-auth google-auth-oauthlib
    !gcloud auth application-default login
else:
    print("Not in Colab; skip gcloud auth.")



Docs: SQLAlchemy execute + Cloud SQL Connector.
Say: Create the base DB engine via the Cloud SQL Connector and verify it with `SELECT 1`; the SELECT-only QueryRunner (the Act step) is built on this engine later.
Why: QueryRunner gives VA (executability) and blocks destructive SQL.
Code pointers: `nl2sql/db.py`, `nl2sql/query_runner.py`.
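To make "SELECT-only" concrete, here is a minimal sketch of a read-only execution gate that also yields VA and an error string for repair. It assumes a first-keyword check plus a single-statement rule and runs against SQLite; `RunMeta` and `run_select_only` are hypothetical names, and the real `QueryRunner.run` may parse SQL more carefully. The keyword filter is deliberately conservative: it can reject valid queries whose string literals contain those words or semicolons.

```python
import re
import sqlite3
from dataclasses import dataclass
from typing import Optional

# Write/DDL keywords that a read-only gate should refuse (deliberately conservative).
FORBIDDEN = re.compile(r"\b(insert|update|delete|drop|alter|create|truncate|grant)\b", re.I)


@dataclass
class RunMeta:
    success: bool
    error: Optional[str] = None
    rowcount: int = 0


def run_select_only(conn: sqlite3.Connection, sql: str) -> RunMeta:
    stmt = sql.strip().rstrip(";").strip()
    if ";" in stmt:
        return RunMeta(False, "multiple statements are not allowed")
    first = stmt.split(None, 1)[0].lower() if stmt else ""
    if first not in {"select", "with"} or FORBIDDEN.search(stmt):
        return RunMeta(False, "only read-only SELECT queries are allowed")
    try:
        rows = conn.execute(stmt).fetchall()
    except sqlite3.Error as exc:
        # The DB error text is what a repair step could feed back to the model.
        return RunMeta(False, str(exc))
    return RunMeta(True, None, len(rows))


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customers (customerNumber INTEGER, country TEXT)")
conn.execute("INSERT INTO customers VALUES (103, 'France')")

print(run_select_only(conn, "SELECT country FROM customers;"))  # success=True, rowcount=1
print(run_select_only(conn, "DROP TABLE customers"))            # rejected before execution
print(run_select_only(conn, "SELECT cntry FROM customers"))      # success=False, DB error text
```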


In [None]:
# 1) Environment + DB
import os
from getpass import getpass

from sqlalchemy import text

from nl2sql.db import create_engine_with_connector, safe_connection

# Expected env vars (set these in a Colab cell):
# INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS, DB_NAME
INSTANCE_CONNECTION_NAME = os.getenv("INSTANCE_CONNECTION_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASS = os.getenv("DB_PASS")
DB_NAME = os.getenv("DB_NAME") or "classicmodels"

if not INSTANCE_CONNECTION_NAME:
    INSTANCE_CONNECTION_NAME = input("Enter INSTANCE_CONNECTION_NAME: ").strip()
if not DB_USER:
    DB_USER = input("Enter DB_USER: ").strip()
if not DB_PASS:
    DB_PASS = getpass("Enter DB_PASS: ")

# Canonical engine builder (shared with scripts + other notebooks).
# Uses Cloud SQL Connector under the hood and ADC for credentials.
engine, connector = create_engine_with_connector(
    instance_connection_name=INSTANCE_CONNECTION_NAME,
    user=DB_USER,
    password=DB_PASS,
    db_name=DB_NAME,
)

with safe_connection(engine) as conn:
    conn.execute(text("SELECT 1"))
print("DB connection OK")


Say: Create an engine factory for TS replica DBs.
Why: TS compares gold vs pred across perturbed DBs to avoid lucky execution.
Refs: `REFERENCES.md#ref-zhong2020-ts`.
Code pointers: this cell (`make_engine`), `nl2sql/eval.py:test_suite_accuracy_for_item`.
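Why one database is not enough: a wrong query can return the gold rows on the base data by coincidence. A minimal sketch of test-suite accuracy under stated assumptions: in-memory SQLite replicas with hand-made perturbed rows stand in for the `classicmodels_ts_XX` MySQL replicas, and `build_replica`, `same_result` and `ts_accuracy` are hypothetical helpers, not the `nl2sql.eval` API.

```python
import sqlite3
from collections import Counter


def build_replica(rows):
    """Hypothetical replica: same schema, different rows."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE products (productCode TEXT, productLine TEXT, buyPrice REAL)")
    conn.executemany("INSERT INTO products VALUES (?, ?, ?)", rows)
    return conn


def same_result(conn, pred_sql, gold_sql) -> bool:
    try:
        pred = Counter(conn.execute(pred_sql).fetchall())
    except sqlite3.Error:
        return False
    return pred == Counter(conn.execute(gold_sql).fetchall())


def ts_accuracy(replicas, pred_sql, gold_sql) -> int:
    # TS = 1 only if the prediction matches gold on every replica.
    return int(all(same_result(c, pred_sql, gold_sql) for c in replicas))


base = build_replica([("P1", "Motorcycles", 40.0), ("P2", "Classic Cars", 95.0)])
# Perturbed replica adds a cheap Classic Car, which separates the two queries.
perturbed = build_replica([
    ("P1", "Motorcycles", 40.0),
    ("P2", "Classic Cars", 95.0),
    ("P3", "Classic Cars", 60.0),
])

gold = "SELECT productCode FROM products WHERE productLine = 'Classic Cars'"
lucky = "SELECT productCode FROM products WHERE buyPrice > 90"

print(same_result(base, lucky, gold))               # True: EX would count this as correct
print(ts_accuracy([base, perturbed], lucky, gold))  # 0: TS exposes the wrong filter
```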


In [None]:
# 1b) Engine factory for TS (multiple DB names)

import sqlalchemy
from sqlalchemy.engine import Engine


def make_engine(db_name: str) -> Engine:
    """Create a new engine bound to a specific TS replica DB name.

    TS (test-suite accuracy) executes the same (gold, pred) SQL across multiple
    replica databases (classicmodels_ts_XX). We keep separate engines so each
    replica is evaluated independently.
    """

    def getconn_for_db():
        return connector.connect(
            INSTANCE_CONNECTION_NAME,
            "pymysql",
            user=DB_USER,
            password=DB_PASS,
            db=db_name,
        )

    return sqlalchemy.create_engine("mysql+pymysql://", creator=getconn_for_db, future=True)



Refs: `REFERENCES.md#ref-zhu2024-survey`, `REFERENCES.md#ref-hong2025-survey`.
Say: Build schema text and load a small debug slice.
Why: Smaller runs let you iterate on postprocess + agent logic quickly.
Code pointers: `nl2sql/schema.py:build_schema_summary`, `data/classicmodels_test_200.json`.
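The `TABLES` parsing in the next cell assumes one `table(col, ...)` line per table. A minimal sketch of building a summary in that shape, using SQLite's `PRAGMA table_info` as a stand-in for MySQL's `information_schema`; the exact output format of the real `build_schema_summary` is an assumption here, and `schema_summary_sqlite` is a hypothetical helper.

```python
import sqlite3


def schema_summary_sqlite(conn: sqlite3.Connection) -> str:
    """Hypothetical helper: one 'table(col1, col2, ...)' line per table, sorted by name."""
    tables = [r[0] for r in conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    )]
    lines = []
    for t in tables:
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
        cols = [row[1] for row in conn.execute(f"PRAGMA table_info({t})")]
        lines.append(f"{t}({', '.join(cols)})")
    return "\n".join(lines)


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE offices (officeCode TEXT PRIMARY KEY, city TEXT)")
conn.execute("CREATE TABLE employees (employeeNumber INTEGER PRIMARY KEY, officeCode TEXT)")

summary = schema_summary_sqlite(conn)
print(summary)
# Same parsing rule as TABLES in the next cell.
tables = {line.split("(", 1)[0].strip() for line in summary.splitlines() if "(" in line}
print(tables)
```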


In [None]:
# 2) Load schema summary + test set (small slice for now)
import json
from pathlib import Path

from nl2sql.schema import build_schema_summary

SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME)

test_path = Path("data/classicmodels_test_200.json")
full_set = json.loads(test_path.read_text(encoding="utf-8"))
# default to a small slice while debugging
test_set = full_set[:5]
print("Demo items:", len(test_set))
# For full run, switch to: test_set = full_set; print("Test items:", len(test_set))

TABLES = {line.split('(', 1)[0].strip() for line in SCHEMA_SUMMARY.splitlines() if '(' in line}
TABLES_LOWER = {t.lower(): t for t in TABLES}



Say: Load the base model (4-bit) and optional PEFT adapters.
Why: Baseline decoding is deterministic; agent sampling is controlled by config.
Refs: `REFERENCES.md#ref-ding2023-peft`.
Code pointers: `nl2sql/llm.py`, `scripts/run_full_pipeline.py`.


In [None]:

# 3) Load model (base or QLoRA adapters)
import os
from getpass import getpass
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER_PATH = os.getenv("ADAPTER_PATH") or "results/adapters/qlora_classicmodels"  # set to None to use base model

HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    HF_TOKEN = getpass("Enter HF_TOKEN (https://huggingface.co/settings/tokens): ").strip()

cc_major, cc_minor = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else (0, 0)
use_bf16 = cc_major >= 8
compute_dtype = torch.bfloat16 if use_bf16 else torch.float16
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("Using bf16:", use_bf16)
print("Adapter path:", ADAPTER_PATH)

# Tokenizer
tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Quantized base model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
    device_map={"": 0} if torch.cuda.is_available() else None,
    token=HF_TOKEN,
)
base_model.generation_config.do_sample = False
base_model.generation_config.temperature = 1.0
base_model.generation_config.top_p = 1.0

# Load adapters if present locally; otherwise use base model
adapter_dir = Path(ADAPTER_PATH) if ADAPTER_PATH else None
if adapter_dir and adapter_dir.exists():
    model = PeftModel.from_pretrained(base_model, adapter_dir, token=HF_TOKEN)
    print("Loaded adapters from", adapter_dir)
else:
    print("Adapter path missing; using base model only. Set ADAPTER_PATH to your local adapter folder or upload it to Colab.")
    model = base_model


## Optional adapter sanity check (run before ReAct)
Say: Quick test to confirm the model/adapters can produce executable SQL.
Why: Catch obvious issues before the agent loop.
Refs: `REFERENCES.md#ref-mosbach2023-icl`.



Refs: `REFERENCES.md#ref-brown2020-gpt3`, `REFERENCES.md#ref-zhong2020-ts`.
Say: Quick end-to-end check: prompt -> generate -> postprocess -> execute -> compare.
Why: Detect early failures before full runs.
Code pointers: `nl2sql/prompting.py`, `nl2sql/postprocess.py`, `nl2sql/query_runner.py`, `nl2sql/eval.py`.
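Chat models often wrap SQL in prose or a Markdown code fence, which is why the check extracts the first SELECT before postprocessing. A minimal regex sketch of that step; it is not the code in `nl2sql/llm.py`, `extract_first_select_sketch` is a hypothetical name, and as a heuristic it ignores semicolons inside string literals and can be fooled by prose lines that start with "select".

```python
import re
from typing import Optional

# First fenced code block (three backticks, optional "sql" tag).
_FENCE = re.compile(r"`{3}(?:sql)?\s*(.*?)`{3}", re.S | re.I)
# A SELECT that starts a line, up to the first semicolon or the end of the text.
_SELECT = re.compile(r"^\s*(SELECT\b.*?)(?:;|\Z)", re.S | re.M | re.I)


def extract_first_select_sketch(text: str) -> Optional[str]:
    fenced = _FENCE.search(text)
    body = fenced.group(1) if fenced else text
    m = _SELECT.search(body)
    return m.group(1).strip() if m else None


fence = "`" * 3
reply = f"Sure! Here is the query:\n{fence}sql\nSELECT city FROM offices WHERE country = 'USA';\n{fence}\nThis lists US offices."
print(extract_first_select_sketch(reply))  # SELECT city FROM offices WHERE country = 'USA'
```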


In [None]:
from nl2sql.prompting import make_few_shot_messages
from nl2sql.llm import extract_first_select
from nl2sql.postprocess import guarded_postprocess
from nl2sql.query_runner import QueryRunner
from nl2sql.eval import execution_accuracy

runner_check = QueryRunner(engine)
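# reuse existing test_set (default small slice); pick 3 exemplars.
# Note: these exemplars overlap with the items checked below, so the k=3 run
# can see its own gold SQL; treat it as a smoke test, not a metric.
exemplars = test_set[:3]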

def run_quick_check(k: int = 0, limit: int = 3):
    print(f"Quick check k={k}")
    for sample in test_set[:limit]:
        shots = exemplars if k > 0 else []
        msgs = make_few_shot_messages(
            schema=SCHEMA_SUMMARY,
            exemplars=shots,
            nlq=sample['nlq'],
        )
        prompt_preview = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        # The Llama 3 chat template already inserts BOS; do not add a second one.
        inputs = tok(prompt_preview, return_tensors="pt", add_special_tokens=False).to(model.device)
        out = model.generate(**inputs, max_new_tokens=256, do_sample=False, pad_token_id=tok.pad_token_id)

        # strip the prompt before decoding the generation
        gen_ids = out[0][inputs.input_ids.shape[-1]:]
        # `gen_text` (not `text`) so sqlalchemy's `text` imported earlier is not shadowed.
        gen_text = tok.decode(gen_ids, skip_special_tokens=True)

        raw_sql = extract_first_select(gen_text) or gen_text
        sql = guarded_postprocess(raw_sql, sample['nlq'])

        meta = runner_check.run(sql, capture_df=False)
        va = meta.success
        ex_ok, _, _ = execution_accuracy(engine=engine, pred_sql=sql, gold_sql=sample['sql'])
        print(f"Q: {sample['nlq']}\nSQL: {sql}\nVA: {va} EX: {ex_ok}\n")

run_quick_check(k=0)
run_quick_check(k=3)


Refs: ReAct pattern `REFERENCES.md#ref-yao2023-react` applied to NL to SQL.


Docs: ReAct paper + SELECT-only execution guard.


Refs: `REFERENCES.md#ref-zhu2024-survey`, `REFERENCES.md#ref-hong2025-survey`.
Say: Import the shared heuristics used by the agent.
Why: Keeps policy (intent checks, scoring) in `nl2sql/` not in the notebook.
Code pointers: `nl2sql/agent_utils.py`.


In [None]:
# Helper imports (optional; used for interactive inspection)
# Main agent loop is in `nl2sql/agent.py`.
from nl2sql.agent_utils import intent_constraints, semantic_score, count_select_columns


### Agent status (for dissertation)

Say: The current loop is a bounded ReAct-style agent (generate -> execute -> observe -> repair/fallback).
Why: It improves correctness without changing model weights; all changes are deterministic or gated.
Refs: `REFERENCES.md#ref-yao2023-react`, `REFERENCES.md#ref-zhai2025-excot`.
Code pointers: `nl2sql/agent.py`.


Refs: execution-based evaluation `REFERENCES.md#ref-ojuri2025-agents`, TS motivation `REFERENCES.md#ref-zhong2020-ts`.




## ReAct execution-guided pipeline (current version)

Refs: `REFERENCES.md#ref-yao2023-react`, `REFERENCES.md#ref-zhai2025-excot`.
Say: These cells configure the canonical agent and evaluation harness.
Code pointers: `nl2sql/agent.py`, `nl2sql/eval.py`.


## Reference Map (Code <-> Literature)

Refs: `REFERENCES.md#ref-yao2023-react`, `REFERENCES.md#ref-zhai2025-excot`, `REFERENCES.md#ref-scholak2021-picard`, `REFERENCES.md#ref-zhu2024-survey`, `REFERENCES.md#ref-li2023-bigbench`, `REFERENCES.md#ref-zhong2020-ts`.

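- Execution feedback + repair (`ref-yao2023-react`, `ref-zhai2025-excot`): `nl2sql/agent.py`
- Output constraints / validity (`ref-scholak2021-picard`): `nl2sql/postprocess.py`
- Projection control + schema linking context (`ref-zhu2024-survey`): `nl2sql/agent_utils.py`
- Semantic evaluation via test suites (`ref-zhong2020-ts`): `nl2sql/eval.py`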

Code pointers: `nl2sql/agent.py`, `nl2sql/postprocess.py`, `nl2sql/agent_utils.py`, `nl2sql/eval.py`.


Refs: `REFERENCES.md#ref-yu2018-spider`.
Say: Reload schema + full test set and create the runner for real evaluation.
Code pointers: `nl2sql/schema.py`, `nl2sql/query_runner.py`.


In [None]:
# 4) Schema summary + test set + QueryRunner
import json
from pathlib import Path
from nl2sql.schema import build_schema_summary
from nl2sql.query_runner import QueryRunner

DB_NAME = os.getenv("DB_NAME") or "classicmodels"

SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME)
# Schema summary is used in prompts to ground column/table choices.
test_path = Path("data/classicmodels_test_200.json")
full_set = json.loads(test_path.read_text(encoding="utf-8"))
test_set = full_set  # change to full_set[:20] when debugging

print("Loaded test set size:", len(test_set))
runner = QueryRunner(engine)  # QueryRunner enforces SELECT-only execution and records errors for VA/EX.


Docs: notebook rerun hygiene.
Say: Defensive re-import so the following cells run cleanly after partial reruns.
Code pointers: `nl2sql/agent.py`, `nl2sql/agent_utils.py`, `nl2sql/postprocess.py`.


In [None]:
# 5) Agent utilities (used inside `nl2sql/agent.py`)
from nl2sql.agent_utils import intent_constraints, semantic_score, count_select_columns, vanilla_candidate


## 6. Agent Implementation (Module-Based, Explainable ReAct Loop)

Say: The agent lives in code so the notebook is just configuration + evaluation.
Why: One source of truth makes the loop easy to defend.
Refs: `REFERENCES.md#ref-yao2023-react`, `REFERENCES.md#ref-zhai2025-excot`.

What the loop does (high level):
- Build prompts (ReAct + optional tabular)
- Generate a small candidate set (bounded)
- Clean + deterministic postprocess
- Execute gate (SELECT-only)
- Intent gate
- Score and pick best
- Optional repair using DB error message
- Fallback to deterministic baseline if needed

Code pointers:
- `nl2sql/agent.py` (`ReactSqlAgent.react_sql`, `evaluate_candidate`, `repair_sql`)
- `nl2sql/postprocess.py` (`guarded_postprocess`)
- `nl2sql/agent_utils.py` (`clean_candidate_with_reason`, `intent_constraints`, `semantic_score`, `count_select_columns`, `enforce_projection_contract`)
- `nl2sql/query_runner.py` (`QueryRunner.run`)
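The control flow in the "What the loop does" list can be sketched in a few dozen lines. This is a hypothetical, simplified stand-in for `ReactSqlAgent.react_sql`: the callables `generate`, `execute`, `intent_ok`, `score`, `repair` and `baseline` are assumptions, not the real signatures. It shows how the budget stays bounded (`num_cands` per step, at most `max_steps` steps) and why every rejection lands in the trace with a reason.

```python
from typing import Callable, Optional


def bounded_react(
    nlq: str,
    generate: Callable[[str, int], list],
    execute: Callable[[str], Optional[str]],  # error text, or None if the SQL ran
    intent_ok: Callable[[str, str], bool],
    score: Callable[[str, str], float],
    repair: Callable[[str, str, str], list],
    baseline: Callable[[str], str],
    num_cands: int = 6,
    max_steps: int = 3,
):
    """Hypothetical loop: generate -> execute gate -> intent gate -> score -> repair -> fallback."""
    trace = []
    cands = generate(nlq, num_cands)
    for step in range(max_steps):
        accepted, last_error = [], None
        for sql in cands:
            err = execute(sql)
            if err is not None:
                trace.append({"step": step, "sql": sql, "reject": "exec", "error": err})
                last_error = (sql, err)
            elif not intent_ok(nlq, sql):
                trace.append({"step": step, "sql": sql, "reject": "intent"})
            else:
                accepted.append(sql)
        if accepted:
            best = max(accepted, key=lambda s: score(nlq, s))
            trace.append({"step": step, "chosen": best})
            return best, trace
        if last_error is None or step == max_steps - 1:
            break  # nothing to repair from, or no step left to try the repair
        cands = repair(nlq, *last_error)  # next step evaluates the repaired candidates
    fallback = baseline(nlq)
    trace.append({"fallback": fallback})
    return fallback, trace


# Toy stubs: the first candidate has a typo that the "DB" rejects.
sql, trace = bounded_react(
    "Which cities have offices in the USA?",
    generate=lambda q, n: ["SELECT cty FROM offices WHERE country = 'USA'"],
    execute=lambda s: "Unknown column 'cty'" if "cty" in s else None,
    intent_ok=lambda q, s: s.upper().startswith("SELECT"),
    score=lambda q, s: -len(s),
    repair=lambda q, s, e: [s.replace("cty", "city")],
    baseline=lambda q: "SELECT city FROM offices",
)
print(sql)
for event in trace:
    print(event)
```

Keeping the fallback outside the loop means the function always returns some SQL, which keeps VA well defined for every item.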


Refs: `REFERENCES.md#ref-yao2023-react`.
Say: Configure the agent in one place (`ReactConfig`).
Why: Bounded steps/candidates/repair makes the loop auditable.
Code pointers: `nl2sql/agent.py` (`ReactConfig`, `ReactSqlAgent`).


In [None]:
# 6) Agent implementation (imported)

from nl2sql.agent import ReactConfig, ReactSqlAgent

# Keep config explicit so it is easy to justify in a viva (bounded steps/candidates/repair).
CFG = ReactConfig(
    max_steps=3,
    num_cands=6,
    do_sample=True,
    temperature=0.5,
    top_p=0.9,
    max_new_tokens=128,
    enable_repair=True,
    repair_num_cands=4,
    use_tabular_prompt=True,
    use_schema_subset=True,
    use_projection_contract=True,
)

agent = ReactSqlAgent(model=model, tok=tok, runner=runner, cfg=CFG)
# Preserve the old function name used later in the notebook.
react_sql = agent.react_sql


Refs: `REFERENCES.md#ref-yao2023-react`.
Say: Run the bounded ReAct loop and capture a trace you can explain.
What to say out loud:
- "The agent proposes candidates, executes them, and uses the DB response as feedback."
- "Candidates must pass execution + intent gates before scoring."
- "Each rejection is logged with a reason."

Code pointers: `nl2sql/agent.py:ReactSqlAgent.react_sql`, `nl2sql/query_runner.py:QueryRunner.run`, `nl2sql/agent_utils.py` (gates + scoring).


In [None]:
# 7) ReAct loop
# The loop lives in `nl2sql/agent.py` (ReactSqlAgent.react_sql).
# This notebook calls `react_sql(...)` for quick checks and full evaluation.


## EX Troubleshooting Checklist

Say: If VA is high but EX is low, it is usually a semantic mismatch.
Refs: `REFERENCES.md#ref-zhong2020-ts`.

Quick checks:
- Projection drift: tighten `enforce_projection_contract`
- Wrong intent: check `intent_constraints`
- Wrong table/join: verify schema-subset prompt
- Missing literals: make sure filters appear in SQL

Code pointers: `nl2sql/agent_utils.py`, `nl2sql/postprocess.py`, `nl2sql/eval.py`.
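The first and last checks can be automated as a quick triage over saved results. A minimal sketch, assuming a naive top-level comma count for projection width (the real `count_select_columns` may parse differently) and treating quoted strings and bare numbers in the gold SQL as the filters that should survive; `projection_width` and `missing_literals` are hypothetical helpers, and the select-list parser does not handle subqueries.

```python
import re


def projection_width(sql: str) -> int:
    """Naive count of top-level select-list items between the first SELECT and FROM."""
    m = re.search(r"\bSELECT\s+(?:DISTINCT\s+)?(.*?)\s+FROM\b", sql, re.I | re.S)
    if not m:
        return 0
    depth, cols = 0, 1
    for ch in m.group(1):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        elif ch == "," and depth == 0:
            cols += 1
    return cols


def missing_literals(pred_sql: str, gold_sql: str) -> set:
    """Quoted strings and bare numbers in the gold SQL that never appear in the prediction."""
    lits = set(re.findall(r"'([^']*)'", gold_sql)) | set(re.findall(r"\b\d+(?:\.\d+)?\b", gold_sql))
    return {lit for lit in lits if lit not in pred_sql}


gold = "SELECT customerName FROM customers WHERE country = 'France' AND creditLimit > 100000"
pred = "SELECT customerName, city, country FROM customers WHERE creditLimit > 100000"
print(projection_width(pred), projection_width(gold))  # 3 1  -> projection drift
print(missing_literals(pred, gold))                    # {'France'} -> dropped filter
```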


Refs: `REFERENCES.md#ref-zhong2020-ts`.
Say: Manual spot-checks before full eval.
Why: Inspect trace and errors before running TS across many DBs.
Code pointers: this cell prints the NLQ, predicted SQL, gold SQL, and the length of the `trace` returned by `react_sql`.


In [None]:
# 8) Quick sanity check on a few items
schema_text = SCHEMA_SUMMARY
for sample in test_set[:5]:
    nlq = sample["nlq"]
    gold = sample["sql"]
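    # Note: test_set[:3] are passed as exemplars and also evaluated here; if react_sql
    # uses them as few-shot shots, those items see their own gold SQL.
    pred, trace = react_sql(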
        nlq=nlq,
        schema_text=schema_text,
        schema_summary=SCHEMA_SUMMARY,
        exemplars=test_set[:3],
    )
    print("NLQ:", nlq)
    print("PRED:", pred)
    print("GOLD:", gold)
    print("TRACE LEN:", len(trace))
    print("-" * 80)


### Stage 3 Interpretation (29 Jan 2026)

Refs: `REFERENCES.md#ref-zhong2020-ts`.
Say: Current loop usually returns executable SQL; remaining issues are projection bloat and spurious ORDER/GROUP BY.
Why: These hurt EM more than EX; use deterministic guards to stabilize.
Code pointers: `nl2sql/postprocess.py`, `nl2sql/agent_utils.py`.
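As one example of a deterministic guard for the spurious ORDER BY case: drop a final ORDER BY when the question has no ordering cue. A minimal sketch under stated assumptions: the cue list is hypothetical, GROUP BY is not handled, it leaves the SQL alone when `LIMIT` is present (there ORDER BY decides which rows come back), and it only strips an ORDER BY clause with no parentheses after it. It is not the logic in `nl2sql/postprocess.py`.

```python
import re

# Hypothetical cue prefixes suggesting the user asked for an ordering.
ORDER_CUES = ("order", "sort", "rank", "top", "highest", "lowest", "most", "least",
              "first", "last", "alphabetic", "ascending", "descending")


def drop_spurious_order_by(sql: str, nlq: str) -> str:
    q = nlq.lower()
    # Prefix match ("sort" also hits "sorted"); extra hits only make the guard keep ORDER BY.
    if any(re.search(rf"\b{cue}", q) for cue in ORDER_CUES):
        return sql
    if re.search(r"\bLIMIT\b", sql, re.I):
        return sql  # with LIMIT, ORDER BY decides which rows are returned
    # Strip a final ORDER BY clause only if no parenthesis or further statement follows it.
    return re.sub(r"\s+ORDER\s+BY\s+[^();]*;?\s*$", "", sql, flags=re.I)


print(drop_spurious_order_by(
    "SELECT city FROM offices WHERE country = 'USA' ORDER BY city;",
    "Which cities have offices in the USA?",
))  # SELECT city FROM offices WHERE country = 'USA'
```

Erring toward keeping ORDER BY is the safer default: a kept clause can only cost EM, while a wrongly dropped one under `LIMIT` would change EX.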


## Run Order (recommended)

Docs: notebook run order (practical).
Say: Follow this order to avoid DB/auth issues and wasted TS runs.

1) Runtime + deps (restart after install)
2) Cloud SQL connector + base engine
3) TS engine factory (`make_engine`)
4) Schema + dataset
5) Agent config + loop
6) TS harness
7) Evaluation


Say: Import the TS evaluator (semantic robustness across replica DBs).
Refs: `REFERENCES.md#ref-zhong2020-ts`.
Code pointers: `nl2sql/eval.py:test_suite_accuracy_for_item`.


In [None]:
# === Test Suite Accuracy (TS) evaluation ===
# Harness now lives in nl2sql.eval for reuse in scripts.
from nl2sql.eval import test_suite_accuracy_for_item


Refs: `REFERENCES.md#ref-zhong2020-ts`.
Say: Cost guards so TS/EX are fast while debugging.
Why: TS multiplies cost by N replica DBs.
Code pointers: `nl2sql/eval.py` (row caps + suite size).
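A row cap such as `MAX_ROWS_TS` matters because some wrong predictions (for example a missing join condition) return a cross product. A minimal sketch of capped fetching with the DB-API `fetchmany`, fetching one extra row to detect truncation; `fetch_capped` is hypothetical, and how `nl2sql/eval.py` scores a truncated result is not assumed here. With drivers that buffer results client-side (PyMySQL's default cursor does), this caps the rows compared but not the rows transferred; a server-side `LIMIT` would cap both.

```python
import sqlite3


def fetch_capped(conn: sqlite3.Connection, sql: str, max_rows: int):
    """Return (rows, truncated); fetch one extra row so truncation is detectable."""
    cur = conn.execute(sql)
    rows = cur.fetchmany(max_rows + 1)
    truncated = len(rows) > max_rows
    return rows[:max_rows], truncated


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (x INTEGER)")
conn.executemany("INSERT INTO t VALUES (?)", [(i,) for i in range(40)])

# A cross join of 40 x 40 = 1600 rows; only 501 are fetched.
rows, truncated = fetch_capped(conn, "SELECT a.x, b.x FROM t a, t b", max_rows=500)
print(len(rows), truncated)  # 500 True
```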


In [None]:
# === Quick test toggles (set before full eval) ===
# Use small values to sanity-check TS/EX before full runs.
QUICK_LIMIT = 20   # number of NLQs to evaluate (set None for full set)
TS_N = 3           # number of TS DBs (set 10 for full TS)
MAX_ROWS_TS = 500  # row cap per query in TS (raise for full)


Refs: `REFERENCES.md#ref-zhong2020-ts`, `REFERENCES.md#ref-yu2018-spider`.
Say: Full evaluation loop (VA/EM/EX/TS) + save results JSON.
Why: Produces per-item metrics and aggregate rates for the dissertation.
Code pointers: this cell, `nl2sql/eval.py`, `nl2sql/query_runner.py`.
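EM and the aggregate rates are easy to state precisely. A minimal sketch, assuming EM normalization means trimming, dropping a trailing semicolon, collapsing whitespace and lowercasing (the real `normalize_sql` may do more or less); `normalize_sql_sketch` and `rates` are hypothetical, and the `demo` flags are toy values, not results. Lowercasing everything, including string literals, is a simplification that can hide real differences such as `'usa'` vs `'USA'`.

```python
import re


def normalize_sql_sketch(sql: str) -> str:
    """Hypothetical EM normalizer: trim, drop trailing ';', collapse whitespace, lowercase."""
    s = sql.strip().rstrip(";").strip()
    s = re.sub(r"\s+", " ", s)
    return s.lower()  # note: also lowercases string literals


def rates(results: list) -> dict:
    """Per-metric rates, with the same empty-run guard as the evaluation cell."""
    n = max(len(results), 1)
    return {f"{m}_rate": sum(r[m] for r in results) / n for m in ("va", "em", "ex", "ts")}


gold = "SELECT city FROM offices WHERE country = 'USA'"
pred = "select city\nFROM offices  WHERE country = 'USA';"
print(normalize_sql_sketch(pred) == normalize_sql_sketch(gold))  # True

# Toy per-item flags, not real results.
demo = [
    {"va": 1, "em": 0, "ex": 1, "ts": 1},
    {"va": 1, "em": 0, "ex": 0, "ts": 0},
    {"va": 0, "em": 0, "ex": 0, "ts": 0},
    {"va": 1, "em": 1, "ex": 1, "ts": 1},
]
print(rates(demo))  # {'va_rate': 0.75, 'em_rate': 0.25, 'ex_rate': 0.5, 'ts_rate': 0.5}
```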


In [None]:
# 9) Full ReAct-style evaluation (VA/EX/EM/TS) over test_set

import json
from functools import lru_cache
from pathlib import Path

from sqlalchemy.engine import Engine

from nl2sql.eval import execution_accuracy, test_suite_accuracy_for_item
from nl2sql.postprocess import normalize_sql

results = []

TS_PREFIX = "classicmodels_ts"
SUITE_DBS = [f"{TS_PREFIX}_{i:02d}" for i in range(1, TS_N + 1)]


@lru_cache(maxsize=32)
def make_engine_cached(db_name: str) -> Engine:
    return make_engine(db_name)


def make_engine_fn(db_name: str) -> Engine:
    return make_engine_cached(db_name)


LIMIT = QUICK_LIMIT  # override from quick toggles
items = test_set[:LIMIT] if LIMIT else test_set
schema_text = SCHEMA_SUMMARY

# Per-item evaluation: generate SQL and compute VA/EM/EX/TS.
for i, sample in enumerate(items, start=1):
    nlq = sample["nlq"]
    gold_sql = sample["sql"]

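    # Note: test_set[:3] are passed as exemplars and also scored below; if react_sql uses
    # them as few-shot shots, exclude them or draw exemplars from a separate split for final numbers.
    pred_sql, trace = react_sql(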
        nlq=nlq,
        schema_text=schema_text,
        schema_summary=SCHEMA_SUMMARY,
        exemplars=test_set[:3],
    )

    # EM is strict (normalized) string match; kept as a diagnostic signal.
    em = int(normalize_sql(pred_sql) == normalize_sql(gold_sql))

    # VA = executability on base DB.
    meta = runner.run(pred_sql, capture_df=False)
    va = int(meta.success)

    # EX = result equivalence on base DB (only meaningful if VA=1).
    ex = 0
    ex_pred_err = None
    ex_gold_err = None
    if va:
        ex_ok, ex_pred_err, ex_gold_err = execution_accuracy(
            engine=engine,
            pred_sql=pred_sql,
            gold_sql=gold_sql,
        )
        ex = int(ex_ok)

    # TS is expensive (runs across N replica DBs). Skip if pred_sql does not
    # execute on the base DB (VA=0) because TS would be 0 anyway.
    ts = 0
    if not va:
        ts_debug = {"skipped": True, "reason": "va=0", "error": meta.error}
    else:
        ts, ts_debug = test_suite_accuracy_for_item(
            make_engine_fn=make_engine_fn,
            suite_db_names=SUITE_DBS,
            gold_sql=gold_sql,
            pred_sql=pred_sql,
            max_rows=MAX_ROWS_TS,
            strict_gold=True,
        )

    results.append(
        {
            "nlq": nlq,
            "gold_sql": gold_sql,
            "pred_sql": pred_sql,
            "va": va,
            "em": em,
            "ex": ex,
            "ts": ts,
            "error": meta.error or ex_pred_err,
            "gold_error": ex_gold_err,
            "ts_debug": ts_debug,
            "trace": trace,
        }
    )

    if i % 10 == 0:
        print(f"Processed {i}/{len(items)}")

va_rate = sum(r["va"] for r in results) / max(len(results), 1)
ex_rate = sum(r["ex"] for r in results) / max(len(results), 1)
em_rate = sum(r["em"] for r in results) / max(len(results), 1)
ts_rate = sum(r["ts"] for r in results) / max(len(results), 1)
print("ReAct VA:", va_rate, "EX:", ex_rate, "EM:", em_rate, "TS:", ts_rate)

Path("results/agent").mkdir(parents=True, exist_ok=True)
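# Name the file by item count so a QUICK_LIMIT run does not overwrite the full 200-item results.
save_path = Path(f"results/agent/results_react_{len(items)}.json")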
save_path.write_text(
    json.dumps(
        {
            "va_rate": va_rate,
            "ex_rate": ex_rate,
            "em_rate": em_rate,
            "ts_rate": ts_rate,
            "items": results,
        },
        ensure_ascii=False,
        indent=2,
    ),
    encoding="utf-8",
)
print("Saved to", save_path)
