# no-model nl-to-sql demo

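primary path only: raw model text -> SQL extraction -> execution -> VA/EM/EX scoring (VA: the predicted query executes; EM: normalized-text exact match with the gold SQL; EX: execution accuracy against the gold SQL). no model is loaded; model outputs are simulated with hard-coded strings.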

links (code paths are relative to the project root that contains the `nl2sql` package):
- code: `nl2sql/core/prompting.py`
- code: `nl2sql/core/llm.py`
- code: `nl2sql/core/query_runner.py`
- code: `nl2sql/evaluation/eval.py`
- docs: [sqlalchemy](https://docs.sqlalchemy.org/en/20/), [sqlite](https://www.sqlite.org/lang_select.html)
- literature: [spider](https://aclanthology.org/D18-1425/), [test suite accuracy](https://aclanthology.org/2020.emnlp-main.29/)


In [None]:
from __future__ import annotations

import json
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import sqlalchemy
from sqlalchemy import text

from nl2sql.core.prompting import SYSTEM_INSTRUCTIONS, make_few_shot_messages
from nl2sql.core.llm import debug_extract_first_select
from nl2sql.core.query_runner import QueryRunner
from nl2sql.evaluation.eval import execution_accuracy

# Reaching this line means every import above resolved. No LLM is loaded anywhere
# in this notebook; the "model outputs" are hard-coded strings (sections 5 and 7).
print("imports ready. no model loaded.")


## 1) pick a benchmark item and define the local demo question

In [None]:
def _load_test_set(candidates: list[str | Path] | None = None) -> list[dict]:
    """Load the ClassicModels benchmark from the first candidate path that is a file.

    Args:
        candidates: Paths to try, in order. Defaults to
            ``data/classicmodels_test_200.json`` and
            ``../data/classicmodels_test_200.json`` so the notebook works when
            launched from the project root or from one directory below it
            (relative paths resolve against the current working directory).

    Returns:
        The parsed JSON payload. The caller indexes it as a list of items with
        ``"nlq"`` and ``"sql"`` keys.

    Raises:
        FileNotFoundError: If none of the candidates is an existing regular
            file. The message lists every path that was tried.
    """
    if candidates is None:
        candidates = [
            Path("data/classicmodels_test_200.json"),
            Path("../data/classicmodels_test_200.json"),
        ]
    tried = [Path(c) for c in candidates]
    for path in tried:
        # is_file() rather than exists(): a directory with this name must be
        # skipped, not handed to read_text() (which would raise IsADirectoryError).
        if path.is_file():
            return json.loads(path.read_text(encoding="utf-8"))
    searched = ", ".join(str(p) for p in tried)
    raise FileNotFoundError(
        f"Could not locate classicmodels_test_200.json (searched: {searched})"
    )


# The real benchmark is loaded only so one item can be printed for reference;
# the rest of the notebook runs on the hand-written question below.
benchmark = _load_test_set()
benchmark_item = benchmark[0]

# Question/gold pair that targets the tiny in-memory `customers` table from section 4.
demo_nlq = "List all customer names in France"
demo_gold_sql = "SELECT customerName FROM customers WHERE country = 'France';"

for label, key in (("Benchmark sample NLQ:", "nlq"), ("Benchmark sample SQL:", "sql")):
    print(label, benchmark_item[key])
print("\nLocal demo NLQ:", demo_nlq)
print("Local demo gold SQL:", demo_gold_sql)


## 2) build schema context and prompt messages

In [None]:
# Deliberately tiny schema text so each stage of the pipeline is easy to explain.
SCHEMA_SUMMARY_DEMO = (
    "Table customers (\n"
    "  customerNumber INT,\n"
    "  customerName TEXT,\n"
    "  contactLastName TEXT,\n"
    "  country TEXT,\n"
    "  creditLimit REAL\n"
    ")\n"
    "Table orders (\n"
    "  orderNumber INT,\n"
    "  customerNumber INT,\n"
    "  orderDate TEXT,\n"
    "  status TEXT\n"
    ")"
)

# Few-shot exemplars, written as (question, gold SQL) pairs and expanded into
# the {"nlq": ..., "sql": ...} dicts that make_few_shot_messages receives.
exemplars = [
    {"nlq": question, "sql": gold}
    for question, gold in (
        (
            "List all customer names in Germany",
            "SELECT customerName FROM customers WHERE country = 'Germany';",
        ),
        (
            "Show customer names and credit limit for customers in France",
            "SELECT customerName, creditLimit FROM customers WHERE country = 'France';",
        ),
    )
]

messages = make_few_shot_messages(
    schema=SCHEMA_SUMMARY_DEMO,
    exemplars=exemplars,
    nlq=demo_nlq,
)

print("system prompt starts with:")
print(SYSTEM_INSTRUCTIONS.splitlines()[0])
print("\nmessage count:", len(messages))
# One line per message: role plus the first 140 chars of content, newlines flattened.
for position, msg in enumerate(messages, start=1):
    flattened = str(msg["content"]).replace("\n", " ")
    print(f"[{position}] {msg['role']}: {flattened[:140]}")


## 3) tiny em normalizer (text compare only)

In [None]:
def simple_sql_norm(sql: str) -> str:
    """Normalize SQL text for the exact-match (EM) comparison.

    Deliberately minimal:
    - strips surrounding whitespace, then any run of trailing semicolons;
    - collapses every internal whitespace run to a single space;
    - lowercases the whole string, string literals included, so
      ``'France'`` and ``'france'`` compare equal.

    ``None`` or an empty string normalizes to ``""``.
    """
    raw = sql if sql else ""
    without_semicolon = raw.strip().rstrip(";")
    return " ".join(without_semicolon.split()).lower()


print("EM normalizer ready:", simple_sql_norm("SELECT customerName FROM customers;"))


## 4) create a tiny local sqlite db for execution checks

In [None]:
# In-memory SQLite engine used for every execution check (VA/EX) below.
# NOTE(review): each DBAPI connection to ":memory:" opens its own empty SQLite
# database. The tables created here stay visible to later `engine` users only
# because of SQLAlchemy's pooling default for :memory: URLs, which reuses one
# connection per thread. Confirm that QueryRunner / execution_accuracy never
# query from a different thread.
engine = sqlalchemy.create_engine("sqlite+pysqlite:///:memory:", future=True)

# Setup statements run in order inside a single transaction:
# - the two tables described by SCHEMA_SUMMARY_DEMO;
# - four seed customers, two of them in France.
# `orders` receives no rows in this cell.
_demo_setup_sql = (
    """
        CREATE TABLE customers (
            customerNumber INTEGER PRIMARY KEY,
            customerName TEXT,
            contactLastName TEXT,
            country TEXT,
            creditLimit REAL
        )
    """,
    """
        CREATE TABLE orders (
            orderNumber INTEGER PRIMARY KEY,
            customerNumber INTEGER,
            orderDate TEXT,
            status TEXT
        )
    """,
    """
        INSERT INTO customers (customerNumber, customerName, contactLastName, country, creditLimit) VALUES
        (103, 'Atelier graphique', 'Schmitt', 'France', 21000.00),
        (112, 'Signal Gift Stores', 'King', 'USA', 71800.00),
        (119, 'La Rochelle Gifts', 'Labrune', 'France', 118200.00),
        (121, 'Baane Mini Imports', 'Petersen', 'Denmark', 81700.00)
    """,
)

with engine.begin() as conn:
    for setup_stmt in _demo_setup_sql:
        conn.execute(text(setup_stmt))

preview_query = text(
    "SELECT customerNumber, customerName, country FROM customers ORDER BY customerNumber"
)
with engine.connect() as conn:
    preview = pd.read_sql(preview_query, conn)

print("Local DB preview:")
display(preview)


## 5) simulate noisy model text and inspect extraction

In [None]:
# Two simulated "model outputs" (no LLM is called):
#   1. prose lead-in + bare SQL against a table that does not exist (`customer`);
#   2. prose lead-in + SQL inside a ```sql fence, against the real `customers` table.
raw_generations = [
    """
I think the answer is:
SELECT customerName, customerNumber
FROM customer
WHERE country = 'France'
ORDER BY customerName DESC
LIMIT 5;
""",
    """
Here is SQL:
```sql
SELECT customerName, customerNumber
FROM customers
WHERE country = 'France'
ORDER BY customerName DESC
LIMIT 5;
```
""",
]

debug_rows: list[dict] = []
for candidate_no, raw in enumerate(raw_generations, start=1):
    print("\n" + "=" * 90)
    print(f"raw candidate {candidate_no}")
    print(raw.strip())

    extraction = debug_extract_first_select(raw)
    # Raw path: if the extractor selects nothing, the untouched text is used as the SQL.
    chosen_sql = extraction.get("selected_sql") or raw

    # One debug row per extraction candidate, tagged with its generation number.
    debug_rows.extend(
        {
            "candidate": candidate_no,
            "accepted": cand.get("accepted"),
            "reject_reason": cand.get("reject_reason"),
            "from_target": cand.get("from_target"),
            "sql": cand.get("candidate_sql"),
        }
        for cand in extraction.get("candidates", [])
    )

    print("\nselected sql (no cleaning):")
    print(chosen_sql)

if debug_rows:
    print("\nextraction debug table:")
    display(pd.DataFrame(debug_rows))


## 6) evaluate candidates in a mini retry loop (va/em/ex)

In [None]:
# Score each simulated generation in turn. Stop at the first attempt that both
# executes (VA) and passes execution_accuracy (EX); this stands in for a retry loop.
qr = QueryRunner(engine, max_rows=50)
attempt_rows: list[dict] = []

for attempt_no, raw in enumerate(raw_generations, start=1):
    pred_sql = debug_extract_first_select(raw).get("selected_sql") or raw

    va_meta = qr.run(pred_sql, capture_df=False)
    em = simple_sql_norm(pred_sql) == simple_sql_norm(demo_gold_sql)
    ex, ex_pred_err, ex_gold_err = execution_accuracy(
        engine=engine,
        pred_sql=pred_sql,
        gold_sql=demo_gold_sql,
    )
    va_ok, ex_ok = bool(va_meta.success), bool(ex)

    attempt_rows.append(
        dict(
            attempt=attempt_no,
            pred_sql=pred_sql,
            va=int(va_ok),
            em=int(bool(em)),
            ex=int(ex_ok),
            error=va_meta.error or ex_pred_err,
            gold_error=ex_gold_err,
        )
    )

    if va_ok and ex_ok:
        break

report_df = pd.DataFrame(attempt_rows)
display(report_df)

if not report_df.empty:
    score_cols = ["va", "em", "ex"]
    per_attempt = report_df.set_index("attempt")[score_cols]
    ax = per_attempt.plot(kind="bar", figsize=(7, 3.5), rot=0)
    ax.set_title("attempt metrics (raw path)")
    ax.set_ylabel("score")
    ax.set_ylim(0, 1.05)  # scores are 0/1; leave a little headroom above full bars
    ax.grid(axis="y", alpha=0.2)
    plt.tight_layout()
    plt.show()

    # The last row is the attempt that ended the loop: the first VA+EX pass,
    # or the final candidate if none passed.
    final_row = report_df.iloc[-1].to_dict()
    print("selected final attempt:", final_row["attempt"])
    print("selected sql:", final_row["pred_sql"])


## 7) explicit-field example in raw mode

In [None]:
# Explicit-field case. The question names exactly three columns. The simulated
# output below differs from the gold SQL in three ways:
# - it returns four columns (it adds creditLimit);
# - the columns are in a different order;
# - it adds an ORDER BY and has no trailing semicolon.
explicit_nlq = "List contact last name, customer name, and customer number for customers in France"
explicit_gold_sql = "SELECT contactLastName, customerName, customerNumber FROM customers WHERE country = 'France';"

explicit_raw = """
SELECT customerName, creditLimit, customerNumber, contactLastName
FROM customers
WHERE country = 'France'
ORDER BY customerName
"""

explicit_extract = debug_extract_first_select(explicit_raw)
# Raw path: fall back to the untouched text when nothing is selected.
explicit_pred_sql = explicit_extract.get("selected_sql") or explicit_raw

print("nlq:", explicit_nlq)
print("\nraw selected sql:")
print(explicit_pred_sql)

explicit_qr = QueryRunner(engine, max_rows=50)
explicit_meta = explicit_qr.run(explicit_pred_sql, capture_df=False)
explicit_em = simple_sql_norm(explicit_pred_sql) == simple_sql_norm(explicit_gold_sql)
explicit_ex, explicit_pred_err, explicit_gold_err = execution_accuracy(
    engine=engine, pred_sql=explicit_pred_sql, gold_sql=explicit_gold_sql
)

metrics_row = dict(
    va=int(bool(explicit_meta.success)),
    em=int(bool(explicit_em)),
    ex=int(bool(explicit_ex)),
    pred_error=explicit_meta.error or explicit_pred_err,
    gold_error=explicit_gold_err,
)
print("\nmetrics in raw mode:")
print(metrics_row)

display(pd.DataFrame([metrics_row]))


## done

this demo follows the primary path: SQL is extracted from the raw generation and executed as-is, with no post-extraction cleaning layer.