# NL → SQL (PostgreSQL) prompt tuning + evaluation\n
\n
This notebook does not "train" the base LLM. Instead it helps you iterate on:\n
- schema formatting\n
- prompt rules\n
- few-shot examples (if you want)\n
- evaluation set (NL question → expected SQL pattern)\n

In [None]:
import os\n
import sys\n
from dotenv import load_dotenv\n
\n
ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))\n
SRC = os.path.join(ROOT, "src")\n
if SRC not in sys.path:\n
    sys.path.insert(0, SRC)\n
\n
load_dotenv(os.path.join(ROOT, ".env"))\n
\n
from nl2sql.db import PostgresDB\n
from nl2sql.agent import generate_sql\n
from nl2sql.config import DEFAULT_GEMINI_MODEL\n
from nl2sql.sql_safety import validate_readonly_sql\n
\n
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")\n
MODEL = DEFAULT_GEMINI_MODEL\n
DATABASE_URL = os.getenv("DATABASE_URL", "")\n
\n
assert GEMINI_API_KEY, "Set GEMINI_API_KEY in .env"\n
assert DATABASE_URL, "Set DATABASE_URL in .env"\n
\n
db = PostgresDB(DATABASE_URL)\n
schema_text = db.fetch_schema()\n
print(schema_text[:1500])

## Define an evaluation set\n
\n
Put your common business questions here. For each item, add a simple check like: must contain a table name, must contain a JOIN, must contain GROUP BY, etc.

In [None]:
EVAL = [\n
    {\n
        "question": "List top 10 customers by total order amount",\n
        "must_contain": ["limit", "order by"],\n
        "must_not_contain": ["insert", "update", "delete", "drop"],\n
    },\n
    {\n
        "question": "Show total revenue per month for 2024",\n
        "must_contain": ["group by"],\n
        "must_not_contain": ["insert", "update", "delete", "drop"],\n
    },\n
]\n
\n
len(EVAL)

## Run evaluation

In [1]:
def check(sql: str, item: dict) -> dict:\n
    s = sql.lower()\n
    ok = True\n
    missing = [x for x in item.get("must_contain", []) if x.lower() not in s]\n
    present_bad = [x for x in item.get("must_not_contain", []) if x.lower() in s]\n
    if missing or present_bad:\n
        ok = False\n
    return {"ok": ok, "missing": missing, "present_bad": present_bad}\n
\n
results = []\n
for item in EVAL:\n
    sql = generate_sql(\n
        provider="gemini",\n
        api_key=GEMINI_API_KEY,\n
        model=MODEL,\n
        schema_text=schema_text,\n
        question=item["question"],\n
        chat_history=None,\n
    )\n
    sql = validate_readonly_sql(sql)\n
    results.append({"question": item["question"], "sql": sql, **check(sql, item)})\n
\n
results

SyntaxError: unexpected character after line continuation character (2470767212.py, line 1)

## Next: iterate on the prompt\n
\n
Edit `src/nl2sql/agent.py` rules and re-run the eval until your patterns pass for your DB schema.