# Agentic Evaluation (Tool-Driven ReAct Loop)

**Core code**:
- `nl2sql/agent_tools.py`, `nl2sql/prompts.py`, `nl2sql/eval.py`
Refs: `REFERENCES.md#ref-yao2023-react`, `REFERENCES.md#ref-zhai2025-excot`, `REFERENCES.md#ref-zhong2020-ts`, `REFERENCES.md#ref-yu2018-spider`


### Install dependencies (pinned)

**What this cell does**: installs the exact versions used for reported metrics.

**Explain with**: `requirements.txt`


In [None]:

%%bash
# Install the pinned package versions used for the reported metrics.
# `set -e` aborts the cell on the first failing command.
set -e
# Give large downloads (torch/CUDA wheels) more time before pip gives up.
export PIP_DEFAULT_TIMEOUT=120

# Clean conflicting preinstalls (`|| true` keeps `set -e` from aborting if this step fails)
pip uninstall -y torch torchvision torchaudio bitsandbytes triton transformers accelerate peft trl datasets numpy pandas fsspec requests google-auth || true

# NOTE(review): `--force-reinstall` also reinstalls each package's dependencies, so a
# later command below may replace a version pinned by an earlier one (e.g. numpy or
# requests pulled in by the HF stack). Verify with `pip freeze` after this cell.

# Base deps
pip install -q --no-cache-dir --force-reinstall   numpy==1.26.4 pandas==2.2.1 fsspec==2024.5.0 requests==2.31.0 google-auth==2.43.0

# Torch + CUDA 12.1 (wheels come from the PyTorch cu121 index, not PyPI)
pip install -q --no-cache-dir --force-reinstall   torch==2.3.1+cu121 torchvision==0.18.1+cu121 torchaudio==2.3.1+cu121   --index-url https://download.pytorch.org/whl/cu121

# bitsandbytes + triton + HF stack
pip install -q --no-cache-dir --force-reinstall   bitsandbytes==0.43.3 triton==2.3.1   transformers==4.44.2 accelerate==0.33.0 peft==0.17.0 trl==0.9.6 datasets==2.20.0

echo "Setup complete. Restart runtime once, then run the rest of the notebook top-to-bottom."


### Sync repo into Colab

**What this cell does**: clones the repo so the notebook uses the same `nl2sql/` code as the scripts, then installs `requirements.txt`.

**Explain with**: `context.md` (reproducibility summary)


In [None]:
# 0) Clone repo (Colab) + install deps
import os
# Detect Colab by trying to import its runtime package.
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    # Clone only once; re-running the cell reuses the existing checkout as-is.
    if not os.path.exists('/content/NLtoSQL'):
        !git clone https://github.com/MacKenzieOBrian/NLtoSQL.git /content/NLtoSQL
    %cd /content/NLtoSQL
    # NOTE(review): this re-resolves requirements.txt after the pinned install cell;
    # confirm it uses the same pins, otherwise pip may change those versions.
    !pip -q install -r requirements.txt
    # Sanity check: core stack imports, and report whether a CUDA GPU is visible.
    import torch, transformers, accelerate, peft
    print('torch', torch.__version__, 'cuda', torch.cuda.is_available())
else:
    print('Not in Colab; using existing workspace')


### ADC auth (optional)

**What this cell does**: authenticates with gcloud ADC for Cloud SQL access.

**Explain with**: `nl2sql/db.py:create_engine_with_connector`


In [None]:
# Run this only if you prefer gcloud-based ADC (no JSON key)
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    %pip install -q --upgrade google-auth google-auth-oauthlib
    !gcloud auth application-default login
else:
    print("Not in Colab; skip gcloud auth.")


### Create DB engine + QueryRunner (the “Act” tool)

**What this cell does**: builds the SQLAlchemy engine and a SELECT‑only executor.

**Code**: `nl2sql/db.py`, `nl2sql/query_runner.py`


In [None]:
# 1) Environment + DB
import os
from getpass import getpass

from sqlalchemy import text

from nl2sql.db import create_engine_with_connector, safe_connection

# Connection settings: INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS, DB_NAME.
# Environment variables win; anything missing (or empty) is asked for
# interactively, with the password read via getpass so it is not echoed.
INSTANCE_CONNECTION_NAME = os.getenv("INSTANCE_CONNECTION_NAME") or input("Enter INSTANCE_CONNECTION_NAME: ").strip()
DB_USER = os.getenv("DB_USER") or input("Enter DB_USER: ").strip()
DB_PASS = os.getenv("DB_PASS") or getpass("Enter DB_PASS: ")
DB_NAME = os.getenv("DB_NAME") or "classicmodels"

# One shared engine builder (the scripts and other notebooks use it too): it goes
# through the Cloud SQL Connector and authenticates with ADC.
engine, connector = create_engine_with_connector(
    instance_connection_name=INSTANCE_CONNECTION_NAME,
    user=DB_USER,
    password=DB_PASS,
    db_name=DB_NAME,
)

# Round-trip a trivial query so bad credentials/instance names fail here, early.
with safe_connection(engine) as conn:
    conn.execute(text("SELECT 1"))
print("DB connection OK")


### TS engine factory (replica DBs)

**What this cell does**: creates engines for test‑suite replicas used in TS.

**Explain with**: `4_EVALUATION.md`, `REFERENCES.md#ref-zhong2020-ts`


In [None]:
# 1b) Engine factory for TS (multiple DB names)

import sqlalchemy
from sqlalchemy.engine import Engine


def make_engine(db_name: str) -> Engine:
    """Return a fresh SQLAlchemy engine bound to the TS replica database ``db_name``.

    Test-suite (TS) accuracy replays each (gold, predicted) SQL pair against
    several replica databases (classicmodels_ts_XX). Giving every replica its own
    engine keeps those executions independent of each other.

    Connections are opened lazily through the notebook's Cloud SQL ``connector``
    with the pymysql driver; INSTANCE_CONNECTION_NAME, DB_USER and DB_PASS are
    looked up each time a new connection is created.
    """
    return sqlalchemy.create_engine(
        "mysql+pymysql://",
        creator=lambda: connector.connect(
            INSTANCE_CONNECTION_NAME,
            "pymysql",
            user=DB_USER,
            password=DB_PASS,
            db=db_name,
        ),
        future=True,
    )


### Build schema summary + load test set

**What this cell does**: builds schema text for prompts and loads ClassicModels test queries.

**Explain with**: `nl2sql/schema.py:build_schema_summary`, `data/classicmodels_test_200.json`


In [None]:
# 2) Load schema summary + test set (small slice for now)
import json
from pathlib import Path

from nl2sql.schema import build_schema_summary

# Schema text used in prompts. The print is a loose check: it only confirms that
# "offices" and "city" each appear somewhere in the summary.
SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME)
print("Schema contains offices.city:", "offices" in SCHEMA_SUMMARY.lower() and "city" in SCHEMA_SUMMARY.lower())

test_path = Path("data/classicmodels_test_200.json")
full_set = json.loads(test_path.read_text(encoding="utf-8"))
# default to a small slice while debugging
test_set = full_set[:5]
print("Demo items:", len(test_set))
# For full run, switch to: test_set = full_set; print("Test items:", len(test_set))

# Small exemplar set (taken from the test set) to improve join behavior: the first
# question mentioning "office", then the earliest remaining items, up to 3 in total.
# NOTE(review): these come from the same file that is evaluated, so an exemplar
# that is also scored has its gold SQL in the prompt. Confirm exemplars are
# excluded from (or disclosed in) reported metrics.
join_exemplars = [it for it in full_set if "office" in it["nlq"].lower()]
REACT_EXEMPLARS = []
if join_exemplars:
    REACT_EXEMPLARS.append(join_exemplars[0])
for it in full_set:
    if it not in REACT_EXEMPLARS:
        REACT_EXEMPLARS.append(it)
    if len(REACT_EXEMPLARS) >= 3:
        break
print("Exemplars:", [e["nlq"] for e in REACT_EXEMPLARS])

# Table names = text before "(" on each schema line that contains "("; the
# lower-case map resolves case-insensitive mentions back to the schema spelling.
TABLES = {line.split('(', 1)[0].strip() for line in SCHEMA_SUMMARY.splitlines() if '(' in line}
TABLES_LOWER = {t.lower(): t for t in TABLES}


### Load model (base + optional adapters)

**What this cell does**: loads the base model and attaches QLoRA adapters if provided.

**Explain with**: `1_LITERATURE.md` (PEFT), `2_METHODOLOGY.md`

**Code**: `nl2sql/llm.py`


In [None]:

# 3) Load model (base or QLoRA adapters)
import os
from getpass import getpass
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
# The ADAPTER_PATH env var wins; otherwise use this repo-relative folder. Edit to None to use the base model.
ADAPTER_PATH = os.getenv("ADAPTER_PATH") or "results/adapters/qlora_classicmodels"  # set to None to use base model

# Hub token from either env var name; otherwise prompt without echo.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    HF_TOKEN = getpass("Enter HF_TOKEN (https://huggingface.co/settings/tokens): ").strip()

# bf16 on GPUs with compute capability >= 8 (Ampere and newer), fp16 otherwise / on CPU.
cc_major, cc_minor = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else (0, 0)
use_bf16 = cc_major >= 8
compute_dtype = torch.bfloat16 if use_bf16 else torch.float16
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("Using bf16:", use_bf16)
print("Adapter path:", ADAPTER_PATH)

# Tokenizer (pad falls back to eos when the tokenizer defines no pad token)
tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Quantized base model: 4-bit NF4 weights with double quantization; compute in compute_dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

# Whole model on GPU 0 when available.
# NOTE(review): bitsandbytes 4-bit loading presumably needs a CUDA GPU; the CPU
# branch (device_map=None) likely fails. Confirm before relying on it.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
    device_map={"": 0} if torch.cuda.is_available() else None,
    token=HF_TOKEN,
)
# Greedy decoding by default; sampling knobs reset to neutral values (presumably to
# avoid "temperature/top_p ignored when do_sample=False" warnings. Verify).
base_model.generation_config.do_sample = False
base_model.generation_config.temperature = 1.0
base_model.generation_config.top_p = 1.0

# Load adapters if present locally; otherwise use base model.
# A missing folder only prints a message (no error), so check the printed
# "Adapter path" line when comparing runs.
adapter_dir = Path(ADAPTER_PATH) if ADAPTER_PATH else None
if adapter_dir and adapter_dir.exists():
    model = PeftModel.from_pretrained(base_model, adapter_dir, token=HF_TOKEN)
    print("Loaded adapters from", adapter_dir)
else:
    print("Adapter path missing; using base model only. Set ADAPTER_PATH to your local adapter folder or upload it to Colab.")
    model = base_model


### Optional smoke check (baseline path)

**What this cell does**: runs a tiny end‑to‑end baseline pass to confirm generation + execution works.

**Explain with**: `nl2sql/prompting.py`, `nl2sql/postprocess.py`, `nl2sql/query_runner.py`


In [None]:
from nl2sql.prompting import make_few_shot_messages
from nl2sql.llm import extract_first_select
from nl2sql.postprocess import guarded_postprocess
from nl2sql.query_runner import QueryRunner
from nl2sql.eval import execution_accuracy

runner_check = QueryRunner(engine)
# Kept for any later cell that references it. run_quick_check builds its own
# per-question shot list so a question's gold SQL never appears in its own prompt.
exemplars = test_set[:3]


def run_quick_check(k: int = 0, limit: int = 3):
    """Smoke-test the baseline path: prompt -> generate -> postprocess -> execute.

    For each of the first ``limit`` items of ``test_set`` this prints the
    question, the post-processed SQL, VA (``runner_check`` executed it without
    error) and EX (``execution_accuracy`` against the gold SQL), plus the
    execution error when VA is False.

    Args:
        k: number of few-shot exemplars (0 = zero-shot). Shots are taken in order
            from ``full_set`` with the question under test removed (leave-one-out),
            so the gold SQL being scored is never shown to the model.
        limit: number of ``test_set`` items to check.
    """
    print(f"Quick check k={k}")
    for sample in test_set[:limit]:
        shots = [ex for ex in full_set if ex["nlq"] != sample["nlq"]][:k] if k > 0 else []
        msgs = make_few_shot_messages(
            schema=SCHEMA_SUMMARY,
            exemplars=shots,
            nlq=sample['nlq'],
        )
        prompt_preview = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already includes the model's special tokens
        # (BOS for Llama-3). Re-adding them while tokenizing would duplicate BOS.
        inputs = tok(prompt_preview, return_tensors="pt", add_special_tokens=False).to(model.device)
        out = model.generate(**inputs, max_new_tokens=256, do_sample=False)

        # strip the prompt before decoding the generation
        gen_ids = out[0][inputs.input_ids.shape[-1]:]
        gen_text = tok.decode(gen_ids, skip_special_tokens=True)

        raw_sql = extract_first_select(gen_text) or gen_text
        sql = guarded_postprocess(raw_sql, sample['nlq'])

        meta = runner_check.run(sql, capture_df=False)
        va = meta.success
        ex_ok, _, _ = execution_accuracy(engine=engine, pred_sql=sql, gold_sql=sample['sql'])
        print(f"Q: {sample['nlq']}\nSQL: {sql}\nVA: {va} EX: {ex_ok}")
        if not va:
            print(f"ERR: {meta.error}")
        print()


run_quick_check(k=0)
run_quick_check(k=3)


### Import deterministic guards

**What this cell does**: loads projection/intent/schema heuristics used by the agent.

**Explain with**: `3_AGENT_DESIGN.md`, `6_LIMITATIONS.md`

**Code**: `nl2sql/agent_utils.py`, `nl2sql/postprocess.py`


In [None]:
# Helper imports (optional; used for interactive inspection)
# Main agent loop is in `nl2sql/agent.py`.
from nl2sql.agent_utils import intent_constraints, semantic_score, count_select_columns


## Reference map (Code ↔ Literature)

- ReAct loop: `REFERENCES.md#ref-yao2023-react`
- Execution feedback: `REFERENCES.md#ref-zhai2025-excot`
- Validity/constraints: `REFERENCES.md#ref-scholak2021-picard`
- Schema linking: `REFERENCES.md#ref-zhu2024-survey`, `REFERENCES.md#ref-li2023-resdsql`
- TS evaluation: `REFERENCES.md#ref-zhong2020-ts`

**Code**: `nl2sql/agent_tools.py`, `nl2sql/prompts.py`, `nl2sql/agent_utils.py`, `nl2sql/eval.py`


### Reload schema + runner (full evaluation mode)

**What this cell does**: rebuilds `SCHEMA_SUMMARY`, reloads the test set with `test_set` switched from the debugging slice to all items, and creates `runner` before evaluation.

**Explain with**: `4_EVALUATION.md`


In [None]:
# 4) Schema summary + test set + QueryRunner
import json
from pathlib import Path
from nl2sql.schema import build_schema_summary
from nl2sql.query_runner import QueryRunner

# Same default as the DB connection cell.
DB_NAME = os.getenv("DB_NAME") or "classicmodels"

SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME)
# Schema summary is used in prompts to ground column/table choices.
test_path = Path("data/classicmodels_test_200.json")
full_set = json.loads(test_path.read_text(encoding="utf-8"))
# Full evaluation set (overrides the debugging slice chosen in the earlier load cell).
test_set = full_set  # change to full_set[:20] when debugging

print("Loaded test set size:", len(test_set))
runner = QueryRunner(engine)  # QueryRunner enforces SELECT-only execution and records errors for VA/EX.


### Defensive re‑import (notebook stability)

**What this cell does**: keeps later cells stable after partial reruns.


In [None]:
# 5) Agent utilities + guardrails
from nl2sql.agent_utils import (
    intent_constraints,
    classify_intent,
    clean_candidate_with_reason,
    enforce_projection_contract,
    vanilla_candidate,
)
from nl2sql.postprocess import guarded_postprocess


## 6. Tool-Driven ReAct Loop (Thought → Action → Observation)

**Code**: `nl2sql/agent_tools.py`, `nl2sql/prompts.py`, this cell (`react_sql`)


## Walkthrough (Information Flow)

### 1) Inputs
- NLQ: a user question (e.g., “List customers in France with their credit limits”).
- Schema: the database tables/columns, used to ground the model.
- Runner: a safe, SELECT‑only executor that returns errors or results.

```python
# High‑level inputs (conceptual)
NLQ = "List customers in France with their credit limits."
SCHEMA_TEXT = SCHEMA_SUMMARY  # built by build_schema_summary(engine, db_name=DB_NAME)
RUNNER = QueryRunner(engine)
```

### 2) The Tool‑Driven Loop
The model does **not** answer in one shot. It follows a sequence of actions and uses observations from each step.

```text
# Simplified loop (conceptual)
get_schema → link_schema → extract_constraints → generate_sql
→ validate_sql → validate_constraints → run_sql
→ (if failure) reflect_sql → retry
```

### 3) Candidate Generation + Guards
The model proposes SQL, then deterministic guards clean and validate it before execution.

```python
raw_sql = generate_sql(NLQ, SCHEMA_TEXT)
clean_sql, reject_reason = clean_candidate_with_reason(raw_sql)  # empty clean_sql means rejected
checked_sql = validate_sql(clean_sql, SCHEMA_TEXT)
```

### 4) Execution + Observation
The SQL is executed safely. Success or error becomes the **Observation** used in the next step.

```python
result = run_sql(checked_sql)
# Observation example:
# "Execution error: unknown column customers.creditLimit"
```

### 5) Repair (If Needed)
If a step fails, the loop forces a repair action rather than discarding the attempt.

```python
if not result.success:
    fixed_sql = repair_sql(NLQ, bad_sql=checked_sql, error=result.error)
```

### 6) Output + Trace
Every action/observation is logged. This makes the loop explainable in demos and defensible in evaluation.

```python
final_sql, trace, decision_log = react_sql(nlq=NLQ, schema_text=SCHEMA_TEXT)
# trace contains step‑by‑step actions + observations; decision_log records why each step was taken
```

**Why this matters:** the loop separates *validity* (does it run?) from *semantic correctness* (does it answer the question?), and it makes failures visible instead of hiding them behind ranking or filtering.


### Define the tool‑driven ReAct loop

**What this cell does**: binds tool context and defines `react_sql(...)` (Thought → Action → Observation).

**Explain with**: `3_AGENT_DESIGN.md`, `7_REACT_DIAGRAMS.md`

**Code**: `nl2sql/agent_tools.py`, `nl2sql/prompts.py`


In [None]:
# 6) Tool-driven ReAct loop (explicit Thought/Action/Observation)
import json
import re
import torch
import time
from nl2sql.prompts import REACT_SYSTEM_PROMPT
from nl2sql.agent_tools import (
    AgentContext,
    set_agent_context,
    get_schema,
    schema_to_text,
    link_schema,
    get_table_samples,
    generate_sql,
    extract_constraints,
    validate_sql,
    validate_constraints,
    run_sql,
    repair_sql,
    finish,
)

# Tool rationale (why each tool exists):
# - get_schema: ground the model in real tables/columns to avoid hallucinations.
# - schema_to_text: convert schema to a readable prompt format.
# - link_schema: narrow schema context to likely tables, reducing wrong joins.
# - extract_constraints: capture structure cues (COUNT/GROUP BY/LIMIT) from the NLQ.
# - generate_sql: model proposes a candidate SQL query.
# - validate_sql: catch formatting/schema errors before execution.
# - validate_constraints: enforce structural intent (e.g., missing GROUP BY).
# - run_sql: execution gate that produces the key Observation in ReAct.
# - repair_sql: forced recovery step when validation/execution fails.
# - get_table_samples: optional grounding aid for ambiguous columns.
# - finish: finalize only after a successful run_sql.

# Configure tool context (single source for engine/model/runner)
# NOTE(review): max_new_tokens=128 here differs from REACT_MAX_NEW_TOKENS=256 below.
# Presumably 128 caps generation inside the tools (e.g. generate_sql/repair_sql),
# while 256 caps each ReAct step in _call_react_llm. Confirm in nl2sql/agent_tools.py.
set_agent_context(
    AgentContext(
        engine=engine,
        db_name=DB_NAME,
        model=model,
        tok=tok,
        runner=runner,
        max_new_tokens=128,
    )
)

# ReAct loop hyperparameters (tuned for stability + cost)
# - REACT_MAX_STEPS: bound loop length for auditability
# - REACT_MAX_NEW_TOKENS: cap per-step generation to avoid run-on text
# - REACT_DO_SAMPLE: deterministic by default for reproducibility
# - REACT_TEMPERATURE / REACT_TOP_P: only passed to generate() when REACT_DO_SAMPLE is True
# - USE_LINK_SCHEMA: prune schema to reduce wrong joins
# - MAX_CLEAN_REJECT_RETRIES: allow one regenerate after guardrails reject
REACT_MAX_STEPS = 8
REACT_MAX_NEW_TOKENS = 256
REACT_DO_SAMPLE = False
REACT_TEMPERATURE = 0.2
REACT_TOP_P = 0.9
USE_LINK_SCHEMA = True  # can be overridden by quick-test toggles later
MAX_CLEAN_REJECT_RETRIES = 1  # force one re-generate if guardrails return empty

# Parse model Action lines like: Action: tool_name[json_args]
# Important: models sometimes emit multiple Action blocks in one response.
# We parse all Action lines and take the last as the model's final decision.
# DOTALL lets the lazy args group span newlines, so JSON args split over several
# lines still parse. The group ends at the first "]" that closes a line. Caveat: if
# an Action line never closes its "[", the match can run into the next Action line.
_ACTION_RE = re.compile(
    r'^\s*Action:\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\[(.*?)\]\s*$',
    re.IGNORECASE | re.MULTILINE | re.DOTALL,
)


def _call_react_llm(history: str) -> str:
    """Generate the model's next ReAct step (Thought/Action text) for *history*.

    Args:
        history: the running transcript (question, prior Actions and
            Observations), sent as the single user message after
            ``REACT_SYSTEM_PROMPT``.

    Returns:
        Only the newly generated text (prompt tokens removed), decoded without
        special tokens and stripped of surrounding whitespace.

    Decoding is capped at ``REACT_MAX_NEW_TOKENS``. It is greedy unless
    ``REACT_DO_SAMPLE`` is set, in which case ``REACT_TEMPERATURE`` and
    ``REACT_TOP_P`` apply.
    """
    messages = [
        {"role": "system", "content": REACT_SYSTEM_PROMPT},
        {"role": "user", "content": history},
    ]
    input_ids = tok.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    # Some tokenizers share pad/eos ids, which prevents generate() from inferring an
    # attention mask reliably. Our prompts are not padded, so an all-ones mask is valid.
    attention_mask = torch.ones_like(input_ids)

    # Stop on every end id the model declares, not just the tokenizer's eos. Chat
    # models may list several ids in generation_config.eos_token_id (e.g. Llama-3's
    # end-of-text and end-of-turn), and an explicit eos_token_id kwarg to generate()
    # replaces that list instead of extending it.
    cfg_eos = getattr(getattr(model, "generation_config", None), "eos_token_id", None)
    if isinstance(cfg_eos, int):
        cfg_eos = [cfg_eos]
    tok_eos = getattr(tok, "eos_token_id", None)
    stop_ids = list(dict.fromkeys(i for i in [*(cfg_eos or []), tok_eos] if i is not None))

    # getattr's default only applies when the attribute is missing. A tokenizer
    # whose pad_token_id is None must still fall back to an eos id.
    pad_id = getattr(tok, "pad_token_id", None)
    if pad_id is None:
        pad_id = tok_eos if tok_eos is not None else (stop_ids[0] if stop_ids else None)

    gen_kwargs = {
        "max_new_tokens": REACT_MAX_NEW_TOKENS,
        "do_sample": REACT_DO_SAMPLE,
        "attention_mask": attention_mask,
        "pad_token_id": pad_id,
    }
    if stop_ids:
        gen_kwargs["eos_token_id"] = stop_ids
    if REACT_DO_SAMPLE:
        gen_kwargs.update({"temperature": REACT_TEMPERATURE, "top_p": REACT_TOP_P})

    with torch.no_grad():
        out = model.generate(input_ids, **gen_kwargs)

    gen_ids = out[0][input_ids.shape[-1] :]
    gen_text = tok.decode(gen_ids, skip_special_tokens=True)
    return gen_text.strip()


def _normalize_llm_text(text: str) -> str:
    """Remove Markdown code fences from model output so Action lines parse cleanly.

    ```json / ```sql openers are reduced to bare ``` first. Then every closed
    ```...``` pair is replaced by its contents, and the result is stripped.
    ``None`` is treated as an empty string. An unclosed fence is left in place.
    """
    cleaned = text or ""
    for opener in ("```json", "```sql"):
        cleaned = cleaned.replace(opener, "```")
    unfenced = re.sub(r"```(.*?)```", r"\1", cleaned, flags=re.DOTALL)
    return unfenced.strip()


def _parse_action(text: str) -> tuple[str | None, dict]:
    """Return ``(tool_name, args)`` from the LAST ``Action: name[json]`` line in *text*.

    The text is de-fenced with ``_normalize_llm_text`` first. With no Action line
    the result is ``(None, {})``. Arguments that are empty, are not valid JSON, or
    do not decode to a JSON object all become ``{}``; the tool name is still returned.
    """
    found = list(_ACTION_RE.finditer(_normalize_llm_text(text)))
    if not found:
        return None, {}
    final = found[-1]
    tool = final.group(1).strip()
    payload = (final.group(2) or "").strip()
    if payload:
        try:
            decoded = json.loads(payload)
        except Exception:
            decoded = None
        if isinstance(decoded, dict):
            return tool, decoded
    return tool, {}


def _canonicalize_table_casing(sql: str, schema_text: str) -> str:
    """Rewrite whole-word, case-insensitive table mentions in *sql* to schema casing.

    Table names are taken from schema lines containing both "(" and ")" (text
    before the first "("). Falsy *sql* or *schema_text* is returned unchanged.
    """
    if not (sql and schema_text):
        return sql
    table_names = [
        row.split("(", 1)[0].strip()
        for row in schema_text.splitlines()
        if "(" in row and ")" in row
    ]
    rewritten = sql
    for name in table_names:
        word = re.compile(rf"\b{re.escape(name)}\b", re.IGNORECASE)
        rewritten = word.sub(name, rewritten)
    return rewritten


def _print_guardrail_stage(label: str, before: str, after: str, *, reason: str | None = None, max_chars: int = 320, show_unchanged: bool = False) -> None:
    # Rationale: explain what each guardrail changed (if anything).
    before = "" if before is None else str(before)
    after = "" if after is None else str(after)
    if reason and not after:
        print(f"{label}: reject ({reason})")
        return
    if before.strip() == after.strip():
        if show_unchanged:
            print(f"{label}: unchanged")
        return
    print(f"{label}: changed")
    print("  before:", _truncate_text(before, max_chars=max_chars))
    print("  after: ", _truncate_text(after, max_chars=max_chars))


def _apply_guardrails(raw_sql: str, nlq: str, schema_text: str) -> tuple[str, str | None, dict]:
    """Run the deterministic SQL cleanup pipeline before validation/execution.

    Stages, in order: clean_candidate_with_reason -> guarded_postprocess ->
    enforce_projection_contract -> _canonicalize_table_casing. Each intermediate
    result is recorded in the returned ``stages`` dict (keys: raw, clean,
    clean_reason, postprocess, projection, casing).

    Returns:
        ``(final_sql, None, stages)`` on success. If cleaning yields an empty
        string, returns ``("", "clean_reject:<reason>", stages)`` and skips the
        later stages.
    """
    stages: dict = {"raw": raw_sql}
    cleaned, clean_reason = clean_candidate_with_reason(raw_sql)
    stages.update(clean=cleaned, clean_reason=clean_reason)
    if not cleaned:
        return "", f"clean_reject:{clean_reason}", stages
    stages["postprocess"] = guarded_postprocess(cleaned, nlq)
    stages["projection"] = enforce_projection_contract(stages["postprocess"], nlq)
    stages["casing"] = _canonicalize_table_casing(stages["projection"], schema_text)
    return stages["casing"], None, stages

def log_decision(decisions: list[dict], step: int, decision: str, reason: str, data: dict | None = None, status: str = "ok") -> dict:
    entry = {"step": step, "decision": decision, "reason": reason, "status": status}
    if data is not None:
        entry["data"] = data
    decisions.append(entry)
    return entry


def format_decision_log(decisions: list[dict], max_items: int | None = 20) -> str:
    if not decisions:
        return "(no decisions logged)"
    out: list[str] = []
    limit = max_items or len(decisions)
    for d in decisions[:limit]:
        line = f"[step {d.get('step')}] {d.get('decision')} — {d.get('reason')} ({d.get('status')})"
        out.append(line)
        data = d.get("data")
        if data is not None:
            try:
                snippet = json.dumps(data, ensure_ascii=False)
            except Exception:
                snippet = str(data)
            if len(snippet) > 400:
                snippet = snippet[:397] + "..."
            out.append(f"  data: {snippet}")
    return "\n".join(out)


def summarize_trace(trace: list[dict]) -> dict:
    """Summarize a ReAct trace and check the required tool ordering.

    Ordering rules, each checked against actions that occurred earlier in the trace:
    generate_sql needs a prior extract_constraints; run_sql needs prior
    validate_sql and validate_constraints; finish needs a prior run_sql. A
    violation is reported once per offending action.

    Returns a dict with: actions, attempted_actions, blocked_steps, repairs,
    forced_repairs, compliance_ok and compliance_errors.
    """
    actions = [entry.get("action") for entry in trace if entry.get("action")]
    attempted = [entry.get("attempted_action") for entry in trace if entry.get("attempted_action") is not None]
    blocked_steps = len([entry for entry in trace if entry.get("blocked")])
    forced_repairs = len([entry for entry in trace if entry.get("forced_action") == "repair_sql"])
    repairs = len([entry for entry in trace if entry.get("action") == "repair_sql"])

    violations: list[str] = []
    earlier: list = []
    for action in actions:
        if action == "generate_sql" and "extract_constraints" not in earlier:
            violations.append("generate_without_constraints")
        if action == "run_sql":
            if "validate_sql" not in earlier:
                violations.append("run_without_validate")
            if "validate_constraints" not in earlier:
                violations.append("run_without_validate_constraints")
        if action == "finish" and "run_sql" not in earlier:
            violations.append("finish_without_run")
        earlier.append(action)

    return {
        "actions": actions,
        "attempted_actions": attempted,
        "blocked_steps": blocked_steps,
        "repairs": repairs,
        "forced_repairs": forced_repairs,
        "compliance_ok": not violations,
        "compliance_errors": violations,
    }


def _truncate_text(s: str, max_chars: int = 1200) -> str:
    """Return *s* as a string of at most ``max_chars`` characters.

    ``None`` becomes ``""``; other values are converted with ``str()``.
    ``max_chars <= 0`` disables truncation. When truncating, the result ends
    with "..." if there is room for it (``max_chars >= 3``); smaller limits get
    a plain hard cut so the bound is never exceeded.
    """
    if s is None:
        return ""
    s = str(s)
    if max_chars <= 0 or len(s) <= max_chars:
        return s
    if max_chars < 3:
        # No room for the ellipsis; s[: max_chars - 3] would slice from the end.
        return s[:max_chars]
    return s[: max_chars - 3] + "..."


def _print_prompt_tail(prompt: str, *, tail_lines: int = 30, max_line_chars: int = 200) -> None:
    """Print the last ``tail_lines`` lines of *prompt* (at least one) for debugging.

    Lines longer than ``max_line_chars`` are cut so that each printed line,
    including a trailing "...", never exceeds that width. As with
    ``_truncate_text``, a limit that is None/0 or negative disables truncation.
    An empty prompt prints ``(empty prompt)``.
    """
    if not prompt:
        print("(empty prompt)")
        return
    lines = prompt.splitlines()
    for ln in lines[-max(1, int(tail_lines)) :]:
        if max_line_chars is not None and 0 < max_line_chars < len(ln):
            # Keep the "..." marker only when it fits inside the limit.
            ln = ln[: max_line_chars - 3] + "..." if max_line_chars >= 3 else ln[:max_line_chars]
        print(ln)


# Plain-English one-liners for each agent tool, used for readable walkthrough output.
_TOOL_EXPLAIN: dict[str, str] = dict(
    get_schema="Look up the database tables/columns (ground truth).",
    link_schema="Focus on the most relevant tables/columns for this question.",
    extract_constraints="Infer structural needs (COUNT, GROUP BY, LIMIT, DISTINCT, etc.).",
    get_table_samples="Fetch a few example rows to ground ambiguous columns.",
    generate_sql="Draft a SQL query for the question.",
    validate_sql="Check the SQL is safe/valid (single SELECT + known schema refs).",
    validate_constraints="Check the SQL matches the question's required structure.",
    run_sql="Run the SQL against the DB and observe results/errors.",
    repair_sql="Fix the SQL using the latest error feedback.",
    finish="Return the final SQL (only after a successful run).",
)


def _extract_last_thought(text: str) -> str:
    """Return the content of the last ``Thought:`` line in *text*, or ``""``.

    Matching ignores case and leading whitespace. The content is everything
    after the first colon, stripped.
    """
    thought_lines = (
        row for row in reversed((text or "").splitlines())
        if row.strip().lower().startswith("thought:")
    )
    latest = next(thought_lines, None)
    return "" if latest is None else latest.split(":", 1)[1].strip()


def _friendly_progress(
    *,
    constraints: dict | None,
    last_sql: str | None,
    last_valid: bool | None,
    last_constraints_ok: bool | None,
    last_run: dict | None,
) -> str:
    parts = []
    parts.append("constraints: " + ("done" if constraints else "pending"))
    parts.append("sql draft: " + ("done" if last_sql else "pending"))
    if last_valid is None:
        parts.append("sql check: pending")
    else:
        parts.append("sql check: " + ("pass" if last_valid else "fail"))
    if last_constraints_ok is None:
        parts.append("shape check: pending")
    else:
        parts.append("shape check: " + ("pass" if last_constraints_ok else "fail"))
    if last_run is None:
        parts.append("run: pending")
    else:
        parts.append("run: " + ("pass" if last_run.get("success") else "fail"))
    return " | ".join(parts)



def _next_required_action(
    *,
    pending_repair_error: str | None,
    pending_force_generate: str | None,
    constraints: dict | None,
    last_sql: str | None,
    last_valid: bool | None,
    last_constraints_ok: bool | None,
    last_run: dict | None,
) -> str:
    """Return the tool the controller expects next, derived from the loop state.

    Pipeline order: extract_constraints -> generate_sql -> validate_sql ->
    validate_constraints -> run_sql -> finish. A pending repair wins first,
    then a pending forced regenerate; any failed check routes to repair_sql.
    """
    if pending_repair_error:
        return "repair_sql"
    if pending_force_generate:
        return "generate_sql"
    if constraints is None:
        return "extract_constraints"
    if last_sql is None:
        return "generate_sql"
    if last_valid is None:
        return "validate_sql"
    if last_valid is False:
        return "repair_sql"
    if last_constraints_ok is None:
        return "validate_constraints"
    if last_constraints_ok is False:
        return "repair_sql"
    if not last_run or not last_run.get("success"):
        return "run_sql"
    return "finish"


def _forced_action_args(
    action: str,
    *,
    constraints: dict | None,
    pending_repair_error: str | None,
    last_error: str | None,
) -> dict:
    """Build the args the controller attaches when it substitutes *action* for the model's choice."""
    if action == "generate_sql":
        return {"constraints": constraints} if constraints else {}
    if action == "repair_sql":
        err = pending_repair_error or last_error
        return {"error": err} if err else {}
    return {}


def _blocked_trace_entry(step: int, attempted_action, attempted_args: dict, obs: dict) -> dict:
    """Trace record for a step whose action was rejected before any tool ran."""
    return {
        "step": step,
        "attempted_action": attempted_action,
        "attempted_args": attempted_args,
        "action": None,
        "args": {},
        "observation": obs,
        "blocked": True,
    }


def _print_guardrail_effects(raw_sql: str, stages: dict | None) -> None:
    """Debug-print each guardrail stage that ran, chaining each stage's input from the previous one."""
    if not stages:
        return
    print("\nGuardrails effects:")
    _print_guardrail_stage("clean_candidate", raw_sql, stages.get("clean"), reason=stages.get("clean_reason"))
    if "postprocess" in stages:
        _print_guardrail_stage("guarded_postprocess", stages.get("clean"), stages.get("postprocess"))
    if "projection" in stages:
        _print_guardrail_stage("projection_contract", stages.get("postprocess"), stages.get("projection"))
    if "casing" in stages:
        _print_guardrail_stage("canonicalize_casing", stages.get("projection"), stages.get("casing"))


def react_sql(
    *,
    nlq: str,
    schema_text: str | None = None,
    schema_summary: str | None = None,
    exemplars: list[dict] | None = None,
    max_steps: int = REACT_MAX_STEPS,
    debug: bool = False,
    debug_sleep_s: float = 0.0,
    debug_prompt_tail_lines: int = 0,
    debug_rows_preview: int = 3,
    auto_order: bool = False,  # If True, force the next required tool step (demo-friendly).
) -> tuple[str, list[dict], list[dict]]:
    """Run the tool-driven ReAct loop for one question and return the final SQL.

    The transcript is bootstrapped with ``get_schema`` and ``link_schema``. Then,
    for up to ``max_steps`` steps, the LLM proposes an ``Action``. The controller
    may override it: a pending repair, a pending forced regenerate, missing
    constraints, a premature run_sql/finish, or every step when ``auto_order``.
    The tool is then executed and its observation is appended to the transcript.
    A successful ``run_sql`` that also passes ``intent_constraints`` finishes
    immediately.

    Args:
        nlq: Natural-language question.
        schema_text: Accepted for interface compatibility; not read here (the
            schema always comes from ``get_schema()``).
        schema_summary: If given, used to build a vanilla fallback candidate when
            the loop runs out of steps.
        exemplars: Few-shot exemplars, used only by the vanilla fallback.
        max_steps: Step budget (default bound from ``REACT_MAX_STEPS`` at
            definition time).
        debug: Print a step-by-step walkthrough.
        debug_sleep_s: Pause after each step when ``debug`` is on.
        debug_prompt_tail_lines: When > 0 (and ``debug``), print this many
            trailing transcript lines before each LLM call.
        debug_rows_preview: Number of result rows previewed after a successful run.
        auto_order: Force the next required tool on every step.

    Returns:
        ``(sql, trace, decision_log)``: the final SQL (finished SQL, fallback
        candidate, last cleaned draft, or ""), the per-step trace records, and
        the decision-log entries.
    """
    trace: list[dict] = []
    history: list[str] = []
    decision_log: list[dict] = []

    schema = get_schema()
    schema_text_full = schema_to_text(schema)
    schema_tables = [line.split("(", 1)[0].strip() for line in schema_text_full.splitlines() if "(" in line]

    # Trace bootstrap (required): user question + get_schema + link_schema
    history.append(f"User question: {nlq}")
    history.append("Action: get_schema[{}]")
    history.append(f"Observation: {schema_text_full}")
    log_decision(decision_log, -1, "get_schema", "loaded schema", {"tables": schema_tables})

    link_obs = link_schema(nlq, schema_text_full, max_tables=6 if USE_LINK_SCHEMA else 0)
    schema_text_focus = link_obs.get("schema_text") or schema_text_full
    # NOTE(review): the transcript always says max_tables=6, even when USE_LINK_SCHEMA is
    # False and 0 is passed above; left unchanged to keep prompt text identical to earlier runs.
    history.append('Action: link_schema[{"max_tables": 6}]')
    history.append(f"Observation: {schema_text_focus}")
    log_decision(decision_log, -1, "link_schema", "prune schema context", link_obs)

    if debug:
        print("=" * 80)
        print("ReAct walkthrough (tool-driven NL->SQL)")
        print("NLQ:", nlq)
        print("What you'll see: a small set of tools run in order until the SQL executes.")
        print("Order: schema -> focus -> requirements -> draft -> checks -> run -> finish")
        print()
        print("[Bootstrap]")
        print(f"- schema tables: {len(schema_tables)}")
        focus_lines = schema_text_focus.splitlines()
        print(f"- link_schema enabled: {bool(USE_LINK_SCHEMA)} | changed: {bool(link_obs.get('changed'))} | focus lines: {len(focus_lines)}")
        focus_tables = [
            name
            for name in (ln.split("(", 1)[0].strip() for ln in focus_lines if "(" in ln and ")" in ln)
            if name
        ]
        if focus_tables:
            print("- focused tables:", ", ".join(focus_tables))
        if any(ln.lower().startswith("join hints:") for ln in focus_lines):
            print("- join hints included: yes")
        print("=" * 80)

    # Controller state carried between steps.
    last_sql: str | None = None  # latest SQL that passed guardrails
    last_error: str | None = None  # most recent validation/execution error text
    last_run: dict | None = None  # most recent run_sql observation
    last_valid: bool | None = None  # validate_sql verdict for last_sql (None = not checked yet)
    last_constraints_ok: bool | None = None  # validate_constraints verdict (None = not checked yet)
    constraints: dict | None = None
    pending_repair_error: str | None = None  # when set, the next step is forced to repair_sql
    pending_force_generate: str | None = None  # when set, the next step is forced to generate_sql
    clean_reject_retries = 0  # forced regenerates scheduled after guardrail rejects

    for step in range(max_steps):
        prompt = "\n".join(history)

        if debug:
            print("\n" + "-" * 80)
            print(f"STEP {step} / {max_steps - 1}")
            print(
                _friendly_progress(
                    constraints=constraints,
                    last_sql=last_sql,
                    last_valid=last_valid,
                    last_constraints_ok=last_constraints_ok,
                    last_run=last_run,
                )
            )
            if debug_prompt_tail_lines and int(debug_prompt_tail_lines) > 0:
                print("\nTranscript tail (what the LLM sees):")
                _print_prompt_tail(prompt, tail_lines=int(debug_prompt_tail_lines))

        llm_out = _call_react_llm(prompt)
        trace.append({"step": step, "llm": llm_out})

        action, args = _parse_action(llm_out)
        if not isinstance(args, dict):
            args = {}
        attempted_action = action
        attempted_args = dict(args)
        history.append(llm_out.strip())

        if debug:
            thought = _extract_last_thought(llm_out)
            if thought:
                print("\nModel thought:", _truncate_text(thought, max_chars=240))
            print("Model action:", action)

        # Controller overrides, applied in this order. Each rewrites history[-1] so the
        # transcript shows the action that actually runs.

        # If we have a pending validation/execution error, force a repair action.
        if pending_repair_error and action != "repair_sql":
            trace.append({"step": step, "forced_action": "repair_sql", "requested_action": action, "reason": pending_repair_error})
            log_decision(decision_log, step, "force_repair", pending_repair_error, {"requested_action": action})
            action = "repair_sql"
            args = {"error": pending_repair_error, "forced": True}
            history[-1] = f"Action: repair_sql[{json.dumps(args, ensure_ascii=False)}]"
            if debug:
                print(f"FORCED -> repair_sql (reason: {pending_repair_error})")

        # If guardrails returned empty SQL, force one regenerate.
        if pending_force_generate and action != "generate_sql":
            trace.append({"step": step, "forced_action": "generate_sql", "requested_action": action, "reason": pending_force_generate})
            log_decision(decision_log, step, "force_generate_sql", pending_force_generate, {"requested_action": action})
            action = "generate_sql"
            args = {"constraints": constraints} if constraints else {}
            history[-1] = f"Action: generate_sql[{json.dumps(args, ensure_ascii=False)}]"
            pending_force_generate = None
            if debug:
                print("FORCED -> generate_sql (one retry after guardrail reject)")

        if constraints is None and action not in ("extract_constraints", "repair_sql"):
            trace.append({"step": step, "forced_action": "extract_constraints", "requested_action": action, "reason": "constraints_missing"})
            log_decision(decision_log, step, "force_extract_constraints", "constraints_missing", {"requested_action": action})
            action = "extract_constraints"
            args = {}
            history[-1] = "Action: extract_constraints[{}]"
            if debug:
                print("FORCED -> extract_constraints (constraints missing)")

        # Ordering. With auto_order every step is pinned to the next required tool (keeps
        # walkthroughs linear). Otherwise only premature run_sql/finish requests are
        # redirected, which avoids burning the step budget on blocked actions.
        if auto_order or action in ("run_sql", "finish"):
            required = _next_required_action(
                pending_repair_error=pending_repair_error,
                pending_force_generate=pending_force_generate,
                constraints=constraints,
                last_sql=last_sql,
                last_valid=last_valid,
                last_constraints_ok=last_constraints_ok,
                last_run=last_run,
            )
            if action != required:
                order_reason = "auto_order" if auto_order else "controller_order"
                trace.append({"step": step, "forced_action": required, "requested_action": action, "reason": order_reason})
                log_decision(decision_log, step, "force_order", order_reason, {"requested_action": action, "required_action": required})
                if debug:
                    if auto_order:
                        print(f"FORCED -> {required} (next required step)")
                    else:
                        print(f"FORCED -> {required} (required before {action})")
                action = required
                args = _forced_action_args(
                    required,
                    constraints=constraints,
                    pending_repair_error=pending_repair_error,
                    last_error=last_error,
                )
                history[-1] = f"Action: {required}[{json.dumps(args, ensure_ascii=False)}]"

        # If we already ran successfully, force finish to avoid extra tool calls.
        # (Defensive: a successful, intent-consistent run normally auto-finishes below.)
        if last_run and last_run.get("success") and action != "finish":
            trace.append({"step": step, "forced_action": "finish", "requested_action": action, "reason": "already_successful_run"})
            log_decision(decision_log, step, "force_finish", "already_successful_run", {"requested_action": action})
            if debug:
                print("FORCED -> finish (run_sql already succeeded)")
            action = "finish"
            args = {}
            history[-1] = "Action: finish[{}]"

        if action is None:
            obs = {"error": "No Action found. Respond with Action: tool[json_args]."}
            history.append(f"Observation: {json.dumps(obs, ensure_ascii=False)}")
            trace.append(_blocked_trace_entry(step, attempted_action, attempted_args, obs))
            continue

        if action not in TOOLS:
            obs = {"error": f"Unknown action: {action}"}
            history.append(f"Observation: {json.dumps(obs, ensure_ascii=False)}")
            trace.append(_blocked_trace_entry(step, attempted_action, attempted_args, obs))
            continue

        # Setup-only tools already ran during bootstrap and are never executed inside the loop.
        if action in ("get_schema", "link_schema"):
            obs = {"error": f"{action} is setup-only and already executed."}
            history.append(f"Observation: {json.dumps(obs, ensure_ascii=False)}")
            trace.append(_blocked_trace_entry(step, attempted_action, attempted_args, obs))
            log_decision(decision_log, step, "blocked_setup_action", action, {"attempted_action": attempted_action})
            if debug:
                print(f"BLOCKED setup action: {action}")
            continue

        # Enforce: run_sql must succeed before finish.
        if action == "finish":
            if not last_run or not last_run.get("success"):
                obs = {"error": "Must call run_sql successfully before finish."}
                history.append(f"Observation: {json.dumps(obs, ensure_ascii=False)}")
                trace.append(_blocked_trace_entry(step, attempted_action, attempted_args, obs))
                if debug:
                    print("FINISH blocked:", obs["error"])
                continue
            result = finish(answer=str(last_run.get("rows", [])), sql=last_sql or "", provenance={"trace": trace})
            trace.append(
                {
                    "step": step,
                    "attempted_action": attempted_action,
                    "attempted_args": attempted_args,
                    "action": "finish",
                    "args": {},
                    "observation": result,
                }
            )
            log_decision(decision_log, step, "finish", "completed", {"sql": result.get("sql", "")})
            if debug:
                print("\nFINISH -> returning final SQL")
                print(result.get("sql", ""))
            return result.get("sql", ""), trace, decision_log

        if debug:
            expl = _TOOL_EXPLAIN.get(action, "")
            if expl:
                print(f"\nTool: {action} — {expl}")
            else:
                print(f"\nTool: {action}")

        executed_action = action
        blocked = False
        auto_finish = False
        auto_finish_payload = None

        # Tool execution (get_schema/link_schema never reach this point; see the guard above).
        if action == "extract_constraints":
            # Rationale: structural cues (COUNT/GROUP BY/LIMIT) are frequent EX failure points.
            res = extract_constraints(nlq)
            constraints = res
            last_constraints_ok = None
            obs = res
            log_decision(decision_log, step, "extract_constraints", "heuristic extraction", res)
            if debug:
                print("\nRequirements extracted:", res)
        elif action == "get_table_samples":
            table = args.get("table")
            n = int(args.get("n", 3)) if str(args.get("n", "")).isdigit() else 3
            obs = get_table_samples(table, n=n)
            if debug:
                print(f"\nSample rows from {table!r} (n={n}):")
                try:
                    print(_truncate_text(json.dumps(obs, ensure_ascii=False, default=str), max_chars=1200))
                except Exception:
                    print(_truncate_text(str(obs), max_chars=1200))
        elif action == "generate_sql":
            # Rationale: model generation step; guardrails immediately clean + normalize output.
            constraints = args.get("constraints") or constraints or {"intent": classify_intent(nlq)}
            # Any pending regenerate request is satisfied by this call. Clearing it here means
            # only the retry cap below can schedule another forced regenerate (previously a
            # stale request survived when the model picked generate_sql itself).
            pending_force_generate = None
            if debug:
                print("\nRequirements:", constraints)
            raw_sql = generate_sql(nlq, schema_text_focus, constraints)
            log_decision(decision_log, step, "generate_sql", "model generation", {"raw_sql": raw_sql})
            sql, reason, stages = _apply_guardrails(raw_sql, nlq, schema_text_full)
            if debug:
                print("\nDraft SQL (raw):")
                print(_truncate_text(raw_sql, max_chars=2000))
                _print_guardrail_effects(raw_sql, stages)
            if not sql:
                obs = {"error": reason, "raw_sql": raw_sql, "hint": "Output a single SELECT statement only."}
                log_decision(decision_log, step, "guardrails", "clean_reject", {"reason": reason, "raw_sql": raw_sql}, status="reject")
                if clean_reject_retries < MAX_CLEAN_REJECT_RETRIES:
                    pending_force_generate = reason
                    clean_reject_retries += 1
                if debug:
                    print("\nGuardrails: REJECT -", reason)
            else:
                last_sql = sql
                last_error = None
                last_valid = None
                last_constraints_ok = None
                pending_repair_error = None
                obs = {"sql": sql}
                log_decision(decision_log, step, "guardrails", "cleaned", {"cleaned_sql": sql})
                if debug:
                    print("\nSQL after guardrails:")
                    print(sql)
        elif action == "repair_sql":
            # Rationale: forced recovery when validation/execution fails.
            if not last_sql:
                obs = {"error": "No SQL to repair. Call generate_sql first."}
                blocked = True
                executed_action = None
            else:
                err = args.get("error") or last_error or ""
                if debug:
                    print("\nPrevious SQL:")
                    print(last_sql)
                raw_sql = repair_sql(nlq, last_sql, err, schema_text_full)
                log_decision(decision_log, step, "repair_sql", "model repair", {"error": err, "raw_sql": raw_sql})
                sql, reason, stages = _apply_guardrails(raw_sql, nlq, schema_text_full)
                if debug:
                    print("\nError to fix:", err)
                    print("\nRepair draft (raw):")
                    print(_truncate_text(raw_sql, max_chars=2000))
                    _print_guardrail_effects(raw_sql, stages)
                if not sql:
                    # pending_repair_error stays set, so the next step repairs again.
                    obs = {"error": reason, "raw_sql": raw_sql}
                    log_decision(decision_log, step, "guardrails", "clean_reject", {"reason": reason, "raw_sql": raw_sql}, status="reject")
                    if debug:
                        print("\nGuardrails: REJECT -", reason)
                else:
                    last_sql = sql
                    last_valid = None
                    last_constraints_ok = None
                    pending_repair_error = None
                    obs = {"sql": sql}
                    log_decision(decision_log, step, "guardrails", "cleaned", {"cleaned_sql": sql})
                    if debug:
                        print("\nSQL after guardrails:")
                        print(sql)
        elif action == "validate_sql":
            # Rationale: catch schema/format errors before hitting the database.
            if not last_sql:
                obs = {"error": "No SQL to validate. Call generate_sql first."}
                blocked = True
                executed_action = None
            else:
                res = validate_sql(last_sql, schema_text_full)
                if res.get("reason") == "no_schema":
                    res = {"valid": False, "reason": "schema_missing"}
                obs = res
                last_valid = bool(res.get("valid"))
                data = dict(res)
                data["schema_text_len"] = len(schema_text_full or "")
                log_decision(decision_log, step, "validate_sql", res.get("reason", ""), data, status="ok" if last_valid else "reject")
                if not last_valid:
                    last_error = res.get("reason")
                    pending_repair_error = last_error
                else:
                    pending_repair_error = None
            if debug:
                if obs.get("error"):
                    print("\nSQL check: blocked -", obs.get("error"))
                elif obs.get("valid"):
                    print("\nSQL check: PASS")
                else:
                    print("\nSQL check: FAIL -", obs.get("reason"))
        elif action == "validate_constraints":
            # Rationale: enforce NLQ-implied structure (aggregation, grouping, limits).
            if not last_sql:
                obs = {"error": "No SQL to validate. Call generate_sql first."}
                blocked = True
                executed_action = None
            elif not constraints:
                obs = {"error": "No constraints found. Call extract_constraints first."}
                blocked = True
                executed_action = None
            else:
                res = validate_constraints(last_sql, constraints)
                obs = res
                last_constraints_ok = bool(res.get("valid"))
                log_decision(decision_log, step, "validate_constraints", res.get("reason", ""), res, status="ok" if last_constraints_ok else "reject")
                if not last_constraints_ok:
                    last_error = res.get("reason")
                    pending_repair_error = last_error
                else:
                    pending_repair_error = None
            if debug:
                if obs.get("error"):
                    print("\nShape check: blocked -", obs.get("error"))
                elif obs.get("valid"):
                    print("\nShape check: PASS")
                else:
                    print("\nShape check: FAIL -", obs.get("reason"))
        elif action == "run_sql":
            # Rationale: execution is the ReAct Observation; it tells the loop what failed.
            if not last_sql:
                obs = {"error": "No SQL to run. Call generate_sql first."}
                blocked = True
                executed_action = None
            elif last_valid is None:
                obs = {"error": "Must call validate_sql before run_sql."}
                blocked = True
                executed_action = None
            elif last_valid is False:
                obs = {"error": "Validation failed. Call repair_sql."}
                blocked = True
                executed_action = None
            elif last_constraints_ok is None:
                obs = {"error": "Must call validate_constraints before run_sql."}
                blocked = True
                executed_action = None
            elif last_constraints_ok is False:
                obs = {"error": "Constraint validation failed. Call repair_sql."}
                blocked = True
                executed_action = None
            else:
                res = run_sql(last_sql)
                log_decision(decision_log, step, "run_sql", "execute", {"success": res.get("success"), "rowcount": res.get("rowcount"), "error": res.get("error")})
                if res.get("success"):
                    # An executable query can still answer the wrong question; an intent
                    # mismatch is downgraded to a failed run so the loop repairs it.
                    ok, why = intent_constraints(nlq, last_sql)
                    if not ok:
                        res = {"success": False, "error": f"Intent mismatch: {why}"}
                        log_decision(decision_log, step, "intent_check", why, {"ok": ok}, status="reject")
                    else:
                        log_decision(decision_log, step, "intent_check", "ok", {"ok": ok})
                        auto_finish = True
                        auto_finish_payload = {"answer": str(res.get("rows", [])), "sql": last_sql or ""}
                obs = res
                last_run = res
                if not res.get("success"):
                    last_error = res.get("error")
                    pending_repair_error = last_error
                else:
                    pending_repair_error = None
            if debug:
                if obs.get("success"):
                    print("\nRun: PASS (rows:", obs.get("rowcount"), ")")
                    rows = obs.get("rows") or []
                    if rows:
                        preview = rows[: max(0, int(debug_rows_preview))]
                        line = _truncate_text(json.dumps(preview, ensure_ascii=False, default=str), max_chars=1200)
                        print("Rows preview:", line)
                else:
                    print("\nRun: FAIL -", obs.get("error"))
        else:
            obs = {"error": f"Unhandled action: {action}"}

        history.append(f"Observation: {json.dumps(obs, ensure_ascii=False, default=str)}")
        trace.append(
            {
                "step": step,
                "attempted_action": attempted_action,
                "attempted_args": attempted_args,
                "action": executed_action,
                "args": args if executed_action else {},
                "observation": obs,
                "blocked": blocked,
            }
        )

        if auto_finish and auto_finish_payload is not None:
            result = finish(
                answer=auto_finish_payload["answer"],
                sql=auto_finish_payload["sql"],
                provenance={"trace": trace},
            )
            trace.append(
                {
                    "step": step,
                    "attempted_action": attempted_action,
                    "attempted_args": attempted_args,
                    "action": "finish",
                    "args": {},
                    "observation": result,
                    "forced_action": "finish",
                    "reason": "auto_finish_after_success",
                }
            )
            log_decision(decision_log, step, "finish", "auto_finish_after_success", {"sql": result.get("sql", "")})
            if debug:
                print("\nAUTO-FINISH -> returning final SQL")
                print(result.get("sql", ""))
            return result.get("sql", ""), trace, decision_log

        if debug and float(debug_sleep_s) > 0:
            time.sleep(float(debug_sleep_s))

    # Fallback if the loop did not finish: a vanilla (non-agentic) candidate when a
    # schema summary is available, otherwise the last cleaned draft (or "").
    fallback = None
    if schema_summary:
        fallback = vanilla_candidate(
            nlq=nlq,
            schema_summary=schema_summary,
            tok=tok,
            model=model,
            exemplars=exemplars or [],
        )
    if fallback:
        trace.append({"step": max_steps, "action": "fallback", "sql": fallback})
        log_decision(decision_log, max_steps, "fallback", "vanilla candidate", {"sql": fallback})
        return fallback, trace, decision_log
    return last_sql or "", trace, decision_log


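## EX Troubleshooting Checklist (high VA, low EX)

Use this when predictions execute (high VA) but return different results from the gold query (low EX):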

- Projection drift → `enforce_projection_contract`
- Intent mismatch → `intent_constraints`
- Wrong tables/joins → check `link_schema`
- Missing literals → check constraints + filters

**Explain with**: `5_ITERATIVE_REFINEMENTS.md`


### Quick sanity check (trace + decision log)

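**What these cells do**: the first is an interactive, single-question walkthrough with step-by-step debug output. The second runs the first five test items and prints, for each, VA, the intent check, a trace summary and the decision log. EX is also printed when `DEBUG_EX = True`.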

**Explain with**: `context.md` (trace fields), `EXAMINER_QA.md`


In [None]:
# 7a) Interactive walkthrough: type an NLQ and watch the loop step-by-step
# Runs react_sql once with debug=True, so every controller step, forced action,
# guardrail effect and run result is printed as it happens.
DEMO_INTERACTIVE = True  # prompt via input(); False always uses DEMO_DEFAULT_NLQ
DEMO_DEFAULT_NLQ = "Which customers are in France?"
DEMO_AUTO_ORDER = True  # keep the walkthrough linear (forces the next required step)
DEMO_SLEEP_S = 0.8  # pause (seconds) after each debug step; set 0 for fast
DEMO_PROMPT_TAIL = 0  # set >0 to show the transcript tail the model sees
SHOW_DECISIONS = False  # also print the formatted decision log (max_items=40)

nlq = ""
if DEMO_INTERACTIVE:
    try:
        nlq = input("Type a ClassicModels question (blank uses default): ").strip()
    except Exception:
        # Best-effort: if input() is unavailable (e.g. no stdin), fall back to the default.
        nlq = ""
if not nlq:
    nlq = DEMO_DEFAULT_NLQ

pred, trace, decisions = react_sql(
    nlq=nlq,
    schema_summary=SCHEMA_SUMMARY,
    exemplars=REACT_EXEMPLARS,
    debug=True,
    auto_order=DEMO_AUTO_ORDER,
    debug_sleep_s=DEMO_SLEEP_S,
    debug_prompt_tail_lines=DEMO_PROMPT_TAIL,
)

print("\nFINAL SQL:")
print(pred)
print("\nTRACE SUMMARY:", summarize_trace(trace))
if SHOW_DECISIONS:
    print("\nDECISIONS:\n" + format_decision_log(decisions, max_items=40))


In [None]:
# 7) Quick sanity check on a few items
# For the first five test items: run the agent with auto_order=True, then re-run the
# prediction through `runner` to report VA and the intent check. EX is optional.
from nl2sql.eval import execution_accuracy
DEBUG_EX = False  # set True for a quick EX check (slower)
DEBUG_TRACE = True  # print trace summary + decision log (max_items=12) per item
for sample in test_set[:5]:
    nlq = sample["nlq"]
    gold = sample["sql"]
    pred, trace, decisions = react_sql(
        nlq=nlq,
        schema_summary=SCHEMA_SUMMARY,
        exemplars=REACT_EXEMPLARS,
        auto_order=True,  # ensure ordered tool steps in sanity check
    )
    print("NLQ:", nlq)
    print("PRED:", pred)
    print("GOLD:", gold)
    if pred:
        # VA: does the predicted SQL execute?
        meta = runner.run(pred, capture_df=False)
        print("VA:", int(meta.success), "ERR:", meta.error)
        ok, why = intent_constraints(nlq, pred)
        print("INTENT:", ok, why)
    else:
        # An empty prediction counts as not valid and not intent-consistent.
        print("VA:", 0, "ERR:", "no prediction")
        print("INTENT:", False, "no prediction")
    if DEBUG_EX and pred:
        # EX: compare prediction vs gold execution results on `engine`.
        ex_ok, pred_err, gold_err = execution_accuracy(engine=engine, pred_sql=pred, gold_sql=gold)
        print("EX:", int(ex_ok), "PRED_ERR:", pred_err, "GOLD_ERR:", gold_err)
    if DEBUG_TRACE and trace:
        summary = summarize_trace(trace)
        print("TRACE LEN:", len(trace))
        print("EXECUTED ACTIONS:", summary.get("actions"))
        attempted = summary.get("attempted_actions") or []
        if attempted:
            # Only the last 10 attempted actions, to keep output short.
            print("ATTEMPTED ACTIONS:", attempted[-10:])
        print("BLOCKED STEPS:", summary.get("blocked_steps"))
        print("COMPLIANCE:", summary.get("compliance_ok"), summary.get("compliance_errors"))
        print("TRACE SUMMARY:", summary)
        print("DECISIONS:\n" + format_decision_log(decisions, max_items=12))
        print("TRACE LAST:", trace[-1])
    else:
        print("TRACE LEN:", len(trace))
    print("-" * 80)


### Import TS evaluator

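**What this cell does**: imports the Test Suite Accuracy (TS) evaluator. TS checks whether a prediction matches the gold results across several database replicas, which makes it a stronger semantic check than single-database EX.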

**Explain with**: `4_EVALUATION.md`, `REFERENCES.md#ref-zhong2020-ts`


In [None]:
# === Test Suite Accuracy (TS) evaluation ===
# Harness now lives in nl2sql.eval for reuse in scripts.
from nl2sql.eval import test_suite_accuracy_for_item


### Debug cost toggles

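**What this cell does**: sets small limits for fast iteration: the number of NLQs evaluated, the number of TS replica databases, and the per-query row cap for TS. Restore the full values (e.g. `QUICK_LIMIT = None`, `TS_N = 10`) before reporting final metrics.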


In [None]:
# === Quick test toggles (set before full eval) ===
# Small values give a fast sanity-check of TS/EX before committing to a full run.
QUICK_LIMIT = 20         # NLQs to evaluate; None evaluates the whole test set
TS_N = 3                 # TS replica DBs to use; 10 for the full test suite
MAX_ROWS_TS = 500        # per-query row cap inside TS; raise for full runs
USE_LINK_SCHEMA = True   # False ablates schema linking
# NOTE(review): USE_LINK_SCHEMA is not referenced by the eval cells in this section;
# presumably read by react_sql or an earlier cell — confirm.


### Full evaluation (VA/EM/EX/TS)

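**What this cell does**: runs the tool-driven ReAct loop over the test set, or only the first `QUICK_LIMIT` items when that toggle is set. It scores each prediction for VA/EM/EX/TS and saves per-item results, with traces and trace summaries, to JSON.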

**Explain with**: `4_EVALUATION.md`, `LOGBOOK.md`


In [None]:
# 8) Full agentic evaluation (VA/EX/EM/TS) over test_set
#
# Per item: run the ReAct agent, then score the predicted SQL with
#   VA - predicted SQL executes on the base DB (via `runner`),
#   EM - normalized string match with the gold SQL (diagnostic only),
#   EX - row equivalence with the gold SQL on the base DB,
#   TS - row equivalence across the test-suite replica DBs.
# Every rate uses the same denominator (all evaluated items): an item whose
# prediction does not execute scores 0 on EX and TS alike.
import json
from functools import lru_cache
from pathlib import Path
from sqlalchemy.engine import Engine
from nl2sql.eval import execution_accuracy, test_suite_accuracy_for_item
from nl2sql.postprocess import normalize_sql

results = []
TS_PREFIX = "classicmodels_ts"
# Replica DB names: classicmodels_ts_01 ... classicmodels_ts_<TS_N>.
SUITE_DBS = [f"{TS_PREFIX}_{i:02d}" for i in range(1, TS_N + 1)]


@lru_cache(maxsize=32)
def make_engine_cached(db_name: str) -> Engine:
    """Build the engine for ``db_name`` once; later calls reuse the cached one (up to 32 names)."""
    return make_engine(db_name)


def make_engine_fn(db_name: str) -> Engine:
    """Engine factory passed to the TS harness; delegates to the cached builder."""
    return make_engine_cached(db_name)


def _rate(scores) -> float:
    """Return the mean of 0/1 scores, or 0.0 when there are none (no ZeroDivisionError)."""
    scores = list(scores)
    return sum(scores) / len(scores) if scores else 0.0


LIMIT = QUICK_LIMIT  # override from quick toggles (None -> full test set)
items = test_set[:LIMIT] if LIMIT else test_set

# Per-item evaluation: generate SQL and compute VA/EM/EX/TS.
for i, sample in enumerate(items, start=1):
    nlq = sample["nlq"]
    gold_sql = sample["sql"]
    pred_sql, trace, decisions = react_sql(
        nlq=nlq,
        schema_summary=SCHEMA_SUMMARY,
        exemplars=REACT_EXEMPLARS,
    )
    trace_summary = summarize_trace(trace)

    # EM is strict (normalized) string match; kept as a diagnostic signal.
    # An empty prediction can never match, so normalize_sql is not called on it.
    em = int(bool(pred_sql) and normalize_sql(pred_sql) == normalize_sql(gold_sql))

    # VA = executability of predicted SQL
    va_meta = runner.run(pred_sql, capture_df=False) if pred_sql else None
    va = int(bool(va_meta and va_meta.success))

    # EX = execution accuracy on base DB (row equivalence)
    ex = 0
    pred_err = None
    gold_err = None
    if va:
        ex_ok, pred_err, gold_err = execution_accuracy(engine=engine, pred_sql=pred_sql, gold_sql=gold_sql)
        ex = int(ex_ok)

    # TS = test-suite accuracy across replica DBs; only attempted for executable
    # predictions (ts stays None otherwise and is scored as a miss below).
    # Note: test_suite_accuracy_for_item returns (ts_pass, debug_info).
    ts = None
    ts_debug = None
    if va:
        ts, ts_debug = test_suite_accuracy_for_item(
            pred_sql=pred_sql,
            gold_sql=gold_sql,
            suite_db_names=SUITE_DBS,
            make_engine_fn=make_engine_fn,
            max_rows=MAX_ROWS_TS,
        )

    results.append(
        {
            "nlq": nlq,
            "gold_sql": gold_sql,
            "pred_sql": pred_sql,
            "va": va,
            "em": em,
            "ex": ex,
            "ts": ts,
            "ts_debug": ts_debug,
            "pred_err": pred_err,
            "gold_err": gold_err,
            "trace": trace,
            "trace_summary": trace_summary,
            "decision_log": decisions,
        }
    )

    if i % 20 == 0 or i == len(items):
        print(f"Processed {i}/{len(items)}")

# Aggregate rates over all evaluated items (0.0 if nothing was evaluated).
va_rate = _rate(r["va"] for r in results)
em_rate = _rate(r["em"] for r in results)
ex_rate = _rate(r["ex"] for r in results)
# Non-executable predictions have ts=None: count them as TS failures so TS shares
# the VA/EX/EM denominator instead of being inflated by dropping failed predictions.
ts_rate = _rate((r["ts"] or 0) for r in results)
# Previous definition (TS over executable predictions only), kept for comparison.
ts_rate_executable = _rate(r["ts"] for r in results if r["ts"] is not None)

print("ReAct VA:", round(va_rate, 3), "EX:", round(ex_rate, 3), "EM:", round(em_rate, 3), "TS:", round(ts_rate, 3))
print("TS (executable preds only):", round(ts_rate_executable, 3), "| items:", len(results))

out = {
    "va_rate": va_rate,
    "ex_rate": ex_rate,
    "em_rate": em_rate,
    "ts_rate": ts_rate,
    "ts_rate_executable": ts_rate_executable,
    "n_items": len(results),
    "items": results,
}
# Full runs keep the canonical filename; truncated (QUICK_LIMIT) runs get their own
# file so a quick check never overwrites the full-run results.
is_partial = len(items) < len(test_set)
out_name = f"results_react_quick_{len(items)}.json" if is_partial else "results_react_200.json"
out_path = Path("results/agent") / out_name
out_path.parent.mkdir(parents=True, exist_ok=True)  # write_text fails if the folder is missing
out_path.write_text(json.dumps(out, indent=2, default=str))
print("Saved to", out_path)
