# Agentic Evaluation (Tool-Driven ReAct Loop)


### Setup (Colab only)


In [None]:

%%bash
# Pin a known-compatible PyTorch (CUDA 12.1) + Hugging Face stack for Colab.
set -e  # abort on the first failing command
export PIP_DEFAULT_TIMEOUT=120  # give slow downloads of large wheels (e.g. torch) more time

# Clean conflicting preinstalls
# (`|| true` keeps `set -e` from aborting when a listed package is not installed)
pip uninstall -y torch torchvision torchaudio bitsandbytes triton transformers accelerate peft trl datasets numpy pandas fsspec requests google-auth || true

# Base deps
pip install -q --no-cache-dir --force-reinstall   numpy==1.26.4 pandas==2.2.1 fsspec==2024.5.0 requests==2.31.0 google-auth==2.43.0

# Torch + CUDA 12.1
pip install -q --no-cache-dir --force-reinstall   torch==2.3.1+cu121 torchvision==0.18.1+cu121 torchaudio==2.3.1+cu121   --index-url https://download.pytorch.org/whl/cu121

# bitsandbytes + triton + HF stack
pip install -q --no-cache-dir --force-reinstall   bitsandbytes==0.43.3 triton==2.3.1   transformers==4.44.2 accelerate==0.33.0 peft==0.17.0 trl==0.9.6 datasets==2.20.0

# Modules already imported by the kernel keep their old versions until the runtime restarts.
echo "Setup complete. Restart runtime once, then run the rest of the notebook top-to-bottom."


### Repo setup (Colab)


In [None]:
# 0) Clone repo (Colab) + install deps
import os
# Set before transformers is imported (below); presumably these keep transformers
# from loading TensorFlow and silence TensorFlow's C++ logging — confirm.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Detect Colab by whether its module is importable.
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    # Clone once; re-running the cell reuses the existing checkout.
    if not os.path.exists('/content/NLtoSQL'):
        !git clone https://github.com/MacKenzieOBrian/NLtoSQL.git /content/NLtoSQL
    %cd /content/NLtoSQL
    # NOTE(review): requirements.txt may re-resolve packages pinned by the setup cell — confirm the pins agree.
    !pip -q install -r requirements.txt
    import torch, transformers, accelerate, peft
    print('torch', torch.__version__, 'cuda', torch.cuda.is_available())
else:
    print('Not in Colab; using existing workspace')


### Optional: ADC auth


In [None]:
# Run this only if you prefer gcloud-based ADC (no JSON key)
# The DB cell below connects through the Cloud SQL Connector using ADC credentials.
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    # NOTE(review): --upgrade here can replace the google-auth==2.43.0 pinned by the setup cell.
    %pip install -q --upgrade google-auth google-auth-oauthlib
    # Interactive login that writes Application Default Credentials for this runtime.
    !gcloud auth application-default login
else:
    print("Not in Colab; skip gcloud auth.")


### Database Connection and Query Execution Layer

This block connects to ClassicModels through the Cloud SQL Connector and verifies the connection with a `SELECT 1`
probe. The query runner used for validation and execution accuracy is created against this same engine in the
schema cell below, so all downstream metrics are computed against the same database.


In [None]:
# 1) Environment + DB
from getpass import getpass

from sqlalchemy import text

from nl2sql.db import create_engine_with_connector, safe_connection

# Connection settings come from the environment (set them in a Colab cell:
# INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS, DB_NAME). Anything missing is
# prompted for interactively; the password goes through getpass so it is not echoed.
DB_NAME = os.getenv("DB_NAME") or "classicmodels"
INSTANCE_CONNECTION_NAME = os.getenv("INSTANCE_CONNECTION_NAME") or input("Enter INSTANCE_CONNECTION_NAME: ").strip()
DB_USER = os.getenv("DB_USER") or input("Enter DB_USER: ").strip()
DB_PASS = os.getenv("DB_PASS") or getpass("Enter DB_PASS: ")

# Canonical engine builder shared with the scripts and other notebooks. It goes
# through the Cloud SQL Connector and authenticates with ADC. The connector is
# kept as a global so the per-replica TS engines can open connections through it.
engine, connector = create_engine_with_connector(
    instance_connection_name=INSTANCE_CONNECTION_NAME,
    user=DB_USER,
    password=DB_PASS,
    db_name=DB_NAME,
)

# Round-trip a trivial query so bad credentials fail here rather than mid-evaluation.
with safe_connection(engine) as conn:
    conn.execute(text("SELECT 1"))
print("DB connection OK")


### TS engine factory


In [None]:
# 1b) Engine factory for TS (multiple DB names)

import sqlalchemy
from sqlalchemy.engine import Engine


def make_engine(db_name: str) -> Engine:
    """Return a fresh SQLAlchemy engine whose connections target ``db_name``.

    Test-suite accuracy (TS) runs the same (gold, pred) SQL against several
    replica databases (``classicmodels_ts_XX``). Giving each replica its own
    engine keeps those evaluations independent of one another.

    Connections are opened through the notebook-level Cloud SQL ``connector``
    using the instance name and credentials collected in the DB cell.
    """

    def _open_connection():
        # SQLAlchemy invokes this whenever its pool needs a new DBAPI connection;
        # the credential globals are read at that moment, not at engine creation.
        return connector.connect(
            INSTANCE_CONNECTION_NAME,
            "pymysql",
            db=db_name,
            user=DB_USER,
            password=DB_PASS,
        )

    replica_engine = sqlalchemy.create_engine("mysql+pymysql://", creator=_open_connection, future=True)
    return replica_engine


### Schema Summary and Test Set

We build a compact schema summary for prompt grounding and load the evaluation set used across runs.
This keeps the context small while preserving the tables and columns required for correct SQL synthesis.


In [None]:
# 2) Schema summary + test set + QueryRunner
# Reassigned again in the "Quick toggles" cell before the full evaluation runs.
QUICK_LIMIT = 20  # number of NLQs to evaluate (set None for full set)
import json
from pathlib import Path
from nl2sql.schema import build_schema_summary
from nl2sql.query_runner import QueryRunner

# Reuse the DB name from the connection cell when it exists.
DB_NAME = globals().get("DB_NAME") or os.getenv("DB_NAME") or "classicmodels"
SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME)

test_path = Path("data/classicmodels_test_200.json")
full_set = json.loads(test_path.read_text(encoding="utf-8"))
test_set = full_set  # evaluation cells slice this via QUICK_LIMIT
print("Loaded test set size:", len(test_set))

# Up to three exemplars: the first "office" question (to encourage join
# behaviour), then the earliest remaining items in file order.
# NOTE(review): these are drawn from the evaluation set itself; if they are ever
# placed in the model prompt during evaluation they leak gold SQL — confirm.
join_exemplars = [candidate for candidate in full_set if "office" in candidate["nlq"].lower()]
REACT_EXEMPLARS = join_exemplars[:1]
for it in full_set:
    if it not in REACT_EXEMPLARS:
        REACT_EXEMPLARS.append(it)
    if len(REACT_EXEMPLARS) >= 3:
        break
print("Exemplars:", [exemplar["nlq"] for exemplar in REACT_EXEMPLARS])

runner = QueryRunner(engine)


### Load model


In [None]:

# 3) Load model (base or QLoRA adapters)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
# NOTE: because of the `or`, an unset *or empty* ADAPTER_PATH env var falls back to
# the default path; to force the base model, edit this assignment to None.
ADAPTER_PATH = os.getenv("ADAPTER_PATH") or "results/adapters/qlora_classicmodels"  # set to None to use base model

# Hugging Face token from the environment, else prompted for (not echoed).
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    HF_TOKEN = getpass("Enter HF_TOKEN (https://huggingface.co/settings/tokens): ").strip()

# bf16 compute only on compute capability >= 8 (Ampere and newer); fp16 otherwise.
# cc_minor is not used.
cc_major, cc_minor = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else (0, 0)
use_bf16 = cc_major >= 8
compute_dtype = torch.bfloat16 if use_bf16 else torch.float16
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("Using bf16:", use_bf16)
print("Adapter path:", ADAPTER_PATH)

# Tokenizer
# Fall back to EOS as the pad token if the tokenizer defines none.
tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Quantized base model
# 4-bit NF4 weights with double quantization; compute runs in compute_dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

# NOTE(review): bitsandbytes 4-bit loading presumably requires a CUDA GPU, so the
# device_map=None (CPU) branch is likely to fail at load time — confirm.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
    device_map={"": 0} if torch.cuda.is_available() else None,
    token=HF_TOKEN,
)
# Greedy decoding by default. temperature/top_p are reset to neutral 1.0, presumably
# to avoid warnings about sampling parameters being set while do_sample=False.
base_model.generation_config.do_sample = False
base_model.generation_config.temperature = 1.0
base_model.generation_config.top_p = 1.0

# Load adapters if present locally; otherwise use base model
adapter_dir = Path(ADAPTER_PATH) if ADAPTER_PATH else None
if adapter_dir and adapter_dir.exists():
    model = PeftModel.from_pretrained(base_model, adapter_dir, token=HF_TOKEN)
    print("Loaded adapters from", adapter_dir)
else:
    print("Adapter path missing; using base model only. Set ADAPTER_PATH to your local adapter folder or upload it to Colab.")
    model = base_model


### Optional: smoke check


In [None]:
from nl2sql.prompting import make_few_shot_messages
from nl2sql.llm import extract_first_select
from nl2sql.postprocess import guarded_postprocess
from nl2sql.eval import execution_accuracy

from nl2sql.agent_utils import _extract_required_columns

runner_check = QueryRunner(engine)
# Kept for backward compatibility with earlier revisions of this cell. run_quick_check
# below does NOT use it: these are the same items it evaluates, so using them as
# shots would put each question's own gold SQL into its prompt.
exemplars = test_set[:3]


def run_quick_check(k: int = 0, limit: int = 3):
    """Generate, post-process and score SQL for the first ``limit`` test items.

    For each item this prints the cleaned SQL with its validity (VA, does it
    execute) and execution accuracy (EX, does it match the gold result).

    Args:
        k: number of few-shot exemplars per prompt (0 = zero-shot). Exemplars are
            taken from ``test_set`` immediately *after* the evaluated slice, so no
            evaluated question ever sees its own (or a sibling's) gold SQL.
        limit: number of leading ``test_set`` items to evaluate (None = all, which
            leaves no items to draw exemplars from).
    """
    print(f"Quick check k={k}")
    samples = test_set[:limit]
    start = len(samples)
    shots = test_set[start:start + k] if k > 0 else []

    for sample in samples:
        msgs = make_few_shot_messages(
            schema=SCHEMA_SUMMARY,
            exemplars=shots,
            nlq=sample['nlq'],
        )
        prompt = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        inputs = tok(prompt, return_tensors="pt").to(model.device)
        out = model.generate(**inputs, max_new_tokens=256, do_sample=False)

        # Decode only the newly generated tokens, not the echoed prompt.
        gen_ids = out[0][inputs.input_ids.shape[-1]:]
        gen_text = tok.decode(gen_ids, skip_special_tokens=True)

        raw_sql = extract_first_select(gen_text) or gen_text
        explicit_fields = _extract_required_columns(sample['nlq'])
        sql = guarded_postprocess(raw_sql, sample['nlq'], explicit_fields=explicit_fields)

        meta = runner_check.run(sql, capture_df=False)
        ex_ok, _, _ = execution_accuracy(engine=engine, pred_sql=sql, gold_sql=sample['sql'])
        print(f"Q: {sample['nlq']}\nSQL: {sql}\nVA: {meta.success} EX: {ex_ok}")
        if not meta.success:
            print(f"ERR: {meta.error}")
        print()


run_quick_check(k=0)
run_quick_check(k=3)



### Guardrails and Post-Processing Utilities

Guardrails provide lightweight correctness constraints (SELECT-only, schema checks, and
intent checks) that catch common cases of SQL that runs but answers the wrong question. Projection cleanup is conservative:
explicit projection is enforced only when the NLQ enumerates fields (or a high-precision
template requires it), and minimal-projection pruning is skipped when constraints require
specific output fields. This keeps post-processing from changing query semantics while
still reducing common formatting noise.


In [None]:
# 4) Agent utilities + guardrails
import importlib
import nl2sql.agent_utils as agent_utils
# Reload so local edits to nl2sql/agent_utils.py take effect without restarting the
# kernel; the names below are imported *after* the reload so they bind to the fresh
# module (this also rebinds _extract_required_columns imported by the smoke-check cell).
importlib.reload(agent_utils)

from nl2sql.agent_utils import (
    _extract_required_columns,
    intent_constraints,
    classify_intent,
    clean_candidate_with_reason,
)



## Tool-Driven ReAct Loop


### Define ReAct loop


# Method: Tool-driven ReAct for NL->SQL

We implement a tool-driven ReAct loop that alternates between reasoning steps and executable actions.
The objective is to ground generation in schema facts, enforce constraints before execution, and
use execution feedback to repair errors while keeping traces auditable.

Decision logs and traces are retained to support error analysis and dissertation reporting.


In [None]:
# 5) Canonical ReAct pipeline (module-backed)

import importlib

import nl2sql.agent_tools as agent_tools
import nl2sql.react_pipeline as react_pipeline

# Reload agent_tools before react_pipeline; presumably react_pipeline imports from
# agent_tools, so this order lets it pick up the freshly reloaded tools — confirm.
importlib.reload(agent_tools)
importlib.reload(react_pipeline)

from nl2sql.agent_tools import AgentContext, set_agent_context
from nl2sql.react_pipeline import ReactAblationConfig, run_react_pipeline, evaluate_react_ablation

# Bind runtime dependencies once for module-level tools.
# Re-run this cell after reloading the model or reconnecting the DB so the tools
# see the new objects.
set_agent_context(
    AgentContext(
        engine=engine,
        db_name=DB_NAME,
        model=model,
        tok=tok,
        runner=runner,
        max_new_tokens=128,  # presumably the per-call generation cap used by the tools — confirm
    )
)


## Design Rationale (Concise)

- **Schema grounding**: constrain the model to known tables/columns to reduce hallucinations.
- **Constraint extraction**: infer structure (COUNT/GROUP BY/LIMIT) from the NLQ before SQL generation.
- **Validation gates**: block invalid or structurally wrong SQL before execution.
- **Execution-guided repair**: use runtime errors to guide correction instead of free-form regeneration.
- **Determinism first**: default to non-sampling for reproducibility; enable sampling only for reranking.


In [None]:
# Canonical toggles + notebook compatibility helpers

USE_LINK_SCHEMA = True
USE_CONSTRAINT_POLICY = True
USE_REPAIR_POLICY = True
REACT_MAX_REPAIRS = 2
REACT_LINK_MAX_TABLES = 6


def _react_cfg(name: str = "notebook_react") -> ReactAblationConfig:
    """Snapshot the current notebook toggles into a ``ReactAblationConfig``.

    The ``USE_*`` / ``REACT_*`` globals are read on every call, so re-running a
    toggle cell changes every config built afterwards.
    """
    toggles = dict(
        use_schema_link=USE_LINK_SCHEMA,
        use_constraint_policy=USE_CONSTRAINT_POLICY,
        use_repair_policy=USE_REPAIR_POLICY,
        max_repairs=REACT_MAX_REPAIRS,
        link_max_tables=REACT_LINK_MAX_TABLES,
    )
    return ReactAblationConfig(name=name, **toggles)


def _trace_to_decisions(trace: list[dict]) -> list[dict]:
    decisions: list[dict] = []
    for i, t in enumerate(trace or []):
        stage = t.get("stage")
        if not stage:
            continue
        d = {"step": i, "decision": stage, "reason": "ok", "status": "ok", "data": {}}

        if stage in ("validate_sql", "validate_constraints"):
            res = t.get("result") or {}
            d["reason"] = res.get("reason", "ok")
            d["status"] = "ok" if res.get("valid") else "reject"
            data = dict(res)
            if t.get("sql"):
                data["sql"] = t.get("sql")
            d["data"] = data
        elif stage == "run_sql":
            res = t.get("result") or {}
            d["reason"] = "execute"
            d["status"] = "ok" if res.get("success") else "reject"
            d["data"] = {
                "sql": t.get("sql"),
                "success": res.get("success"),
                "rowcount": res.get("rowcount"),
                "error": res.get("error"),
            }
        elif stage == "intent_check":
            d["reason"] = t.get("reason", "ok")
            d["status"] = "ok" if t.get("ok") else "reject"
            d["data"] = {"ok": t.get("ok")}
        else:
            d["data"] = {k: v for k, v in t.items() if k != "stage"}

        decisions.append(d)
    return decisions


def format_decision_log(decisions: list[dict], max_items: int | None = 20) -> str:
    if not decisions:
        return "(no decisions logged)"
    out: list[str] = []
    lim = max_items if max_items is not None else len(decisions)
    for d in decisions[:lim]:
        out.append(f"[step {d.get('step')}] {d.get('decision')} - {d.get('reason')} ({d.get('status')})")
        data = d.get("data")
        if data:
            s = str(data)
            if len(s) > 320:
                s = s[:317] + "..."
            out.append(f"  data: {s}")
    return "\n".join(out)


def summarize_trace(trace: list[dict]) -> dict:
    """Summarize a pipeline trace into the dict shape older analysis cells expect.

    ``actions`` lists the non-empty ``stage`` values in order. ``repairs`` and
    ``forced_repairs`` both count ``repair_sql`` stages. The remaining keys
    (``attempted_actions``, ``blocked_steps``, ``compliance_ok``,
    ``compliance_errors``) are fixed placeholders kept for backward
    compatibility and are not derived from the trace.

    A ``None`` or empty trace yields an empty summary instead of raising
    (the same tolerance ``_trace_to_decisions`` applies).
    """
    actions = [t.get("stage") for t in (trace or []) if t.get("stage")]
    repairs = actions.count("repair_sql")
    return {
        "actions": actions,
        "attempted_actions": [],
        "blocked_steps": 0,
        "repairs": repairs,
        "forced_repairs": repairs,
        "compliance_ok": True,
        "compliance_errors": [],
    }


## Implementation Notes

`react_sql` below is a thin notebook wrapper around `run_react_pipeline` from `nl2sql.react_pipeline`. It builds the
config from the notebook toggles, converts the returned trace into a decision log, and optionally prints a per-stage view.


In [None]:
def react_sql(
    *,
    nlq: str,
    schema_text: str | None = None,
    schema_summary: str | None = None,
    exemplars: list[dict] | None = None,
    max_steps: int = 8,
    debug: bool = False,
    debug_sleep_s: float = 0.0,
    debug_prompt_tail_lines: int = 0,
    debug_rows_preview: int = 3,
    auto_order: bool = False,
) -> tuple[str, list[dict], list[dict]]:
    """Run the canonical ReAct pipeline for one question.

    Thin notebook wrapper around ``run_react_pipeline``, configured from the
    current notebook toggles via ``_react_cfg``.

    Args:
        nlq: natural-language question to translate.
        debug: print a compact per-stage view of the trace.
        debug_sleep_s: when ``debug`` is on, pause this many seconds after each
            printed stage so a live walkthrough can be followed (0 = no pause).

    The remaining keyword arguments (``schema_text``, ``schema_summary``,
    ``exemplars``, ``max_steps``, ``debug_prompt_tail_lines``,
    ``debug_rows_preview``, ``auto_order``) are accepted for compatibility with
    older notebook cells but are NOT forwarded: the pipeline call receives only
    ``nlq`` and the config.

    Returns:
        ``(pred_sql, trace, decisions)``, where ``decisions`` is ``trace``
        flattened by ``_trace_to_decisions``.
    """
    import time

    cfg = _react_cfg(name="react_notebook")
    pred_sql, trace = run_react_pipeline(nlq=nlq, config=cfg)
    decisions = _trace_to_decisions(trace)

    if not debug:
        return pred_sql, trace, decisions

    print("ReAct pipeline (module-backed) trace:")
    for i, t in enumerate(trace or []):
        stage = t.get("stage")
        if not stage:
            continue
        if stage in ("generate_sql", "repair_sql"):
            raw = t.get("raw_sql")
            cleaned = t.get("sql_after_guardrails")
            if raw:
                print(f"[step {i}] {stage} raw_sql: {' '.join(str(raw).split())[:220]}")
            if cleaned:
                print(f"[step {i}] {stage} cleaned_sql: {' '.join(str(cleaned).split())[:220]}")
        elif stage in ("validate_sql", "validate_constraints"):
            res = t.get("result") or {}
            print(f"[step {i}] {stage}: valid={res.get('valid')} reason={res.get('reason')}")
        elif stage == "run_sql":
            res = t.get("result") or {}
            print(f"[step {i}] run_sql: success={res.get('success')} rowcount={res.get('rowcount')} error={res.get('error')}")
        elif stage == "intent_check":
            print(f"[step {i}] intent_check: ok={t.get('ok')} reason={t.get('reason')}")
        else:
            print(f"[step {i}] {stage}")
        # Pace the walkthrough so each stage can be read as it appears.
        if debug_sleep_s and debug_sleep_s > 0:
            time.sleep(debug_sleep_s)

    return pred_sql, trace, decisions


### Interactive walkthrough


In [None]:
# 6a) Interactive walkthrough: type an NLQ and watch the loop step-by-step
DEMO_INTERACTIVE = True
DEMO_DEFAULT_NLQ = "Which customers are in France?"
DEMO_AUTO_ORDER = True  # forwarded to react_sql as auto_order
DEMO_SLEEP_S = 0.8  # forwarded as debug_sleep_s (set 0 for fast)
DEMO_PROMPT_TAIL = 0  # forwarded as debug_prompt_tail_lines
SHOW_DECISIONS = False  # also print the flattened decision log

nlq = ""
if DEMO_INTERACTIVE:
    try:
        nlq = input("Type a ClassicModels question (blank uses default): ").strip()
    except Exception:
        # No usable stdin (e.g. a non-interactive kernel): use the default question.
        nlq = ""
nlq = nlq or DEMO_DEFAULT_NLQ

pred, trace, decisions = react_sql(
    nlq=nlq,
    schema_summary=SCHEMA_SUMMARY,
    exemplars=REACT_EXEMPLARS,
    debug=True,
    auto_order=DEMO_AUTO_ORDER,
    debug_sleep_s=DEMO_SLEEP_S,
    debug_prompt_tail_lines=DEMO_PROMPT_TAIL,
)

print("\nFINAL SQL:", pred, sep="\n")
print("\nTRACE SUMMARY:", summarize_trace(trace))
if SHOW_DECISIONS:
    print("\nDECISIONS:\n" + format_decision_log(decisions, max_items=40))


## Quick Sanity Check

A small, fast slice used to validate end-to-end behavior before longer runs. It surfaces
structural errors early and confirms the tool ordering and constraint checks behave as expected.


In [None]:
# 6) Quick sanity check on a few items (ReAct full loop)

from nl2sql.eval import execution_accuracy

DEBUG_EX = False  # also compute EX per item (runs both predicted and gold SQL)
DEBUG_TRACE = False  # always print the trace summary and decision log
DEBUG_TRACE_ON_MISMATCH = True  # print them when pred differs textually from gold
DEBUG_SQL_BUILD = True  # include the SQL timeline (raw -> cleaned -> validated -> executed)
DEBUG_SQL_MAX_CHARS = 220  # truncation width for each SQL timeline line


def _normalize_sql_for_print(sql: str, max_chars: int = 220) -> str:
    s = " ".join(str(sql or "").split())
    if not s:
        return ""
    if max_chars and len(s) > max_chars:
        s = s[: max_chars - 3] + "..."
    return s


def _collect_sql_timeline(decisions: list[dict]) -> list[tuple[int, str, str]]:
    out: list[tuple[int, str, str]] = []
    for d in decisions:
        step = int(d.get("step", -1))
        dec = d.get("decision")
        data = d.get("data") or {}
        if dec in ("generate_sql", "repair_sql"):
            raw = data.get("raw_sql")
            if raw:
                out.append((step, "raw_sql", raw))
            cleaned = data.get("sql_after_guardrails")
            if cleaned:
                out.append((step, "cleaned_sql", cleaned))
        elif dec in ("validate_sql", "validate_constraints", "run_sql"):
            sql = data.get("sql")
            if sql:
                out.append((step, dec + "_sql", sql))
    return out


def _print_sql_timeline(items: list[tuple[int, str, str]], max_chars: int = 220) -> None:
    """Print a SQL timeline, one normalized line per entry.

    Entries that normalize to empty are skipped, and so is any entry identical
    to the last printed one. The result is that only changes in the SQL show up.
    """
    if not items:
        print("SQL BUILD: (no SQL events captured)")
        return
    print("SQL BUILD:")
    previous = None
    for step, label, sql in items:
        shown = _normalize_sql_for_print(sql, max_chars=max_chars)
        if shown and shown != previous:
            previous = shown
            print(f"  [step {step}] {label}: {shown}")


# Run the full ReAct loop on the first five test items and print per-item diagnostics.
# Loop variables (nlq, pred, trace, decisions, ...) intentionally remain as notebook globals.
for sample in test_set[:5]:
    nlq = sample["nlq"]
    gold = sample["sql"]
    pred, trace, decisions = react_sql(
        nlq=nlq,
        schema_summary=SCHEMA_SUMMARY,
        exemplars=REACT_EXEMPLARS,
        auto_order=True,
    )

    print("NLQ:", nlq)
    print("PRED:", pred)
    print("GOLD:", gold)

    # VA: does the prediction execute? INTENT: does it satisfy the NLQ-derived constraints?
    if pred:
        meta = runner.run(pred, capture_df=False)
        print("VA:", int(meta.success), "ERR:", meta.error)
        ok, why = intent_constraints(nlq, pred)
        print("INTENT:", ok, why)
    else:
        print("VA:", 0, "ERR:", "no prediction")
        print("INTENT:", False, "no prediction")

    # EX executes both gold and predicted SQL, so it is opt-in here.
    if DEBUG_EX and pred:
        ex_ok, pred_err, gold_err = execution_accuracy(engine=engine, pred_sql=pred, gold_sql=gold)
        print("EX:", int(ex_ok), "PRED_ERR:", pred_err, "GOLD_ERR:", gold_err)

    # Cheap textual mismatch (trimmed, case-insensitive) used only to decide whether to dump
    # the trace; semantically equivalent SQL with different text still counts as a mismatch.
    mismatch = (pred or "").strip().lower() != (gold or "").strip().lower()
    if trace and (DEBUG_TRACE or (DEBUG_TRACE_ON_MISMATCH and mismatch)):
        summary = summarize_trace(trace)
        print("TRACE LEN:", len(trace))
        print("EXECUTED ACTIONS:", summary.get("actions"))
        print("BLOCKED STEPS:", summary.get("blocked_steps"))
        print("COMPLIANCE:", summary.get("compliance_ok"), summary.get("compliance_errors"))
        print("TRACE SUMMARY:", summary)
        print("DECISIONS:\n" + format_decision_log(decisions, max_items=20))
        if DEBUG_SQL_BUILD:
            _print_sql_timeline(_collect_sql_timeline(decisions), max_chars=DEBUG_SQL_MAX_CHARS)
    else:
        print("TRACE LEN:", len(trace))

    print("-" * 80)


### Test Suite Accuracy (TS)

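TS measures logical correctness by executing the predicted and gold SQL on several replica databases and comparing the outcomes on each.
It is stricter than single-database execution accuracy (EX): a query that returns the right rows on one database only by coincidence is exposed, which highlights schema-linking and join errors.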


In [None]:
# === Test Suite Accuracy (TS) evaluation ===
# Harness now lives in nl2sql.eval for reuse in scripts.
from nl2sql.eval import test_suite_accuracy_for_item


### Quick toggles


In [None]:
# === Quick test toggles (set before full eval) ===
# Use small values to sanity-check TS/EX before full runs.

QUICK_LIMIT = 20  # items evaluated by the full-eval cell (None = entire test set)
TS_N = 5  # number of TS replica DBs: classicmodels_ts_01 .. classicmodels_ts_05
MAX_ROWS_TS = 500  # passed to the TS harness as ts_max_rows

# Canonical ablation toggles
# _react_cfg reads these globals at call time, so these values override the earlier toggle cell.
USE_LINK_SCHEMA = True
USE_CONSTRAINT_POLICY = True
USE_REPAIR_POLICY = True
REACT_MAX_REPAIRS = 2
REACT_LINK_MAX_TABLES = 6


### Full Evaluation (VA/EX/EM/TS)

This run computes the standard metrics on the test set, truncated to `QUICK_LIMIT` items when that is set (use `None`
for the full set), and records per-item outputs for later analysis and comparison between system variants.


In [None]:
# 7) Full agentic evaluation (VA/EX/EM/TS) over test_set

from pathlib import Path
import json
from functools import lru_cache
from sqlalchemy import Engine


def strip_trace_cycles(trace):
    """Return ``trace`` unchanged.

    Compatibility shim kept so existing callers keep working. No filtering is
    performed here.
    """
    # Canonical trace in react_pipeline is acyclic; keep helper for compatibility.
    return trace


# TS replica database names: classicmodels_ts_01 .. classicmodels_ts_{TS_N:02d}.
TS_PREFIX = "classicmodels_ts"
SUITE_DBS = [f"{TS_PREFIX}_{i:02d}" for i in range(1, TS_N + 1)]


@lru_cache(maxsize=32)
def make_engine_cached(db_name: str) -> Engine:
    """Return one shared engine per replica DB name for as long as this cache lives.

    NOTE(review): cached engines are never disposed, and re-running this cell
    creates a new cache (and new engines).
    """
    return make_engine(db_name)


def make_engine_fn(db_name: str) -> Engine:
    """Engine factory handed to the TS harness; delegates to the memoized builder."""
    return make_engine_cached(db_name)


cfg = _react_cfg(name="react_eval")
report = evaluate_react_ablation(
    test_set=test_set,
    engine=engine,
    config=cfg,
    limit=QUICK_LIMIT,
    allow_extra_columns_ex=False,
    ts_suite_db_names=SUITE_DBS,
    ts_make_engine_fn=make_engine_fn,
    ts_max_rows=MAX_ROWS_TS,
    progress_every=20,
)

# Back-compat shape used by existing analysis cells (one record per evaluated item).
results = []
for item in report["items"]:
    trace = strip_trace_cycles(item.get("trace") or [])
    decisions = _trace_to_decisions(trace)
    results.append(
        {
            "nlq": item["nlq"],
            "gold_sql": item["gold_sql"],
            "pred_sql": item["pred_sql"],
            "va": item["va"],
            "em": item["em"],
            "ex": item["ex"],
            "ts": item.get("ts"),
            "ts_debug": item.get("ts_debug"),
            "pred_err": item.get("error"),
            "gold_err": item.get("gold_error"),
            "trace": trace,
            "trace_summary": summarize_trace(trace),
            "decision_log": decisions,
        }
    )

va_rate = report["va_rate"]
em_rate = report["em_rate"]
ex_rate = report["ex_rate"]
# ts_rate may be reported as None (TS not computed); keep the historical 0.0 fallback.
ts_rate = report["ts_rate"] if report["ts_rate"] is not None else 0.0

print("ReAct VA:", round(va_rate, 3), "EX:", round(ex_rate, 3), "EM:", round(em_rate, 3), "TS:", round(ts_rate, 3))

out = {
    "va_rate": va_rate,
    "ex_rate": ex_rate,
    "em_rate": em_rate,
    "ts_rate": ts_rate,
    # Record how many items this file actually covers: with QUICK_LIMIT set the run
    # is only a slice, even though the filename says 200.
    "n_items": len(results),
    "limit": QUICK_LIMIT,
    "items": results,
}
out_path = Path("results/agent/results_react_200.json")
# A fresh checkout may not have results/agent/ yet; write_text would raise FileNotFoundError.
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(out, indent=2, default=str), encoding="utf-8")
print("Saved to", out_path)


In [None]:
# 7b) EX failure profiling (quick categories)
from collections import Counter

def _ex_reasons(decision_log):
    reasons = []
    for d in decision_log or []:
        r = d.get('reason')
        if r and r not in ('ok', 'success'):
            reasons.append(r)
    return reasons

def categorize_ex_failure(item):
    """Bucket one evaluation record into a coarse EX outcome category.

    The checks run in this order:
    1. No prediction gives ``'no_pred'``.
    2. Invalid SQL (``va == 0``) gives ``'invalid_sql'``.
    3. A correct execution (``ex == 1``) gives ``'correct'``.
    4. Otherwise the first rule below that matches any informative reason in
       the item's decision log decides the category; if none match, the result
       is ``'other'``.
    """
    if not item.get('pred_sql'):
        return 'no_pred'
    if item.get('va') == 0:
        return 'invalid_sql'
    if item.get('ex') == 1:
        return 'correct'

    # Informative reasons only: drop falsy ones and the neutral 'ok'/'success'.
    reasons = [
        reason
        for entry in (item.get('decision_log') or [])
        if (reason := entry.get('reason')) and reason not in ('ok', 'success')
    ]
    # (category, predicate) pairs in priority order; only missing_agg is a prefix match.
    rules = (
        ('missing_value_hint', lambda r: 'missing_value_hint' in r),
        ('missing_location_table', lambda r: 'missing_location_table' in r),
        ('missing_location_column', lambda r: 'missing_location_column' in r),
        ('missing_agg', lambda r: r.startswith('missing_agg')),
        ('missing_group_by', lambda r: 'missing_group_by' in r),
        ('missing_order_by', lambda r: 'missing_order_by' in r),
        ('missing_limit', lambda r: 'missing_limit' in r),
    )
    for category, matches in rules:
        if any(matches(r) for r in reasons):
            return category
    return 'other'

# Tally categories over all saved per-item results (the 'correct' bucket is included too).
counts = Counter(categorize_ex_failure(r) for r in results)
print('EX failure categories:')
# Most frequent category first.
for k, v in counts.most_common():
    print(f'  {k}: {v}')
