# Baseline Eval


## Summary
- Purpose: establish the prompting-only baseline for NL->SQL (`k=0` and `k=3`).
- Scope: runs on the fixed 200-item ClassicModels benchmark for paired comparisons.
- Outputs: baseline run JSON files under `results/baseline/` for downstream analysis.


In [None]:
%%bash
# Environment setup: remove preinstalled packages that may conflict, then
# reinstall a pinned stack in three groups (base deps, Torch, HF tooling).
# Stop at the first failing command so a broken install is not masked.
set -e
# pip reads its network timeout (seconds) from this variable.
export PIP_DEFAULT_TIMEOUT=120

# Clean conflicting preinstalls
# `|| true` keeps `set -e` from aborting when the uninstall step fails.
pip uninstall -y torch torchvision torchaudio bitsandbytes triton transformers accelerate peft trl datasets numpy pandas fsspec requests google-auth scipy scikit-learn || true

# Base deps
# NOTE(review): scipy and scikit-learn are the only unpinned packages here.
pip install -q --no-cache-dir --force-reinstall \
  numpy==1.26.4 pandas==2.2.1 scipy scikit-learn \
  fsspec==2024.5.0 requests==2.31.0 google-auth==2.43.0

# Torch + CUDA 12.1
# The +cu121 builds are pulled from the PyTorch wheel index given below.
pip install -q --no-cache-dir --force-reinstall \
  torch==2.3.1+cu121 torchvision==0.18.1+cu121 torchaudio==2.3.1+cu121 \
  --index-url https://download.pytorch.org/whl/cu121

# bitsandbytes + triton + HF stack
pip install -q --no-cache-dir --force-reinstall \
  bitsandbytes==0.43.3 triton==2.3.1 \
  transformers==4.44.2 accelerate==0.33.0 peft==0.17.0 trl==0.9.6 datasets==2.20.0

# Reminder for the user: the runtime has to be restarted once after this cell.
echo "Setup complete. Restart runtime once, then run the rest of the notebook top-to-bottom."


## 0) Bootstrap


In [None]:
import os, sys, shutil
from pathlib import Path

# If this notebook is opened directly in Colab (not from a cloned repo), clone the repo first.
if Path("data/classicmodels_test_200.json").exists() is False and Path("/content").exists():
    repo_dir = Path("/content/NLtoSQL")
    if repo_dir.exists():
        shutil.rmtree(repo_dir)
    !git clone https://github.com/MacKenzieOBrian/NLtoSQL.git "{repo_dir}"
    os.chdir(repo_dir)

# Ensure repo root is on sys.path for `import nl2sql`
sys.path.insert(0, os.getcwd())
print("cwd:", os.getcwd())


## 1) Install


In [None]:
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    !pip -q install -r requirements.txt
else:
    print("Not in Colab; ensure requirements are installed.")


In [None]:
# GCP sign-in is Colab-only; when the Colab module is missing the step is
# skipped and credentials must already be configured.
try:
    from google.colab import auth
except ModuleNotFoundError:
    auth = None

if not auth:
    print("Not running in Colab; ensure ADC or service account auth is configured.")
else:
    auth.authenticate_user()


In [None]:
# Hugging Face credentials for the gated model: prefer a token from the
# environment, otherwise fall back to the interactive notebook login.
import os

hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
if not hf_token:
    # Best effort: a failed login is reported, not raised.
    try:
        from huggingface_hub import notebook_login
        notebook_login()
    except Exception as err:
        print("HF auth not configured:", err)
else:
    # Mirror the token into HUGGINGFACE_HUB_TOKEN whichever variable supplied it.
    os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
    print("Using HF token from env")


## 2) Data + DB


In [None]:
import json
from getpass import getpass

# Connection settings come from the environment; anything missing is prompted
# for interactively below. DB_NAME defaults to "classicmodels".
INSTANCE_CONNECTION_NAME = os.getenv("INSTANCE_CONNECTION_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASS = os.getenv("DB_PASS")
DB_NAME = os.getenv("DB_NAME", "classicmodels")

if not INSTANCE_CONNECTION_NAME:
    INSTANCE_CONNECTION_NAME = input("Enter INSTANCE_CONNECTION_NAME: ").strip()
if not DB_USER:
    DB_USER = input("Enter DB_USER: ").strip()
if not DB_PASS:
    # getpass avoids echoing the password into the cell output.
    DB_PASS = getpass("Enter DB_PASS: ")

print("Using DB:", DB_NAME)

# Load the fixed benchmark. The context manager closes the file handle, which
# the previous `open(...).read()` form left to the garbage collector.
with open("data/classicmodels_test_200.json", "r", encoding="utf-8") as fh:
    test_set = json.load(fh)
print("Loaded test items:", len(test_set))


In [None]:
from nl2sql.db import create_engine_with_connector

# Create the database engine (and its connector) for the benchmark database
# using the settings collected in the previous cell. `engine` is reused by the
# schema summary and by every evaluation run below.
# NOTE(review): `connector` is never closed in the visible cells — confirm
# whether it needs explicit cleanup at the end of the session.
engine, connector = create_engine_with_connector(
    instance_connection_name=INSTANCE_CONNECTION_NAME,
    user=DB_USER,
    password=DB_PASS,
    db_name=DB_NAME,
)

print("Engine ready")


## 3) Load Model


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hub id of the model under evaluation; also used for run metadata and aliases.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

print("Loading tokenizer...")
# token=True: use the Hugging Face credential configured in the auth cell.
tok = AutoTokenizer.from_pretrained(MODEL_ID, token=True)
# Reuse EOS as the padding token when the tokenizer does not define one.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Try 4-bit loading (fallback to fp/bf16)
try:
    # NF4 4-bit weights with double quantization; compute runs in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    print("Attempting 4-bit quantized load...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=True,
    )
except Exception as e:
    # Any failure of the quantized path (not only a missing bitsandbytes/GPU)
    # lands here, so the original error is printed before retrying unquantized.
    print("4-bit load failed, falling back. Error:")
    print(e)
    # Unquantized fallback: bf16 with automatic placement when CUDA is
    # available, otherwise fp32 with no device map.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
        token=True,
    )

# Disable sampling and beam search so generation defaults to greedy decoding.
model.generation_config.do_sample = False
model.generation_config.num_beams = 1

print("Model device:", model.device)


## 4) Schema Context


In [None]:
from nl2sql.schema import build_schema_summary

# Text summary of the database schema, built once from the live engine and
# passed to every run as prompt context.
# NOTE(review): max_cols_per_table=50 presumably caps the columns listed per
# table — confirm against nl2sql.schema.build_schema_summary.
SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME, max_cols_per_table=50)
print("Schema summary built (chars):", len(SCHEMA_SUMMARY))


## 5) Run Controls


In [None]:
# Orchestration helpers for baseline sweep runs.
# Core NL->SQL logic stays in nl2sql imports below.
import re
import subprocess
import shutil
from functools import lru_cache
from datetime import datetime, timezone
from pathlib import Path

import pandas as pd

from nl2sql.eval import eval_run
from nl2sql.db import create_engine_with_connector
import nl2sql.prompting as prompting_mod

# Snapshot of the library's system prompt taken when this cell runs. It backs
# the "default" prompt variant and is the base text for derived variants.
DEFAULT_SYSTEM_INSTRUCTIONS = prompting_mod.SYSTEM_INSTRUCTIONS


def _model_alias_from_id(model_id: str) -> str:
    tail = (model_id or "model").split("/")[-1]
    alias = re.sub(r"[^a-z0-9]+", "_", tail.lower()).strip("_")
    return alias or "model"


# Prompt variants for controlled prompt-ablation runs.
# Keys are the accepted values of `prompt_variant` in run_baseline_grid, which
# temporarily installs the selected text as prompting_mod.SYSTEM_INSTRUCTIONS.
PROMPT_VARIANTS = {
    # Library prompt, unchanged.
    "default": DEFAULT_SYSTEM_INSTRUCTIONS,
    # Minimal prompt: task statement plus output-format rules only.
    "schema_only_minimal": """You are an expert data analyst writing MySQL queries.
Given the database schema and a natural language question, write a single SQL SELECT query.

Rules:
- Output ONLY SQL (no explanation, no markdown).
- Output exactly ONE statement, starting with SELECT.
- Use only tables/columns listed in the schema.
""",
    # Default prompt cut off just before its "- Routing hints:" marker.
    # If the marker is absent, split() returns the whole prompt, so this variant
    # silently becomes the default prompt with trailing whitespace stripped.
    "no_routing_hints": DEFAULT_SYSTEM_INSTRUCTIONS.split("- Routing hints:")[0].rstrip(),
}

# Schema truncation variants for schema-ablation experiments.
def schema_variant_text(schema_text: str, variant: str) -> str:
    """Return the schema text for the requested ablation variant.

    "full" returns ``schema_text`` untouched; "first_80_lines" and
    "first_40_lines" return that many leading lines joined with "\\n".

    Raises:
        ValueError: if ``variant`` is not one of the names above.
    """
    schema_lines = schema_text.splitlines()
    if variant == "full":
        return schema_text
    # The truncating variants differ only in how many leading lines they keep.
    for name, cap in (("first_80_lines", 80), ("first_40_lines", 40)):
        if variant == name:
            return "\n".join(schema_lines[:cap])
    raise ValueError(f"Unknown SCHEMA_VARIANT: {variant}")

# Exemplar-pool strategies for few-shot ablations.
def exemplar_pool_for_strategy(items: list[dict], strategy: str) -> list[dict]:
    """Select the few-shot exemplar pool for an ablation strategy.

    Strategies:
        "all":        every item (returned as a new list).
        "brief_sql":  items ranked by SQL length, keeping the shortest
                      ``max(50, 40% of the items)``.
        "join_heavy": items whose SQL contains the keyword JOIN.
        "agg_heavy":  items whose SQL calls SUM/AVG/COUNT/MIN/MAX.

    A filtered pool with fewer than 10 items is discarded in favour of a copy
    of all items.

    Raises:
        ValueError: if ``strategy`` is not one of the names above.
    """
    if strategy == "all":
        return list(items)

    def _sql(x):
        return str(x.get("sql", "")).strip()

    def _is_join(sql: str) -> bool:
        # Word-boundary match so JOIN is found regardless of the surrounding
        # whitespace (newline, tab), which is common in multi-line SQL. A plain
        # " join " substring test missed those and could empty the pool.
        return bool(re.search(r"\bjoin\b", sql.lower()))

    def _is_agg(sql: str) -> bool:
        return bool(re.search(r"\b(sum|avg|count|min|max)\s*\(", sql.lower()))

    if strategy == "brief_sql":
        ranked = sorted(items, key=lambda x: len(_sql(x)))
        keep = max(50, int(0.4 * len(ranked)))
        pool = ranked[:keep]
    elif strategy == "join_heavy":
        pool = [x for x in items if _is_join(_sql(x))]
    elif strategy == "agg_heavy":
        pool = [x for x in items if _is_agg(_sql(x))]
    else:
        raise ValueError(f"Unknown EXEMPLAR_STRATEGY: {strategy}")

    # Too small a pool would make k-shot sampling degenerate; use everything.
    return pool if len(pool) >= 10 else list(items)

# Record the short git commit hash for provenance; any failure (no git, not a
# repository, ...) degrades to the placeholder "unknown".
try:
    commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip()
except Exception:
    commit = "unknown"

# Metadata shared by every run JSON saved from this notebook; per-run fields
# are layered on top of a copy of this dict inside the sweep runner.
run_metadata_base = dict(
    commit=commit,
    model_id=MODEL_ID,
    model_alias=_model_alias_from_id(MODEL_ID),
    notebook="02_baseline_prompting_eval.ipynb",
    method="baseline",
)

# Create the output root up front so later artifact writes find it in place.
Path("results/baseline").mkdir(parents=True, exist_ok=True)


# Optional Colab helper: copy completed run artifacts to Drive.
def persist_run_to_drive(
    run_dir: str | Path,
    model_alias: str,
    k_values: list[int],
    run_tag: str,
    persist_root: str = "/content/drive/MyDrive/nl2sql_persistent_runs",
):
    # Copy baseline run outputs to Google Drive to survive Colab disconnects.
    run_dir = Path(run_dir)
    root = Path(persist_root)
    root.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%SZ")
    dst = root / f"{run_dir.name}_{stamp}"
    dst.mkdir(parents=True, exist_ok=True)

    for p in run_dir.glob("*"):
        if p.is_file():
            shutil.copy2(p, dst / p.name)

    mf_src = Path("results/baseline/model_family")
    mf_dst = dst / "model_family"
    mf_dst.mkdir(parents=True, exist_ok=True)
    for k in k_values:
        mf = mf_src / f"{model_alias}_k{k}.json"
        if mf.exists():
            shutil.copy2(mf, mf_dst / mf.name)

    (dst / "backup_manifest.txt").write_text(
        "\n".join([
            f"run_dir={run_dir}",
            f"backup_dir={dst}",
            f"run_tag={run_tag}",
            f"model_alias={model_alias}",
            f"k_values={k_values}",
        ]),
        encoding="utf-8",
    )
    return dst

# Main sweep runner: executes k/seed grid and writes JSON+CSV artifacts.
def run_baseline_grid(
    *,
    k_values: list[int],
    seeds: list[int],
    run_tag: str,
    prompt_variant: str,
    schema_variant: str,
    exemplar_strategy: str,
    limit: int | None = None,
    copy_canonical: bool = True,
    copy_model_family: bool = True,
    model_alias: str | None = None,
    enable_ts_for_k: set[int] | None = None,
    ts_n: int = 10,
    ts_prefix: str = "classicmodels_ts",
    ts_max_rows: int = 500,
):
    # Basic guardrails for reproducible runs.
    if not seeds:
        raise ValueError("Provide at least one seed")

    if prompt_variant not in PROMPT_VARIANTS:
        raise ValueError(f"Unknown PROMPT_VARIANT: {prompt_variant}")

    # Run directory is timestamped for traceability and collision safety.
    ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%SZ")
    run_dir = Path("results/baseline/runs") / f"{run_tag}_{ts}"
    run_dir.mkdir(parents=True, exist_ok=True)

    schema_used = schema_variant_text(SCHEMA_SUMMARY, schema_variant)
    exemplar_pool = exemplar_pool_for_strategy(test_set, exemplar_strategy)
    resolved_model_alias = model_alias or run_metadata_base.get("model_alias") or _model_alias_from_id(MODEL_ID)

    ts_enabled_k = set(enable_ts_for_k or set())
    ts_suite_db_names = (
        [f"{ts_prefix}_{i:02d}" for i in range(1, ts_n + 1)]
        if ts_enabled_k and ts_n > 0
        else None
    )

    # Lazy connector cache for optional TS evaluation databases.
    ts_connectors = {}

    @lru_cache(maxsize=32)
    def _make_engine_cached(db_name: str):
        eng, conn = create_engine_with_connector(
            instance_connection_name=INSTANCE_CONNECTION_NAME,
            user=DB_USER,
            password=DB_PASS,
            db_name=db_name,
        )
        ts_connectors[db_name] = conn
        return eng

    def _make_engine_fn(db_name: str):
        return _make_engine_cached(db_name)

    rows = []
    primary_seed = seeds[0]  # first seed is used for canonical/model-family copies

    # Prompt override is temporary and restored in finally.
    old_prompt = prompting_mod.SYSTEM_INSTRUCTIONS
    prompting_mod.SYSTEM_INSTRUCTIONS = PROMPT_VARIANTS[prompt_variant]

    try:
        for k in k_values:
            # Keep k=0 tied to primary seed for canonical zero-shot consistency.
            seed_list = [primary_seed] if k == 0 else seeds
            for seed in seed_list:
                save_path = run_dir / f"results_k{k}_seed{seed}.json"

                run_meta = dict(run_metadata_base)
                run_meta.update({
                    "run_tag": run_tag,
                    "k": k,
                    "seed": seed,
                    "prompt_variant": prompt_variant,
                    "schema_variant": schema_variant,
                    "exemplar_strategy": exemplar_strategy,
                    "exemplar_pool_size": len(exemplar_pool),
                    "model_alias": resolved_model_alias,
                    "ts_enabled": bool(k in ts_enabled_k),
                    "ts_for_k_values": sorted(ts_enabled_k),
                    "ts_n": ts_n if ts_suite_db_names else 0,
                })

                items = eval_run(
                    test_set=test_set,
                    exemplar_pool=exemplar_pool,
                    k=k,
                    limit=limit,
                    seed=seed,
                    engine=engine,
                    model=model,
                    tokenizer=tok,
                    schema_summary=schema_used,
                    save_path=str(save_path),
                    run_metadata=run_meta,
                    ts_suite_db_names=ts_suite_db_names if k in ts_enabled_k else None,
                    ts_make_engine_fn=_make_engine_fn if k in ts_enabled_k else None,
                    ts_max_rows=ts_max_rows,
                    avoid_exemplar_leakage=True,
                )

                # Aggregate per-run rates for quick table summaries.
                n = len(items)
                va = sum(int(x.va) for x in items) / max(n, 1)
                em = sum(int(x.em) for x in items) / max(n, 1)
                ex = sum(int(x.ex) for x in items) / max(n, 1)
                ts_values = [int(x.ts) for x in items if getattr(x, "ts", None) is not None]
                ts_rate = (sum(ts_values) / len(ts_values)) if ts_values else None

                rows.append({
                    "run_tag": run_tag,
                    "prompt_variant": prompt_variant,
                    "schema_variant": schema_variant,
                    "exemplar_strategy": exemplar_strategy,
                    "exemplar_pool_size": len(exemplar_pool),
                    "k": k,
                    "seed": seed,
                    "n": n,
                    "va_rate": va,
                    "em_rate": em,
                    "ex_rate": ex,
                    "ts_rate": ts_rate,
                    "ts_n": len(ts_values),
                    "json_path": str(save_path),
                })

                if seed == primary_seed and k in {0, 3}:
                    if copy_canonical:
                        target = (
                            Path("results/baseline/results_zero_shot_200.json")
                            if k == 0
                            else Path("results/baseline/results_few_shot_k3_200.json")
                        )
                        target.parent.mkdir(parents=True, exist_ok=True)
                        shutil.copy2(save_path, target)
                        print(f"Updated canonical file: {target}")

                    if copy_model_family:
                        model_target = Path("results/baseline/model_family") / f"{resolved_model_alias}_k{k}.json"
                        model_target.parent.mkdir(parents=True, exist_ok=True)
                        shutil.copy2(save_path, model_target)
                        print(f"Updated model-family file: {model_target}")
    finally:
        for conn in ts_connectors.values():
            try:
                conn.close()
            except Exception:
                pass
        prompting_mod.SYSTEM_INSTRUCTIONS = old_prompt

    # Save per-run table and per-k mean/std summary.
    df = pd.DataFrame(rows).sort_values(["k", "seed"]).reset_index(drop=True)
    df.to_csv(run_dir / "grid_summary.csv", index=False)

    agg = (
        df.groupby(["prompt_variant", "schema_variant", "exemplar_strategy", "k"], as_index=False)
        .agg(
            runs=("seed", "count"),
            va_mean=("va_rate", "mean"),
            va_std=("va_rate", "std"),
            em_mean=("em_rate", "mean"),
            em_std=("em_rate", "std"),
            ex_mean=("ex_rate", "mean"),
            ex_std=("ex_rate", "std"),
            ts_mean=("ts_rate", "mean"),
            ts_std=("ts_rate", "std"),
        )
    )
    agg.to_csv(run_dir / "grid_summary_by_k.csv", index=False)

    print("Saved grid run to:", run_dir)
    return df, agg, run_dir


### 5A) Run Plan


In [None]:
# ============================
# Plan-driven baseline controls
# ============================
# Edit only this block for most experiments.
# Keep one RUN_PLAN active per execution to isolate effects clearly.
RUN_PLAN = "quick"  # quick | k5_core | seed_backfill | full_sweep | ts_k3 | custom

# Ablation axes; each must be a value accepted by run_baseline_grid.
PROMPT_VARIANT = "default"
SCHEMA_VARIANT = "full"
EXEMPLAR_STRATEGY = "all"

MODEL_ALIAS = _model_alias_from_id(MODEL_ID)
COPY_MODEL_FAMILY = True
# With False, the canonical zero-shot / k=3 files are left untouched by the run.
COPY_CANONICAL = False

# TS evaluation databases are named <TS_PREFIX>_01 .. <TS_PREFIX>_<TS_N>.
TS_N = 10
TS_PREFIX = "classicmodels_ts"
TS_MAX_ROWS = 500

# Only consulted when RUN_PLAN == "custom".
CUSTOM_PLAN = {
    "k_values": [0, 3],
    "seeds": [7],
    "run_tag": "baseline_custom",
    "enable_ts": False,
}

# Resolve plan -> concrete k/seed/run_tag values.
if RUN_PLAN == "quick":
    K_VALUES = [0, 3]
    SEEDS = [7]
    RUN_TAG = "baseline_main"
    ENABLE_TS = False
elif RUN_PLAN == "k5_core":
    K_VALUES = [5]
    SEEDS = [7]
    RUN_TAG = "baseline_k5_core"
    ENABLE_TS = False
elif RUN_PLAN == "seed_backfill":
    K_VALUES = [0, 3, 5]
    SEEDS = [17, 27]
    RUN_TAG = "baseline_seed_backfill"
    ENABLE_TS = False
elif RUN_PLAN == "full_sweep":
    K_VALUES = [0, 3, 5]
    SEEDS = [7, 17, 27]
    RUN_TAG = "baseline_e1_k_sweep"
    ENABLE_TS = False
elif RUN_PLAN == "ts_k3":
    K_VALUES = [3]
    SEEDS = [7]
    RUN_TAG = "baseline_ts_k3"
    ENABLE_TS = True
elif RUN_PLAN == "custom":
    K_VALUES = list(CUSTOM_PLAN["k_values"])
    SEEDS = list(CUSTOM_PLAN["seeds"])
    RUN_TAG = str(CUSTOM_PLAN["run_tag"])
    ENABLE_TS = bool(CUSTOM_PLAN["enable_ts"])
else:
    raise ValueError(f"Unknown RUN_PLAN: {RUN_PLAN}")

# TS is typically enabled only for k=3 semantic robustness checks.
TS_FOR_K_VALUES = [3]

# TS runs only for k values present in both lists; flag the case where TS was
# requested but the plan contains none of them, which would otherwise pass
# silently with no TS evaluation at all.
if ENABLE_TS and not set(K_VALUES) & set(TS_FOR_K_VALUES):
    print(
        f"WARNING: ENABLE_TS is True but K_VALUES={K_VALUES} shares no value "
        f"with TS_FOR_K_VALUES={TS_FOR_K_VALUES}; no TS evaluation will run."
    )

# Run card is a quick sanity check before launching a long run.
print("Baseline run card:")
print({
    "run_plan": RUN_PLAN,
    "run_tag": RUN_TAG,
    "model_id": MODEL_ID,
    "model_alias": MODEL_ALIAS,
    "k_values": K_VALUES,
    "seeds": SEEDS,
    "prompt_variant": PROMPT_VARIANT,
    "schema_variant": SCHEMA_VARIANT,
    "exemplar_strategy": EXEMPLAR_STRATEGY,
    "enable_ts": ENABLE_TS,
})

# Colab persistence toggle
PERSIST_TO_DRIVE = True
DRIVE_PERSIST_ROOT = "/content/drive/MyDrive/nl2sql_persistent_runs"


In [None]:
# Demo: End-to-end NLQ -> faulty SQL -> cleaned SQL (single cell)
from IPython.display import display, HTML
import pandas as pd
import html

from nl2sql.core.prompting import make_few_shot_messages
from nl2sql.agent.constraint_policy import build_constraints
from nl2sql.core.llm import debug_extract_first_select
from nl2sql.core.postprocess import debug_guarded_postprocess


# Small helper to print section headers in the notebook output.
def show_title(text):
    """Render ``text`` (HTML-escaped) as an <h3> heading in the cell output."""
    heading = html.escape(text)
    markup = "<h3 style='margin:12px 0 6px 0'>" + heading + "</h3>"
    display(HTML(markup))


# Small helper to render SQL/text in a boxed monospace block.
def show_pre(text, label=None):
    """Display ``text`` in a bordered, monospace <pre> box.

    ``text`` is stringified and HTML-escaped. A truthy ``label`` is escaped and
    shown in bold above the content; otherwise no label row is emitted.
    """
    if label:
        label_html = (
            "<div style='font-weight:600;margin-bottom:6px'>"
            + html.escape(label)
            + "</div>"
        )
    else:
        label_html = ""
    body = html.escape(str(text))
    parts = [
        "<div style='border:1px solid #ddd;border-radius:8px;padding:10px 12px;margin:8px 0'>",
        label_html,
        "<pre style='white-space:pre-wrap;margin:0;font-family:ui-monospace, SFMono-Regular, Menlo, Consolas, monospace'>",
        body,
        "</pre>",
        "</div>",
    ]
    display(HTML("".join(parts)))


# Convert postprocess steps into a small readable table.
def steps_df(pp):
    """Tabulate the entries of ``pp["steps"]`` as a DataFrame.

    One row per step with columns ``changed`` ("yes"/"no" from the step's
    truthy ``changed`` value), ``stage`` and ``note`` (empty string when the
    step has no note).
    """
    records = []
    for step in pp["steps"]:
        flag = "yes" if step["changed"] else "no"
        records.append({
            "changed": flag,
            "stage": step["stage"],
            "note": step.get("note", ""),
        })
    return pd.DataFrame(records)


# Demo NLQs: one implicit question and one explicit field-list question.
DEMO_NLQ_IMPLICIT = "List all customer names in France"
DEMO_NLQ_EXPLICIT = "List contact last name, customer name, and customer number for customers in France"

# Use real schema text if available; otherwise use a minimal fallback.
# The globals() guard lets this cell run even when the schema cell was skipped.
schema_text = (
    SCHEMA_SUMMARY
    if "SCHEMA_SUMMARY" in globals() and isinstance(SCHEMA_SUMMARY, str) and SCHEMA_SUMMARY.strip()
    else "Table customers (customerNumber INT, customerName TEXT, contactLastName TEXT, country TEXT, creditLimit REAL)"
)

# Pull a couple of real exemplars when the benchmark is loaded.
# Only the first two benchmark items are considered, and only if they are dicts
# carrying both "nlq" and "sql"; the list stays empty otherwise.
exemplars = []
if "test_set" in globals() and isinstance(test_set, list):
    exemplars = [x for x in test_set[:2] if isinstance(x, dict) and "nlq" in x and "sql" in x]

# Build the same style of messages used by the real pipeline.
messages = make_few_shot_messages(schema=schema_text, exemplars=exemplars, nlq=DEMO_NLQ_IMPLICIT)
# NOTE(review): build_constraints' result is used below as a dict (.get);
# its exact keys are defined in nl2sql.agent.constraint_policy.
constraints_implicit = build_constraints(DEMO_NLQ_IMPLICIT, schema_text)

# Step 1: show the NLQ and the prompt context.
show_title("Step 1 - NLQ and prompt context")
display(pd.DataFrame([
    {
        "nlq": DEMO_NLQ_IMPLICIT,
        "schema_lines": len(schema_text.splitlines()),
        "exemplars_used": len(exemplars),
        "message_count": len(messages),
        "explicit_fields": constraints_implicit.get("explicit_fields"),
    }
]))

# Preview of the first six messages: newlines flattened, content cut to 140 chars.
display(pd.DataFrame([
    {
        "role": m.get("role"),
        "content_preview": str(m.get("content", "")).replace("\n", " ")[:140],
    }
    for m in messages[:6]
]))

# Step 2: simulate a noisy/faulty model output (on purpose).
# The text is hard-coded (no model call): it wraps one real SELECT statement in
# a decoy "select ..." sentence and surrounding prose, so the extraction and
# cleaning steps below have something to work on.
FAULTY_TEXT = """Model draft + noise:
select from the options above

SQL:
SELECT c.customerNumber, c.customerName, c.contactLastName, c.creditLimit
FROM customers c
WHERE c.country = 'France'
ORDER BY c.customerName DESC
LIMIT 5;

Extra explanation after SQL.
"""

show_title("Step 2 - Simulated faulty SQL draft")
show_pre(FAULTY_TEXT, "Faulty model output (simulated)")

# Step 3: run extraction logic to pick the best SQL candidate.
show_title("Step 3 - Extraction debug")
extract_debug = debug_extract_first_select(FAULTY_TEXT)
# Fall back to the raw text when the extractor reports no selected SQL, so the
# later steps always receive a string.
selected_sql = extract_debug.get("selected_sql") or FAULTY_TEXT

# One row per candidate reported by the extractor, numbered from 1.
display(pd.DataFrame([
    {
        "candidate": i,
        "accepted": c.get("accepted"),
        "reject_reason": c.get("reject_reason"),
        "from_target": c.get("from_target"),
        "candidate_sql": c.get("candidate_sql"),
    }
    for i, c in enumerate(extract_debug.get("candidates", []), start=1)
]))
show_pre(selected_sql, "Selected SQL candidate")

# Step 4A: clean SQL for implicit-field question behavior.
show_title("Step 4A - Cleaning trace (implicit fields)")
# explicit_fields is forwarded only when the constraints flag an explicit
# projection; otherwise None is passed.
pp_a = debug_guarded_postprocess(
    selected_sql,
    DEMO_NLQ_IMPLICIT,
    explicit_fields=constraints_implicit.get("explicit_fields") if constraints_implicit.get("explicit_projection") else None,
    required_fields=constraints_implicit.get("required_output_fields"),
)
display(steps_df(pp_a))
show_pre(pp_a["final_sql"], "Final cleaned SQL (implicit)")

# Step 4B: clean SQL for explicit-field question behavior.
# Same extracted SQL as 4A, but with a hand-written field list matching
# DEMO_NLQ_EXPLICIT instead of constraints derived from the question.
show_title("Step 4B - Cleaning trace (explicit fields)")
pp_b = debug_guarded_postprocess(
    selected_sql,
    DEMO_NLQ_EXPLICIT,
    explicit_fields=["contactLastName", "customerName", "customerNumber"],
)
display(steps_df(pp_b))
show_pre(pp_b["final_sql"], "Final cleaned SQL (explicit)")


In [None]:
# Execute selected sweep.
# All arguments come from the run-plan cell. limit=None is passed through to
# eval_run unchanged, and TS is requested only when the plan enables it.
baseline_grid, baseline_by_k, baseline_run_dir = run_baseline_grid(
    k_values=K_VALUES,
    seeds=SEEDS,
    run_tag=RUN_TAG,
    prompt_variant=PROMPT_VARIANT,
    schema_variant=SCHEMA_VARIANT,
    exemplar_strategy=EXEMPLAR_STRATEGY,
    limit=None,
    copy_canonical=COPY_CANONICAL,
    copy_model_family=COPY_MODEL_FAMILY,
    model_alias=MODEL_ALIAS,
    enable_ts_for_k=set(TS_FOR_K_VALUES) if ENABLE_TS else None,
    ts_n=TS_N,
    ts_prefix=TS_PREFIX,
    ts_max_rows=TS_MAX_ROWS,
)

print("\nPer-run rows:")
display(baseline_grid)
print("\nPer-k summary (mean/std across seeds):")
display(baseline_by_k)

# Persist artifacts so they survive Colab runtime disconnects.
# NOTE(review): no visible cell mounts Google Drive; mount it before running
# this cell if the backup is meant to outlive the runtime.
if PERSIST_TO_DRIVE:
    # Backup is best-effort: a failure is reported but does not fail the cell,
    # since the run results already exist under baseline_run_dir.
    try:
        backup_dir = persist_run_to_drive(
            run_dir=baseline_run_dir,
            model_alias=MODEL_ALIAS,
            k_values=K_VALUES,
            run_tag=RUN_TAG,
            persist_root=DRIVE_PERSIST_ROOT,
        )
        print("Persistent backup saved to:", backup_dir)
    except Exception as e:
        print("Drive backup skipped/failed:", e)


In [None]:
# 6) Quick canonical summary (reads saved k=0 and k=3 JSON outputs)
# NOTE: run_baseline_grid rewrites these canonical files only when called with
# copy_canonical=True; otherwise the numbers printed here come from whatever
# run last updated them, not necessarily the sweep executed above.
import json

# Context managers close the file handles deterministically.
with open("results/baseline/results_zero_shot_200.json", "r", encoding="utf-8") as fh:
    zero = json.load(fh)
with open("results/baseline/results_few_shot_k3_200.json", "r", encoding="utf-8") as fh:
    few = json.load(fh)

# EM falls back to 0.0 when a file has no "em_rate" key; VA and EX are required.
print("Zero-shot:", "VA", round(zero["va_rate"], 3), "EM", round(zero.get("em_rate", 0.0), 3), "EX", round(zero["ex_rate"], 3))
print("Few-shot:",  "VA", round(few["va_rate"], 3),  "EM", round(few.get("em_rate", 0.0), 3),  "EX", round(few["ex_rate"], 3))
