# Agentic Evaluation (ReAct-style)

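This notebook adds a minimal ReAct-style loop for NL→SQL. It reuses the same benchmark (`data/classicmodels_test_200.json`) and metrics as the prompt-only and QLoRA runs, so gains can be compared directly: VA (the predicted SQL executes), EX (execution accuracy against the gold query), and EM (exact string match). TS (test-suite accuracy) is planned.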

Plan (step-by-step):
1) Clone repo (Colab) + install deps
2) Environment + DB connection
3) Load schema summary + test set
4) Load model (base or QLoRA adapters)
5) Define ReAct prompt + loop (Thought → Action → Observation → Refinement)
6) Run evaluation (VA/EX/EM) and save to `results/agent/…`


Docs I leaned on: HF Transformers quantization (https://huggingface.co/docs/transformers/main_classes/quantization), PEFT/TRL (https://huggingface.co/docs/peft/, https://huggingface.co/docs/trl/), Cloud SQL connector + SQLAlchemy creator (https://cloud.google.com/sql/docs/mysql/connect-run, https://docs.sqlalchemy.org/en/20/core/engines.html#custom-dbapi-connect), ReAct (https://arxiv.org/abs/2210.03629).

## Setup (run first, then restart)
In a fresh Colab GPU runtime, run this one cell to clean preinstalls and pin the CUDA 12.1 torch/bitsandbytes/triton stack. When it finishes, **Runtime → Restart runtime**, then run the rest of the notebook from the clone cell onward without more restarts.

**Docs (setup):** HF Transformers quantization + BitsAndBytes (4-bit) https://huggingface.co/docs/transformers/main_classes/quantization, bnb https://github.com/TimDettmers/bitsandbytes.

In [None]:

%%bash
set -e
export PIP_DEFAULT_TIMEOUT=120

# Clean conflicting preinstalls
pip uninstall -y torch torchvision torchaudio bitsandbytes triton transformers accelerate peft trl datasets numpy pandas fsspec requests google-auth || true

# Base deps
pip install -q --no-cache-dir --force-reinstall   numpy==1.26.4 pandas==2.2.1 fsspec==2024.5.0 requests==2.31.0 google-auth==2.43.0

# Torch + CUDA 12.1
pip install -q --no-cache-dir --force-reinstall   torch==2.3.1+cu121 torchvision==0.18.1+cu121 torchaudio==2.3.1+cu121   --index-url https://download.pytorch.org/whl/cu121

# bitsandbytes + triton + HF stack
pip install -q --no-cache-dir --force-reinstall   bitsandbytes==0.43.3 triton==2.3.1   transformers==4.44.2 accelerate==0.33.0 peft==0.17.0 trl==0.9.6 datasets==2.20.0

echo "Setup complete. Restart runtime once, then run the rest of the notebook top-to-bottom."
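
After the restart, it can help to confirm that the pinned versions are the ones actually importable before loading the model. A minimal sketch, assuming the pins from the setup cell; `check_pins` is a hypothetical helper, not part of the repo.

```python
from importlib import metadata

# Pins copied from the setup cell; extend with the other packages as needed.
PINS = {
    "torch": "2.3.1+cu121",
    "bitsandbytes": "0.43.3",
    "triton": "2.3.1",
    "transformers": "4.44.2",
    "peft": "0.17.0",
}

def check_pins(pins):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches

for name, (want, got) in check_pins(PINS).items():
    print(f"{name}: expected {want}, found {got}")
```

An empty printout means every listed pin matched; any line printed points at a package to reinstall before continuing.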


Model load: HF 4-bit NF4 + BitsAndBytes; deterministic decoding. If adapters exist, we load them.

In [None]:
# 0) Clone repo (Colab) + install deps
import os
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    if not os.path.exists('/content/NLtoSQL'):
        !git clone https://github.com/MacKenzieOBrian/NLtoSQL.git /content/NLtoSQL
    %cd /content/NLtoSQL
    !pip -q install -r requirements.txt
    import torch, transformers, accelerate, peft
    print('torch', torch.__version__, 'cuda', torch.cuda.is_available())
else:
    print('Not in Colab; using existing workspace')


Prompt/eval: build prompts (system+schema+k exemplars), generate SQL, postprocess, and compute VA/EX/EM.

**Ref:** Colab clone/install pattern; keeps notebooks thin and code in `nl2sql/`. Hugging Face/Colab standard workflow.

### Reference notes (what this builds on)
- DB access: Cloud SQL Connector + SQLAlchemy creator (GCP docs: https://cloud.google.com/sql/docs/mysql/connect-run) for secure pooled ClassicModels access.
- Schema/prompting: uses repo helpers (`nl2sql.schema`, `prompting`) aligned with schema-grounded NL→SQL prompting (survey: https://arxiv.org/abs/2410.06011).
- Model load: HF Transformers 4-bit NF4 with BitsAndBytes (quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization), same pattern as QLoRA.
- Agent loop: ReAct-style Thought→Action→Observation→Refinement, inspired by Yao et al. 2023 (https://arxiv.org/abs/2210.03629) and agentic NL→SQL in Ojuri et al. 2025.
- Eval: repo harness (`nl2sql.eval`, `QueryRunner`) for VA/EX/EM; TS planned.


## Optional: use gcloud ADC (without a key)

**Ref:** GCP ADC flow (docs: https://cloud.google.com/docs/authentication/provide-credentials-adc). Optional fallback if no service account JSON.

In [None]:
# Run this only if you prefer gcloud-based ADC (no JSON key)
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    %pip install -q --upgrade google-auth google-auth-oauthlib
    !gcloud auth application-default login
else:
    print("Not in Colab; skip gcloud auth.")


**Ref:** Pinned CUDA 12.1 torch/bitsandbytes/triton stack per HF/BnB guidance for 4-bit loads on Colab GPUs.

**Ref:** Cloud SQL Connector + SQLAlchemy creator (GCP MySQL docs: https://cloud.google.com/sql/docs/mysql/connect-run) for secure ClassicModels access.

**Docs (auth/DB):** Cloud SQL connector pattern https://cloud.google.com/sql/docs/mysql/connect-run; SQLAlchemy creator hook https://docs.sqlalchemy.org/en/20/core/engines.html#custom-dbapi-connect.

In [None]:

# 1) Environment + DB
import os
from getpass import getpass
from pathlib import Path

from google.cloud.sql.connector import Connector
from google.oauth2.service_account import Credentials
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine

# Expected env vars (set these in a Colab cell):
# GOOGLE_APPLICATION_CREDENTIALS=/content/sa.json
# INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS, DB_NAME
INSTANCE_CONNECTION_NAME = os.getenv("INSTANCE_CONNECTION_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASS = os.getenv("DB_PASS")
DB_NAME = os.getenv("DB_NAME") or "classicmodels"
GOOGLE_CREDS = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")

if not INSTANCE_CONNECTION_NAME:
    INSTANCE_CONNECTION_NAME = input("Enter INSTANCE_CONNECTION_NAME: ").strip()
if not DB_USER:
    DB_USER = input("Enter DB_USER: ").strip()
if not DB_PASS:
    DB_PASS = getpass("Enter DB_PASS: ")

creds = None
if GOOGLE_CREDS:
    creds = Credentials.from_service_account_file(GOOGLE_CREDS)
    print(f"Using service account from {GOOGLE_CREDS}")
else:
    print("Using default ADC (gcloud auth or Colab auth). If this fails, set GOOGLE_APPLICATION_CREDENTIALS.")

connector = Connector(credentials=creds)

def getconn():
    return connector.connect(
        INSTANCE_CONNECTION_NAME,
        "pymysql",
        user=DB_USER,
        password=DB_PASS,
        db=DB_NAME,
    )

engine: Engine = create_engine(
    "mysql+pymysql://",
    creator=getconn,
    future=True,
)

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
print("DB connection OK")


**Ref:** Schema helper in `nl2sql.schema`; schema-grounded prompting per NL→SQL survey (https://arxiv.org/abs/2410.06011).

**Docs (schema prompts):** NL→SQL schema-grounded prompting survey https://arxiv.org/abs/2410.06011; Spider-style listings.

In [None]:
# 2) Load schema summary + test set (small slice for now)
import json
from nl2sql.schema import build_schema_summary

SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME)

test_path = Path("data/classicmodels_test_200.json")
full_set = json.loads(test_path.read_text(encoding="utf-8"))
# default to a small slice while debugging
test_set = full_set[:5]
print("Demo items:", len(test_set))
# For full run, switch to: test_set = full_set; print("Test items:", len(test_set))

TABLES = {line.split('(', 1)[0].strip() for line in SCHEMA_SUMMARY.splitlines() if '(' in line}
TABLES_LOWER = {t.lower(): t for t in TABLES}
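
The `TABLES` parse above relies on each table line in the schema summary looking like `table(col, ...)`. A small sketch of that parsing on a made-up summary string; the sample text is hypothetical and the real `build_schema_summary` output may be formatted differently.

```python
# Hypothetical schema summary; the real build_schema_summary output may differ.
sample_summary = """customers(customerNumber, customerName, country)
orders(orderNumber, orderDate, customerNumber)
Notes: foreign keys omitted"""

def parse_tables(summary: str) -> set[str]:
    """Collect the text before the first '(' on every line that contains one."""
    return {
        line.split("(", 1)[0].strip()
        for line in summary.splitlines()
        if "(" in line
    }

tables = parse_tables(sample_summary)
# Lower-cased lookup used later to restore canonical table casing.
tables_lower = {t.lower(): t for t in tables}
print(sorted(tables))
```

Lines without a parenthesis (such as the trailing note) are ignored, so free-text lines in the summary do not end up in the table set.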


**Ref:** HF Transformers 4-bit NF4 + BitsAndBytes (quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization); adapters via PEFT.

**Docs (model load):** HF 4-bit NF4 quantization https://huggingface.co/docs/transformers/main_classes/quantization; PEFT/QLoRA https://huggingface.co/docs/peft/.

In [None]:

# 3) Load model (base or QLoRA adapters)
import os
from getpass import getpass
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER_PATH = os.getenv("ADAPTER_PATH") or "results/adapters/qlora_classicmodels"  # set to None to use base model

HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    HF_TOKEN = getpass("Enter HF_TOKEN (https://huggingface.co/settings/tokens): ").strip()

cc_major, cc_minor = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else (0, 0)
use_bf16 = cc_major >= 8
compute_dtype = torch.bfloat16 if use_bf16 else torch.float16
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("Using bf16:", use_bf16)
print("Adapter path:", ADAPTER_PATH)

# Tokenizer
tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Quantized base model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
    device_map={"": 0} if torch.cuda.is_available() else None,
    token=HF_TOKEN,
)
base_model.generation_config.do_sample = False
base_model.generation_config.temperature = 1.0
base_model.generation_config.top_p = 1.0

# Load adapters if present locally; otherwise use base model
adapter_dir = Path(ADAPTER_PATH) if ADAPTER_PATH else None
if adapter_dir and adapter_dir.exists():
    model = PeftModel.from_pretrained(base_model, adapter_dir, token=HF_TOKEN)
    print("Loaded adapters from", adapter_dir)
else:
    print("Adapter path missing; using base model only. Set ADAPTER_PATH to your local adapter folder or upload it to Colab.")
    model = base_model


## Optional adapter sanity check (run before ReAct)
Quick check to see if the loaded model/adapters produce valid SQL on a tiny slice. Uses the prompt harness (k=0/k=3) and executes the SQL to report VA/EX.

**Docs (prompt/eval):** ICL patterns https://arxiv.org/abs/2005.14165; execution-based metrics (VA/EX) https://aclanthology.org/2020.emnlp-main.29/.

In [None]:
from nl2sql.prompting import make_few_shot_messages
from nl2sql.llm import extract_first_select
from nl2sql.postprocess import guarded_postprocess
from nl2sql.query_runner import QueryRunner
from nl2sql.eval import execution_accuracy

runner_check = QueryRunner(engine)
# reuse existing test_set (default small slice); the 3 exemplars overlap with the checked items, so k=3 here is a smoke test only
exemplars = test_set[:3]

def run_quick_check(k: int = 0, limit: int = 3):
    print(f"Quick check k={k}")
    for sample in test_set[:limit]:
        shots = exemplars if k > 0 else []
        msgs = make_few_shot_messages(
            schema=SCHEMA_SUMMARY,
            exemplars=shots,
            nlq=sample['nlq'],
        )
        prompt_preview = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
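        # chat template already inserts BOS, so skip special tokens to avoid a duplicate
        inputs = tok(prompt_preview, return_tensors="pt", add_special_tokens=False).to(model.device)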
        out = model.generate(**inputs, max_new_tokens=256, do_sample=False)

        # strip the prompt before decoding the generation
        gen_ids = out[0][inputs.input_ids.shape[-1]:]
        text = tok.decode(gen_ids, skip_special_tokens=True)

        raw_sql = extract_first_select(text) or text
        sql = guarded_postprocess(raw_sql, sample['nlq'])

        meta = runner_check.run(sql, capture_df=False)
        va = meta.success
        ex_ok, _, _ = execution_accuracy(engine=engine, pred_sql=sql, gold_sql=sample['sql'])
        print(f"Q: {sample['nlq']}\nSQL: {sql}\nVA: {va} EX: {ex_ok}\n")

run_quick_check(k=0)
run_quick_check(k=3)


**Ref:** ReAct pattern (Yao et al. 2023: https://arxiv.org/abs/2210.03629) adapted for NL→SQL with `QueryRunner` as the Act step.

**Docs (ReAct):** ReAct loop (Yao et al. 2023) https://arxiv.org/abs/2210.03629; safe Act via SELECT-only executor.

In [None]:

# Helper imports (ReAct helpers live in nl2sql/agent_utils)
from nl2sql.agent_utils import (
    clean_candidate,
    build_tabular_prompt,
    vanilla_candidate,
    classify_error,
    error_hint,
    semantic_score,
    count_select_columns,
)


### Agent status (for dissertation)
Current loop = execution-guided reranker: sampled candidates, SELECT-only filter, semantic rerank, error-classified repair, deterministic few-shot fallback.
Not yet full ReAct: we don’t enforce structured `Thought / Action: SCHEMA_LOOKUP[...] / Action: EXEC_SQL[...] / Observation: ... / FINISH[...]`, so the model isn’t forced to read and react to its own Observations.
Planned upgrade (if time permits): add an explicit tool grammar and feed Observations back into the prompt so the model can revise after execution errors (Yao et al., 2023).
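
One way the planned tool grammar could be parsed is a single-line regex over the model output. This is only a sketch: the action names come from the note above, but `parse_action` and the exact surface syntax (optional `Action:` prefix, argument in square brackets on one line) are assumptions, not committed code.

```python
import re
from typing import Optional

# Action names from the status note; the surface syntax is an assumption.
ACTION_RE = re.compile(r"^\s*(?:Action:\s*)?(SCHEMA_LOOKUP|EXEC_SQL|FINISH)\[(.*)\]\s*$")

def parse_action(model_output: str) -> Optional[tuple[str, str]]:
    """Return (action_name, argument) for the first well-formed action line, else None."""
    for line in model_output.splitlines():
        m = ACTION_RE.match(line)
        if m:
            return m.group(1), m.group(2).strip()
    return None

step = "Thought: I need the customers table.\nAction: EXEC_SQL[SELECT COUNT(*) FROM customers;]"
print(parse_action(step))
```

Returning `None` for anything off-grammar would let the loop reply with a corrective Observation instead of executing free-form text. Multi-line SQL inside the brackets is not handled by this sketch.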


**Ref:** Repo eval (`nl2sql.eval`) for VA/EX/EM; execution-based metrics align with Ojuri et al. 2025 and EMNLP’20 TS.
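
EX compares what the predicted and gold queries return rather than how they are written. The sketch below illustrates the idea on an in-memory SQLite stand-in; it is not the repo's `execution_accuracy` implementation, and the order-insensitive row comparison and the `results_match` name are assumptions.

```python
import sqlite3
from collections import Counter

def results_match(conn, pred_sql: str, gold_sql: str) -> bool:
    """True if both queries execute and return the same rows, ignoring row order."""
    try:
        pred_rows = conn.execute(pred_sql).fetchall()
    except sqlite3.Error:
        return False  # prediction does not execute, so it cannot match
    gold_rows = conn.execute(gold_sql).fetchall()
    return Counter(pred_rows) == Counter(gold_rows)

# Tiny stand-in table; the real runs use ClassicModels on MySQL.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE customers (customerNumber INTEGER, country TEXT);
INSERT INTO customers VALUES (1, 'France'), (2, 'USA'), (3, 'France');
""")

gold = "SELECT customerNumber FROM customers WHERE country = 'France'"
same_rows = gold + " ORDER BY customerNumber DESC"
extra_col = "SELECT customerNumber, country FROM customers WHERE country = 'France'"

print(results_match(conn, same_rows, gold))
print(results_match(conn, extra_col, gold))
print(results_match(conn, "SELECT nope FROM customers", gold))
```

In this strict version an extra ORDER BY still matches while an extra column does not. How the repo harness treats column differences is not shown in this notebook.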


## ReAct execution-guided pipeline (best version so far)
These cells mirror the committed helper layer (`nl2sql/agent_utils.py`) and set up the current execution-guided reranker + evaluation harness.

In [None]:
# 4) Schema summary + test set + QueryRunner
import json
from pathlib import Path
from nl2sql.schema import build_schema_summary
from nl2sql.query_runner import QueryRunner

DB_NAME = os.getenv("DB_NAME") or "classicmodels"

SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME)
test_path = Path("data/classicmodels_test_200.json")
full_set = json.loads(test_path.read_text(encoding="utf-8"))
test_set = full_set  # change to full_set[:20] when debugging

print("Loaded test set size:", len(test_set))
runner = QueryRunner(engine)

In [None]:
# 5) Agent utilities: semantic reranker, baseline candidate, error taxonomy
from nl2sql.agent_utils import (
    clean_candidate,
    build_tabular_prompt,
    vanilla_candidate,
    classify_error,
    error_hint,
    semantic_score,
    count_select_columns,
)

## Staged Debugging Guide (read this first)

This notebook uses **stages** to avoid over‑engineering before SQL is stable.

**How to run:**
1. Set `STAGE = 0` (in cell #6). Run cells in order.
2. Use the quick sanity check (cell #8). If **PRED is empty**, stop and inspect trace summaries.
3. Only move to the next stage after the current one is stable.

**Stages**
- **STAGE 0 — Minimal execution‑gated**: generate → extract SQL → execute → return first valid.
- **STAGE 1 — + Clamps**: strip ORDER/LIMIT unless requested; trim projections for list queries.
- **STAGE 2 — + Rerank & diversity**: tabular prompt + semantic score.
- **STAGE 3 — + Repair**: one‑shot repair using DB error hints.

**Where failures show up**
- *Rejected: not a clean SELECT* → model output is noisy; check raw candidates.
- *No clean candidates* → filters too strict or model not outputting SQL.
- *ERROR: ...* → SQL parsed but failed execution; fix joins or columns.

**Debug checklist**
- Ensure cell #6 ran after any edits (defines STAGE + helper funcs).
- Confirm `STAGE` in quick check output.
- If `PRED` is empty at STAGE 0, temporarily bypass filters and inspect raw output.
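
For the "inspect trace summaries" step, a small counter over the trace can show where candidates are being lost. A sketch, assuming trace entries shaped like the ones the loop logs (`phase`/`reason` keys from the structured logger, `ta`/`obs` keys from observations); `summarize_trace` is a hypothetical helper.

```python
from collections import Counter

def summarize_trace(trace: list[dict]) -> Counter:
    """Count entries as 'phase[:reason]' for logged phases, 'obs:<label>' for observations."""
    counts = Counter()
    for entry in trace:
        if "phase" in entry:
            key = entry["phase"]
            if entry.get("reason"):
                key += f":{entry['reason']}"
        else:
            # Observation entries: keep only the leading label (e.g. ERROR, SUCCESS).
            key = "obs:" + str(entry.get("obs", "")).split(":", 1)[0]
        counts[key] += 1
    return counts

# Hypothetical trace for illustration only.
demo_trace = [
    {"step": "0", "phase": "clean", "reason": "no_from"},
    {"step": "0", "phase": "clean", "reason": "no_from"},
    {"step": "0", "phase": "exec", "success": "False"},
    {"ta": "", "obs": "ERROR: Unknown column 'x'"},
]
print(summarize_trace(demo_trace))
```

A trace dominated by `clean:*` entries points at noisy generation or strict filters, while `exec` and `obs:ERROR` entries point at joins or column names.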

In [None]:
# 6) Helper: staged controls + candidate generation + error-aware repair
import re
import torch
import transformers
from transformers import StoppingCriteria, StoppingCriteriaList
import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, DML

from nl2sql.llm import extract_first_select
from nl2sql.postprocess import guarded_postprocess
import nl2sql.agent_utils as agent_utils_mod  # for monkey-patching cleaner

transformers.logging.set_verbosity_error()
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# ----- Stage controls -----
STAGE = 1  # 0=minimal, 1=+clamps, 2=+rerank+tabular+sampling, 3=+repair
USE_CLAMPS = STAGE >= 1
USE_RERANK = STAGE >= 2
USE_TABULAR_PROMPT = STAGE >= 2
USE_SAMPLING = STAGE >= 2
USE_REPAIR = STAGE >= 3

# ----- Debug controls -----
DEBUG = True
DEBUG_RAW = False          # print raw candidates (noisy)
DEBUG_REJECT_SAMPLE = 2

# ---------- generation stop (first semicolon) ----------
class StopOnSemicolon(StoppingCriteria):
    def __init__(self, tokenizer):
        self.tok = tokenizer
        self.semi_id = tokenizer.encode(";", add_special_tokens=False)[-1]

    def __call__(self, input_ids, scores, **kwargs):
        return input_ids[0, -1].item() == self.semi_id

# ---------- keyword normaliser ----------
def _normalize_spaced_keywords(text: str) -> str:
    keywords = [
        "select", "from", "where", "group", "by", "order", "limit",
        "join", "inner", "left", "right", "on", "having", "distinct",
    ]
    for kw in keywords:
        # word boundaries stop short keywords ("on", "by") matching inside identifiers
        pattern = r"\b" + r"\s*".join(list(kw)) + r"\b"
        text = re.sub(pattern, kw.upper(), text, flags=re.I)
    return text

# ---------- shared core cleaner (trim prompt echo, relaxed) ----------
ECHO_CUTOFF_RE = re.compile(r"(?is)\b(show step|show output|output only|respond with|no markdown|no explanation|y/n)\b")

def strip_prompt_echo(sql: str) -> str:
    m = ECHO_CUTOFF_RE.search(sql or "")
    if not m:
        return sql
    return (sql[:m.start()]).strip()

def _clean_candidate_core(raw: str):
    """
    Returns (sql_or_none, reason)
    """
    if not raw:
        return None, "empty"

    raw = _normalize_spaced_keywords(raw)
    sql = extract_first_select(raw) or raw
    sql = sql.strip()

    lower = sql.lower()
    idx = lower.find("select")
    if idx == -1:
        return None, "no_select"
    sql = sql[idx:].strip()
    lower = sql.lower()

    # Trim common prompt-echo tails if present after SQL
    for marker in [
        "output only sql",
        "no explanation",
        "no markdown",
        "show output",
        "y/n",
    ]:
        pos = lower.find(marker)
        if pos != -1:
            sql = sql[:pos].strip()
            lower = sql.lower()

    # Cut at first semicolon to drop trailing chatter
    if ";" in sql:
        sql = sql.split(";", 1)[0].strip()
        lower = sql.lower()

    # Strip prompt-echo tails that appear after SQL
    sql = strip_prompt_echo(sql)
    lower = sql.lower()

    # Minimal junk filter (avoid rejecting useful SQL)
    bad_phrases = ("```",)
    if any(bp in lower for bp in bad_phrases):
        return None, "bad_phrase"

    if not lower.startswith("select"):
        return None, "no_select"

    # Require FROM and block FROM dual
    if not re.search(r"(?is)\bfrom\b", sql):
        return None, "no_from"
    if re.search(r"(?is)\bfrom\s+dual\b", sql):
        return None, "from_dual"
    if re.search(r"(?is)\bgroup\s+by\s+null\b", sql):
        return None, "group_by_null"

    return sql + ";", "ok"

# Local cleaner for ReAct (returns reason)
def clean_candidate(raw: str):
    return _clean_candidate_core(raw)  # (sql, reason)

# Monkey-patch agent_utils.clean_candidate so vanilla_candidate uses relaxed cleaner
def _clean_candidate_for_agent_utils(raw: str):
    sql, _ = _clean_candidate_core(raw)
    return sql  # Optional[str]

agent_utils_mod.clean_candidate = _clean_candidate_for_agent_utils

# ---------- Lightweight guardrails / clamps ----------
def _has_keyword(text: str, keywords) -> bool:
    lt = (text or "").lower()
    return any(k in lt for k in keywords)

def strip_order_by_if_not_requested(sql: str, nlq: str) -> str:
    if _has_keyword(nlq, ["order", "sort", "top", "first", "highest", "lowest", "descending", "ascending", "limit", "rank"]):
        return sql
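    # keep the whitespace before LIMIT so "... FROM t ORDER BY x LIMIT 5" does not become "... FROM tLIMIT 5"
    out = re.sub(r"(?is)\s+order\s+by\s+.*?(?=(\s+limit\b|;|$))", "", sql)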
    return out if out.endswith(";") else out + ";"

def trim_to_first_column(sql: str) -> str:
    parsed = sqlparse.parse(sql)
    if not parsed:
        return sql
    stmt = parsed[0]
    new_tokens, seen_select, trimmed = [], False, False
    for tok in stmt.tokens:
        if tok.ttype is DML and tok.value.upper() == "SELECT":
            seen_select = True
            new_tokens.append(tok)
            continue
        if seen_select and isinstance(tok, IdentifierList) and not trimmed:
            try:
                first_ident = next(tok.get_identifiers())
            except StopIteration:
                new_tokens.append(tok)
            else:
                new_tokens.append(IdentifierList([first_ident]))
                trimmed = True
            continue
        if seen_select and isinstance(tok, Identifier) and not trimmed:
            new_tokens.append(tok)
            trimmed = True
            continue
        new_tokens.append(tok)
    cleaned = "".join(str(t) for t in new_tokens).strip()
    return cleaned if cleaned.endswith(";") else cleaned + ";"

def strip_group_by_if_not_requested(sql: str, nlq: str) -> str:
    nl = (nlq or "").lower()
    if any(w in nl for w in ["per ", "by ", "each", "group"]):
        return sql
    if not any(w in nl for w in ["count", "how many", "number of"]):
        return sql
    out = re.sub(r"(?is)\s+group\s+by\s+[^;]+", "", sql)
    return out if out.endswith(";") else out + ";"

def ensure_group_key_in_select(sql: str, nlq: str) -> str:
    nl = (nlq or "").lower()
    if not ("per " in nl or "by " in nl):
        return sql

    s = sql.strip().rstrip(";")
    gb = re.search(r"(?is)group\s+by\s+([a-zA-Z0-9_\.]+)", s)
    if not gb:
        return sql
    group_key = gb.group(1)

    sel = re.search(r"(?is)^\s*select\s+(.*?)\s+from\s+", s)
    if not sel:
        return sql
    select_part = sel.group(1)

    if group_key.lower() in select_part.lower():
        return sql

    new_select = f"{group_key}, {select_part}"
    out = re.sub(r"(?is)^\s*select\s+(.*?)\s+from\s+", f"SELECT {new_select} FROM ", s, count=1)
    return out + ";"

def apply_clamps(sql: str, nlq: str) -> str:
    if not USE_CLAMPS:
        return sql
    sql = strip_order_by_if_not_requested(sql, nlq)
    sql = strip_group_by_if_not_requested(sql, nlq)
    if _has_keyword(nlq, ["which", "who", "what", "list", "show"]) and " and " not in (nlq or "").lower():
        sql = trim_to_first_column(sql)
    sql = ensure_group_key_in_select(sql, nlq)
    return sql

def canonicalize_table_casing(sql: str) -> str:
    if not sql:
        return sql
    def repl(m):
        tbl = m.group(2)
        canon = TABLES_LOWER.get(tbl.lower(), tbl)
        return m.group(1) + canon
    return re.sub(r"(?is)\b(from|join)\s+([a-zA-Z_][a-zA-Z0-9_]*)\b", repl, sql)

# ---------- Prompts ----------
def build_react_prompt(nlq: str, schema_text: str, history: list[dict], observation: str) -> str:
    history_text = "\n\n".join(
        f"Thought/Action: {h.get('ta','')}\nObservation: {h.get('obs','')}"
        for h in history
    ) or "None yet."
    return f"""
You are an expert MySQL analyst.

TASK:
- Write exactly ONE valid MySQL SELECT statement.
- Output only SQL (no explanation, no markdown).
- The output must include a FROM clause.
- Use only schema columns.
- Use ORDER BY/LIMIT only if explicitly asked.

Schema:
{schema_text}

Question:
{nlq}

Recent steps:
{history_text}

Last observation:
{observation}

Respond with only the final SQL statement.
""".strip()

def build_tabular_prompt(nlq: str, schema_text: str) -> str:
    return f"""
You are an expert SQL engineer. Think through tables and join keys, then output one SELECT.

Schema:
{schema_text}

Question: {nlq}

Output only the final SQL statement and nothing else.
""".strip()

def projection_guard(sql: str, nlq: str) -> str:
    return guarded_postprocess(sql, nlq)

# ---------- Candidate generation ----------
def generate_candidates(prompt: str, num: int = 1):
    do_sample = USE_SAMPLING
    if not do_sample:
        num = 1

    # force SQL start
    prompt = prompt.rstrip() + "\nSELECT "
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    gen_kwargs = dict(max_new_tokens=128, do_sample=do_sample, num_return_sequences=num)
    if do_sample:
        gen_kwargs.update(dict(temperature=0.5, top_p=0.9))

    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)

    cands = []
    for i in range(num):
        gen_ids = out[i][inputs.input_ids.shape[-1]:]
        gen = tok.decode(gen_ids, skip_special_tokens=True)
        gen = _normalize_spaced_keywords(gen)
        cands.append("SELECT " + gen)
    return cands

# ---------- Error-aware repair ----------
def repair_sql(nlq: str, bad_sql: str, error_msg: str, schema_text: str):
    if not USE_REPAIR:
        return None, {"enabled": False}
    prompt = f"""
You are an expert MySQL engineer.

Schema:
{schema_text}

User question:
{nlq}

Invalid SQL:
{bad_sql}

Database error:
{error_msg}

Fix the SQL so it is valid MySQL and answers the question.
Output ONLY the corrected SELECT statement.
""".strip()
    fixes = generate_candidates(prompt, num=1)
    if not fixes:
        return None, {"enabled": True, "status": "no_fix_generated"}

    raw_fix = fixes[0]
    sql = extract_first_select(raw_fix) or raw_fix
    sql = sql.strip()
    if not sql.lower().startswith("select"):
        return None, {"enabled": True, "status": "fix_not_select", "raw_fix": raw_fix}

    sql = sql if sql.endswith(";") else sql + ";"
    sql = projection_guard(sql, nlq)
    sql = apply_clamps(sql, nlq)

    try:
        meta = runner.run(sql)
        return (sql if meta.success else None), {
            "enabled": True,
            "status": "exec_ok" if meta.success else "exec_fail",
            "raw_fix": raw_fix,
            "fixed_sql": sql,
            "exec_error": meta.error,
        }
    except Exception as e:
        return None, {
            "enabled": True,
            "status": "exception",
            "raw_fix": raw_fix,
            "fixed_sql": sql,
            "exc": str(e),
        }


In [None]:
# 7) ReAct definition (STAGE-gated)
# STAGE 0 uses minimal execution-gated decoding.
# STAGE >=1 enables the full ReAct-style pipeline.

# small structured logger for trace
def _log(history, **kwargs):
    def _trim(x, n=500):
        if x is None:
            return None
        s = str(x)
        return s if len(s) <= n else s[:n] + "…"
    history.append({k: _trim(v) for k, v in kwargs.items()})

from nl2sql.llm import extract_first_select


def _minimal_sql(raw: str):
    raw = _normalize_spaced_keywords(raw or "")
    sql = extract_first_select(raw)
    if not sql:
        return None
    sql = sql.split(";", 1)[0].strip() + ";"
    return sql


def _react_sql_minimal(
    nlq: str,
    schema_text: str,
    schema_summary,
    max_steps: int = 3,
    num_cands: int = 1,
    exemplars=None,
):
    history = []
    observation = "Start."

    for _ in range(max_steps):
        prompt = build_react_prompt(nlq, schema_text, history, observation)
        raw = generate_candidates(prompt, num=1)[0]

        sql = _minimal_sql(raw)
        if not sql:
            history.append({"ta": raw, "obs": "Rejected: no SELECT found"})
            observation = "No SELECT"
            continue

        sql = projection_guard(sql, nlq)
        try:
            meta = runner.run(sql)
            if meta.success:
                history.append({"ta": sql, "obs": "SUCCESS"})
                return sql, history
            observation = f"ERROR: {meta.error or 'exec fail'}"
        except Exception as e:
            observation = f"ERROR: {e}"
        # record the failure and carry it into the next prompt's "Last observation"
        history.append({"ta": sql, "obs": observation})

    history.append({"ta": "", "obs": "No valid SQL found"})
    return "", history


def _react_sql_full(
    nlq: str,
    schema_text: str,
    schema_summary,
    max_steps: int = 3,
    num_cands: int = 4,
    exemplars=None,
):
    """
    ReAct-inspired execution-guided agent:
    - multi-candidate generation
    - strict SELECT-only filtering
    - clamping + guarded postprocess
    - execution-based gating
    - semantic reranking
    - optional one-step repair
    - deterministic few-shot baseline fallback
    """
    history: list[dict] = []
    observation = "Start."
    final_sql = None

    for step in range(max_steps):
        prompt_main = build_react_prompt(nlq, schema_text, history, observation)
        prompt_tab = build_tabular_prompt(nlq, schema_text)

        raw_cands = []
        raw_cands += generate_candidates(prompt_main, num=max(1, num_cands // 2))
        raw_cands += generate_candidates(prompt_tab, num=num_cands - len(raw_cands))

        candidates = []
        for raw in raw_cands:
            sql, reason = clean_candidate(raw)
            if not sql:
                _log(history, step=step, phase="clean", raw=raw, reason=reason)
                continue
            sql_pp = projection_guard(sql, nlq)
            sql_pp = canonicalize_table_casing(sql_pp)
            sql_pp = apply_clamps(sql_pp, nlq)
            candidates.append((raw, sql, sql_pp))
            _log(history, step=step, phase="candidate", raw=raw, cleaned=sql, post=sql_pp)

        last_error = None
        best_score = float("-inf")
        best_sql = None

        for raw, sql_clean, sql_exec in candidates:
            try:
                meta = runner.run(sql_exec, capture_df=False)
                _log(history, step=step, phase="exec", sql=sql_exec, success=meta.success, error=meta.error)
                if not meta.success:
                    raise ValueError(meta.error or "exec fail")

                s_sem = semantic_score(nlq, sql_exec)
                if USE_RERANK and s_sem < 1.0:
                    _log(history, step=step, phase="accept_gate", sql=sql_exec, score=s_sem, decision="reject")
                    continue
                s_cols = count_select_columns(sql_exec)
                total_score = s_sem - 0.5 * s_cols

                if total_score > best_score:
                    best_score = total_score
                    best_sql = sql_exec

            except Exception as e:
                last_error = str(e)
                _log(history, step=step, phase="exec_exception", sql=sql_exec, exc=last_error)

                repaired, repinfo = repair_sql(nlq, sql_exec, last_error, schema_text)
                _log(history, step=step, phase="repair", bad_sql=sql_exec, error=last_error,
                     rep_status=repinfo.get("status"), raw_fix=repinfo.get("raw_fix"),
                     fixed_sql=repinfo.get("fixed_sql"), rep_exec_error=repinfo.get("exec_error") or repinfo.get("exc"))

                if repaired:
                    s_sem2 = semantic_score(nlq, repaired)
                    if USE_RERANK and s_sem2 < 1.0:
                        _log(history, step=step, phase="repair_reject", sql=repaired, score=s_sem2, decision="reject")
                    else:
                        best_sql = repaired
                        break
                continue

        if best_sql is None and last_error:
            kind = classify_error(last_error)
            hint = error_hint(kind, last_error)
            observation = f"{last_error} — {hint}"
            history.append({"ta": "", "obs": observation})
            continue

        if best_sql is None:
            observation = "All candidates failed execution"
            history.append({"ta": "", "obs": observation})
            continue

        observation = "SUCCESS"
        final_sql = best_sql
        history.append({"ta": best_sql, "obs": observation})
        break

    if final_sql:
        return final_sql, history

    fallback = vanilla_candidate(
        nlq=nlq,
        schema_summary=schema_summary,
        tok=tok,
        model=model,
        exemplars=exemplars or test_set[:3],
    )
    if fallback:
        history.append({"ta": fallback, "obs": "Baseline fallback after no successful candidate"})
        return fallback, history

    history.append({"ta": "", "obs": "No valid SQL found (including baseline fallback)"})
    return "", history


# Dispatch based on stage
if STAGE == 0:
    print("STAGE 0: using minimal execution-gated react_sql")
    react_sql = _react_sql_minimal
else:
    print("STAGE >=1: using full ReAct-style react_sql")
    react_sql = _react_sql_full
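
The selection rule in `_react_sql_full` scores each executable candidate as its semantic score minus 0.5 per selected column, after an accept gate at 1.0. The sketch below reproduces that rule with stand-in scorers: `toy_semantic_score` and `toy_count_columns` are hypothetical simplifications, not the `semantic_score` and `count_select_columns` helpers from `nl2sql.agent_utils`.

```python
import re

def toy_semantic_score(nlq: str, sql: str) -> float:
    """Stand-in scorer: number of NLQ words (3+ letters) that also appear in the SQL."""
    words = set(re.findall(r"[a-z]{3,}", nlq.lower()))
    return float(sum(1 for w in words if w in sql.lower()))

def toy_count_columns(sql: str) -> int:
    """Stand-in column counter: comma-separated items between SELECT and FROM."""
    m = re.search(r"(?is)select\s+(.*?)\s+from\s", sql)
    return len(m.group(1).split(",")) if m else 0

def pick_best(nlq: str, candidates: list[str], min_sem: float = 1.0):
    best_sql, best_score = None, float("-inf")
    for sql in candidates:
        sem = toy_semantic_score(nlq, sql)
        if sem < min_sem:  # accept gate, as in the loop when reranking is on
            continue
        score = sem - 0.5 * toy_count_columns(sql)  # penalise wide projections
        if score > best_score:
            best_sql, best_score = sql, score
    return best_sql, best_score

nlq = "List customer names in France"
cands = [
    "SELECT customerName, phone, city FROM customers WHERE country = 'France';",
    "SELECT customerName FROM customers WHERE country = 'France';",
]
print(pick_best(nlq, cands))
```

With equal semantic scores, the column penalty favours the narrower projection, which is the behaviour the loop relies on to limit projection bloat.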


In [None]:
# 8) Quick sanity check on a few items
schema_text = SCHEMA_SUMMARY
for sample in test_set[:5]:
    nlq = sample["nlq"]
    gold = sample["sql"]
    pred, trace = react_sql(
        nlq=nlq,
        schema_text=schema_text,
        schema_summary=SCHEMA_SUMMARY,
        max_steps=3,
        num_cands=4,
        exemplars=test_set[:3],
    )
    print("NLQ:", nlq)
    print("PRED:", pred)
    print("GOLD:", gold)
    print("TRACE LEN:", len(trace))
    print("-" * 80)

### Stage 3 Interpretation (29 Jan 2026)

- **Valid SQL stability:** Stage 3 generally returns executable SQL; the remaining issues are **projection bloat** (extra columns) and **unnecessary ORDER BY/GROUP BY**.
- **Metric impact:** These are EM regressions more than EX regressions. Use clamps + final normalization to keep outputs canonical.
- **Trace logging upgrade:** The ReAct loop now logs **raw → cleaned → post‑clamp → exec error → repair attempt**, so failures can be attributed to generation vs cleaning vs execution vs repair.
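
The "final normalization" mentioned above could be as light as canonicalising whitespace, comma spacing, and case before the EM comparison. A sketch under that assumption; `normalize_for_em` is hypothetical, and the evaluation cell in this notebook only strips the trailing semicolon and lowercases.

```python
import re

def normalize_for_em(sql: str) -> str:
    """Canonicalise SQL text for exact-match comparison only (never for execution)."""
    s = (sql or "").strip().rstrip(";").strip()
    s = re.sub(r"\s+", " ", s)       # collapse newlines and runs of spaces
    s = re.sub(r"\s*,\s*", ", ", s)  # uniform comma spacing
    return s.lower()

pred = "SELECT  customerName,country\nFROM customers ;"
gold = "select customerName, country from customers"
print(normalize_for_em(pred) == normalize_for_em(gold))
```

Because this also rewrites text inside string literals, the normalised form should be used for matching only, with the original SQL kept for execution.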


In [None]:
# 9) Full ReAct-style evaluation (VA/EX/EM) over test_set
from nl2sql.eval import execution_accuracy
results = []
LIMIT = None  # set to e.g. 20 for a quick slice
items = test_set[:LIMIT] if LIMIT else test_set
schema_text = SCHEMA_SUMMARY

for i, sample in enumerate(items, start=1):
    nlq = sample["nlq"]
    gold_sql = sample["sql"]

    pred_sql, trace = react_sql(
        nlq=nlq,
        schema_text=schema_text,
        schema_summary=SCHEMA_SUMMARY,
        max_steps=3,
        num_cands=4,
        exemplars=test_set[:3],
    )

    pred_clean = pred_sql.strip().rstrip(";").lower()
    gold_clean = gold_sql.strip().rstrip(";").lower()
    em = int(pred_clean == gold_clean)
    va = ex = 0

    try:
        meta = runner.run(pred_sql)
        va = int(meta.success)
        if meta.success:
            ex_ok, _, _ = execution_accuracy(engine=engine, pred_sql=pred_sql, gold_sql=gold_sql)
            ex = int(ex_ok)
    except Exception:
        va = 0
        ex = 0

    results.append({
        "nlq": nlq,
        "gold_sql": gold_sql,
        "pred_sql": pred_sql,
        "va": va,
        "em": em,
        "ex": ex,
        "trace": trace,
    })

    if i % 10 == 0:
        print(f"Processed {i}/{len(items)}")

va_rate = sum(r["va"] for r in results) / len(results)
ex_rate = sum(r["ex"] for r in results) / len(results)
em_rate = sum(r["em"] for r in results) / len(results)
print("ReAct VA:", va_rate, "EX:", ex_rate, "EM:", em_rate)

Path("results/agent").mkdir(parents=True, exist_ok=True)
save_path = Path("results/agent/results_react_200.json")
save_path.write_text(
    json.dumps({
        "va_rate": va_rate,
        "ex_rate": ex_rate,
        "em_rate": em_rate,
        "items": results,
    }, ensure_ascii=False, indent=2),
    encoding="utf-8",
)
print("Saved to", save_path)