# Validate the QLoRA training set (ClassicModels)

This notebook **does not generate** training data. Instead, it validates that an existing JSONL training set contains only `SELECT` queries that execute cleanly on the ClassicModels database, and that it does **not** overlap the fixed 200-item benchmark test set.

## Inputs
- Benchmark test set (fixed): `data/classicmodels_test_200.json`
- Training set (pre-generated): `data/train/classicmodels_train_200.jsonl`
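For reference, the training file is JSON Lines: one JSON object per line, each carrying at least the `nlq` and `sql` keys that the checks read. A minimal parsing sketch; the two records below are hypothetical examples, not rows from the real file.

```python
import json

# Hypothetical sample in the expected JSONL shape (blank lines are tolerated).
sample_jsonl = (
    '{"nlq": "How many customers are there?", "sql": "SELECT COUNT(*) FROM customers;"}\n'
    '\n'
    '{"nlq": "List all product lines.", "sql": "SELECT productLine FROM productlines"}\n'
)

# One json.loads() per non-empty line; each record is a dict with 'nlq' and 'sql'.
records = [json.loads(line) for line in sample_jsonl.splitlines() if line.strip()]
for r in records:
    print(r['nlq'], '->', r['sql'])
```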

## Checks performed
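- **Train/test leakage**: rejects training rows whose NLQ (natural-language question) exactly matches a benchmark NLQ
- **SELECT-only**: rejects rows whose SQL is not a `SELECT` statement
- **Deduplication**: keeps only the first training row for each NLQ
- **Executability (VA)**: runs each remaining SQL query against the live DB to confirm it executes without error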

## Output
- A validation report printed in the notebook output (no files are written; problems must be fixed in the source JSONL)


In [None]:
# If running in Colab: install dependencies
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    !pip -q install -r requirements.txt

import json
import os
from pathlib import Path

from sqlalchemy import text

from nl2sql.db import create_engine_with_connector


## 1) Configure DB connection

This uses the same environment variables as the other notebooks:

- `INSTANCE_CONNECTION_NAME` (e.g. `project:region:instance`)
- `DB_USER`
- `DB_PASS`
- `DB_NAME` (usually `classicmodels`)

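If `INSTANCE_CONNECTION_NAME`, `DB_USER`, or `DB_PASS` is not set, the notebook prompts for it (the password via `getpass`, so it is not echoed). `DB_NAME` is never prompted for; it defaults to `classicmodels`.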
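The fallback rules in the cell below can also be written as a small, testable helper. This is a minimal sketch: `resolve_setting` and `resolve_db_settings` are hypothetical names, not part of `nl2sql`. Injecting the environment mapping and the prompt functions lets you exercise the logic without a terminal.

```python
import os
from getpass import getpass
from typing import Callable, Dict, Mapping, Optional


def resolve_setting(
    name: str,
    env: Mapping[str, str],
    prompt: Callable[[str], str],
    default: Optional[str] = None,
    strip: bool = True,
) -> str:
    """Return env[name] if set and non-empty, else `default`, else ask via `prompt`."""
    value = env.get(name)
    if value:
        return value
    if default is not None:
        return default
    answer = prompt(f'Enter {name}: ')
    return answer.strip() if strip else answer


def resolve_db_settings(
    env: Mapping[str, str] = os.environ,
    ask: Callable[[str], str] = input,
    ask_secret: Callable[[str], str] = getpass,
) -> Dict[str, str]:
    """Hypothetical helper mirroring the rules: prompt for three settings, default DB_NAME."""
    return {
        'INSTANCE_CONNECTION_NAME': resolve_setting('INSTANCE_CONNECTION_NAME', env, ask),
        'DB_USER': resolve_setting('DB_USER', env, ask),
        # The password is read without echo and left unstripped.
        'DB_PASS': resolve_setting('DB_PASS', env, ask_secret, strip=False),
        'DB_NAME': resolve_setting('DB_NAME', env, ask, default='classicmodels'),
    }
```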


In [None]:
from getpass import getpass

INSTANCE_CONNECTION_NAME = os.getenv('INSTANCE_CONNECTION_NAME')
DB_USER = os.getenv('DB_USER')
DB_PASS = os.getenv('DB_PASS')
DB_NAME = os.getenv('DB_NAME') or 'classicmodels'

if not INSTANCE_CONNECTION_NAME:
    INSTANCE_CONNECTION_NAME = input('Enter INSTANCE_CONNECTION_NAME: ').strip()
if not DB_USER:
    DB_USER = input('Enter DB_USER: ').strip()
if not DB_PASS:
    DB_PASS = getpass('Enter DB_PASS: ')

engine, connector = create_engine_with_connector(
    instance_connection_name=INSTANCE_CONNECTION_NAME,
    user=DB_USER,
    password=DB_PASS,
    db_name=DB_NAME,
)

with engine.connect() as conn:
    conn.execute(text('SELECT 1')).fetchone()
print('DB connection OK')


## 2) Load benchmark + training set

The benchmark is fixed at 200 items. The training set must be separate (no exact NLQ overlap).
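Because the overlap rule is exact string equality after stripping whitespace, trivially different phrasings (different case, doubled spaces, a trailing `?`) are not flagged. If you want a stricter audit, one option is to also compare normalised keys. This is a hedged sketch, not the notebook's definition of leakage: `normalise_nlq` and `near_overlaps` are hypothetical helpers, and the normalisation choices (NFKC, casefolding, whitespace collapsing, dropping trailing `?.!`) are assumptions.

```python
import re
import unicodedata


def normalise_nlq(nlq: str) -> str:
    """Hypothetical normaliser: NFKC, casefold, collapse whitespace, drop trailing punctuation."""
    s = unicodedata.normalize('NFKC', nlq).casefold()
    s = re.sub(r'\s+', ' ', s).strip()
    return s.rstrip(' ?.!')


def near_overlaps(train_nlqs, test_nlqs):
    """Return train NLQs that match a test NLQ after normalisation but not exactly."""
    test_exact = {t.strip() for t in test_nlqs}
    test_norm = {normalise_nlq(t) for t in test_nlqs}
    return [
        n for n in train_nlqs
        if n.strip() not in test_exact and normalise_nlq(n) in test_norm
    ]


test = ['List all customers in France.', 'How many orders were shipped?']
train = ['list all customers in  France', 'How many orders were shipped?', 'Show all products.']
print(near_overlaps(train, test))
```

Exact duplicates are left to the existing check; this only surfaces the near misses it would not catch.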


In [None]:
test_path = Path('data/classicmodels_test_200.json')
train_path = Path('data/train/classicmodels_train_200.jsonl')

test_items = json.loads(test_path.read_text(encoding='utf-8'))
test_nlqs = {str(x['nlq']).strip() for x in test_items}

train_records = []
for lineno, line in enumerate(train_path.read_text(encoding='utf-8').splitlines(), start=1):
    line = line.strip()
    if not line:
        continue  # tolerate blank lines in the JSONL file
    try:
        train_records.append(json.loads(line))
    except json.JSONDecodeError as e:
        raise ValueError(f'Invalid JSON on line {lineno} of {train_path}: {e}') from e

print('Test items:', len(test_items))
print('Train items:', len(train_records))

# Shape check: every training row (not just the first few) needs 'nlq' and 'sql'
for i, r in enumerate(train_records):
    assert 'nlq' in r and 'sql' in r, f'Missing keys at row {i}: {r}'


## 3) Leakage + safety checks

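- **Leakage** is defined here as an *exact NLQ string match* (after stripping surrounding whitespace) between train and test.
- **Safety** is defined as `SELECT`-only, checked by whether the stripped SQL text starts with `select` (case-insensitive). Queries that open with a comment or a `WITH` clause are therefore rejected as well.
- **Duplicates**: if the same NLQ appears more than once in the training set, only the first row is kept.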
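A prefix check alone would accept `SELECT 1; DROP TABLE customers`. A slightly stricter heuristic strips leading comments and rejects anything containing a second statement. This is a sketch under stated assumptions: `is_single_select` is a hypothetical helper, it is still a regex heuristic rather than a SQL parser (a `;` inside a string literal is treated as a statement separator, and MySQL `#` comments are not stripped), and like the notebook it rejects `WITH` queries. Connecting as a read-only database user is the more robust safeguard.

```python
import re

# Leading SQL comments: `-- ...` line comments and `/* ... */` block comments.
_LEADING_COMMENTS = re.compile(r'^(\s*(--[^\n]*\n|/\*.*?\*/))*\s*', re.DOTALL)


def is_single_select(sql: str) -> bool:
    """Heuristic: True if, after leading comments, the text is one SELECT statement.

    An optional trailing ';' is allowed; any other ';' counts as a second statement.
    """
    body = _LEADING_COMMENTS.sub('', sql, count=1).strip()
    if body.endswith(';'):
        body = body[:-1].rstrip()
    if ';' in body:
        return False  # multiple statements (or a ';' inside a literal)
    # \b stops words like "selection" from matching
    return re.match(r'select\b', body, flags=re.IGNORECASE) is not None


print(is_single_select('SELECT 1; DROP TABLE customers;'))
```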


In [None]:
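overlap = []      # (row, nlq): NLQ also appears in the benchmark
non_select = []   # (row, nlq, sql snippet): statement does not start with SELECT
duplicates = []   # (row, nlq): NLQ already seen earlier in the training set
deduped = []      # rows that pass every check, first occurrence of each NLQ
seen = set()

for idx, r in enumerate(train_records):
    nlq = str(r['nlq']).strip()
    sql = str(r['sql']).strip()

    # Leakage: exact (whitespace-stripped) NLQ match against the benchmark
    if nlq in test_nlqs:
        overlap.append((idx, nlq))
        continue

    # Safety: accept only statements whose text starts with "select" (case-insensitive)
    if not sql.lower().startswith('select'):
        non_select.append((idx, nlq, sql[:120]))
        continue

    # Dedup: keep only the first row for each NLQ
    if nlq in seen:
        duplicates.append((idx, nlq))
        continue
    seen.add(nlq)

    # Normalise so every kept query ends with a semicolon
    deduped.append({'nlq': nlq, 'sql': sql if sql.endswith(';') else sql + ';'})

print('Exact NLQ overlaps with test:', len(overlap))
print('Non-SELECT rows:', len(non_select))
print('Duplicate NLQs dropped:', len(duplicates))
print('Rows kept after all checks:', len(deduped))

if overlap:
    print('First overlaps:')
    for i, nlq in overlap[:5]:
        print('  -', i, nlq)

if non_select:
    print('First non-SELECT rows:')
    for i, nlq, sql_snip in non_select[:5]:
        print('  -', i, nlq, '->', sql_snip)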


## 4) Executability validation (VA)

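This runs each query that passed the checks above against the live database. A query is considered executable (VA=True) if it runs without raising an exception. VA says nothing about whether the returned rows actually answer the NLQ.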
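To see what the executability check catches without a live Cloud SQL instance, here is a minimal sketch against an in-memory SQLite database. SQLite is a swapped-in stand-in, not ClassicModels on MySQL: dialect differences (functions such as `DATE_FORMAT`, backtick quoting) mean a query can pass here and fail on MySQL, or the reverse. `check_executable` is a hypothetical helper and the one-table schema is an illustrative assumption.

```python
import sqlite3


def check_executable(conn: sqlite3.Connection, queries):
    """Run each query; return (index, error message) for those that raise.

    A query counts as executable if execute() and fetching one row both succeed.
    """
    failures = []
    for idx, sql in enumerate(queries):
        cur = conn.cursor()
        try:
            cur.execute(sql)
            cur.fetchmany(1)
        except sqlite3.Error as e:
            failures.append((idx, str(e)))
        finally:
            cur.close()
    return failures


# Stand-in schema: a tiny in-memory table named like a ClassicModels table.
demo_conn = sqlite3.connect(':memory:')
demo_conn.execute('CREATE TABLE customers (customerNumber INTEGER, country TEXT)')
demo_conn.execute("INSERT INTO customers VALUES (103, 'France')")

demo_queries = [
    "SELECT COUNT(*) FROM customers WHERE country = 'France';",
    'SELECT customerName FROM customers;',  # column missing from the stand-in table
    'SELECT * FROM no_such_table;',
]
failed_demo = check_executable(demo_conn, demo_queries)
for idx, err in failed_demo:
    print(idx, err)
```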


In [None]:
failed = []

with engine.connect() as conn:
    for idx, r in enumerate(deduped):
        sql = r['sql']
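        try:
            res = conn.execute(text(sql))
            # Fetch at most one row to confirm the statement produced a result.
            # With buffered DB-API drivers the full result may already have been
            # transferred by execute(), so this does not necessarily bound memory use.
            res.fetchmany(1)
            res.close()  # release the cursor before the next statement
        except Exception as e:
            failed.append((idx, r['nlq'], sql, repr(e)))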

print('Executable (VA=True):', len(deduped) - len(failed), '/', len(deduped))
print('Failed:', len(failed))

if failed[:5]:
    print('First failures:')
    for i, nlq, sql, err in failed[:5]:
        print('---')
        print('row:', i)
        print('nlq:', nlq)
        print('sql:', sql)
        print('err:', err)


## 5) Next step

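If all checks come back clean (**0** overlaps, **0** non-SELECT rows, and **0** execution failures), proceed to `notebooks/05_qlora_train_eval.ipynb` to fine-tune and evaluate.

Otherwise, fix the problematic rows in `data/train/classicmodels_train_200.jsonl` and re-run this notebook. This notebook does not write a cleaned file, so rows it skips (overlaps, non-SELECT statements, duplicate NLQs) remain in the JSONL until you edit it.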
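If you run this notebook as part of a pipeline, it can help to fail loudly instead of reading the printed counts. A minimal sketch; `readiness_problems` is a hypothetical helper that takes the `overlap`, `non_select`, and `failed` lists built in the cells above.

```python
def readiness_problems(overlap, non_select, failed):
    """Return a human-readable list of checks that did not come back clean."""
    counts = {
        'test-set NLQ overlaps': len(overlap),
        'non-SELECT rows': len(non_select),
        'execution failures': len(failed),
    }
    return [f'{n} {name}' for name, n in counts.items() if n]


# In the notebook, after the checks have run:
#   problems = readiness_problems(overlap, non_select, failed)
#   assert not problems, 'Fix the training JSONL: ' + '; '.join(problems)
print(readiness_problems([], [], []))
```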
