# Training Set Validation
## Inputs
## Checks performed
## Output


## Summary
- Purpose: validate the prebuilt training set before QLoRA fine-tuning.
- Scope: checks train/test NLQ leakage, SELECT-only safety, and SQL executability (VA).
- Outputs: validation artifacts under `results/training_set_validation/`.


In [None]:
# If running in Colab: install dependencies
# The import of `google.colab` doubles as environment detection: when it
# fails, the pip line below is never reached.
try:
    import google.colab  # noqa: F401
    # IPython shell escape (not plain Python): quietly install the packages
    # listed in requirements.txt, resolved relative to the working directory.
    !pip -q install -r requirements.txt
except Exception:
    # Best-effort by design: any failure in the block above is ignored so
    # the cell is a no-op when `google.colab` cannot be imported.
    pass

import json
import os
from pathlib import Path

from sqlalchemy import text

from nl2sql.db import create_engine_with_connector


## 1) DB Connection


In [None]:
from getpass import getpass

# Connection settings are taken from the environment when present; any
# missing (or empty) value is requested interactively instead. The password
# prompt goes through getpass rather than input().
INSTANCE_CONNECTION_NAME = (
    os.getenv('INSTANCE_CONNECTION_NAME')
    or input('Enter INSTANCE_CONNECTION_NAME: ').strip()
)
DB_USER = os.getenv('DB_USER') or input('Enter DB_USER: ').strip()
DB_PASS = os.getenv('DB_PASS') or getpass('Enter DB_PASS: ')
# The database name is never prompted for; it falls back to a fixed default.
DB_NAME = os.getenv('DB_NAME') or 'classicmodels'

# `engine` is reused by the executability check later in the notebook;
# `connector` is kept alongside it as returned by the project helper.
engine, connector = create_engine_with_connector(
    instance_connection_name=INSTANCE_CONNECTION_NAME,
    user=DB_USER,
    password=DB_PASS,
    db_name=DB_NAME,
)

# Round-trip a trivial query so a bad configuration fails in this cell
# rather than midway through validation.
with engine.connect() as probe:
    probe.execute(text('SELECT 1')).fetchone()
print('DB connection OK')


## 2) Load Data


In [None]:
# Locations of the held-out test benchmark (one JSON array) and the
# training set (JSONL: one JSON object per line).
test_path = Path('data/classicmodels_test_200.json')
train_path = Path('data/train/classicmodels_train_200.jsonl')

test_items = json.loads(test_path.read_text(encoding='utf-8'))
# Whitespace-stripped test questions, used later for exact-match leakage detection.
test_nlqs = {x['nlq'].strip() for x in test_items}

# Parse the JSONL training file, ignoring blank lines.
train_records = []
for line in train_path.read_text(encoding='utf-8').splitlines():
    line = line.strip()
    if not line:
        continue
    train_records.append(json.loads(line))

print('Test items:', len(test_items))
print('Train items:', len(train_records))

# Basic shape checks.
# Every record is validated (not just the first few): the following cells
# index r['nlq'] and r['sql'] on all rows, so a malformed row anywhere in
# the file would otherwise surface later as a bare KeyError. An explicit
# raise is used instead of `assert` so the check also runs under `python -O`.
for i, r in enumerate(train_records):
    if not isinstance(r, dict) or 'nlq' not in r or 'sql' not in r:
        raise ValueError(f'Missing keys at row {i}: {r}')


## 3) Leakage, SELECT-only, and Dedup Checks


In [None]:
# Buckets filled while scanning the training set:
#   overlap    - (row index, nlq) for questions that also appear in the test set
#   non_select - (row index, nlq, first 120 chars of sql) for non-SELECT statements
#   deduped    - surviving {'nlq', 'sql'} records, first occurrence of each nlq only
overlap, non_select, deduped = [], [], []
seen = set()

for row_no, record in enumerate(train_records):
    question = str(record['nlq']).strip()
    query = str(record['sql']).strip()

    if question in test_nlqs:
        # Exact (whitespace-stripped) match against a test question: leaked row.
        overlap.append((row_no, question))
    elif not query.lower().lstrip().startswith('select'):
        # Anything that does not begin with SELECT is excluded.
        non_select.append((row_no, question, query[:120]))
    elif question not in seen:
        seen.add(question)
        # Normalise the statement so it always ends with a semicolon.
        if not query.endswith(';'):
            query = query + ';'
        deduped.append({'nlq': question, 'sql': query})

print('Exact NLQ overlaps with test:', len(overlap))
print('Non-SELECT rows:', len(non_select))
print('After NLQ-dedup:', len(deduped))

# Show at most five examples from each rejection bucket.
if overlap:
    print('First overlaps:')
    for row_no, question in overlap[:5]:
        print('  -', row_no, question)

if non_select:
    print('First non-SELECT rows:')
    for row_no, question, snippet in non_select[:5]:
        print('  -', row_no, question, '->', snippet)


## 4) VA Check (SQL Executability)


In [None]:
# Rows whose SQL could not be executed, stored as
# (index into `deduped`, nlq, sql, repr(exception)).
failed = []

with engine.connect() as conn:
    for idx, r in enumerate(deduped):
        sql = r['sql']
        try:
            res = conn.execute(text(sql))
            try:
                # Fetch a tiny sample to force evaluation without pulling huge results.
                res.fetchmany(1)
            finally:
                # Only one row was requested, so the result is usually not
                # exhausted. Close it explicitly so the cursor is released
                # before the next statement runs on this same connection.
                res.close()
        except Exception as e:
            # Any failure (syntax error, unknown column, driver error, ...)
            # counts as "not executable"; record it and keep going.
            failed.append((idx, r['nlq'], sql, repr(e)))

print('Executable (VA=True):', len(deduped) - len(failed), '/', len(deduped))
print('Failed:', len(failed))

# Print up to five failures in full for quick diagnosis.
if failed[:5]:
    print('First failures:')
    for i, nlq, sql, err in failed[:5]:
        print('---')
        print('row:', i)
        print('nlq:', nlq)
        print('sql:', sql)
        print('err:', err)


## 5) Save Artifacts


In [None]:
from datetime import datetime, timezone
from pathlib import Path
import json
import shutil

# Counts reused in several report fields.
va_total = len(deduped)
va_failures = len(failed)

# Summary of the whole validation run; only the first five failures are
# embedded to keep the report small.
report = {
    "timestamp_utc": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
    "test_items": len(test_items),
    "train_items_raw": len(train_records),
    "train_items_deduped": va_total,
    "overlap_count": len(overlap),
    "non_select_count": len(non_select),
    "va_pass_count": va_total - va_failures,
    "va_fail_count": va_failures,
    "sample_failures": [
        {"row": row_no, "nlq": question, "sql": statement, "error": message}
        for row_no, question, statement, message in failed[:5]
    ],
}

# Always write the report locally.
local_dir = Path("results/training_set_validation")
local_dir.mkdir(parents=True, exist_ok=True)
local_report = local_dir / "validation_report.json"
payload = json.dumps(report, indent=2)
local_report.write_text(payload, encoding="utf-8")
print(f"Saved local report: {local_report}")

# Optional Colab Drive backup: only attempted when the Drive mount point
# exists; each run gets its own timestamped folder.
if not Path("/content/drive/MyDrive").exists():
    print("Drive not detected at /content/drive/MyDrive (local save completed).")
else:
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%SZ")
    drive_dir = Path("/content/drive/MyDrive/nl2sql_persistent_runs/training_set_validation") / stamp
    drive_dir.mkdir(parents=True, exist_ok=True)
    drive_report = drive_dir / "validation_report.json"
    shutil.copy2(local_report, drive_report)
    print(f"Saved Drive backup: {drive_report}")
