<a href="https://colab.research.google.com/github/MacKenzieOBrian/NLtoSQL/blob/master/Copy_of_week7_minimal_test_notebook_clean.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cell 1 — Repo clone + clean workspace
This cell resets `/content` and re-clones the GitHub repository so the notebook runs against the latest tracked code and data files. This is standard practice in Colab to avoid stale state across sessions (old notebooks, cached files, partial clones). Because the clone always tracks the latest commit, record the printed commit hash alongside any results so they remain reproducible.

In [None]:
import os, shutil

# Reset to a known-good directory
# (leave the repo first so the rmtree below never deletes the current working directory)
os.chdir("/content")

# Remove any previous or partial clone so the fresh clone below starts clean
repo_dir = "/content/NLtoSQL"
if os.path.exists(repo_dir):
    shutil.rmtree(repo_dir)

# Fresh clone (IPython shell escape; "$repo_dir" is expanded from the Python variable)
!git clone https://github.com/MacKenzieOBrian/NLtoSQL.git "$repo_dir"
os.chdir(repo_dir)

# Sanity checks: print the checked-out commit (record it with results) and list the repo contents
!git rev-parse --short HEAD
!ls

Cloning into '/content/NLtoSQL'...
remote: Enumerating objects: 155, done.[K
remote: Counting objects: 100% (155/155), done.[K
remote: Compressing objects: 100% (112/112), done.[K
remote: Total 155 (delta 101), reused 89 (delta 43), pack-reused 0 (from 0)[K
Receiving objects: 100% (155/155), 173.25 KiB | 1.39 MiB/s, done.
Resolving deltas: 100% (101/101), done.
e9a062e
 ARCHITECTURE.md   LOGBOOK.md	     'week7_minimal_test_notebook (2).py'
 CONFIG.md	   NOTES.md	      week7_minimal_test_notebook_clean.ipynb
 data		   oldlogbook.txt
 DATA.md	   requirements.txt


# Cell 2 — GCP auth

In [None]:
# Authenticate to Google Cloud. Inside Colab this runs the interactive sign-in
# flow; elsewhere the google.colab package is absent, and the cell only prints
# a reminder to configure credentials (gcloud/ADC or a service account).
try:
    from google.colab import auth
except ModuleNotFoundError:
    auth = None

if auth is None:
    print("Not running in Colab; ensure GCP auth via gcloud/ADC or service account if needed.")
else:
    auth.authenticate_user()

# Cell 3 — Project configuration (env vars / project id)



In [None]:
import os

# GCP project used by the Cloud SQL connector. An existing GOOGLE_CLOUD_PROJECT
# environment variable wins, so the notebook can target another project without
# editing code; otherwise fall back to the original project id.
project_id = os.getenv("GOOGLE_CLOUD_PROJECT") or "modified-enigma-476414-h9"
os.environ["GOOGLE_CLOUD_PROJECT"] = project_id

print("GOOGLE_CLOUD_PROJECT set to:", os.environ["GOOGLE_CLOUD_PROJECT"])


GOOGLE_CLOUD_PROJECT set to: modified-enigma-476414-h9


# Cell 4 — Install dependencies from requirements.txt

This cell installs pinned dependencies so the runtime matches the documented toolchain.

In [None]:
import sys
# Prefer installing from pinned requirements.txt for reproducibility.
# sys.executable is the running kernel's interpreter, so packages land in the
# environment this notebook imports from. requirements.txt is resolved relative
# to the current directory, which Cell 1 sets to the repository root.
!{sys.executable} -m pip install --upgrade pip
!{sys.executable} -m pip install -r requirements.txt




# Cell 5 — Imports & logging setup

This cell centralises imports and initialises logging. It’s standard practice to make the execution environment explicit and to ensure later cells can emit consistent, timestamped debug information during DB and model evaluation.

In [None]:
import os
import logging
import json
from datetime import datetime, timezone
from typing import Any, Dict, Optional

import pandas as pd
import sqlalchemy
from sqlalchemy import text
from sqlalchemy.engine import Engine

from google.cloud.sql.connector import Connector
# NOTE(review): `retry` is not referenced in any visible cell — candidate for removal.
from google.api_core import retry

# pymysql is the DBAPI driver the Cloud SQL connector is asked to use ("pymysql" in getconn).
import pymysql

# INFO-level root logging so logger.info(...) calls in later cells are displayed.
# (basicConfig is a no-op if the root logger already has handlers configured.)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("nl2sql_db")

# Cell 6 — Connection parameters (prompt for secrets if missing)

This cell collects DB connection details securely (e.g., using getpass) if they’re not already provided via environment variables. This is standard practice for notebooks: it prevents credentials being stored in plaintext while keeping the workflow runnable.

In [None]:
from getpass import getpass

# Connection settings come from the environment first; only the missing
# ones are requested interactively (the password via a hidden prompt).
INSTANCE_CONNECTION_NAME = os.getenv("INSTANCE_CONNECTION_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASS = os.getenv("DB_PASS")
DB_NAME = os.getenv("DB_NAME", "classicmodels")

INSTANCE_CONNECTION_NAME = INSTANCE_CONNECTION_NAME or input("Enter INSTANCE_CONNECTION_NAME: ").strip()
DB_USER = DB_USER or input("Enter DB_USER: ").strip()
# getpass keeps the typed password out of the notebook output.
DB_PASS = DB_PASS or getpass("Enter DB_PASS: ")

print("Using DB:", DB_NAME)

Enter INSTANCE_CONNECTION_NAME: modified-enigma-476414-h9:europe-west2:classicmodels
Enter DB_USER: root
Enter DB_PASS: ··········
Using DB: classicmodels


# Cell 7 — Cloud SQL connector + SQLAlchemy engine

This cell builds a secure DB engine using the Cloud SQL Connector and SQLAlchemy’s creator pattern. It’s standard practice for cloud-hosted databases because it avoids exposing DB endpoints directly and provides a stable interface for executing queries and retrieving schema metadata.

In [None]:
from contextlib import contextmanager

import atexit

connector = Connector()
# The connector was never closed; release its background resources when the
# Python process exits (e.g. on kernel restart).
# NOTE(review): an unclosed connector may explain the "Unclosed client session"
# asyncio error printed during gold-SQL validation — confirm.
atexit.register(connector.close)


def getconn():
    """SQLAlchemy creator hook: open a new pymysql connection via the Cloud SQL connector.

    Reads the notebook-level INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS and DB_NAME.
    """
    return connector.connect(
        INSTANCE_CONNECTION_NAME,
        "pymysql",
        user=DB_USER,
        password=DB_PASS,
        db=DB_NAME,
    )


# Use a creator function so SQLAlchemy delegates connection creation to the connector.
# The URL only selects the MySQL dialect + pymysql driver; host and credentials
# come from getconn().
engine: Engine = sqlalchemy.create_engine(
    "mysql+pymysql://",
    creator=getconn,
    future=True,
)

@contextmanager
def safe_connection(engine: Engine):
    """
    Yield a connection from ``engine`` and close it when the ``with`` block exits.

    The connection is closed whether the block finishes normally or raises.
    If ``engine.connect()`` itself fails, its exception propagates and there
    is nothing to close.
    """
    connection = engine.connect()
    try:
        yield connection
    finally:
        connection.close()

# Cell 8 — Schema exploration helpers
This cell defines functions to list tables and fetch column metadata from INFORMATION_SCHEMA. Schema introspection is standard in NL→SQL because models need explicit schema grounding; automating extraction avoids manual schema drift and keeps prompts aligned with the live database.

In [None]:
def list_tables(engine: Engine) -> list:
    """Return the names reported by ``SHOW TABLES`` for the connected database.

    In MySQL, SHOW TABLES also lists views.
    """
    show_tables = text("SHOW TABLES;")
    with safe_connection(engine) as conn:
        rows = conn.execute(show_tables).fetchall()
    # Each row's first column holds the table name.
    return [name for name, *_ in rows]

def get_table_columns(engine: Engine, table_name: str, db_name: Optional[str] = None) -> pd.DataFrame:
    """
    Return a DataFrame describing the columns of ``table_name``, in definition order.

    Columns: COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY. An unknown table
    yields an empty DataFrame rather than an error.

    Args:
        engine: engine used to run the INFORMATION_SCHEMA query.
        table_name: table to describe.
        db_name: schema to look in. Defaults to the notebook-level DB_NAME, so
            existing two-argument calls behave exactly as before.
    """
    schema = DB_NAME if db_name is None else db_name
    query = text("""
        SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = :db AND TABLE_NAME = :table
        ORDER BY ORDINAL_POSITION
    """)
    with safe_connection(engine) as conn:
        df = pd.read_sql(query, conn, params={"db": schema, "table": table_name})
    return df


# Cell 9 — QueryRunner (read-only executor + logging)
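This cell defines a controlled SQL execution wrapper that rejects statements containing destructive keywords (DROP, DELETE, UPDATE, …) and logs execution metadata. This is standard in text-to-SQL evaluation: it protects the database, supports VA (validity) measurement, and provides structured error traces for analysis. The keyword blocklist is a heuristic, not a sandbox. For real protection, connect with a read-only database user rather than `root`.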

In [None]:
import re
import time


class QueryExecutionError(Exception):
    """Raised by QueryRunner._safety_check when a statement is rejected before execution."""


def now_utc_iso() -> str:
    """Return the current UTC time in ISO-8601 form with a trailing 'Z'."""
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")


class QueryRunner:
    """
    Execute generated SQL safely against the engine, capture results and metadata,
    and keep a history suitable for evaluation and error analysis.

    The safety check is a keyword blocklist, not a sandbox: it rejects any
    statement that mentions a forbidden keyword as a whole word. Use a
    read-only database account for real protection.
    """

    def __init__(self, engine: "Engine", max_rows: int = 1000, forbidden_tokens=None):
        """
        Args:
            engine: SQLAlchemy engine; a fresh connection is opened for each run().
            max_rows: cap applied to the preview DataFrame and to the reported rowcount.
            forbidden_tokens: keywords that cause rejection. They are matched
                case-insensitively as whole words, and surrounding whitespace in
                a token is ignored. Defaults to DROP/DELETE/TRUNCATE/ALTER/CREATE/
                UPDATE/INSERT.
        """
        self.engine = engine
        self.max_rows = max_rows
        self.history = []
        self.forbidden_tokens = forbidden_tokens or [
            "drop ", "delete ", "truncate ", "alter ", "create ", "update ", "insert "
        ]

    def _safety_check(self, sql: str) -> None:
        """Raise QueryExecutionError if ``sql`` is empty or contains a forbidden keyword.

        Whole-word matching catches a keyword followed by any whitespace or
        punctuation (e.g. a tab or newline, which the old "keyword + space"
        substring test missed). It does not flag identifiers that merely
        contain a keyword, such as ``last_update``.
        """
        lowered = (sql or "").strip().lower()
        if not lowered:
            raise QueryExecutionError("Empty SQL string")
        for token in self.forbidden_tokens:
            keyword = token.strip().lower()
            if not keyword:
                continue
            # (?<!\w) / (?!\w): the keyword must not be glued to letters, digits or '_'.
            if re.search(rf"(?<!\w){re.escape(keyword)}(?!\w)", lowered):
                raise QueryExecutionError(f"Destructive SQL token detected: {token.strip()}")

    def run(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
        capture_df: bool = True,
    ) -> Dict[str, Any]:
        """
        Execute a SELECT-style query and return its metadata dict (also appended to history).

        Keys of the returned dict:
            sql, params, timestamp: the request and when it was made (UTC).
            success: whether the query executed.
            rowcount: number of rows returned, capped at max_rows.
            exec_time_s: execution time in seconds.
            error: str(exception) on failure.
            columns: result column names.
            result_preview: DataFrame of at most max_rows rows, or None when
                capture_df is False.

        Ordinary errors (blocked SQL, syntax or database errors) are recorded in
        "error" rather than raised.
        """
        entry = {
            "sql": sql,
            "params": params,
            "timestamp": now_utc_iso(),
            "success": False,
            "rowcount": 0,
            "exec_time_s": None,
            "error": None,
            "columns": None,
            "result_preview": None,
        }
        try:
            self._safety_check(sql)
            # perf_counter is monotonic and high-resolution, unlike wall-clock deltas.
            start = time.perf_counter()

            with safe_connection(self.engine) as conn:
                result = conn.execute(sqlalchemy.text(sql), params or {})
                rows = result.fetchall()
                cols = list(result.keys())

            exec_time = time.perf_counter() - start

            df = None
            if capture_df:
                df = pd.DataFrame(rows, columns=cols)
                if len(df) > self.max_rows:
                    df = df.iloc[: self.max_rows]

            entry.update({
                "success": True,
                "rowcount": min(len(rows), self.max_rows),
                "exec_time_s": exec_time,
                "columns": cols,
                "result_preview": df,
            })
        except Exception as e:
            # Deliberately broad: any failure is data for the VA metric,
            # not a reason to abort an evaluation loop.
            entry.update({
                "error": str(e),
                "success": False,
            })
        finally:
            self.history.append(entry)
        return entry

    def last(self):
        """Return the most recent history entry, or None if nothing has run yet."""
        return self.history[-1] if self.history else None

    def save_history(self, path: str):
        """Persist history (without DataFrames) to JSON for later analysis."""
        serializable = []
        for h in self.history:
            s = {k: v for k, v in h.items() if k != "result_preview"}
            serializable.append(s)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(serializable, f, indent=2, default=str)

# Cell 10 — Smoke tests (DB connectivity + schema)
This cell runs minimal “does it work” queries and prints schema samples. Smoke tests are standard practice to verify the experimental apparatus (DB access + schema + executor) before attributing downstream errors to the model.

In [None]:
def fetch_sample_customers(limit: int = 10) -> pd.DataFrame:
    """Return up to ``limit`` customers (number, name, country) as a smoke-test sample.

    Runs against the notebook-level ``engine``.
    """
    sample_query = text(
        "SELECT customerNumber, customerName, country FROM customers LIMIT :limit;"
    )
    with safe_connection(engine) as conn:
        return pd.read_sql(sample_query, conn, params={"limit": limit})

try:
    tables = list_tables(engine)
    # Report the configured database rather than a hard-coded name (DB_NAME can come from an env var).
    logger.info("Tables in %s: %s", DB_NAME, tables)

    sample_df = fetch_sample_customers(5)
    display(sample_df)

    # Optionally print each table's schema (comment out if too verbose)
    for table_name in tables:
        print(f"\nSchema for table: {table_name}")
        df_columns = get_table_columns(engine, table_name)
        display(df_columns)

except Exception as e:
    # Best-effort smoke test: log the full traceback but keep the notebook running.
    logger.exception("Smoke test failed: %s", e)

Unnamed: 0,customerNumber,customerName,country
0,103,Atelier graphique,France
1,112,Signal Gift Stores,USA
2,114,"Australian Collectors, Co.",Australia
3,119,La Rochelle Gifts,France
4,121,Baane Mini Imports,Norway



Schema for table: customers


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,customerNumber,int,NO,PRI
1,customerName,varchar,NO,
2,contactLastName,varchar,NO,
3,contactFirstName,varchar,NO,
4,phone,varchar,NO,
5,addressLine1,varchar,NO,
6,addressLine2,varchar,YES,
7,city,varchar,NO,
8,state,varchar,YES,
9,postalCode,varchar,YES,



Schema for table: employees


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,employeeNumber,int,NO,PRI
1,lastName,varchar,NO,
2,firstName,varchar,NO,
3,extension,varchar,NO,
4,email,varchar,NO,
5,officeCode,varchar,NO,MUL
6,reportsTo,int,YES,MUL
7,jobTitle,varchar,NO,



Schema for table: offices


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,officeCode,varchar,NO,PRI
1,city,varchar,NO,
2,phone,varchar,NO,
3,addressLine1,varchar,NO,
4,addressLine2,varchar,YES,
5,state,varchar,YES,
6,country,varchar,NO,
7,postalCode,varchar,NO,
8,territory,varchar,NO,



Schema for table: orderdetails


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,orderNumber,int,NO,PRI
1,productCode,varchar,NO,PRI
2,quantityOrdered,int,NO,
3,priceEach,decimal,NO,
4,orderLineNumber,smallint,NO,



Schema for table: orders


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,orderNumber,int,NO,PRI
1,orderDate,date,NO,
2,requiredDate,date,NO,
3,shippedDate,date,YES,
4,status,varchar,NO,
5,comments,text,YES,
6,customerNumber,int,NO,MUL



Schema for table: payments


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,customerNumber,int,NO,PRI
1,checkNumber,varchar,NO,PRI
2,paymentDate,date,NO,
3,amount,decimal,NO,



Schema for table: productlines


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,productLine,varchar,NO,PRI
1,textDescription,varchar,YES,
2,htmlDescription,mediumtext,YES,
3,image,mediumblob,YES,



Schema for table: products


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,productCode,varchar,NO,PRI
1,productName,varchar,NO,
2,productLine,varchar,NO,MUL
3,productScale,varchar,NO,
4,productVendor,varchar,NO,
5,productDescription,text,NO,
6,quantityInStock,smallint,NO,
7,buyPrice,decimal,NO,
8,MSRP,decimal,NO,


# Cell 11 — Load static NLQ–SQL test set
This cell loads the benchmark dataset (classicmodels_test_200.json). Separating dataset loading is standard so evaluation is repeatable and the exact test items are fixed and version-controlled.

In [None]:
# Fixed, version-controlled benchmark of NLQ / gold-SQL pairs (path relative to the repo root).
TEST_SET_PATH = "data/classicmodels_test_200.json"

with open(TEST_SET_PATH, encoding="utf-8") as fh:
    test_set = json.load(fh)

print(f"Loaded {len(test_set)} test items from {TEST_SET_PATH}")

Loaded 200 test items from data/classicmodels_test_200.json


# Cell 12 — Optional: validate gold SQL against the live DB
This cell executes the gold queries to confirm they all run (e.g., 200/200). This is standard evaluation hygiene: it prevents false negatives in EX/VA caused by incorrect reference SQL or schema mismatches.

In [None]:
from typing import List, Tuple

def validate_test_set(path: str = "data/classicmodels_test_200.json",
                      limit: Optional[int] = None) -> Tuple[list, list]:
    """Execute every gold query in the benchmark and report which ones fail.

    Args:
        path: JSON file holding a list of {"nlq": ..., "sql": ...} items.
        limit: only check the first ``limit`` items; falsy (None/0) checks all.

    Returns:
        (successes, failures): the indices of gold queries that ran, and one
        dict per failing query with its index, NLQ, SQL and error message.
    """
    with open(path, "r", encoding="utf-8") as fh:
        items = json.load(fh)
    if limit:
        items = items[:limit]

    runner = QueryRunner(engine, max_rows=200)
    successes: List[int] = []
    failures: List[Dict[str, Any]] = []

    for idx, item in enumerate(items):
        outcome = runner.run(item["sql"], capture_df=False)
        if not outcome["success"]:
            failures.append({
                "index": idx,
                "nlq": item.get("nlq"),
                "sql": item.get("sql"),
                "error": outcome["error"],
            })
            continue
        successes.append(idx)

    print(f"Ran {len(items)} queries. Success: {len(successes)}. Failures: {len(failures)}.")
    if not failures:
        print("All queries succeeded in this run.")
    else:
        print("Failures (first 5):")
        for failure in failures[:5]:
            print(failure)
    return successes, failures

# Quick validation on the first 50 gold queries; call with limit=None to check the full set.
successes, failures = validate_test_set(limit=50)

ERROR:asyncio:Unclosed client session
client_session: <aiohttp.client.ClientSession object at 0x7c2fff92f6b0>


Ran 50 queries. Success: 50. Failures: 0.
All queries succeeded in this run.


# Cell 13 — Hugging Face authentication

This cell authenticates to Hugging Face Hub for model access, especially for gated checkpoints.

In [None]:
from huggingface_hub import notebook_login
# Interactive Hugging Face login widget. The saved token is what `token=True`
# uses when the gated Llama 3 checkpoint is downloaded in the model-loading cell.
notebook_login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Cell 14 — Load model + deterministic smoke test
This cell loads the LLM once, configures deterministic generation defaults for evaluation, and runs a trivial “OK” test to confirm inference works. Model-load and smoke-test separation is standard practice to ensure later NL→SQL failures are prompt/data issues rather than model/runtime issues.

Libraries used:

*   torch: runs the model on GPU and controls “no gradients” mode.
*   transformers: loads the tokenizer + model.
*   BitsAndBytesConfig: tells Transformers how to load the model in 4‑bit.


# Model loading (try 4‑bit, else fallback)
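The `try:` block attempts 4‑bit NF4 loading, which shrinks the memory footprint enough for the 8B model to fit on Colab GPUs. If that fails, the `except` branch reloads the model without quantisation: bfloat16 on GPU, or float32 on CPU.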

# Deterministic defaults
do_sample=False means “don’t roll dice, always pick the most likely next token” (good for repeatable evaluation).

num_beams=1 means no beam search (simple deterministic decoding).

# Smoke test
Builds a tiny chat conversation: system message + user asks for “OK”.


*   apply_chat_template(...) formats it the way Llama‑3‑Instruct expects.
* torch.no_grad() makes it inference-only (faster, less memory).
* model.generate(...) produces up to 3 tokens.

We slice out only the newly generated part and decode it, so it prints just OK.




In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Gated checkpoint: requires the Hugging Face login from the previous cell.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

print("Loading tokenizer...")
# token=True: authenticate with the locally stored Hugging Face token.
tok = AutoTokenizer.from_pretrained(MODEL_ID, token=True)
# If the tokenizer has no pad token, reuse EOS so padding-aware calls have one.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Try 4-bit loading
# NF4 weights with double quantisation keep the 8B model small enough for a
# Colab GPU; bnb_4bit_compute_dtype makes the matmuls run in bfloat16.
try:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    print("Attempting 4-bit quantized load...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=True,
    )
except Exception as e:
    # Any failure falls back to an unquantised load (presumably e.g. bitsandbytes
    # missing or no CUDA device). Note: despite the message, this is bfloat16 on
    # GPU; float32 is used only on CPU.
    print("4-bit load failed, falling back to full-precision load.\nError:")
    print(e)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
        token=True,
    )

print("\nModel loaded!")
print("Device:", next(model.parameters()).device)

# Deterministic defaults WITHOUT warnings
# Greedy decoding (no sampling, no beam search). temperature/top_p/top_k are reset
# to transformers' neutral defaults (1.0 / 1.0 / 50) so non-default sampling values
# don't trigger warnings while do_sample=False.
model.generation_config.do_sample = False
model.generation_config.num_beams = 1
model.generation_config.temperature = 1.0
model.generation_config.top_p = 1.0
model.generation_config.top_k = 50

# Deterministic smoke test (chat-format + decode only new tokens)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Reply with only the word OK."},
]
# add_generation_prompt=True appends the assistant-turn header so the model answers next.
input_ids = tok.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

# Inference only: no autograd graph is built.
with torch.no_grad():
    out = model.generate(
        input_ids,
        max_new_tokens=3,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )

# generate() returns prompt + continuation; slice off the prompt tokens.
gen_ids = out[0][input_ids.shape[-1]:]
print(tok.decode(gen_ids, skip_special_tokens=True).strip())


Loading tokenizer...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/51.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Attempting 4-bit quantized load...


config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]



model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]


Model loaded!
Device: cuda:0
OK


# Cell 15 — Batch baseline evaluation (zero‑shot vs few‑shot)
This cell runs a small evaluation loop over my queries to produce VA (does the SQL execute?) and EX (does it match the gold SQL after normalisation?) for both zero‑shot (k=0) and few‑shot (k>0) prompting.

It is standard practice to do this after the smoke tests because it separates “model/runtime works” from “method works,” and it produces the first reportable numbers for my project.

Uses my prompt pipeline (SCHEMA_SUMMARY, make_few_shot_prompt, generate_sql, enforce_minimal_projection).
Uses QueryRunner to compute VA safely (read‑only) and to log errors.
Saves per‑item outputs so error analysis can be done later.

## Cell 15 — Batch evaluation (zero-shot vs few-shot) with VA/EX

This cell runs the first reportable baseline evaluation by looping over the ClassicModels benchmark and comparing **zero-shot** (k=0) to **few-shot prompting** (k=3) under controlled, deterministic decoding.

For each NLQ, we:
1) Build a schema-grounded prompt (schema + optional exemplars + NLQ) to support database-aware generation [1], [8], [9].
2) Generate SQL deterministically (`do_sample=False`) to eliminate sampling noise so metric changes reflect method changes rather than randomness [10], [18], [20].
3) Post-process the model output to extract a single executable `SELECT ...;` (and apply a minimal-projection heuristic for simple list-intent queries), consistent with execution-aware controls used in Text-to-SQL systems [13], [2], [10].
4) Execute the SQL via QueryRunner to compute **VA** (executability) and compute **EX** as a strict normalised string match baseline [10], [18], [19].

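Interpretation note: **EX is intentionally strict** and will undercount semantically correct queries (aliasing, equivalent joins/subqueries, harmless rewrites). VA confirms executability; TS/result-equivalence is the next step for semantic correctness [18], [10]. Few-shot exemplars are drawn from the same benchmark, so the item being evaluated must never appear among its own exemplars. Otherwise the model can copy the gold SQL and inflate EX.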


In [None]:

# 15. Batch evaluation: zero-shot vs few-shot (VA/EX) — self-contained version

import json, random, re
from datetime import datetime, timezone

def normalize_sql(s: str) -> str:
    """Canonicalise SQL for the strict exact-match (EX) comparison.

    The steps are:
    - collapse every whitespace run to a single space;
    - drop trailing semicolons together with any whitespace around them;
    - lowercase the result.

    Nothing else (aliases, quoting, clause order) is normalised, so EX stays
    deliberately strict. ``None`` or ``""`` gives ``""``.
    """
    s = re.sub(r"\s+", " ", (s or "").strip())
    # Previously "SELECT 1 ;" normalised to "select 1 " (trailing space kept),
    # which never matched "select 1". Strip ';' and whitespace together.
    s = re.sub(r"[\s;]+$", "", s)
    return s.lower()

# ---------- Build schema summary (PK/name-like columns first) ----------
def build_schema_summary(engine: Engine, max_cols_per_table: int = 50) -> str:
    """Return a compact schema string with one ``table(col1, col2, ...)`` line per table.

    Each table's columns are ordered as follows:
    - first, primary-key columns and columns whose names contain
      name/id/line/code/number (case-insensitive), in schema order;
    - then the remaining columns, in schema order.

    The combined list is truncated to ``max_cols_per_table``.
    """
    summary_lines = []
    for table in list_tables(engine):
        columns = get_table_columns(engine, table)
        names = columns["COLUMN_NAME"].tolist()

        is_pk = columns["COLUMN_KEY"].fillna("").isin(["PRI"])
        is_identifier_like = columns["COLUMN_NAME"].astype(str).str.contains(
            r"name|id|line|code|number", case=False, regex=True
        )
        keep_first = (is_pk | is_identifier_like).tolist()

        front = [name for name, flagged in zip(names, keep_first) if flagged]
        back = [name for name in names if name not in front]
        selected = (front + back)[:max_cols_per_table]

        summary_lines.append(f"{table}({', '.join(selected)})")
    return "\n".join(summary_lines)

# Computed once from the live database (SHOW TABLES plus one column query per
# table) and reused for every prompt built in eval_run.
SCHEMA_SUMMARY = build_schema_summary(engine, max_cols_per_table=50)

# ---------- Prompt template ----------
SYSTEM_INSTRUCTIONS = """You are an expert data analyst writing MySQL queries.
Given the database schema and a natural language question, write a single SQL SELECT query.

Rules:
- Output ONLY SQL (no explanation, no markdown).
- Output exactly ONE statement, starting with SELECT.
- Select ONLY the columns needed to answer the question (minimal projection).
- Use only the tables/columns in the schema.
- Prefer explicit JOIN syntax.
- Use LIMIT when the question implies "top" or "first".
"""


def make_few_shot_prompt(schema: str, exemplars: list, nlq: str) -> list:
    """Build the chat messages for one NL→SQL request.

    The message order is:
    - the system rules;
    - the schema;
    - one user/assistant pair per exemplar (NLQ → SQL);
    - the target NLQ.

    With an empty ``exemplars`` list this is the zero-shot prompt.

    Args:
        schema: schema summary text shown to the model.
        exemplars: dicts with "nlq" and "sql" keys.
        nlq: the question to answer.

    Returns:
        List of {"role", "content"} dicts for ``tok.apply_chat_template``.
    """
    msgs = [
        {"role": "system", "content": SYSTEM_INSTRUCTIONS},
        {"role": "user", "content": "Schema:\n" + schema},
    ]
    for ex in exemplars:
        msgs.append({"role": "user", "content": f"NLQ: {ex['nlq']}"})
        # Exactly one trailing ';'. Trailing whitespace is trimmed too, so
        # "...;\n" or "... ;" no longer turn into "...;\n;".
        exemplar_sql = re.sub(r"[\s;]+$", "", ex["sql"]).strip()
        msgs.append({"role": "assistant", "content": exemplar_sql + ";"})
    msgs.append({"role": "user", "content": f"NLQ: {nlq}"})
    return msgs

# ---------- Robust SQL extraction ----------
SQL_RE = re.compile(r"(?is)\bselect\b.*?(;|\Z)")
# Markdown code-fence markers (``` / ```sql / ```mysql) that chat models sometimes
# emit despite the "no markdown" instruction.
CODE_FENCE_RE = re.compile(r"```(?:mysql|sql)?", re.IGNORECASE)


def generate_sql(messages, max_new_tokens: int = 128) -> str:
    """Generate SQL for ``messages`` with greedy decoding and extract a single statement.

    The extracted SQL runs from the first ``SELECT`` up to the first ``;`` (or end of
    text) and always ends with ``;``. Markdown fences are removed before extraction.
    If no ``SELECT`` is found, the raw decoded text is returned unchanged, so it can
    be inspected during error analysis.
    """
    input_ids = tok.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        out = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tok.eos_token_id,
        )

    # Decode only the newly generated tokens (drop the prompt).
    gen_ids = out[0][input_ids.shape[-1]:]
    gen_text = tok.decode(gen_ids, skip_special_tokens=True).strip()

    # Strip fences first. Otherwise a statement without ';' is captured up to
    # end-of-text and carries a trailing "```" into the SQL, which then fails VA.
    m = SQL_RE.search(CODE_FENCE_RE.sub("", gen_text))
    if not m:
        return gen_text
    sql = m.group(0).strip()
    if not sql.endswith(";"):
        sql += ";"
    return sql

# ---------- Minimal-projection postprocess for "List all ..." ----------
LIST_ALL_RE = re.compile(r"(?is)^\s*list\s+all\s+")
SELECT_LIST_RE = re.compile(r"(?is)^\s*select\s+(.*?)\s+from\s+", re.DOTALL)

def enforce_minimal_projection(sql: str, nlq: str) -> str:
    if not sql or not nlq:
        return sql
    if not LIST_ALL_RE.search(nlq.strip()):
        return sql
    m = SELECT_LIST_RE.search(sql)
    if not m:
        return sql
    select_part = m.group(1).strip()
    if "*" in select_part:
        return sql
    first_expr = select_part.split(",")[0].strip()
    rebuilt = re.sub(SELECT_LIST_RE, lambda mm: f"SELECT {first_expr} FROM ", sql, count=1)
    return rebuilt

# ---------- Batch eval ----------
def now_utc_iso() -> str:
    """Current UTC time as an ISO-8601 string, using 'Z' instead of '+00:00'."""
    stamp = datetime.now(tz=timezone.utc).isoformat()
    return stamp.replace("+00:00", "Z")

def _sample_exemplars(rng: random.Random, pool: list, item: dict, k: int) -> list:
    """Draw ``k`` few-shot exemplars from ``pool`` without the item being evaluated.

    A pool entry is excluded if it is the same object as ``item`` or has the same
    "nlq" text, so the model never sees the gold SQL of the question it answers.
    Returns [] when k <= 0. Returns fewer than k exemplars only if the filtered
    pool is smaller than k.
    """
    if k <= 0:
        return []
    candidates = [ex for ex in pool if ex is not item and ex.get("nlq") != item.get("nlq")]
    return rng.sample(candidates, min(k, len(candidates)))


def eval_run(test_set, k: int, limit: int | None = 50, seed: int = 7, save_path: str | None = None):
    """Evaluate zero-/few-shot NL→SQL generation and report VA and EX.

    Args:
        test_set: list of dicts with "nlq" and gold "sql" keys.
        k: few-shot exemplars per item (0 = zero-shot). Exemplars are sampled from
            ``test_set`` but never include the item being evaluated.
        limit: evaluate only the first ``limit`` items; falsy (None/0) means all.
        seed: seed for exemplar sampling so runs are repeatable.
        save_path: if given, write the summary and per-item results as JSON.

    Returns:
        List of per-item dicts with these keys:
        - i, nlq;
        - gold_sql, raw_sql, pred_sql;
        - va, ex, error;
        - exemplar_nlqs: the NLQs of the exemplars used.
    """
    rng = random.Random(seed)
    items = test_set[:limit] if limit else test_set

    qr = QueryRunner(engine, max_rows=50)
    results = []

    for i, item in enumerate(items):
        nlq = item["nlq"]
        gold_sql = item["sql"]

        # Sampling from the full test set used to allow the current item (and its
        # gold SQL) into its own prompt, inflating EX. Exclude it explicitly.
        exemplars = _sample_exemplars(rng, test_set, item, k)
        messages = make_few_shot_prompt(SCHEMA_SUMMARY, exemplars, nlq)

        raw_sql = generate_sql(messages)
        pred_sql = enforce_minimal_projection(raw_sql, nlq)

        # VA: does the prediction execute? EX: strict normalised string match with gold.
        meta = qr.run(pred_sql, capture_df=False)
        ex = normalize_sql(pred_sql) == normalize_sql(gold_sql)

        results.append({
            "i": i,
            "nlq": nlq,
            "gold_sql": gold_sql,
            "raw_sql": raw_sql,
            "pred_sql": pred_sql,
            "va": bool(meta["success"]),
            "ex": bool(ex),
            "error": meta.get("error"),
            # Recorded so exemplar effects (and absence of leakage) can be audited.
            "exemplar_nlqs": [e["nlq"] for e in exemplars],
        })

    va_rate = sum(r["va"] for r in results) / max(len(results), 1)
    ex_rate = sum(r["ex"] for r in results) / max(len(results), 1)
    print(f"k={k} | n={len(results)} | VA={va_rate:.3f} | EX={ex_rate:.3f}")

    if save_path:
        payload = {
            "timestamp": now_utc_iso(),
            "k": k,
            "seed": seed,
            "limit": limit,
            "n": len(results),
            "va_rate": va_rate,
            "ex_rate": ex_rate,
            "results": results,
        }
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
        print("Saved:", save_path)

    return results

# Run a quick sanity check first (20 items)
# zero_20 = eval_run(test_set, k=0, limit=20, seed=7, save_path="results_zero_shot_20.json")
# few_20  = eval_run(test_set, k=3, limit=20, seed=7, save_path="results_few_shot_k3_20.json")

# Full run over all 200 items (limit=None): one model generation and one DB query per item.
zero_200 = eval_run(test_set, k=0, limit=None, seed=7, save_path="results_zero_shot_200.json")
few_200  = eval_run(test_set, k=3, limit=None, seed=7, save_path="results_few_shot_k3_200.json")


k=0 | n=200 | VA=0.810 | EX=0.000
Saved: results_zero_shot_200.json
k=3 | n=200 | VA=0.865 | EX=0.250
Saved: results_few_shot_k3_200.json
