# Agentic Evaluation (Tool-Driven ReAct Loop)


### Setup (Colab only)


In [None]:

%%bash
# Pinned dependency install for the Colab runtime. The final echo asks for one
# runtime restart before the remaining cells are run.
# Abort the cell on the first failing command.
set -e
# Raise pip's default network timeout (seconds).
export PIP_DEFAULT_TIMEOUT=120

# Clean conflicting preinstalls
# `|| true` keeps `set -e` from aborting when the uninstall step fails.
pip uninstall -y torch torchvision torchaudio bitsandbytes triton transformers accelerate peft trl datasets numpy pandas fsspec requests google-auth || true

# Base deps
pip install -q --no-cache-dir --force-reinstall   numpy==1.26.4 pandas==2.2.1 fsspec==2024.5.0 requests==2.31.0 google-auth==2.43.0

# Torch + CUDA 12.1
# Wheels come from the PyTorch cu121 index rather than PyPI.
pip install -q --no-cache-dir --force-reinstall   torch==2.3.1+cu121 torchvision==0.18.1+cu121 torchaudio==2.3.1+cu121   --index-url https://download.pytorch.org/whl/cu121

# bitsandbytes + triton + HF stack
pip install -q --no-cache-dir --force-reinstall   bitsandbytes==0.43.3 triton==2.3.1   transformers==4.44.2 accelerate==0.33.0 peft==0.17.0 trl==0.9.6 datasets==2.20.0

echo "Setup complete. Restart runtime once, then run the rest of the notebook top-to-bottom."


### Repo setup (Colab)


In [None]:
# 0) Clone repo (Colab) + install deps
# Detects Colab by attempting to import `google.colab`; on Colab the project is
# cloned (once) into /content/NLtoSQL and its requirements are installed.
import os
# NOTE(review): presumably tells transformers / TensorFlow to stay quiet or
# skip TF entirely; confirm against the installed library versions.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    # Any import failure is treated as "not running on Colab".
    IN_COLAB = False

if IN_COLAB:
    # Clone only when the checkout is not already present.
    if not os.path.exists('/content/NLtoSQL'):
        !git clone https://github.com/MacKenzieOBrian/NLtoSQL.git /content/NLtoSQL
    %cd /content/NLtoSQL
    # NOTE(review): this runs after the pinned setup cell; confirm that
    # requirements.txt does not override those pinned versions.
    !pip -q install -r requirements.txt
    # Import the core stack and report the torch version / CUDA availability.
    import torch, transformers, accelerate, peft
    print('torch', torch.__version__, 'cuda', torch.cuda.is_available())
else:
    print('Not in Colab; using existing workspace')


### Optional: ADC auth


In [None]:
# Run this only if you prefer gcloud-based ADC (no JSON key)
# The Colab check is repeated here so the cell can be run on its own.
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    # Any import failure is treated as "not running on Colab".
    IN_COLAB = False

if IN_COLAB:
    # NOTE(review): `--upgrade` may move google-auth off the version pinned in
    # the setup cell (google-auth==2.43.0); confirm this is intended.
    %pip install -q --upgrade google-auth google-auth-oauthlib
    # Interactive login that provisions application-default credentials.
    !gcloud auth application-default login
else:
    print("Not in Colab; skip gcloud auth.")


### Database Connection and Query Execution Layer

This block establishes a reproducible connection to ClassicModels and initializes the query runner
used for validation and execution accuracy. It ensures all downstream metrics are computed
against a consistent database state.


In [None]:
# 1) Environment + DB
from getpass import getpass

from sqlalchemy import text

from nl2sql.db import create_engine_with_connector, safe_connection

# Connection settings are read from the environment first (set them in a Colab
# cell): INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS, DB_NAME.
# Any missing value is requested interactively; the password prompt uses
# getpass. Only DB_NAME has a built-in default.
INSTANCE_CONNECTION_NAME = (
    os.getenv("INSTANCE_CONNECTION_NAME")
    or input("Enter INSTANCE_CONNECTION_NAME: ").strip()
)
DB_USER = os.getenv("DB_USER") or input("Enter DB_USER: ").strip()
DB_PASS = os.getenv("DB_PASS") or getpass("Enter DB_PASS: ")
DB_NAME = os.getenv("DB_NAME") or "classicmodels"

# Canonical engine builder (shared with scripts + other notebooks); per the
# original note it uses the Cloud SQL Connector and ADC for credentials.
# The returned connector is kept as a global.
engine, connector = create_engine_with_connector(
    db_name=DB_NAME,
    password=DB_PASS,
    user=DB_USER,
    instance_connection_name=INSTANCE_CONNECTION_NAME,
)

# Fail fast: run a trivial round-trip query before any later cell uses the DB.
with safe_connection(engine) as conn:
    conn.execute(text("SELECT 1"))
print("DB connection OK")


### Test-Suite (TS) Engine Factory


In [None]:
# 1b) Engine factory for TS (multiple DB names)

import sqlalchemy
from sqlalchemy.engine import Engine


def make_engine(db_name: str) -> Engine:
    """Build a fresh SQLAlchemy engine for one test-suite (TS) replica database.

    Test-suite accuracy runs the same (gold, pred) SQL pair against several
    replica databases (classicmodels_ts_XX). Each replica gets its own engine
    so it is evaluated independently of the others.

    Args:
        db_name: Name of the replica database to connect to.

    Returns:
        An engine whose connections are opened through the notebook-level
        ``connector`` using the notebook's credentials.
    """

    def _open_connection():
        # Credentials are read when a connection is opened, not when the
        # engine is built, so they follow the current notebook globals.
        credentials = {"user": DB_USER, "password": DB_PASS, "db": db_name}
        return connector.connect(INSTANCE_CONNECTION_NAME, "pymysql", **credentials)

    return sqlalchemy.create_engine(
        "mysql+pymysql://",
        creator=_open_connection,
        future=True,
    )


### Schema Summary and Test Set

We build a compact schema summary for prompt grounding and load the evaluation set used across runs.
This keeps the context small while preserving the tables and columns required for correct SQL synthesis.


In [None]:
# 2) Schema summary + test set + QueryRunner

# Default evaluation size; later read as the default for EVAL_LIMIT.
QUICK_LIMIT = 20  # number of NLQs to evaluate (set None for full set)
import json
from pathlib import Path
from nl2sql.schema import build_schema_summary
from nl2sql.query_runner import QueryRunner

# Reuse DB_NAME if an earlier cell defined it; otherwise fall back to the
# environment and finally to the default database name.
DB_NAME = globals().get("DB_NAME") or os.getenv("DB_NAME") or "classicmodels"
# Schema text passed to prompts and constraint checks in later cells.
SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME)
# Relative path: resolved against the current working directory.
test_path = Path("data/classicmodels_test_200.json")
full_set = json.loads(test_path.read_text(encoding="utf-8"))
# No slicing happens here; test_set is the same list object as full_set.
test_set = full_set  # slice later via QUICK_LIMIT
print("Loaded test set size:", len(test_set))

# Small exemplar set (seeded to encourage join behavior).
# The first item whose NLQ mentions "office" is placed first, then items are
# added in file order (skipping duplicates) until there are three.
# NOTE(review): exemplars are drawn from the evaluation set itself; confirm
# that this overlap is acceptable wherever REACT_EXEMPLARS is consumed.
join_exemplars = [it for it in full_set if "office" in it["nlq"].lower()]
REACT_EXEMPLARS = []
if join_exemplars:
    REACT_EXEMPLARS.append(join_exemplars[0])
for it in full_set:
    if it not in REACT_EXEMPLARS:
        REACT_EXEMPLARS.append(it)
    if len(REACT_EXEMPLARS) >= 3:
        break
print("Exemplars:", [e["nlq"] for e in REACT_EXEMPLARS])

# Shared query runner bound to the primary engine; used by later cells.
runner = QueryRunner(engine)


### Load model


In [None]:

# 3) Load model (base or QLoRA adapters)
# Loads the tokenizer and a 4-bit quantized base model, then wraps it with
# local PEFT adapters when the adapter directory exists.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER_PATH = os.getenv("ADAPTER_PATH") or "results/adapters/qlora_classicmodels"  # set to None to use base model

# Token lookup order: HF_TOKEN, then HUGGINGFACE_TOKEN, then a hidden prompt.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    HF_TOKEN = getpass("Enter HF_TOKEN (https://huggingface.co/settings/tokens): ").strip()

# bf16 is selected when the GPU's compute capability major version is >= 8;
# otherwise (including CPU-only, reported as (0, 0)) fp16 is used.
cc_major, cc_minor = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else (0, 0)
use_bf16 = cc_major >= 8
compute_dtype = torch.bfloat16 if use_bf16 else torch.float16
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("Using bf16:", use_bf16)
print("Adapter path:", ADAPTER_PATH)

# Tokenizer
tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# Reuse EOS as the padding token when the tokenizer defines none.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Quantized base model
# 4-bit NF4 weights with double quantization; compute runs in compute_dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

# With CUDA available the whole model is mapped to device 0.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
    device_map={"": 0} if torch.cuda.is_available() else None,
    token=HF_TOKEN,
)
# Greedy decoding by default.
# NOTE(review): temperature/top_p are presumably reset to 1.0 to avoid
# generation-config warnings when sampling is off; confirm.
base_model.generation_config.do_sample = False
base_model.generation_config.temperature = 1.0
base_model.generation_config.top_p = 1.0

# Load adapters if present locally; otherwise use base model
# The fallback message below is also printed when ADAPTER_PATH is empty.
adapter_dir = Path(ADAPTER_PATH) if ADAPTER_PATH else None
if adapter_dir and adapter_dir.exists():
    model = PeftModel.from_pretrained(base_model, adapter_dir, token=HF_TOKEN)
    print("Loaded adapters from", adapter_dir)
else:
    print("Adapter path missing; using base model only. Set ADAPTER_PATH to your local adapter folder or upload it to Colab.")
    model = base_model


### Optional: smoke check


In [None]:
from nl2sql.prompting import make_few_shot_messages
from nl2sql.llm import extract_first_select
from nl2sql.postprocess import guarded_postprocess
from nl2sql.eval import execution_accuracy
from nl2sql.constraint_policy import build_constraints
import random

# Dedicated runner and exemplar pool for the smoke check.
runner_check = QueryRunner(engine)
exemplars = test_set[:3]


def run_quick_check(k: int = 0, limit: int | None = 3, seed: int = 7) -> None:
    """Run a small smoke test of prompt -> generate -> postprocess -> execute.

    For each sampled item this builds a (few-shot) prompt, generates SQL
    greedily, applies guarded post-processing, executes the SQL, and prints
    validity (VA) and execution accuracy (EX).

    Args:
        k: Number of few-shot exemplars to include; ``0`` means zero-shot.
            At most ``k`` exemplars are used (the pool holds three).
        limit: Number of test items to sample; ``None`` evaluates the whole
            test set in its original order.
        seed: Seed for the sampling RNG, so repeated runs pick the same items.
    """
    rng = random.Random(seed)
    pool = test_set if limit is None else rng.sample(test_set, k=min(limit, len(test_set)))
    print(f"Quick check | k={k} | n={len(pool)}")

    for sample in pool:
        if k > 0:
            # Exclude the item being evaluated from its own prompt: exemplars
            # come from the test set and would otherwise leak its gold SQL.
            # Then honour k as a count rather than as an on/off switch.
            shots = [ex for ex in exemplars if ex is not sample][:k]
        else:
            shots = []

        messages = make_few_shot_messages(
            schema=SCHEMA_SUMMARY,
            exemplars=shots,
            nlq=sample['nlq'],
        )
        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tok(prompt, return_tensors='pt').to(model.device)
        out = model.generate(**inputs, max_new_tokens=256, do_sample=False)

        # Decode only the newly generated tokens (drop the echoed prompt).
        gen_ids = out[0][inputs.input_ids.shape[-1]:]
        generated = tok.decode(gen_ids, skip_special_tokens=True)

        # Fall back to the raw generation when no SELECT can be extracted.
        raw_sql = extract_first_select(generated) or generated
        constraints = build_constraints(sample['nlq'], SCHEMA_SUMMARY)
        explicit_fields = constraints.get('explicit_fields')
        explicit_projection = constraints.get('explicit_projection')
        required_fields = constraints.get('required_output_fields')

        # Explicit fields are only forwarded when the constraints flag an
        # explicit projection.
        sql = guarded_postprocess(
            raw_sql,
            sample['nlq'],
            explicit_fields=explicit_fields if explicit_projection else None,
            required_fields=required_fields,
        )

        meta = runner_check.run(sql, capture_df=False)
        ex_ok, _, _ = execution_accuracy(engine=engine, pred_sql=sql, gold_sql=sample['sql'])

        print(f"Q: {sample['nlq']}")
        print(f"SQL: {sql}")
        print(f"VA: {int(meta.success)} EX: {int(ex_ok)}")
        if meta.error:
            print(f"ERR: {meta.error}")
        print('-' * 80)


run_quick_check(k=0, limit=3)
run_quick_check(k=3, limit=3)


### Guardrails and Post-Processing Utilities

Guardrails provide lightweight correctness constraints (SELECT-only, schema checks, and
intent checks) that prevent "runs but wrong" SQL. Projection cleanup is conservative:
explicit projection is enforced only when the NLQ enumerates fields (or a high-precision
template requires it), and minimal-projection pruning is skipped when constraints require
specific output fields. This keeps post-processing from changing query semantics while
still reducing common formatting noise.


In [None]:
# 4) Agent utilities + guardrails
# Reload the project modules first so that edits made on disk are picked up
# without restarting the kernel. The `from ... import` statements below run
# after the reloads and therefore bind names from the reloaded modules.
import importlib
import nl2sql.agent_utils as agent_utils
import nl2sql.postprocess as postprocess
import nl2sql.eval as eval_mod
importlib.reload(agent_utils)
importlib.reload(postprocess)
importlib.reload(eval_mod)

from nl2sql.agent_utils import (
    _extract_required_columns,
    intent_constraints,
    classify_intent,
    clean_candidate_with_reason,
)
from nl2sql.postprocess import guarded_postprocess
from nl2sql.eval import execution_accuracy


## Tool-Driven ReAct (Interactive Evaluation Harness)

This section is organised as an experiment control surface: from one place you can quickly test
code changes, inspect step-by-step traces, and run reproducible metric evaluations.


### Why These Tools Exist (and Where They Come From)

The loop uses explicit actions defined in `nl2sql/agent_tools.py` and orchestrated by
`nl2sql/react_pipeline.py`.

Design motivation:
- ReAct action-observation control (`REFERENCES.md#ref-yao2023-react`)
- Execution-guided repair and rejection (`REFERENCES.md#ref-wang2018-eg-decoding`)
- Constraint-first validation (`REFERENCES.md#ref-scholak2021-picard`)
- Schema/value linking pressure points (`REFERENCES.md#ref-wang2020-ratsql`, `REFERENCES.md#ref-li2023-resdsql`, `REFERENCES.md#ref-lin2020-bridge`)


### Interactive Controls


In [None]:
# 5a) Runtime controls + module binding

# When True, nl2sql.agent_tools and nl2sql.react_pipeline are reloaded below.
RELOAD_REACT_MODULES = True
# Used as the config name, in the run metadata, and in the final summary line.
RUN_NAME = 'react_core_notebook'

# Core loop toggles (keep minimal by default)
# These are copied into ReactAblationConfig by _react_cfg().
USE_LINK_SCHEMA = True
USE_CONSTRAINT_POLICY = True
USE_REPAIR_POLICY = True
USE_INTENT_GATE = False
STOP_ON_FIRST_SUCCESS = True
REACT_MAX_REPAIRS = 2
REACT_LINK_MAX_TABLES = 6

# Single-question walkthrough controls
DEMO_INTERACTIVE = False  # set True to type your own NLQ
DEMO_DEFAULT_NLQ = 'Which customers are in France?'
DEMO_DEBUG = True
DEMO_SHOW_DECISIONS = True

# Batch smoke controls
BATCH_SAMPLE_SIZE = 5
# Seeds the smoke-test sampler; also passed as `seed` to the full evaluation.
BATCH_RANDOM_SEED = 7
BATCH_DEBUG_ON_MISMATCH = True
# Maximum length of the SQL strings stored in the smoke-test table.
DEBUG_SQL_MAX_CHARS = 220

# Full evaluation controls
# Falls back to 20 when the cell defining QUICK_LIMIT has not been run.
EVAL_LIMIT = QUICK_LIMIT if 'QUICK_LIMIT' in globals() else 20
# Replica database names are built as f"{TS_PREFIX}_{i:02d}" for i in 1..TS_N.
TS_N = 5
TS_PREFIX = 'classicmodels_ts'
# Passed to the evaluation as ts_max_rows.
MAX_ROWS_TS = 500
EVAL_OUTPUT_PATH = 'results/agent/results_react_200.json'

import importlib
import nl2sql.agent_tools as agent_tools
import nl2sql.react_pipeline as react_pipeline

# Optionally reload the ReAct modules so on-disk edits take effect; the
# imports below run afterwards and bind names from the (re)loaded modules.
if RELOAD_REACT_MODULES:
    importlib.reload(agent_tools)
    importlib.reload(react_pipeline)

from nl2sql.agent_tools import AgentContext, set_agent_context
from nl2sql.react_pipeline import (
    ReactAblationConfig,
    core_react_config,
    run_react_pipeline,
    evaluate_react_ablation,
)

# Register the notebook's engine, model, tokenizer and runner as the agent
# context. This is done after any reload above.
# NOTE(review): presumably the tools in agent_tools read this context at call
# time; confirm in nl2sql/agent_tools.py.
set_agent_context(
    AgentContext(
        engine=engine,
        db_name=DB_NAME,
        model=model,
        tok=tok,
        runner=runner,
        max_new_tokens=128,
    )
)

print('ReAct context bound.')
print('Run name:', RUN_NAME)
print('Eval limit:', EVAL_LIMIT, '| TS replicas:', TS_N)


### Tool Justification Matrix

Use this table in the viva or write-up to explain each tool in one sentence.


In [None]:
# 5b) Tool justifications (purpose + code source + research anchor)

import pandas as pd
from IPython.display import display

# One entry per agent tool: (tool, why_used, source, research_anchor).
# Kept as compact tuples and expanded into dict rows below; the key order
# fixes the column order of the resulting DataFrame.
_TOOL_COLUMNS = ('tool', 'why_used', 'source', 'research_anchor')
_TOOL_SPECS = (
    (
        'get_schema',
        'Ground generation in real tables/columns and avoid hallucinated schema.',
        'nl2sql/agent_tools.py:get_schema',
        'Schema-grounded parsing; Spider setup (ref-yu2018-spider)',
    ),
    (
        'link_schema',
        'Reduce prompt scope to relevant tables before generation.',
        'nl2sql/agent_tools.py:link_schema',
        'RAT-SQL / RESDSQL schema linking (ref-wang2020-ratsql, ref-li2023-resdsql)',
    ),
    (
        'extract_constraints',
        'Convert NL intent into structural checks (agg/order/limit/projection hints).',
        'nl2sql/agent_tools.py:extract_constraints',
        'Constraint-first decoding and validation (ref-scholak2021-picard)',
    ),
    (
        'generate_sql',
        'Produce initial SQL candidate from NLQ + focused schema + constraints.',
        'nl2sql/agent_tools.py:generate_sql',
        'LLM generation baseline for text-to-SQL (ref-zhu2024-survey)',
    ),
    (
        'validate_sql',
        'Block invalid schema references and malformed SQL before execution.',
        'nl2sql/agent_tools.py:validate_sql',
        'Validity gating before runtime execution (ref-scholak2021-picard)',
    ),
    (
        'validate_constraints',
        'Catch "runs-but-wrong-shape" SQL early (missing grouping/order/fields).',
        'nl2sql/agent_tools.py:validate_constraints',
        'Semantic structure checks for execution reliability (ref-zhai2025-excot)',
    ),
    (
        'run_sql',
        'Provide execution observation used for VA and repair triggers.',
        'nl2sql/agent_tools.py:run_sql',
        'Execution-guided feedback (ref-wang2018-eg-decoding)',
    ),
    (
        'repair_sql',
        'Revise SQL only after concrete validation/execution failure evidence.',
        'nl2sql/agent_tools.py:repair_sql',
        'Feedback-driven correction loops (ref-yao2023-react, ref-zhai2025-excot)',
    ),
)

tool_rows = [dict(zip(_TOOL_COLUMNS, spec)) for spec in _TOOL_SPECS]

tool_justification_df = pd.DataFrame(tool_rows)
display(tool_justification_df)


### ReAct Wrapper and Trace Helpers

These helpers make experiments repeatable and easier to inspect after each code change.


In [None]:
# 5c) Notebook wrapper around module-level ReAct

from nl2sql.postprocess import normalize_sql


def _react_cfg(name: str | None = None) -> ReactAblationConfig:
    """Build a ReactAblationConfig from the notebook's runtime toggles.

    Only the ``name`` attribute is taken from ``core_react_config``; every
    other field comes from the notebook-level control constants.

    Args:
        name: Config name; defaults to ``RUN_NAME`` when empty or ``None``.
    """
    resolved_name = core_react_config(name=name or RUN_NAME).name
    toggles = {
        'use_schema_link': USE_LINK_SCHEMA,
        'use_constraint_policy': USE_CONSTRAINT_POLICY,
        'use_repair_policy': USE_REPAIR_POLICY,
        'use_intent_gate': USE_INTENT_GATE,
        'stop_on_first_success': STOP_ON_FIRST_SUCCESS,
        'max_repairs': REACT_MAX_REPAIRS,
        'link_max_tables': REACT_LINK_MAX_TABLES,
    }
    return ReactAblationConfig(name=resolved_name, **toggles)


def _trace_to_decisions(trace: list[dict]) -> list[dict]:
    decisions = []
    for i, t in enumerate(trace or []):
        stage = t.get('stage')
        if not stage:
            continue
        entry = {'step': i, 'decision': stage, 'status': 'ok', 'reason': 'ok', 'data': {}}

        if stage in ('validate_sql', 'validate_constraints'):
            res = t.get('result') or {}
            entry['status'] = 'ok' if res.get('valid') else 'reject'
            entry['reason'] = res.get('reason', 'ok')
            entry['data'] = dict(res)
            if t.get('sql'):
                entry['data']['sql'] = t.get('sql')
        elif stage == 'run_sql':
            res = t.get('result') or {}
            entry['status'] = 'ok' if res.get('success') else 'reject'
            entry['reason'] = 'execution'
            entry['data'] = {
                'sql': t.get('sql'),
                'success': res.get('success'),
                'rowcount': res.get('rowcount'),
                'error': res.get('error'),
            }
        elif stage == 'repair_sql':
            entry['reason'] = t.get('reason', 'repair')
            entry['data'] = {
                'raw_sql': t.get('raw_sql'),
                'cleaned_sql': t.get('sql_after_guardrails'),
                'repair_count': t.get('repair_count'),
            }
        else:
            entry['data'] = {k: v for k, v in t.items() if k != 'stage'}

        decisions.append(entry)
    return decisions


def summarize_trace(trace: list[dict]) -> dict:
    """Summarise a pipeline trace as its stage sequence plus run settings.

    Returns:
        A dict with ``actions`` (stage names in order, skipping records with
        no stage), ``repairs`` (number of ``'repair_sql'`` stages), and the
        current values of ``STOP_ON_FIRST_SUCCESS`` and ``USE_INTENT_GATE``.
    """
    stages = []
    for record in trace or []:
        stage = record.get('stage')
        if stage:
            stages.append(stage)

    return {
        'actions': stages,
        'repairs': stages.count('repair_sql'),
        'first_success_stop': STOP_ON_FIRST_SUCCESS,
        'intent_gate': USE_INTENT_GATE,
    }


def format_decision_log(decisions: list[dict], max_items: int = 25) -> str:
    """Render decision entries as a readable multi-line string.

    Each entry produces a header line; entries with a truthy ``data`` payload
    also get an indented ``data:`` line, truncated to 320 characters.

    Args:
        decisions: Entries with ``step``, ``decision``, ``status`` and
            ``reason`` keys (``data`` optional).
        max_items: Only the first ``max_items`` entries are rendered.

    Returns:
        The joined lines, or ``'(no decisions)'`` for an empty/``None`` log.
    """
    if not decisions:
        return '(no decisions)'

    lines = []
    for entry in decisions[:max_items]:
        lines.append(f"[step {entry['step']}] {entry['decision']} -> {entry['status']} ({entry['reason']})")

        payload = entry.get('data')
        if not payload:
            continue
        rendered = str(payload)
        # Cap the payload at 320 characters, ellipsis included.
        rendered = rendered if len(rendered) <= 320 else rendered[:317] + '...'
        lines.append(f'  data: {rendered}')

    return '\n'.join(lines)


def react_sql(nlq: str, debug: bool = False) -> tuple[str, list[dict], list[dict]]:
    """Run the ReAct pipeline for one question using the notebook toggles.

    Args:
        nlq: Natural-language question.
        debug: When True, print the trace summary and up to 40 decisions.

    Returns:
        ``(pred_sql, trace, decisions)`` where ``decisions`` is the trace
        converted by ``_trace_to_decisions``.
    """
    pred_sql, trace = run_react_pipeline(nlq=nlq, config=_react_cfg())
    decisions = _trace_to_decisions(trace)

    if not debug:
        return pred_sql, trace, decisions

    print('Trace summary:', summarize_trace(trace))
    print(format_decision_log(decisions, max_items=40))
    return pred_sql, trace, decisions


## Interactive Walkthrough (Single NLQ)

Use this cell while developing new guardrails/constraints to inspect behavior step-by-step.


In [None]:
# 6a) Single-question interactive walkthrough

from nl2sql.constraint_policy import build_constraints

# Pick the question: typed input when interactive mode is on, otherwise (or
# when the input is blank or the prompt fails) the configured default.
try:
    nlq = (
        input('Type a ClassicModels question (blank uses default): ').strip()
        if DEMO_INTERACTIVE
        else ''
    )
except Exception:
    nlq = ''
nlq = nlq or DEMO_DEFAULT_NLQ

print('NLQ:', nlq)
pred, trace, decisions = react_sql(nlq=nlq, debug=DEMO_DEBUG)

print('\nFINAL SQL:')
print(pred or '(no prediction)')
print('\nTRACE SUMMARY:', summarize_trace(trace))

# Validity check: execute the final SQL once, if there is one.
if not pred:
    print('VA: 0 | Error: no prediction')
else:
    meta = runner.run(pred, capture_df=False)
    print('VA:', int(meta.success), '| Error:', meta.error)

# Show only the constraint fields that are useful when debugging.
constraints = build_constraints(nlq, SCHEMA_SUMMARY)
interesting_keys = [
    'explicit_projection',
    'required_output_fields',
    'required_tables',
    'needs_location',
    'value_hints',
]
print('Constraints:', {key: constraints.get(key) for key in interesting_keys})

if DEMO_SHOW_DECISIONS:
    print('\nDECISIONS:\n' + format_decision_log(decisions, max_items=50))


## Batch Smoke Test (Fast Iteration)

Run this after each pipeline change before full evaluation.


In [None]:
# 6b) Batch smoke test on random items

import random
from collections import Counter
import pandas as pd

from nl2sql.constraint_policy import build_constraints
from nl2sql.validation import validate_constraints
from nl2sql.postprocess import normalize_sql
from nl2sql.eval import execution_accuracy


def _short_sql(sql: str | None, max_chars: int = 220) -> str:
    s = ' '.join(str(sql or '').split())
    if max_chars and len(s) > max_chars:
        s = s[:max_chars - 3] + '...'
    return s


# Seeded sample so repeated smoke runs evaluate the same items.
rng = random.Random(BATCH_RANDOM_SEED)
samples = rng.sample(test_set, k=min(BATCH_SAMPLE_SIZE, len(test_set)))

# Explicit column order: an empty sample (e.g. BATCH_SAMPLE_SIZE = 0) still
# yields a well-formed frame instead of a column-less one that breaks the
# column selection below.
SMOKE_COLUMNS = [
    'nlq',
    'va',
    'em',
    'ex',
    'pred_sql',
    'gold_sql',
    'error',
    'constraint_valid',
    'constraint_reason',
    'repairs',
    'decision_log',
    'trace_summary',
    'gold_error',
]
rows = []

for sample in samples:
    nlq = sample['nlq']
    gold = sample['sql']
    pred, trace, decisions = react_sql(nlq=nlq, debug=False)
    # Computed once per item and reused for both summary columns.
    trace_summary = summarize_trace(trace)

    if pred:
        # VA: does the predicted SQL execute? EX: do its results match gold?
        meta = runner.run(pred, capture_df=False)
        va = int(meta.success)
        ex_ok, pred_err, gold_err = execution_accuracy(
            engine=engine,
            pred_sql=pred,
            gold_sql=gold,
            allow_extra_columns=False,
        )
        ex = int(ex_ok)
        error = meta.error or pred_err
    else:
        va = 0
        ex = 0
        error = 'no_prediction'
        gold_err = None

    # EM: exact match after SQL normalisation.
    em = int(normalize_sql(pred or '') == normalize_sql(gold or ''))

    constraints = build_constraints(nlq, SCHEMA_SUMMARY)
    v_constraints = (
        validate_constraints(pred, constraints, schema_text=SCHEMA_SUMMARY)
        if pred
        else {'valid': False, 'reason': 'no_prediction'}
    )

    rows.append(
        {
            'nlq': nlq,
            'va': va,
            'em': em,
            'ex': ex,
            'pred_sql': _short_sql(pred, DEBUG_SQL_MAX_CHARS),
            'gold_sql': _short_sql(gold, DEBUG_SQL_MAX_CHARS),
            'error': error,
            'constraint_valid': int(bool(v_constraints.get('valid'))),
            'constraint_reason': v_constraints.get('reason'),
            'repairs': trace_summary.get('repairs', 0),
            'decision_log': decisions,
            'trace_summary': trace_summary,
            'gold_error': gold_err,
        }
    )

smoke_df = pd.DataFrame(rows, columns=SMOKE_COLUMNS)
display(smoke_df[['nlq', 'va', 'em', 'ex', 'constraint_valid', 'repairs', 'error']])

if smoke_df.empty:
    # Rates over zero rows would be NaN; say so explicitly instead.
    print('Smoke metrics: no samples evaluated (check BATCH_SAMPLE_SIZE and the test set).')
else:
    print('Smoke metrics:')
    print({
        'n': len(smoke_df),
        'va_rate': round(float(smoke_df['va'].mean()), 3),
        'em_rate': round(float(smoke_df['em'].mean()), 3),
        'ex_rate': round(float(smoke_df['ex'].mean()), 3),
    })

if BATCH_DEBUG_ON_MISMATCH:
    # Print the full decision log for every item that failed VA or EX.
    mismatches = smoke_df[(smoke_df['ex'] == 0) | (smoke_df['va'] == 0)]
    for _, row in mismatches.iterrows():
        print('\nNLQ:', row['nlq'])
        print('PRED:', row['pred_sql'])
        print('GOLD:', row['gold_sql'])
        print('ERR:', row['error'])
        print('DECISIONS:\n' + format_decision_log(row['decision_log'], max_items=25))
        print('-' * 80)


### Test Suite Accuracy (TS)

TS checks semantic consistency across perturbed database replicas and reduces lucky single-DB matches.


In [None]:
# 6c) TS harness
from nl2sql.eval import test_suite_accuracy_for_item


### Full Evaluation Controls


In [None]:
# 7a) Full evaluation runtime config

from functools import lru_cache
from pathlib import Path
import subprocess

from sqlalchemy.engine import Engine


def _git_commit_short() -> str | None:
    try:
        return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], text=True).strip()
    except Exception:
        return None


# Keep an already-defined TS_PREFIX; otherwise fall back to the default.
TS_PREFIX = globals().get('TS_PREFIX', 'classicmodels_ts')
# Replica database names: <prefix>_01 ... <prefix>_<TS_N>, zero-padded to two digits.
SUITE_DBS = [f"{TS_PREFIX}_{replica_no:02d}" for replica_no in range(1, TS_N + 1)]


@lru_cache(maxsize=32)
def make_engine_cached(db_name: str) -> Engine:
    """Return the engine for *db_name*, building it once per database name.

    Results are memoised with an LRU cache of up to 32 entries, so repeated
    requests for the same replica reuse one engine.
    """
    replica_engine = make_engine(db_name)
    return replica_engine


def make_engine_fn(db_name: str) -> Engine:
    """Engine factory passed to the TS evaluation; delegates to the cached builder."""
    cached_engine = make_engine_cached(db_name)
    return cached_engine


# Provenance passed to the evaluation as run_metadata.
RUN_METADATA = dict(
    notebook='03_agentic_eval.ipynb',
    model_id=MODEL_ID,
    adapter_path=ADAPTER_PATH,
    commit=_git_commit_short(),
    run_name=RUN_NAME,
)

# Echo the effective configuration before the (long) evaluation run.
print('Evaluation config:')
for _label, _value in (
    ('  limit:', EVAL_LIMIT),
    ('  ts_dbs:', SUITE_DBS),
    ('  output:', EVAL_OUTPUT_PATH),
    ('  metadata:', RUN_METADATA),
):
    print(_label, _value)


### Full Evaluation (VA/EX/EM/TS)

This run is the reproducible report path used for dissertation tables.


In [None]:
# 7b) Full evaluation run

import json

# Run the ReAct evaluation with the notebook toggles; per-item TS is computed
# against the replica databases listed in SUITE_DBS.
cfg = _react_cfg(name=RUN_NAME)
report = evaluate_react_ablation(
    test_set=test_set,
    engine=engine,
    config=cfg,
    limit=EVAL_LIMIT,
    allow_extra_columns_ex=False,
    ts_suite_db_names=SUITE_DBS,
    ts_make_engine_fn=make_engine_fn,
    ts_max_rows=MAX_ROWS_TS,
    progress_every=20,
    seed=BATCH_RANDOM_SEED,
    run_metadata=RUN_METADATA,
)

# Backward-compatible shape for existing analysis scripts/cells.
# Report items are renamed (error -> pred_err, gold_error -> gold_err) and
# enriched with the trace summary and decision log.
results = []
for item in report['items']:
    trace = item.get('trace') or []
    decisions = _trace_to_decisions(trace)
    results.append(
        {
            'nlq': item['nlq'],
            'gold_sql': item['gold_sql'],
            'pred_sql': item['pred_sql'],
            'va': item['va'],
            'em': item['em'],
            'ex': item['ex'],
            'ts': item.get('ts'),
            'ts_debug': item.get('ts_debug'),
            'pred_err': item.get('error'),
            'gold_err': item.get('gold_error'),
            'trace': trace,
            'trace_summary': summarize_trace(trace),
            'decision_log': decisions,
        }
    )

# Same report, with the per-item list replaced by the enriched rows.
out = {
    **report,
    'items': results,
}

out_path = Path(EVAL_OUTPUT_PATH)
out_path.parent.mkdir(parents=True, exist_ok=True)
# default=str stringifies any value json cannot serialise natively.
out_path.write_text(json.dumps(out, indent=2, default=str), encoding='utf-8')

# ts_rate keeps its numeric fallback for any later cell that reads it, but the
# printed summary shows "n/a" when TS was not computed. Printing 0.000 would
# report a missing metric as a zero score.
ts_rate = report['ts_rate'] if report.get('ts_rate') is not None else 0.0
ts_display = f"{report['ts_rate']:.3f}" if report.get('ts_rate') is not None else 'n/a'
print(
    f"{RUN_NAME} | n={report['n']} | "
    f"VA={report['va_rate']:.3f} EM={report['em_rate']:.3f} "
    f"EX={report['ex_rate']:.3f} TS={ts_display}"
)
print('Saved to', out_path)


In [None]:
# 7c) EX failure profiling (quick categories)

from collections import Counter
import pandas as pd


def _decision_reasons(decision_log: list[dict]) -> list[str]:
    reasons = []
    for d in decision_log or []:
        r = d.get('reason')
        if r and r not in ('ok', 'execution', 'stop_on_first_success'):
            reasons.append(r)
    return reasons


def categorize_ex_failure(item: dict) -> str:
    """Assign one coarse outcome category to an evaluated item.

    Checks are applied in order and the first match wins:
    no prediction, invalid SQL (VA = 0), correct (EX = 1), then categories
    derived from the reasons recorded in the item's decision log.

    Args:
        item: Result row with ``pred_sql``, ``va``, ``ex`` and
            ``decision_log``. Missing or ``None`` ``va``/``ex`` count as 0.

    Returns:
        One of ``'no_prediction'``, ``'invalid_sql'``, ``'correct'``,
        ``'value_linking'``, ``'location_linking'``, ``'aggregation'``,
        ``'grouping'``, ``'ordering_limit'``, ``'join_path'`` or
        ``'other_semantic'``.
    """
    if not item.get('pred_sql'):
        return 'no_prediction'
    # `or 0` also covers a key that is present with value None, which
    # int(item.get(key, 0)) would reject with a TypeError.
    if int(item.get('va') or 0) == 0:
        return 'invalid_sql'
    if int(item.get('ex') or 0) == 1:
        return 'correct'

    reasons = _decision_reasons(item.get('decision_log') or [])

    # Ordered (markers, category) rules; a rule matches when any marker is a
    # substring of any recorded reason.
    rules = (
        (('missing_value_hint',), 'value_linking'),
        (('missing_location',), 'location_linking'),
        (('missing_agg',), 'aggregation'),
        (('missing_group_by',), 'grouping'),
        (('missing_order_by', 'missing_limit'), 'ordering_limit'),
        (('join',), 'join_path'),
    )
    for markers, category in rules:
        if any(marker in reason for reason in reasons for marker in markers):
            return category
    return 'other_semantic'


# Tally outcome categories over all evaluated items. The tally includes
# 'correct' as well as the failure categories.
failure_counts = Counter(categorize_ex_failure(r) for r in results)

# most_common() already orders by descending count, and naming the columns
# explicitly keeps the frame well-formed (with a 'count' column) even when
# `results` is empty.
failure_df = pd.DataFrame(failure_counts.most_common(), columns=['category', 'count'])

display(failure_df)
