# Agentic Evaluation (ReAct-style)

Purpose: run a **ReAct‑inspired, execution‑guided NL→SQL pipeline** on ClassicModels,
report **VA / EX / EM / TS**, and keep everything reproducible.

Run order:
1) Environment + DB connection
2) Schema + test set
3) Model load
4) ReAct helpers + prompts
5) Quick sanity check
6) Full eval (VA/EX/EM/TS)


Docs used:
- HF Transformers quantization (4‑bit NF4)
- PEFT / QLoRA
- Cloud SQL Connector + SQLAlchemy creator
- ReAct (reason → act → observe → revise)


## Setup (run first, then restart)
In a fresh Colab GPU runtime, run this one cell to clean preinstalls and pin the CUDA 12.1 torch/bitsandbytes/triton stack. When it finishes, **Runtime → Restart runtime**, then run the rest of the notebook from the clone cell onward without more restarts.

**Docs (setup):** HF Transformers quantization + BitsAndBytes (4-bit) https://huggingface.co/docs/transformers/main_classes/quantization, bnb https://github.com/TimDettmers/bitsandbytes.

In [None]:

%%bash
set -e
export PIP_DEFAULT_TIMEOUT=120

# Clean conflicting preinstalls
pip uninstall -y torch torchvision torchaudio bitsandbytes triton transformers accelerate peft trl datasets numpy pandas fsspec requests google-auth || true

# Base deps
pip install -q --no-cache-dir --force-reinstall   numpy==1.26.4 pandas==2.2.1 fsspec==2024.5.0 requests==2.31.0 google-auth==2.43.0

# Torch + CUDA 12.1
pip install -q --no-cache-dir --force-reinstall   torch==2.3.1+cu121 torchvision==0.18.1+cu121 torchaudio==2.3.1+cu121   --index-url https://download.pytorch.org/whl/cu121

# bitsandbytes + triton + HF stack
pip install -q --no-cache-dir --force-reinstall   bitsandbytes==0.43.3 triton==2.3.1   transformers==4.44.2 accelerate==0.33.0 peft==0.17.0 trl==0.9.6 datasets==2.20.0

echo "Setup complete. Restart runtime once, then run the rest of the notebook top-to-bottom."


Model load: HF 4-bit NF4 + BitsAndBytes; deterministic decoding. If adapters exist, we load them.

In [None]:
# 0) Clone repo (Colab) + install deps
import os
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    if not os.path.exists('/content/NLtoSQL'):
        !git clone https://github.com/MacKenzieOBrian/NLtoSQL.git /content/NLtoSQL
    %cd /content/NLtoSQL
    !pip -q install -r requirements.txt
    import torch, transformers, accelerate, peft
    print('torch', torch.__version__, 'cuda', torch.cuda.is_available())
else:
    print('Not in Colab; using existing workspace')


Prompt/eval: build prompts (system+schema+k exemplars), generate SQL, postprocess, and compute VA/EX/EM.

**Ref:** Colab clone/install pattern; keeps notebooks thin and code in `nl2sql/`. Hugging Face/Colab standard workflow.

### Reference notes (what this builds on)
- DB access: Cloud SQL Connector + SQLAlchemy creator
- Schema/prompting: schema‑grounded prompting surveys
- Model load: HF 4‑bit NF4 + PEFT
- Agent loop: ReAct‑style execution feedback
- Eval: VA/EX/EM + TS harness


## Optional: use gcloud ADC (without a key)

**Ref:** GCP ADC flow (docs: https://cloud.google.com/docs/authentication/provide-credentials-adc). Optional fallback if no service account JSON.

In [None]:
# Run this only if you prefer gcloud-based ADC (no JSON key)
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    %pip install -q --upgrade google-auth google-auth-oauthlib
    !gcloud auth application-default login
else:
    print("Not in Colab; skip gcloud auth.")


**Ref:** Pinned CUDA12.1 torch/bitsandbytes/triton stack per HF/BnB guidance for 4-bit loads on Colab GPUs.

**Ref:** Cloud SQL Connector pattern (GCP MySQL docs: https://cloud.google.com/sql/docs/mysql/connect-run) with the SQLAlchemy `creator` hook (https://docs.sqlalchemy.org/en/20/core/engines.html#custom-dbapi-connect) for secure ClassicModels access.

In [None]:

# 1) Environment + DB
import os
from getpass import getpass
from pathlib import Path

from google.cloud.sql.connector import Connector
from google.oauth2.service_account import Credentials
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine

# Expected env vars (set these in a Colab cell):
# GOOGLE_APPLICATION_CREDENTIALS=/content/sa.json
# INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS, DB_NAME
INSTANCE_CONNECTION_NAME = os.getenv("INSTANCE_CONNECTION_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASS = os.getenv("DB_PASS")
DB_NAME = os.getenv("DB_NAME") or "classicmodels"
GOOGLE_CREDS = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")

if not INSTANCE_CONNECTION_NAME:
    INSTANCE_CONNECTION_NAME = input("Enter INSTANCE_CONNECTION_NAME: ").strip()
if not DB_USER:
    DB_USER = input("Enter DB_USER: ").strip()
if not DB_PASS:
    DB_PASS = getpass("Enter DB_PASS: ")

creds = None
if GOOGLE_CREDS:
    creds = Credentials.from_service_account_file(GOOGLE_CREDS)
    print(f"Using service account from {GOOGLE_CREDS}")
else:
    print("Using default ADC (gcloud auth or Colab auth). If this fails, set GOOGLE_APPLICATION_CREDENTIALS.")

connector = Connector(credentials=creds)

def getconn():
    return connector.connect(
        INSTANCE_CONNECTION_NAME,
        "pymysql",
        user=DB_USER,
        password=DB_PASS,
        db=DB_NAME,
    )

engine: Engine = create_engine(
    "mysql+pymysql://",
    creator=getconn,
    future=True,
)

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
print("DB connection OK")


In [None]:
# 1b) Engine factory for TS (multiple DB names)
from sqlalchemy import create_engine

def make_engine(db_name: str) -> Engine:
    """Create a new engine for a specific DB name using the same Cloud SQL connector."""
    def getconn_for_db():
        return connector.connect(
            INSTANCE_CONNECTION_NAME,
            "pymysql",
            user=DB_USER,
            password=DB_PASS,
            db=db_name,
        )
    return create_engine("mysql+pymysql://", creator=getconn_for_db, future=True)


**Ref:** Schema helper in `nl2sql.schema`; schema-grounded prompting with Spider-style listings, per the NL→SQL survey (https://arxiv.org/abs/2410.06011).

In [None]:
# 2) Load schema summary + test set (small slice for now)
import json
from nl2sql.schema import build_schema_summary

SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME)

test_path = Path("data/classicmodels_test_200.json")
full_set = json.loads(test_path.read_text(encoding="utf-8"))
# default to a small slice while debugging
test_set = full_set[:5]
print("Demo items:", len(test_set))
# For full run, switch to: test_set = full_set; print("Test items:", len(test_set))

TABLES = {line.split('(', 1)[0].strip() for line in SCHEMA_SUMMARY.splitlines() if '(' in line}
TABLES_LOWER = {t.lower(): t for t in TABLES}


**Ref:** HF Transformers 4-bit NF4 + BitsAndBytes (quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization); adapters via PEFT/QLoRA (https://huggingface.co/docs/peft/).

In [None]:

# 3) Load model (base or QLoRA adapters)
import os
from getpass import getpass
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER_PATH = os.getenv("ADAPTER_PATH") or "results/adapters/qlora_classicmodels"  # set to None to use base model

HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    HF_TOKEN = getpass("Enter HF_TOKEN (https://huggingface.co/settings/tokens): ").strip()

cc_major, cc_minor = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else (0, 0)
use_bf16 = cc_major >= 8
compute_dtype = torch.bfloat16 if use_bf16 else torch.float16
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("Using bf16:", use_bf16)
print("Adapter path:", ADAPTER_PATH)

# Tokenizer
tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Quantized base model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
    device_map={"": 0} if torch.cuda.is_available() else None,
    token=HF_TOKEN,
)
base_model.generation_config.do_sample = False
base_model.generation_config.temperature = 1.0
base_model.generation_config.top_p = 1.0

# Load adapters if present locally; otherwise use base model
adapter_dir = Path(ADAPTER_PATH) if ADAPTER_PATH else None
if adapter_dir and adapter_dir.exists():
    model = PeftModel.from_pretrained(base_model, adapter_dir, token=HF_TOKEN)
    print("Loaded adapters from", adapter_dir)
else:
    print("Adapter path missing; using base model only. Set ADAPTER_PATH to your local adapter folder or upload it to Colab.")
    model = base_model


## Optional adapter sanity check (run before ReAct)
Quick check to see if the loaded model/adapters produce valid SQL on a tiny slice. Uses the prompt harness (k=0/k=3) and executes the SQL to report VA/EX.

**Docs (prompt/eval):** ICL patterns https://arxiv.org/abs/2005.14165; execution-based metrics (VA/EX) https://aclanthology.org/2020.emnlp-main.29/.
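
To make the k=0 / k=3 distinction concrete, here is a minimal sketch of how a few-shot chat prompt could be laid out. It is an assumption, not the repo's code: the real `make_few_shot_messages` in `nl2sql.prompting` may differ, and the `few_shot_messages` name and the system prompt wording are invented for this illustration. Only the `nlq` / `sql` exemplar keys are taken from the notebook.

```python
def few_shot_messages(schema: str, exemplars: list[dict], nlq: str) -> list[dict]:
    """Hypothetical layout: system prompt with schema, then k (question, SQL) turns."""
    messages = [{
        "role": "system",
        "content": "You translate questions into MySQL SELECT statements.\n"
                   f"Schema:\n{schema}",
    }]
    for ex in exemplars:  # an empty list gives the zero-shot (k=0) prompt
        messages.append({"role": "user", "content": ex["nlq"]})
        messages.append({"role": "assistant", "content": ex["sql"]})
    messages.append({"role": "user", "content": nlq})  # the question to answer
    return messages


demo_shots = [
    {"nlq": "How many offices are there?", "sql": "SELECT COUNT(*) FROM offices;"},
    {"nlq": "List all product lines.", "sql": "SELECT productLine FROM productlines;"},
]
zero_shot = few_shot_messages("offices(officeCode, city)", [], "List office cities.")
two_shot = few_shot_messages("offices(officeCode, city)", demo_shots, "List office cities.")
print(len(zero_shot), len(two_shot))
```

Each exemplar adds one user turn and one assistant turn, so prompt length grows linearly with k.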

In [None]:
from nl2sql.prompting import make_few_shot_messages
from nl2sql.llm import extract_first_select
from nl2sql.postprocess import guarded_postprocess
from nl2sql.query_runner import QueryRunner
from nl2sql.eval import execution_accuracy

runner_check = QueryRunner(engine)
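# Reuse the existing test_set (default small slice) and pick 3 exemplars.
# NOTE: these exemplars are also the first 3 items evaluated below, so the
# k=3 numbers from this check are optimistic (gold SQL appears in the prompt).
exemplars = test_set[:3]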

def run_quick_check(k: int = 0, limit: int = 3):
    print(f"Quick check k={k}")
    for sample in test_set[:limit]:
        shots = exemplars if k > 0 else []
        msgs = make_few_shot_messages(
            schema=SCHEMA_SUMMARY,
            exemplars=shots,
            nlq=sample['nlq'],
        )
        prompt_preview = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        inputs = tok(prompt_preview, return_tensors="pt").to(model.device)
        out = model.generate(**inputs, max_new_tokens=256, do_sample=False)

        # strip the prompt before decoding the generation
        gen_ids = out[0][inputs.input_ids.shape[-1]:]
        text = tok.decode(gen_ids, skip_special_tokens=True)

        raw_sql = extract_first_select(text) or text
        sql = guarded_postprocess(raw_sql, sample['nlq'])

        meta = runner_check.run(sql, capture_df=False)
        va = meta.success
        ex_ok, _, _ = execution_accuracy(engine=engine, pred_sql=sql, gold_sql=sample['sql'])
        print(f"Q: {sample['nlq']}\nSQL: {sql}\nVA: {va} EX: {ex_ok}\n")

run_quick_check(k=0)
run_quick_check(k=3)


**Ref:** ReAct loop (Yao et al. 2023: https://arxiv.org/abs/2210.03629) adapted for NL→SQL, with `QueryRunner` as a safe, SELECT-only Act step.

In [None]:

# Helper imports (ReAct helpers live in nl2sql/agent_utils)
from nl2sql.agent_utils import (
    clean_candidate,
    build_tabular_prompt,
    vanilla_candidate,
    classify_error,
    error_hint,
    semantic_score,
    count_select_columns,
)


### Agent status (for dissertation)
Current loop: **generate → clean → post‑process → execute → intent‑gate → score → (repair)**.
Observation is fed into the next step (ReAct pattern).
TS is included as suite‑based semantic evaluation.
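
As a rough illustration of suite-based evaluation, the sketch below accepts a prediction only if it returns the same rows as the gold query on every database in a small suite. It runs on in-memory SQLite rather than ClassicModels; the `suite_match` and `make_db` helpers and the toy `products` table are assumptions made for this example and do not reproduce the TS harness used in the notebook.

```python
import sqlite3


def make_db(rows):
    """Build one in-memory database variant with a toy products table."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE products (name TEXT, price REAL)")
    conn.executemany("INSERT INTO products VALUES (?, ?)", rows)
    return conn


def suite_match(pred_sql, gold_sql, suite):
    """True only if pred and gold return the same rows on every database."""
    for conn in suite:
        try:
            pred = sorted(conn.execute(pred_sql).fetchall(), key=repr)
        except sqlite3.Error:
            return False  # a non-executable prediction fails the suite
        gold = sorted(conn.execute(gold_sql).fetchall(), key=repr)
        if pred != gold:
            return False
    return True


suite = [
    make_db([("car", 50.0), ("boat", 120.0)]),
    make_db([("car", 100.0), ("boat", 120.0)]),  # boundary value separates > from >=
]
gold = "SELECT name FROM products WHERE price > 100"
pred = "SELECT name FROM products WHERE price >= 100"
print(suite_match(pred, gold, suite[:1]))  # single database only
print(suite_match(pred, gold, suite))      # full suite
```

The second database exists only to expose the `>` versus `>=` difference, which a single database instance can hide.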


**Ref:** Repo eval (`nl2sql.eval`) for VA/EX/EM; execution-based metrics align with Ojuri et al. 2025 and EMNLP’20 TS.
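
The sketch below shows how the three per-query metrics can disagree on the same prediction. It is not the `nl2sql.eval` implementation: the `score` helper, the toy `customers` table, and the whitespace/case normalisation used for EM are assumptions made for this example, and it runs on in-memory SQLite instead of MySQL.

```python
import sqlite3


def score(conn, pred_sql, gold_sql):
    """Return (VA, EX, EM) for one prediction; illustrative only."""
    def norm(s):
        return " ".join(s.strip().rstrip(";").lower().split())

    em = norm(pred_sql) == norm(gold_sql)  # EM: normalised string equality
    try:
        pred_rows = conn.execute(pred_sql).fetchall()
    except sqlite3.Error:
        return False, False, em            # VA fails: query does not execute
    gold_rows = conn.execute(gold_sql).fetchall()
    # EX: compare result sets, ignoring row order in this sketch
    ex = sorted(pred_rows, key=repr) == sorted(gold_rows, key=repr)
    return True, ex, em


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customers (name TEXT, country TEXT)")
conn.executemany(
    "INSERT INTO customers VALUES (?, ?)",
    [("Atelier", "France"), ("Mini Gifts", "USA")],
)

gold = "SELECT name FROM customers WHERE country = 'USA';"
pred = "SELECT c.name FROM customers c WHERE c.country = 'USA';"
print(score(conn, pred, gold))                          # aliased but equivalent query
print(score(conn, "SELECT nme FROM customers;", gold))  # misspelled column
```

The aliased prediction executes and returns the gold rows but is not a string match, which is why EX and EM are reported separately.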


## ReAct execution-guided pipeline (best version so far)
These cells mirror the committed helper layer (`nl2sql/agent_utils.py`) and set up the current execution-guided reranker + evaluation harness.

## Reference Map (Code ↔ Literature)
- Execution guidance & repair → ExCoT [2], ReAct [16], TS [18]
- Constrained decoding/output hygiene → PICARD [13], surveys [8], [9]
- Projection contract → survey [8], BigBench Text‑to‑SQL [1]
- Intent constraints → ExCoT [2], survey [8], benchmark eval [20]
- Schema‑subset prompting → RESDSQL [17], surveys [8], [9]


In [None]:
# 4) Schema summary + test set + QueryRunner
import json
from pathlib import Path
from nl2sql.schema import build_schema_summary
from nl2sql.query_runner import QueryRunner

DB_NAME = os.getenv("DB_NAME") or "classicmodels"

SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME)
test_path = Path("data/classicmodels_test_200.json")
full_set = json.loads(test_path.read_text(encoding="utf-8"))
test_set = full_set  # change to full_set[:20] when debugging

print("Loaded test set size:", len(test_set))
runner = QueryRunner(engine)

In [None]:
# 5) Agent utilities: semantic reranker, baseline candidate, error taxonomy
from nl2sql.agent_utils import (
    clean_candidate,
    build_tabular_prompt,
    vanilla_candidate,
    classify_error,
    error_hint,
    semantic_score,
    count_select_columns,
)

## 6. Helper Layer: Staged Controls, Candidate Generation, and Error-Aware Repair

This cell implements the *infrastructure layer* of the agentic text-to-SQL system.  
Generating SQL and executing it directly is not enough in practice: LLM-based systems
need additional tooling to ensure **syntactic validity**, **intent alignment**, and
**robust recovery from execution errors**.

The design of this helper layer is motivated by findings from recent text-to-SQL
benchmarks and surveys, which show that a large proportion of failures arise from
formatting errors, over-generation, or schema mismatches rather than incorrect
semantic reasoning [1], [8], [9].

---

### 6.1 Staged Feature Controls

We introduce a staged configuration mechanism (`STAGE = 0–3`) to progressively enable
agent capabilities:

- **Stage 0**: Minimal execution-gated decoding
- **Stage 1**: Lightweight post-processing and clamping
- **Stage 2**: Multi-candidate generation and reranking
- **Stage 3**: Execution-feedback-driven repair

This staged design enables controlled ablation studies and mirrors experimental
methodologies used in large-scale text-to-SQL evaluations [1], [18], [20].
It also allows us to isolate the impact of execution feedback and repair mechanisms,
which are known to significantly improve execution accuracy [2].
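
One way to read the stages is as cumulative feature flags. The mapping below is a hypothetical sketch: the flag names and the `flags_for_stage` helper are illustrative and are not taken from the notebook's configuration.

```python
# Hypothetical mapping from stage number to the features that stage introduces.
STAGE_FEATURES = {
    0: set(),                          # execution-gated decoding only
    1: {"postprocess", "clamps"},      # lightweight post-processing and clamping
    2: {"multi_candidate", "rerank"},  # multi-candidate generation and reranking
    3: {"repair"},                     # execution-feedback-driven repair
}


def flags_for_stage(stage: int) -> set[str]:
    """Union of the features introduced at every stage up to `stage`."""
    if stage not in STAGE_FEATURES:
        raise ValueError(f"unknown stage: {stage}")
    enabled: set[str] = set()
    for s in range(stage + 1):
        enabled |= STAGE_FEATURES[s]
    return enabled


for stage in range(4):
    print(stage, sorted(flags_for_stage(stage)))
```

Because each stage only adds features, the difference between two consecutive runs can be attributed to the features introduced at that stage, which is what makes the ablation readable.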

---

### 6.2 Controlled Generation via Stopping Criteria

LLMs frequently produce trailing explanations, multiple SQL statements, or partial
queries. To mitigate this, we implement a custom decoding constraint that halts
generation at the first semicolon.

This approach uses the `StoppingCriteria` API from the Hugging Face Transformers
library and provides a lightweight alternative to grammar-constrained decoding
methods such as PICARD [13].

The goal is to enforce a *single-statement SELECT-only output*, improving the
**Valid SQL (VA)** metric without modifying model weights.
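
The stopping rule can be illustrated without a model. The sketch below simulates a stream of decoded tokens and stops at the first one that ends with a semicolon; `generate_until_semicolon` and the toy token list are assumptions made for this example, whereas the notebook applies the same idea to token ids through `StoppingCriteria`.

```python
from typing import Iterable


def generate_until_semicolon(token_stream: Iterable[str], max_new_tokens: int = 128) -> str:
    """Consume tokens until one ends with ';' or the token budget is used up."""
    out = []
    for i, tok_text in enumerate(token_stream):
        if i >= max_new_tokens:
            break
        out.append(tok_text)
        if tok_text.rstrip().endswith(";"):
            break  # first statement terminator: stop decoding here
    return "".join(out)


# Toy stream: without the stop rule the model would continue with an explanation.
stream = ["SELECT", " COUNT", "(*)", " FROM", " orders", ";", " This", " query", " counts"]
print(generate_until_semicolon(stream))
```

Stopping early also saves decoding time, since tokens after the semicolon would be discarded by the cleaner anyway.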

---

### 6.3 SQL Normalisation and Prompt-Echo Removal

Empirical inspection of model outputs reveals frequent contamination from prompt
instructions (e.g. "Output SQL only", "No explanation"). These artifacts negatively
impact syntactic validity and downstream execution.

To address this, we apply deterministic cleaning steps:
- Extract the first valid `SELECT` statement
- Remove prompt-echo phrases
- Enforce the presence of a `FROM` clause
- Reject dangling or incomplete SQL fragments

Such post-processing is standard practice in text-to-SQL systems and is necessary
to fairly evaluate execution accuracy independently of formatting noise [1], [18].
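
The cleaning steps listed above can be sketched in a few lines. This is a simplified, standalone illustration: `clean_sql`, its echo phrases, and its reason strings are assumptions for this example, and the notebook's own cleaner applies more checks.

```python
import re

ECHO_RE = re.compile(r"(?i)\b(output sql only|output only sql|no explanation|no markdown)\b")
DANGLING_RE = re.compile(r"(?i)\b(where|join|on|group\s+by|order\s+by|having|limit)\s*$")


def clean_sql(raw: str):
    """Return (sql, "ok") or (None, reason). Simplified illustration."""
    m = re.search(r"(?i)\bselect\b", raw or "")
    if not m:
        return None, "no_select"
    sql = raw[m.start():]                # drop any chatter before SELECT
    echo = ECHO_RE.search(sql)
    if echo:                             # drop echoed prompt instructions
        sql = sql[:echo.start()]
    sql = sql.split(";", 1)[0].strip()   # keep the first statement only
    if not re.search(r"(?i)\bfrom\b", sql):
        return None, "no_from"
    if DANGLING_RE.search(sql):          # query ends on a clause keyword
        return None, "dangling_clause"
    return sql + ";", "ok"


print(clean_sql("Sure! SELECT city FROM offices; No explanation."))
print(clean_sql("SELECT city FROM offices WHERE"))
print(clean_sql("SELECT 1"))
```

Returning a reason alongside the SQL keeps rejections countable, so cleaning failures can be reported separately from execution failures.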

---

### 6.4 Intent-Aware Clamping and Guardrails

LLMs tend to over-generate SQL clauses (e.g. unnecessary `ORDER BY` or `GROUP BY`).
We introduce lightweight *intent-aware clamps* that:
- Remove ordering unless explicitly requested
- Ensure grouping keys appear in the SELECT projection
- Restrict projections for simple listing queries

These heuristics align with observations in recent surveys that LLMs struggle with
projection minimality and clause relevance [8], [9].
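
As an example of one clamp, the sketch below removes `ORDER BY` when the question carries no ordering cue. The `strip_unrequested_order_by` name and the cue list are assumptions for this illustration. Substring matching is deliberately crude (for instance, "top" also matches "stop"), so a clamp like this should be treated as a heuristic rather than a guarantee.

```python
import re

ORDER_HINTS = ("sort", "order by", "top", "highest", "lowest",
               "ascending", "descending", "rank")


def strip_unrequested_order_by(sql: str, nlq: str) -> str:
    """Drop ORDER BY when the question contains no ordering cue."""
    if any(h in nlq.lower() for h in ORDER_HINTS):
        return sql
    body = sql.strip().rstrip(";")
    # Remove ORDER BY up to a following LIMIT (if any) or the end of the query.
    body = re.sub(r"(?is)\s+order\s+by\s+.*?(?=\s+limit\b|$)", "", body)
    return body + ";"


print(strip_unrequested_order_by(
    "SELECT customerName FROM customers ORDER BY customerName;",
    "List all customer names"))
print(strip_unrequested_order_by(
    "SELECT customerName FROM customers ORDER BY creditLimit DESC LIMIT 5;",
    "Top 5 customers by credit limit"))
```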

---

### 6.5 Execution Feedback and Error-Aware Repair

A key contribution of this work is the integration of execution feedback as a
first-class signal.

When a generated query fails during execution, the database error message is fed
back into the model via a repair prompt. This allows the agent to generate a corrected
query grounded in the actual database behaviour.

This mechanism is inspired by:
- ReAct-style action–observation loops [16]
- Execution-guided reasoning frameworks such as ExCoT [2]
- Agent-based text-to-SQL systems integrating environment feedback [10], [21]

Rather than treating execution failure as terminal, the agent uses it as a corrective
signal for the next attempt (no weights are updated), with the aim of improving
robustness on complex aggregation queries.
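
The feedback loop can be shown end to end on a toy database. In the sketch below the model is replaced by a stub: `fake_model`, `answer_with_repair`, and the in-memory SQLite `orders` table are all assumptions for this illustration, standing in for the LLM, the agent loop, and ClassicModels respectively.

```python
import sqlite3


def fake_model(prompt: str) -> str:
    """Stand-in for the LLM: corrects a wrong column once it sees the DB error."""
    if "no such column" in prompt:
        return "SELECT COUNT(*) FROM orders WHERE status = 'Shipped';"
    return "SELECT COUNT(*) FROM orders WHERE state = 'Shipped';"


def answer_with_repair(conn, nlq: str, max_steps: int = 3):
    """Generate, execute, and feed any execution error into the next prompt."""
    observation, trace = "Start.", []
    for step in range(max_steps):
        prompt = f"Question: {nlq}\nLast observation: {observation}"
        sql = fake_model(prompt)
        try:
            rows = conn.execute(sql).fetchall()
        except sqlite3.Error as exc:
            observation = f"Database error: {exc}"  # observation drives the retry
            trace.append((step, sql, observation))
            continue
        trace.append((step, sql, "ok"))
        return sql, rows, trace
    return None, None, trace  # step budget exhausted without a valid query


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (orderNumber INTEGER, status TEXT)")
conn.executemany(
    "INSERT INTO orders VALUES (?, ?)",
    [(1, "Shipped"), (2, "Shipped"), (3, "Cancelled")],
)

sql, rows, trace = answer_with_repair(conn, "How many orders were shipped?")
print(sql, rows)
for entry in trace:
    print(entry)
```

The trace keeps every attempt together with its observation, which is the raw material for the error analysis.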


In [None]:
# 6) Helper: staged controls + candidate generation + error-aware repair
import re
import torch
import transformers
from transformers import StoppingCriteria, StoppingCriteriaList
import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, DML

from nl2sql.llm import extract_first_select
from nl2sql.postprocess import guarded_postprocess
import nl2sql.agent_utils as agent_utils_mod  # for monkey-patching cleaner
from nl2sql.agent_utils import build_schema_subset, enforce_projection_contract, intent_constraints

transformers.logging.set_verbosity_error()
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# --- Simple ReAct config (replaces the earlier STAGE/USE_* gating) ---
CFG = {
    "max_steps": 3,
    "num_cands": 6,
    "do_sample": True,
    "temperature": 0.5,
    "top_p": 0.9,
    "max_new_tokens": 128,
    "enable_repair": True,
    "use_tabular_prompt": True,
    "use_projection_contract": True,
    "use_clamps": True,
    "use_schema_subset": True,
}

# ----- Debug controls -----
DEBUG = True
DEBUG_RAW = False
DEBUG_REJECT_SAMPLE = 2

# ---------- generation stop (first semicolon) ----------
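class StopOnSemicolon(StoppingCriteria):
    """Stop each sequence once its most recent token is the semicolon token."""
    def __init__(self, tokenizer):
        # Assumes ";" maps to a single token id; fused tokens such as ");" are not caught.
        self.semi_id = tokenizer.encode(";", add_special_tokens=False)[-1]

    def __call__(self, input_ids, scores, **kwargs):
        # Return one flag per row so that, with num_return_sequences > 1,
        # a semicolon in the first candidate does not cut off the others.
        return input_ids[:, -1] == self.semi_id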

# ---------- keyword normaliser ----------
def _normalize_spaced_keywords(text: str) -> str:
    keywords = [
        "select", "from", "where", "group", "by", "order", "limit",
        "join", "inner", "left", "right", "on", "having", "distinct",
    ]
    for kw in keywords:
        pattern = r"\b" + r"\s*".join(list(kw)) + r"\b"
        text = re.sub(pattern, kw.upper(), text, flags=re.I)
    return text

# ---------- shared core cleaner (trim prompt echo, relaxed) ----------
ECHO_CUTOFF_RE = re.compile(
    r"(?is)\b(show step|show output|output only|respond with|no markdown|no explanation|y/n|outputformatting|output formatting)\b"
)
DANGLING_RE = re.compile(r"(?is)\b(group\s+by|order\s+by|join|on|where|having|limit|offset)\s*;?\s*$")

def strip_prompt_echo(sql: str) -> str:
    m = ECHO_CUTOFF_RE.search(sql or "")
    if not m:
        return sql
    return (sql[:m.start()]).strip()


def _clean_candidate_core(raw: str):
    """
    Returns (sql_or_none, reason)
    """
    if not raw:
        return None, "empty"

    raw = _normalize_spaced_keywords(raw)
    sql = extract_first_select(raw) or raw
    sql = sql.strip()

    lower = sql.lower()
    idx = lower.find("select")
    if idx == -1:
        return None, "no_select"
    sql = sql[idx:].strip()
    lower = sql.lower()

    for marker in [
        "output only sql",
        "no explanation",
        "no markdown",
        "show output",
        "y/n",
        "outputformatting",
        "output formatting",
    ]:
        pos = lower.find(marker)
        if pos != -1:
            sql = sql[:pos].strip()
            lower = sql.lower()

    if ";" in sql:
        sql = sql.split(";", 1)[0].strip()
        lower = sql.lower()

    sql = strip_prompt_echo(sql)
    lower = sql.lower()

    if "```" in lower or "..." in lower or re.search(r"(?i)output", sql):
        return None, "bad_phrase"

    if not lower.startswith("select"):
        return None, "no_select"

    sql = re.sub(r"(?is)^\s*select\s+select", "SELECT", sql).strip()
    lower = sql.lower()

    if not re.search(r"(?is)\bfrom\b", sql):
        return None, "no_from"
    if re.search(r"(?is)\bfrom\s+dual\b", sql):
        return None, "from_dual"
    if re.search(r"(?is)\bgroup\s+by\s+null\b", sql):
        return None, "group_by_null"
    if DANGLING_RE.search(sql):
        return None, "dangling_clause"

    return sql + ";", "ok"

# Local cleaner for ReAct (returns reason)
def clean_candidate(raw: str):
    return _clean_candidate_core(raw)

# Monkey-patch agent_utils.clean_candidate so vanilla_candidate uses relaxed cleaner
def _clean_candidate_for_agent_utils(raw: str):
    sql, _ = _clean_candidate_core(raw)
    return sql

agent_utils_mod.clean_candidate = _clean_candidate_for_agent_utils

# ---------- Lightweight guardrails / clamps ----------
def _has_keyword(text: str, keywords) -> bool:
    lt = (text or "").lower()
    return any(k in lt for k in keywords)


def strip_order_by_if_not_requested(sql: str, nlq: str) -> str:
    if _has_keyword(nlq, ["sort", "sorted", "order by", "top", "first", "highest", "lowest", "descending", "ascending", "limit", "rank"]):
        return sql
    out = re.sub(r"(?is)\s+order\s+by\s+.*?(?=(\blimit\b|;|$))", "", sql)
    return out if out.endswith(";") else out + ";"


def trim_to_first_column(sql: str) -> str:
    parsed = sqlparse.parse(sql)
    if not parsed:
        return sql
    stmt = parsed[0]
    new_tokens, seen_select, trimmed = [], False, False
    for tok in stmt.tokens:
        if tok.ttype is DML and tok.value.upper() == "SELECT":
            seen_select = True
            new_tokens.append(tok)
            continue
        if seen_select and isinstance(tok, IdentifierList) and not trimmed:
            try:
                first_ident = next(tok.get_identifiers())
            except StopIteration:
                new_tokens.append(tok)
            else:
                new_tokens.append(IdentifierList([first_ident]))
                trimmed = True
            continue
        if seen_select and isinstance(tok, Identifier) and not trimmed:
            new_tokens.append(tok)
            trimmed = True
            continue
        new_tokens.append(tok)
    cleaned = "".join(str(t) for t in new_tokens).strip()
    return cleaned if cleaned.endswith(";") else cleaned + ";"


def strip_group_by_if_not_requested(sql: str, nlq: str) -> str:
    nl = (nlq or "").lower()
    if any(w in nl for w in ["per ", "by ", "each", "group"]):
        return sql
    if not any(w in nl for w in ["count", "how many", "number of"]):
        return sql
    out = re.sub(r"(?is)\s+group\s+by\s+[^;]+", "", sql)
    return out if out.endswith(";") else out + ";"


def ensure_group_key_in_select(sql: str, nlq: str) -> str:
    nl = (nlq or "").lower()
    if not ("per " in nl or "by " in nl):
        return sql

    s = sql.strip().rstrip(";")
    gb = re.search(r"(?is)\bgroup\s+by\s+([a-zA-Z0-9_\.]+)", s)
    if not gb:
        return sql
    group_key = gb.group(1)

    sel = re.search(r"(?is)^\s*select\s+(.*?)\s+from\s+", s)
    if not sel:
        return sql
    select_part = sel.group(1)

    if group_key.lower() in select_part.lower():
        return sql

    new_select = f"{group_key}, {select_part}"
    out = re.sub(r"(?is)^\s*select\s+(.*?)\s+from\s+", f"SELECT {new_select} FROM ", s, count=1)
    return out + ";"


# explicit field list heuristic for clamp decisions
def _nlq_has_explicit_fields(nlq: str) -> bool:
    nl = (nlq or "").lower()
    if "," not in nl and " and " not in nl and " with " not in nl:
        return False
    field_hints = ["msrp", "product code", "product name", "product line", "order number",
                   "customer name", "credit limit", "city", "country", "phone"]
    return any(h in nl for h in field_hints)


def add_group_by_if_requested(sql: str, nlq: str) -> str:
    if not any(k in (nlq or "").lower() for k in [" per ", " by ", " each "]):
        return sql
    if re.search(r"(?is)group\s+by", sql or ""):
        return sql
    if "order number" in (nlq or "").lower() or "per order" in (nlq or "").lower():
        key = "orderNumber"
    else:
        return sql
    s = sql.strip().rstrip(";")
    return s + f" GROUP BY {key};"


def apply_clamps(sql: str, nlq: str) -> str:
    if not CFG["use_clamps"]:
        return sql
    sql = strip_order_by_if_not_requested(sql, nlq)
    sql = strip_group_by_if_not_requested(sql, nlq)
    # Only trim projection when NLQ does not explicitly list fields.
    if _has_keyword(nlq, ["which", "who", "what", "list", "show"]) and not _nlq_has_explicit_fields(nlq):
        sql = trim_to_first_column(sql)
    sql = ensure_group_key_in_select(sql, nlq)
    sql = add_group_by_if_requested(sql, nlq)
    return sql


def canonicalize_table_casing(sql: str) -> str:
    if not sql:
        return sql
    def repl(m):
        tbl = m.group(2)
        canon = TABLES_LOWER.get(tbl.lower(), tbl)
        return f"{m.group(1)} {canon}"
    return re.sub(r"(?is)\b(from|join)\s+([a-zA-Z_][a-zA-Z0-9_]*)\b", repl, sql)

# ---------- Prompts ----------
def build_react_prompt(nlq: str, schema_text: str, history: list[dict], observation: str) -> str:
    schema_view = build_schema_subset(schema_text, nlq) if CFG["use_schema_subset"] else schema_text
    history_text = "\n\n".join(
        f"Thought/Action: {h.get('ta','')}\nObservation: {h.get('obs','')}"
        for h in history
    ) or "None yet."
    return f"""
You are an expert MySQL analyst.

TASK:
- Write exactly ONE valid MySQL SELECT statement.
- Output only SQL (no explanation, no markdown).
- The output must include a FROM clause.
- Use only schema columns.
- Use ORDER BY/LIMIT only if explicitly asked.

Schema:
{schema_view}

Question:
{nlq}

Recent steps:
{history_text}

Last observation:
{observation}

Respond with only the final SQL statement.
""".strip()


def build_tabular_prompt(nlq: str, schema_text: str) -> str:
    schema_view = build_schema_subset(schema_text, nlq) if CFG["use_schema_subset"] else schema_text
    return f"""
You are an expert SQL engineer. Think through tables and join keys, then output one SELECT.

Schema:
{schema_view}

Question: {nlq}

Output only the final SQL statement and nothing else.
""".strip()


def projection_guard(sql: str, nlq: str) -> str:
    return guarded_postprocess(sql, nlq)

# ---------- Candidate generation ----------
def generate_candidates(prompt: str, num: int = 1, do_sample: bool | None = None):
    if do_sample is None:
        do_sample = CFG["do_sample"]

    if not do_sample:
        num = 1

    prompt = prompt.rstrip() + "\nSELECT "
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    gen_kwargs = dict(
        max_new_tokens=CFG["max_new_tokens"],
        do_sample=do_sample,
        num_return_sequences=num,
        stopping_criteria=StoppingCriteriaList([StopOnSemicolon(tok)]),
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
    if do_sample:
        gen_kwargs.update(dict(temperature=CFG["temperature"], top_p=CFG["top_p"]))

    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)

    cands = []
    for i in range(num):
        gen_ids = out[i][inputs.input_ids.shape[-1]:]
        gen = tok.decode(gen_ids, skip_special_tokens=True)
        gen = _normalize_spaced_keywords(gen)
        gen = re.sub(r"(?is)^\s*select\s*", "", gen)
        cands.append("SELECT " + gen)
    return cands


def extract_best_select(text: str) -> str | None:
    text = _normalize_spaced_keywords(text or "")
    parts = [p.strip() for p in text.split(";") if p.strip()]
    for p in parts:
        sel = extract_first_select(p) or p
        sel = sel.strip()
        if re.match(r"(?is)^select", sel) and re.search(r"(?is)from", sel):
            if re.match(r"(?is)^select\s+from", sel):
                continue
            return sel + ";"
    return None

# ---------- Error-aware repair ----------
def repair_sql(nlq: str, bad_sql: str, error_msg: str, schema_text: str):
    if not CFG["enable_repair"]:
        return None, {"enabled": False}
    prompt = f"""
You are an expert MySQL engineer.

Schema:
{schema_text}

User question:
{nlq}

Invalid SQL:
{bad_sql}

Database error:
{error_msg}

Fix the SQL so it is valid MySQL and answers the question.
Output ONLY the corrected SELECT statement.
""".strip()

    fixes = generate_candidates(prompt, num=4)
    if not fixes:
        return None, {"enabled": True, "status": "no_fix_generated"}

    for raw_fix in fixes:
        cand = extract_best_select(raw_fix)
        if not cand:
            continue
        sql = cand if cand.endswith(";") else cand + ";"
        sql = projection_guard(sql, nlq)
        sql = canonicalize_table_casing(sql)
        if CFG["use_projection_contract"]:
            sql = enforce_projection_contract(sql, nlq)
        sql = apply_clamps(sql, nlq)
        try:
            meta = runner.run(sql, capture_df=False)
            if meta.success:
                return sql, {
                    "enabled": True,
                    "status": "exec_ok",
                    "raw_fix": raw_fix,
                    "fixed_sql": sql,
                    "exec_error": meta.error,
                }
            last_err = meta.error
        except Exception as e:
            last_err = str(e)

        last_info = {
            "enabled": True,
            "status": "exec_fail",
            "raw_fix": raw_fix,
            "fixed_sql": sql,
            "exec_error": last_err,
        }

    return None, last_info if 'last_info' in locals() else {"enabled": True, "status": "no_valid_fix"}


In [None]:
# 7) Simple full ReAct loop (explainable + observation feedback)

def _log(history, **kwargs):
    """Small structured trace for dissertation analysis."""
    def _trim(x, n=500):
        if x is None:
            return None
        s = str(x)
        return s if len(s) <= n else s[:n] + "…"
    history.append({k: _trim(v) for k, v in kwargs.items()})


def postprocess_sql(sql: str, nlq: str) -> str:
    """
    Deterministic control layer to enforce output shape and reduce EX failures.
    """
    sql = projection_guard(sql, nlq)

    if CFG["use_projection_contract"]:
        sql = enforce_projection_contract(sql, nlq)

    sql = canonicalize_table_casing(sql)

    if CFG["use_clamps"]:
        sql = apply_clamps(sql, nlq)

    return sql


def make_prompts(nlq: str, schema_text: str, history: list[dict], observation: str) -> list[str]:
    prompts = [build_react_prompt(nlq, schema_text, history, observation)]
    if CFG["use_tabular_prompt"]:
        prompts.append(build_tabular_prompt(nlq, schema_text))
    return prompts


def evaluate_candidate(nlq: str, raw: str):
    """
    Turns raw text into a scored executable SQL (or returns why it failed).
    """
    sql, reason = clean_candidate(raw)
    if not sql:
        return None, {"phase": "clean_reject", "reason": reason, "raw": raw}

    sql = postprocess_sql(sql, nlq)

    # execution gate
    meta = runner.run(sql, capture_df=False)
    if not meta.success:
        return None, {"phase": "exec_fail", "sql": sql, "error": meta.error}

    # intent gate
    ok, why = intent_constraints(nlq, sql)
    if not ok:
        return None, {"phase": "intent_reject", "sql": sql, "reason": why}

    # simple score (explainable)
    s_sem = semantic_score(nlq, sql)
    s_cols = count_select_columns(sql)
    s_extra = score_sql(nlq, sql)
    score = s_sem - 0.5 * s_cols + s_extra

    return (sql, score), {"phase": "accept", "sql": sql, "score": score, "sem": s_sem, "cols": s_cols, "extra": s_extra}


def react_sql(nlq: str, schema_text: str, schema_summary=None, exemplars=None):
    """
    Full ReAct-style loop:
    - generate candidates
    - clean + postprocess
    - execute (observation)
    - revise (prompt includes observation)
    - optional repair on errors
    """
    history = []
    observation = "Start."
    last_failed_sql = None
    last_error = None

    for step in range(CFG["max_steps"]):
        prompts = make_prompts(nlq, schema_text, history, observation)

        per_prompt = max(1, CFG["num_cands"] // len(prompts))
        raw_cands = []
        for p in prompts:
            raw_cands += generate_candidates(p, num=per_prompt, do_sample=CFG["do_sample"])

        best = None

        for raw in raw_cands:
            result, log = evaluate_candidate(nlq, raw)
            _log(history, step=step, **log)

            if log["phase"] == "exec_fail":
                last_error = log.get("error")
                last_failed_sql = log.get("sql")

            if result:
                sql, score = result
                if (best is None) or (score > best[1]):
                    best = (sql, score)

        if best:
            sql, score = best
            _log(history, step=step, phase="final", sql=sql, score=score)
            return sql, history

        if CFG["enable_repair"] and last_error and last_failed_sql:
            repaired, repinfo = repair_sql(nlq, last_failed_sql, last_error, schema_text)
            _log(history, step=step, phase="repair", bad_sql=last_failed_sql, error=last_error, **repinfo)

            if repaired:
                meta = runner.run(repaired, capture_df=False)
                _log(history, step=step, phase="repair_exec", sql=repaired, success=meta.success, error=meta.error)
                if meta.success:
                    ok, why = intent_constraints(nlq, repaired)
                    _log(history, step=step, phase="intent_check", sql=repaired, ok=ok, reason=why)
                    if ok:
                        _log(history, step=step, phase="final", sql=repaired, score="repair_accept")
                        return repaired, history

            observation = f"Previous SQL failed: {last_error}. Revise tables/joins and try again."
        else:
            observation = "No executable candidates. Try a simpler join path."

        _log(history, step=step, phase="observation", obs=observation)

    if schema_summary is not None:
        fallback = vanilla_candidate(
            nlq=nlq,
            schema_summary=schema_summary,
            tok=tok,
            model=model,
            exemplars=exemplars or [],
        )
        if fallback:
            _log(history, step=CFG["max_steps"], phase="fallback", sql=fallback)
            return fallback, history

    _log(history, step=CFG["max_steps"], phase="fail", reason="No valid SQL found")
    return "", history


## EX Troubleshooting Checklist

If EX is low but VA is high, the SQL runs but answers a different question. The cause is usually *semantic alignment* (projection, intent, join choice).

**Quick checks:**
- **Projection drift**: NLQ lists fields but SQL returns extras or wrong order → tighten `enforce_projection_contract`.
- **Wrong intent**: list questions returning aggregates or groupings → check `intent_constraints`.
- **Wrong table/join**: NLQ terms not reflected in SQL tables → verify schema‑subset prompt and join hints.
- **Literal mismatch**: NLQ mentions a literal (e.g., ‘USA’, ‘San Francisco’) but the SQL omits it or filters on a different value.

**Debug workflow:**
1. Run quick check on 5–10 items.
2. Inspect the trace `phase` values (`clean_reject` → `exec_fail` → `intent_reject`) to find the gate where candidates are dropped.
3. Adjust projection/intent/schema subset before touching repair.
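
Step 2 of the workflow can be sketched as a small tally over one trace. This is a minimal sketch, assuming trace entries are dicts with a `phase` key as written by `_log`; `phase_counts`, `first_failing_gate`, and the example entries are hypothetical and not part of the notebook.

```python
from collections import Counter

def phase_counts(trace):
    """Count how often each `phase` value appears in one trace."""
    return Counter(str(entry.get("phase")) for entry in trace)

def first_failing_gate(trace):
    """Return the earliest gate (in pipeline order) that rejected a candidate, or None."""
    counts = phase_counts(trace)
    for phase in ("clean_reject", "exec_fail", "intent_reject"):
        if counts[phase]:
            return phase
    return None

# Hand-written entries in the shape produced by _log (all values are strings).
example_trace = [
    {"step": "0", "phase": "clean_reject", "reason": "no SELECT"},
    {"step": "0", "phase": "exec_fail", "sql": "SELECT x FROM customers", "error": "Unknown column 'x'"},
    {"step": "1", "phase": "accept", "sql": "SELECT customerName FROM customers", "score": "1.0"},
]

print(phase_counts(example_trace))
print(first_failing_gate(example_trace))
```

If most items stop at `clean_reject`, the problem is in generation or cleaning; if they stop at `intent_reject`, look at the projection and intent rules first.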


In [None]:
# 8) Quick sanity check on a few items
schema_text = SCHEMA_SUMMARY
for sample in test_set[:5]:
    nlq = sample["nlq"]
    gold = sample["sql"]
    pred, trace = react_sql(
        nlq=nlq,
        schema_text=schema_text,
        schema_summary=SCHEMA_SUMMARY,
        exemplars=test_set[:3],
    )
    print("NLQ:", nlq)
    print("PRED:", pred)
    print("GOLD:", gold)
    print("TRACE LEN:", len(trace))
    print("-" * 80)


### Stage 3 Interpretation (29 Jan 2026)

- **Valid SQL stability:** Stage 3 generally returns executable SQL; the remaining issues are **projection bloat** (extra columns) and **unnecessary ORDER BY/GROUP BY** clauses.
- **Metric impact:** These are EM regressions more than EX regressions. Use clamps + final normalization to keep outputs canonical.
- **Trace logging upgrade:** The ReAct loop now logs **raw → cleaned → post‑clamp → exec error → repair attempt**, so failures can be attributed to generation vs cleaning vs execution vs repair.
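
To see how often a prediction carries an ORDER BY or GROUP BY that the gold query lacks, a rough regex check may be enough. This is a minimal sketch: `extra_clauses` and `CLAUSE_PATTERNS` are hypothetical names, and the patterns also match keywords inside subqueries or string literals, so treat hits as candidates for inspection rather than confirmed errors.

```python
import re

# Case-insensitive clause detectors (same style as the TS ORDER BY check).
CLAUSE_PATTERNS = {
    "ORDER BY": re.compile(r"(?is)\border\s+by\b"),
    "GROUP BY": re.compile(r"(?is)\bgroup\s+by\b"),
}

def extra_clauses(pred_sql: str, gold_sql: str) -> list[str]:
    """Return clause names that appear in pred_sql but not in gold_sql."""
    return [
        name
        for name, pattern in CLAUSE_PATTERNS.items()
        if pattern.search(pred_sql or "") and not pattern.search(gold_sql or "")
    ]

pred = "SELECT customerName FROM customers ORDER BY customerName"
gold = "SELECT customerName FROM customers"
print(extra_clauses(pred, gold))
```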


In [None]:
# === Test Suite Accuracy (TS) evaluation ===
# Based on the distilled test suites idea: compare pred vs gold results across multiple perturbed DBs.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Iterable, Optional
import re
import math
from sqlalchemy import text
from sqlalchemy.engine import Engine

TS_ORDER_BY_RE = re.compile(r"(?is)\border\s+by\b")

def _has_order_by(sql: str) -> bool:
    return bool(TS_ORDER_BY_RE.search(sql or ""))

def _coerce_cell(x: Any) -> Any:
    """Normalize SQL result cells for robust equality."""
    if x is None:
        return None
    if isinstance(x, float):
        if math.isnan(x):
            return "NaN"
        return round(x, 10)
    return x

def _normalize_rows(rows: Iterable[Iterable[Any]]) -> list[tuple[Any, ...]]:
    out = []
    for r in rows:
        out.append(tuple(_coerce_cell(v) for v in r))
    return out

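def _sorted_rows(rows: list[tuple[Any, ...]]) -> list[tuple[Any, ...]]:
    # Sort NULLs first with an (is_not_none, value) key. Substituting "" for None
    # raises TypeError when a NULL shares a column with numeric values.
    return sorted(rows, key=lambda t: tuple((v is not None, v) for v in t))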

@dataclass
class QueryRun:
    ok: bool
    rows: Optional[list[tuple[Any, ...]]] = None
    error: Optional[str] = None

def run_select(engine: Engine, sql: str, max_rows: int = 2000) -> QueryRun:
    """Execute SELECT and fetch up to max_rows. Returns rows as tuples."""
    try:
        with engine.connect() as conn:
            res = conn.execute(text(sql))
            fetched = res.fetchmany(max_rows)
            rows = _normalize_rows(fetched)
        return QueryRun(ok=True, rows=rows)
    except Exception as e:
        return QueryRun(ok=False, rows=None, error=str(e))

def results_match(
    gold_rows: list[tuple[Any, ...]],
    pred_rows: list[tuple[Any, ...]],
    ordered: bool,
) -> bool:
    """Compare result sets; ordered if ORDER BY exists."""
    if ordered:
        return gold_rows == pred_rows
    return _sorted_rows(gold_rows) == _sorted_rows(pred_rows)

def test_suite_accuracy_for_item(
    make_engine_fn,
    suite_db_names: list[str],
    gold_sql: str,
    pred_sql: str,
    *,
    max_rows: int = 2000,
    strict_gold: bool = True,
) -> tuple[int, dict]:
    """
    Returns (ts_pass, debug_info)

    strict_gold=True:
      if gold fails on any suite db, treat as TS=0 (suite generation bug / invalid gold on that db)
    strict_gold=False:
      ignore suite DBs where gold fails (uses remaining DBs)
    """
    ordered = _has_order_by(gold_sql) or _has_order_by(pred_sql)

    per_db = []
    usable = 0
    all_ok = True

    for db in suite_db_names:
        eng = make_engine_fn(db)

        g = run_select(eng, gold_sql, max_rows=max_rows)
        p = run_select(eng, pred_sql, max_rows=max_rows)

        if not g.ok:
            per_db.append({
                "db": db,
                "gold_ok": False,
                "pred_ok": p.ok,
                "gold_error": g.error,
                "pred_error": p.error if not p.ok else None,
                "match": False,
            })
            if strict_gold:
                all_ok = False
            continue

        usable += 1

        if not p.ok:
            per_db.append({
                "db": db,
                "gold_ok": True,
                "pred_ok": False,
                "gold_error": None,
                "pred_error": p.error,
                "match": False,
            })
            all_ok = False
            continue

        match = results_match(g.rows or [], p.rows or [], ordered=ordered)
        per_db.append({
            "db": db,
            "gold_ok": True,
            "pred_ok": True,
            "match": match,
            "ordered_compare": ordered,
            "gold_sample": (g.rows or [])[:10],
            "pred_sample": (p.rows or [])[:10],
        })
        if not match:
            all_ok = False

    if not strict_gold and usable == 0:
        all_ok = False

    ts = 1 if all_ok else 0
    debug = {"ordered_compare": ordered, "usable_dbs": usable, "per_db": per_db}
    return ts, debug


In [None]:
# === Quick test toggles (set before full eval) ===
# Use small values to sanity‑check TS/EX before full runs.
QUICK_LIMIT = 20   # number of NLQs to evaluate (set None for full set)
TS_N = 3           # number of TS DBs (set 10 for full TS)
MAX_ROWS_TS = 500  # row cap per query in TS (raise for full)


## TS Smoke‑Test Checklist (fast sanity run)

Before running full TS (200×10 DBs), do a quick sanity check:
- Set `QUICK_LIMIT = 5`
- Set `TS_N = 3`
- Set `MAX_ROWS_TS = 500`

After the run, check:
- Cases where **EX=0 but TS=1** (likely column order/label mismatch).
- Cases where **EX=1 but TS=0** (rare; likely truncation or ORDER BY).
- Cases where **VA=1, EX=0, TS=0** (wrong semantics).

If these look sensible, scale to full run:
`QUICK_LIMIT=None`, `TS_N=10`, `MAX_ROWS_TS=2000`.
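
The three checks above amount to grouping items by their `(va, ex, ts)` pattern. This is a minimal sketch, assuming each result row is a dict with `va`, `ex`, and `ts` keys as built in the evaluation cell; `bucket_by_metrics` and the example rows are hypothetical.

```python
from collections import defaultdict

def bucket_by_metrics(results):
    """Group result rows by their (va, ex, ts) pattern."""
    buckets = defaultdict(list)
    for row in results:
        buckets[(row["va"], row["ex"], row["ts"])].append(row)
    return buckets

# Hand-made rows with the same metric keys as the evaluation results.
example_results = [
    {"nlq": "q1", "va": 1, "ex": 1, "ts": 1},
    {"nlq": "q2", "va": 1, "ex": 0, "ts": 1},
    {"nlq": "q3", "va": 1, "ex": 0, "ts": 0},
]

buckets = bucket_by_metrics(example_results)
# Patterns from the checklist: EX=0/TS=1, EX=1/TS=0, and VA=1/EX=0/TS=0.
for pattern in [(1, 0, 1), (1, 1, 0), (1, 0, 0)]:
    print(pattern, [r["nlq"] for r in buckets.get(pattern, [])])
```

For the disagreement buckets, the `ts_debug["per_db"]` entries of each row show which suite DB diverged.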


In [None]:
# 9) Full ReAct-style evaluation (VA/EX/EM/TS) over test_set
import json
from functools import lru_cache
from pathlib import Path

from nl2sql.eval import execution_accuracy

results = []

TS_PREFIX = "classicmodels_ts"
# TS_N is defined in the quick-test toggles cell; set it there before running this cell.
SUITE_DBS = [f"{TS_PREFIX}_{i:02d}" for i in range(1, TS_N + 1)]

@lru_cache(maxsize=32)
def make_engine_cached(db_name: str) -> Engine:
    return make_engine(db_name)

def make_engine_fn(db_name: str) -> Engine:
    return make_engine_cached(db_name)

LIMIT = QUICK_LIMIT  # override from quick toggles
items = test_set[:LIMIT] if LIMIT else test_set
schema_text = SCHEMA_SUMMARY

for i, sample in enumerate(items, start=1):
    nlq = sample["nlq"]
    gold_sql = sample["sql"]

    pred_sql, trace = react_sql(
        nlq=nlq,
        schema_text=schema_text,
        schema_summary=SCHEMA_SUMMARY,
        exemplars=test_set[:3],
    )

    ts, ts_debug = test_suite_accuracy_for_item(
        make_engine_fn=make_engine_fn,
        suite_db_names=SUITE_DBS,
        gold_sql=gold_sql,
        pred_sql=pred_sql,
        max_rows=MAX_ROWS_TS,
        strict_gold=True,
    )

    pred_clean = pred_sql.strip().rstrip(";").lower()
    gold_clean = gold_sql.strip().rstrip(";").lower()
    em = int(pred_clean == gold_clean)
    va = ex = 0

    try:
        meta = runner.run(pred_sql)
        va = int(meta.success)
        if meta.success:
            ex_ok, _, _ = execution_accuracy(engine=engine, pred_sql=pred_sql, gold_sql=gold_sql)
            ex = int(ex_ok)
    except Exception:
        va = 0
        ex = 0

    results.append({
        "nlq": nlq,
        "gold_sql": gold_sql,
        "pred_sql": pred_sql,
        "va": va,
        "em": em,
        "ex": ex,
        "ts": ts,
        "ts_debug": ts_debug,
        "trace": trace,
    })

    if i % 10 == 0:
        print(f"Processed {i}/{len(items)}")

va_rate = sum(r["va"] for r in results) / len(results)
ex_rate = sum(r["ex"] for r in results) / len(results)
em_rate = sum(r["em"] for r in results) / len(results)
ts_rate = sum(r["ts"] for r in results) / len(results)
print("ReAct VA:", va_rate, "EX:", ex_rate, "EM:", em_rate, "TS:", ts_rate)

Path("results/agent").mkdir(parents=True, exist_ok=True)
save_path = Path("results/agent/results_react_200.json")
save_path.write_text(
    json.dumps({
        "va_rate": va_rate,
        "ex_rate": ex_rate,
        "em_rate": em_rate,
        "ts_rate": ts_rate,
        "items": results,
    }, ensure_ascii=False, indent=2),
    encoding="utf-8",
)
print("Saved to", save_path)
