# Build a DB-validated training set (LLM-assisted)

This notebook generates NLQ→SQL training pairs using an LLM, then **validates every SQL query** against the live ClassicModels database (VA must be True) before saving.

Why this exists:
- QLoRA needs a **training** dataset that is **separate** from the 200-item benchmark (`data/classicmodels_test_200.json`).
- If you generate data with an LLM, you must be strict: only keep examples that execute, are SELECT-only, and do not overlap the test NLQs.

Output:
- `data/train/classicmodels_train_200.jsonl` (JSON Lines: `{ "nlq": ..., "sql": ... }`)


In [None]:
import os, sys, shutil
from pathlib import Path

# If opened directly in Colab, clone the repo first
if Path("data/classicmodels_test_200.json").exists() is False and Path("/content").exists():
    repo_dir = Path("/content/NLtoSQL")
    if repo_dir.exists():
        shutil.rmtree(repo_dir)
    !git clone https://github.com/MacKenzieOBrian/NLtoSQL.git "{repo_dir}"
    os.chdir(repo_dir)

sys.path.insert(0, os.getcwd())
print("cwd:", os.getcwd())


## 0) Install dependencies (Colab)

Install pinned deps and restart runtime if Colab asks.


In [None]:
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    !pip -q install -r requirements.txt
else:
    print("Not in Colab; ensure requirements are installed.")


## 1) Auth (GCP + Hugging Face)


In [None]:
# --- GCP auth ---
# The Colab auth helper is only importable inside a Colab runtime.
try:
    from google.colab import auth
except ModuleNotFoundError:
    auth = None

if not auth:
    print("Not running in Colab; ensure ADC/service account auth is configured.")
else:
    auth.authenticate_user()

# --- Hugging Face auth ---
# Prefer a token from the environment; otherwise fall back to interactive login.
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
if not hf_token:
    try:
        from huggingface_hub import notebook_login
        notebook_login()
    except Exception as exc:
        # Best effort: report the problem and let later cells fail if auth is needed.
        print("HF auth not configured:", exc)
else:
    os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
    print("Using HF token from env")


## 2) Load test set (to avoid overlap)


In [None]:
import json

# Load the benchmark items. The context manager closes the file handle as soon
# as parsing finishes instead of leaving it to the garbage collector.
with open("data/classicmodels_test_200.json", "r", encoding="utf-8") as fh:
    test_set = json.load(fh)

# Whitespace-trimmed benchmark questions, used to reject overlapping NLQs.
test_nlqs = {x["nlq"].strip() for x in test_set}
print("Loaded test NLQs:", len(test_nlqs))


## 3) DB engine + schema summary


In [None]:
from getpass import getpass

from nl2sql.db import create_engine_with_connector
from nl2sql.query_runner import QueryRunner
from nl2sql.schema import build_schema_summary

# Connection settings: take each value from the environment and prompt only for
# those that are unset or empty. The password prompt does not echo.
INSTANCE_CONNECTION_NAME = (
    os.getenv("INSTANCE_CONNECTION_NAME")
    or input("Enter INSTANCE_CONNECTION_NAME: ").strip()
)
DB_USER = os.getenv("DB_USER") or input("Enter DB_USER: ").strip()
DB_PASS = os.getenv("DB_PASS") or getpass("Enter DB_PASS: ")
DB_NAME = os.getenv("DB_NAME", "classicmodels")

db_settings = {
    "instance_connection_name": INSTANCE_CONNECTION_NAME,
    "user": DB_USER,
    "password": DB_PASS,
    "db_name": DB_NAME,
}
engine, connector = create_engine_with_connector(**db_settings)

# Schema text that is later placed in the LLM prompt, plus the runner used to
# execute candidate SQL.
SCHEMA_SUMMARY = build_schema_summary(engine, db_name=DB_NAME, max_cols_per_table=50)
qr = QueryRunner(engine, max_rows=50)
print("Schema summary length:", len(SCHEMA_SUMMARY))


## 4) Load the LLM

You can generate training pairs with the same base model used in the baseline (or swap to a different generator model).


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tok = AutoTokenizer.from_pretrained(MODEL_ID, token=True)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

model_kwargs = {
    "quantization_config": bnb_config,
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "token": True,
}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_kwargs)

print("Model loaded on:", model.device)


## 5) Generate + validate pairs

Policy (strict):
- Keep only **single SELECT** statements.
- Must execute successfully (VA=True) on the live DB.
- Discard if NLQ overlaps the benchmark NLQs.
- Deduplicate by normalized SQL and NLQ.


In [None]:
import json
import random
import re
from nl2sql.llm import extract_first_select
from nl2sql.postprocess import normalize_sql

# Greedy pattern with DOTALL: matches from the first "{" to the last "}" in a
# string, spanning newlines.
JSON_RE = re.compile(r"\{.*\}", re.DOTALL)

# System prompt: asks for a single JSON object with keys `nlq` and `sql`, and
# lists the rules the generated SQL and NLQ must follow.
SYSTEM = """You create training data for NL-to-SQL over the ClassicModels MySQL database.
Return ONE JSON object with exactly these keys: nlq, sql.

Rules:
- sql must be a single MySQL SELECT statement.
- No comments, no markdown, no extra keys.
- Use only tables/columns present in the schema.
- Make the NLQ natural and specific.
"""

def propose_one(rng: random.Random) -> dict:
    """Sample one candidate NLQ/SQL pair from the LLM.

    The model is prompted with the system rules and the schema summary, its
    reply is decoded, and the first JSON object in the reply is parsed.

    Args:
        rng: Random source used to derive the torch sampling seed for this
            call, so that a seeded ``rng`` makes the sequence of samples
            repeatable (best effort; kernels may still be nondeterministic).

    Returns:
        A dict with exactly the keys ``nlq`` and ``sql``, both stripped strings.

    Raises:
        ValueError: If the reply contains no JSON object, the JSON is invalid
            (``json.JSONDecodeError`` is a ``ValueError``), the keys are not
            exactly ``nlq``/``sql``, or either value is not a string.
    """
    messages = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": "Schema:\n" + SCHEMA_SUMMARY},
        {
            "role": "user",
            "content": "Generate a new, non-trivial training example (joins/aggregations/filters encouraged).",
        },
    ]

    input_ids = tok.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Tie sampling to the caller's rng; previously rng was ignored, so the
    # notebook's SEED had no influence on what was generated.
    torch.manual_seed(rng.getrandbits(32))

    with torch.no_grad():
        out = model.generate(
            input_ids,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tok.eos_token_id,
        )

    # Keep only the newly generated tokens (drop the prompt prefix).
    gen_ids = out[0][input_ids.shape[-1] :]
    text = tok.decode(gen_ids, skip_special_tokens=True).strip()

    m = JSON_RE.search(text)
    if not m:
        raise ValueError("No JSON object found")
    # Parse the first complete JSON object starting at the match. Unlike
    # json.loads on the greedy span, this tolerates trailing text that
    # happens to contain another "}".
    obj, _ = json.JSONDecoder().raw_decode(text, m.start())
    if not isinstance(obj, dict) or set(obj.keys()) != {"nlq", "sql"}:
        keys = sorted(obj.keys()) if isinstance(obj, dict) else type(obj).__name__
        raise ValueError(f"Bad keys: {keys}")
    # Reject null/number/list values rather than stringifying them into
    # bogus examples such as the NLQ "None".
    if not isinstance(obj["nlq"], str) or not isinstance(obj["sql"], str):
        raise ValueError("nlq and sql must be strings")

    nlq = obj["nlq"].strip()
    sql = obj["sql"].strip()
    # Prefer a SELECT extracted from the sql field, then from the whole reply,
    # and only then fall back to the raw field.
    sql = extract_first_select(sql) or extract_first_select(text) or sql
    return {"nlq": nlq, "sql": sql}

def is_acceptable(candidate: dict) -> tuple[bool, str | None]:
    nlq = candidate.get("nlq", "").strip()
    sql = candidate.get("sql", "").strip()

    if not nlq or not sql:
        return False, "empty"
    if nlq in test_nlqs:
        return False, "nlq-overlaps-test"

    sql_norm = normalize_sql(sql)
    if not sql_norm.startswith("select "):
        return False, "not-select"

    meta = qr.run(sql, capture_df=False)
    if not meta.success:
        return False, f"va-false: {meta.error}"
    return True, None

from collections import Counter

TARGET_N = 200
MAX_ATTEMPTS = 1200
SEED = 7

rng = random.Random(SEED)
seen_sql = set()
seen_nlq = set()
accepted = []
# Why attempts were dropped, so a low yield can be diagnosed after the run.
dropped = Counter()

for attempt in range(1, MAX_ATTEMPTS + 1):
    try:
        cand = propose_one(rng)
    except Exception as e:
        # Best effort: a malformed reply just costs one attempt.
        dropped["parse-fail"] += 1
        if attempt % 50 == 0:
            print("attempt", attempt, "parse-fail", str(e)[:120])
        continue

    # Check for duplicates before validating, so repeated candidates do not
    # cost a database round-trip. Empty SQL is left for is_acceptable to reject.
    nlq_key = cand["nlq"].strip().lower()
    sql_key = normalize_sql(cand["sql"]) if cand["sql"].strip() else ""
    if sql_key in seen_sql or nlq_key in seen_nlq:
        dropped["duplicate"] += 1
        continue

    ok, reason = is_acceptable(cand)
    if not ok:
        # Group by the reason prefix; the text after ":" is a per-query error.
        dropped[str(reason).split(":", 1)[0]] += 1
        if attempt % 50 == 0:
            print("attempt", attempt, "reject", reason)
        continue

    seen_sql.add(sql_key)
    seen_nlq.add(nlq_key)
    accepted.append(cand)

    if len(accepted) % 25 == 0:
        print("accepted", len(accepted), "/", TARGET_N)
    if len(accepted) >= TARGET_N:
        break

print("Final accepted:", len(accepted))
if len(accepted) < TARGET_N:
    print("WARNING: accepted", len(accepted), "of", TARGET_N, "after", MAX_ATTEMPTS, "attempts")
print("Dropped attempts by reason:", dict(dropped))
accepted[:3]


## 6) Save to `data/train/`


In [None]:
from pathlib import Path
import json

out_dir = Path("data/train")
out_path = out_dir / "classicmodels_train_200.jsonl"

# Opening with "w" truncates the file, so an empty run would wipe a previously
# saved training set. Stop before touching the file in that case.
if not accepted:
    raise RuntimeError(f"No accepted examples; refusing to overwrite {out_path} with an empty file.")

out_dir.mkdir(parents=True, exist_ok=True)

# JSON Lines: one {"nlq": ..., "sql": ...} object per line, non-ASCII kept as-is.
with out_path.open("w", encoding="utf-8") as f:
    f.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in accepted)

print("Wrote:", out_path, "lines:", len(accepted))
