
# NL-to-SQL scaffold (classicmodels)

What this notebook does:
- Auth to GCP
- Safe Cloud SQL connection via connector + SQLAlchemy
- Schema helpers + QueryRunner tool
- Smoke tests and dataset validator
- Base Llama-3-8B load (pre-QLoRA placeholder)


In [None]:
# Auth to Google Cloud
# Runs the interactive authentication helper from google.colab for the current user.
# NOTE(review): google.colab is presumably only importable inside a Colab runtime —
# confirm before running this notebook elsewhere.
from google.colab import auth
auth.authenticate_user()

## Project context
Set the GCP project ID for this session. In production, supply it through an environment variable rather than hardcoding it in the notebook.

In [None]:
import os

# Prefer a project id already present in the environment so the notebook can be
# pointed at another project without editing code; fall back to the development
# project when the variable is unset or empty.
project_id = os.environ.get("GOOGLE_CLOUD_PROJECT") or "modified-enigma-476414-h9"
# Export the resolved value so the rest of the process sees the same project.
os.environ["GOOGLE_CLOUD_PROJECT"] = project_id


## Installs
Most of the packages below are installed unpinned. For reproducible runs, pin exact versions in a requirements file (or a dedicated install cell).

In [None]:
# Install dependencies with the interpreter that runs this kernel (sys.executable),
# so packages land in the kernel's own environment.
import sys
!{sys.executable} -m pip install --upgrade pip
# Database stack: only SQLAlchemy and cryptography are pinned here; this line is
# force-reinstalled with the pip cache disabled.
!{sys.executable} -m pip install "cloud-sql-python-connector[pymysql]" SQLAlchemy==2.0.7 pymysql cryptography==41.0.0 --force-reinstall --no-cache-dir
# Model/training stack: installed unpinned, so versions float between runs.
!{sys.executable} -m pip install accelerate
!{sys.executable} -m pip install bitsandbytes
!{sys.executable} -m pip install peft
!{sys.executable} -m pip install transformers
!{sys.executable} -m pip install datasets
!{sys.executable} -m pip install trl


In [None]:
# Imports and logger
import os
import logging
from google.cloud.sql.connector import Connector
import sqlalchemy
from sqlalchemy import text
import pymysql
from typing import Optional

# basicConfig() does nothing when the root logger already has handlers, which
# would leave the root level untouched and silently drop INFO records.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("nl2sql_db")
# Set the level on the named logger as well so INFO messages are emitted
# regardless of how the root logger was configured beforehand.
logger.setLevel(logging.INFO)


## Connection params
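Connection settings are read from environment variables first. During development, `INSTANCE_CONNECTION_NAME`, `DB_USER` and `DB_PASS` are requested interactively when missing; `DB_NAME` defaults to `classicmodels`.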

In [None]:
from getpass import getpass

# Set these environment variables in Colab using:
#   %env DB_USER=... %env DB_NAME=... etc (or use secrets manager)
# Resolve each setting from the environment; an unset or empty value falls
# through to an interactive prompt (prompts run in the order listed here).
INSTANCE_CONNECTION_NAME = (
    os.getenv("INSTANCE_CONNECTION_NAME")
    or input("Enter INSTANCE_CONNECTION_NAME: ").strip()
)
DB_USER = os.getenv("DB_USER") or input("Enter DB_USER: ").strip()
# The password prompt goes through getpass instead of input().
DB_PASS = os.getenv("DB_PASS") or getpass("Enter DB_PASS: ")
# The database name never prompts; it falls back to a fixed default.
DB_NAME = os.getenv("DB_NAME", "classicmodels")


## Connector + engine setup

In [None]:
from google.api_core import retry
from sqlalchemy.engine import Engine
import time
from contextlib import contextmanager

# Single Connector instance shared by every pooled connection created below.
connector = Connector()

def getconn():
    """SQLAlchemy creator hook using the Cloud SQL connector.

    Returns:
        A new DB-API connection opened through the module-level ``connector``
        with the pymysql driver, using INSTANCE_CONNECTION_NAME, DB_USER,
        DB_PASS and DB_NAME as read at call time.
    """
    return connector.connect(
        INSTANCE_CONNECTION_NAME,
        "pymysql",
        user=DB_USER,
        password=DB_PASS,
        db=DB_NAME,
    )

# The URL carries only the dialect/driver; actual connections come from getconn().
# pool_pre_ping tests a pooled connection before handing it out and pool_recycle
# retires connections older than 30 minutes, so a notebook left idle does not
# fail its next query on a connection the server has already dropped.
engine: Engine = sqlalchemy.create_engine(
    "mysql+pymysql://",
    creator=getconn,
    future=True,
    pool_pre_ping=True,
    pool_recycle=1800,
)

@contextmanager
def safe_connection(engine):
    """Context manager that opens a connection on ``engine`` and always closes it.

    Yields:
        The object returned by ``engine.connect()``. It is closed when the
        ``with`` block exits, whether normally or through an exception.
    """
    # If connect() itself raises there is nothing to clean up, so it sits
    # outside the try block.
    connection = engine.connect()
    try:
        yield connection
    finally:
        connection.close()


## Schema exploration helpers

In [None]:
import pandas as pd

def list_tables(engine) -> list:
    """List the tables visible on the current connection.

    Args:
        engine: Engine passed through to ``safe_connection``.

    Returns:
        The first column of every row returned by ``SHOW TABLES``.
    """
    with safe_connection(engine) as conn:
        rows = conn.execute(text("SHOW TABLES;")).fetchall()
    # Each row has the table name in its first column.
    table_names = []
    for row in rows:
        table_names.append(row[0])
    return table_names

def get_table_columns(engine, table_name: str) -> pd.DataFrame:
    """Describe the columns of one table via INFORMATION_SCHEMA.

    Args:
        engine: Engine passed through to ``safe_connection``.
        table_name: Table to look up inside the schema named by ``DB_NAME``.

    Returns:
        DataFrame with COLUMN_NAME, DATA_TYPE, IS_NULLABLE and COLUMN_KEY,
        ordered by ORDINAL_POSITION.
    """
    column_query = text("""
        SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = :db AND TABLE_NAME = :table
        ORDER BY ORDINAL_POSITION
    """)
    # Both values are passed as bound parameters, never interpolated into the SQL.
    bind_params = {"db": DB_NAME, "table": table_name}
    with safe_connection(engine) as conn:
        return pd.read_sql(column_query, conn, params=bind_params)


## Smoke tests

In [None]:
def fetch_sample_customers(limit: int = 10):
    """Fetch a few rows from ``customers`` using the module-level ``engine``.

    Args:
        limit: Value bound to the query's LIMIT clause.

    Returns:
        DataFrame with customerNumber, customerName and country.
    """
    sample_query = text("SELECT customerNumber, customerName, country FROM customers LIMIT :limit;")
    bind_params = {"limit": limit}
    with safe_connection(engine) as conn:
        return pd.read_sql(sample_query, conn, params=bind_params)

# Best-effort connectivity check: list tables, then show a handful of customers.
# Failures are logged with a traceback rather than raised, so the notebook
# keeps running even when the database is unreachable.
try:
    tables = list_tables(engine)
    # Report the configured database name rather than a hardcoded one, since
    # DB_NAME can be overridden through the environment.
    logger.info("Tables in %s: %s", DB_NAME, tables)
    sample_df = fetch_sample_customers(5)
    display(sample_df)
except Exception as e:
    logger.exception("Smoke test failed: %s", e)


## QueryRunner (read-only tool)

In [None]:
import json
from datetime import datetime, timezone
from typing import Any, Dict

class QueryExecutionError(Exception):
    """Error raised when a SQL string is rejected before it is executed."""

def now_utc_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with a ``Z`` suffix."""
    utc_now = datetime.now(timezone.utc)
    iso_stamp = utc_now.isoformat()
    # isoformat() renders the UTC offset as "+00:00"; swap it for "Z".
    return iso_stamp.replace("+00:00", "Z")

class QueryRunner:
    """
    Execute generated SQL against the engine behind a keyword deny-list, capture
    results and metadata, and keep a history suitable for evaluation and error analysis.

    The deny-list is a textual heuristic, not a security boundary: it only knows the
    configured keywords and also rejects statements that merely mention one of them
    inside a string literal or comment. Pair it with a read-only database account.

    Attributes:
        engine: Engine handed to ``safe_connection`` for every run.
        max_rows: Cap applied to the stored preview frame and the reported rowcount.
        history: One metadata dict per ``run`` call, in call order.
        forbidden_tokens: Keywords that cause a statement to be rejected.
    """
    def __init__(self, engine, max_rows: int = 1000, forbidden_tokens=None):
        """Store the engine and limits; no connection is opened here."""
        self.engine = engine
        self.max_rows = max_rows
        self.history = []
        self.forbidden_tokens = forbidden_tokens or ["drop ", "delete ", "truncate ", "alter ", "create ", "update ", "insert "]

    def _safety_check(self, sql: str) -> None:
        """Reject empty SQL and SQL containing a forbidden keyword.

        Keywords are matched case-insensitively as whole words, so any whitespace
        (tab, newline) or end-of-string after the keyword is caught, while
        identifiers that merely contain a keyword (e.g. ``lastUpdate``) are not.

        Raises:
            QueryExecutionError: If ``sql`` is empty/None or contains a forbidden keyword.
        """
        import re  # local import: `re` is not imported at notebook level

        lowered = (sql or "").strip().lower()
        if not lowered:
            raise QueryExecutionError("Empty SQL string")
        for token in self.forbidden_tokens:
            keyword = token.strip().lower()
            if not keyword:
                # A blank token would match everywhere; ignore it.
                continue
            # Lookarounds instead of \b so tokens that start or end with a
            # non-word character still behave sensibly.
            pattern = r"(?<!\w)" + re.escape(keyword) + r"(?!\w)"
            if re.search(pattern, lowered):
                raise QueryExecutionError(f"Destructive SQL token detected: {token.strip()}")

    def run(self, sql: str, params: Optional[Dict[str, Any]] = None, capture_df: bool = True) -> Dict[str, Any]:
        """Run one statement and record what happened.

        Args:
            sql: SQL text to execute.
            params: Optional bound parameters for the statement.
            capture_df: When True, keep up to ``max_rows`` rows as a DataFrame
                under ``result_preview``.

        Returns:
            The metadata dict for this call (also appended to ``history``) with keys
            sql, params, timestamp, success, rowcount, exec_time_s, error, columns
            and result_preview. Errors are captured in ``error`` and never raised.
        """
        entry = {
            "sql": sql,
            "params": params,
            "timestamp": now_utc_iso(),
            "success": False,
            "rowcount": 0,
            "exec_time_s": None,
            "error": None,
            "columns": None,
            "result_preview": None,
        }
        try:
            self._safety_check(sql)
            start = datetime.now(timezone.utc)
            with safe_connection(self.engine) as conn:
                result = conn.execute(sqlalchemy.text(sql), params or {})
                rows = result.fetchall()
                cols = list(result.keys())
            end = datetime.now(timezone.utc)
            # Covers connect + execute + fetch, not DataFrame construction.
            exec_time = (end - start).total_seconds()
            df = None
            if capture_df:
                df = pd.DataFrame(rows, columns=cols)
                if len(df) > self.max_rows:
                    df = df.iloc[: self.max_rows]
            entry.update({
                "success": True,
                # Reported count is capped at max_rows even when more rows came back.
                "rowcount": min(len(rows), self.max_rows),
                "exec_time_s": exec_time,
                "columns": cols,
                "result_preview": df
            })
        except Exception as e:
            # Deliberately broad: a failed query is data for error analysis, not a crash.
            entry.update({
                "error": str(e),
                "success": False
            })
        finally:
            self.history.append(entry)
        return entry

    def last(self):
        """Return the most recent history entry, or None when nothing has run."""
        return self.history[-1] if self.history else None

    def save_history(self, path: str):
        """Write the history to ``path`` as JSON, omitting the preview DataFrames.

        Values that JSON cannot encode natively are written via ``str()``.
        """
        serializable = []
        for h in self.history:
            s = {k: v for k, v in h.items() if k != "result_preview"}
            serializable.append(s)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(serializable, f, indent=2, default=str)


## QueryRunner quick test

In [None]:
# Run one known-good query through the runner and show the outcome.
qr = QueryRunner(engine, max_rows=200)
test_sql = "SELECT customerNumber, customerName, country FROM customers LIMIT 10;"
meta = qr.run(test_sql)
print("Success:", meta["success"])
if meta["success"]:
    display(meta["result_preview"])
else:
    print("Error:", meta["error"])

# List and display schema
for table_name in list_tables(engine):
    # The leading blank line must be written as the "\n" escape: a physical line
    # break inside a single-quoted f-string is a syntax error.
    print(f"\nSchema for table: {table_name}")
    df_columns = get_table_columns(engine, table_name)
    display(df_columns)


## Dataset validation helper
Run the static classicmodels test set against the live DB.

In [None]:
def validate_test_set(path: str = "data/classicmodels_test_200.json", limit: Optional[int] = None):
    """Run every SQL statement of a JSON test set through a fresh QueryRunner.

    Args:
        path: JSON file holding a list of objects with "nlq" and "sql" entries.
        limit: When given, only the first ``limit`` items are run (``0`` runs none).

    Returns:
        ``(successes, failures)``: indices of items that executed, and one dict
        per failed item with its index, nlq, sql and error message. An item
        without an "sql" entry is reported as a failure instead of aborting the run.
    """
    import json
    with open(path, "r", encoding="utf-8") as fh:
        items = json.load(fh)
    # Compare against None so that an explicit limit of 0 is honoured.
    if limit is not None:
        items = items[:limit]

    runner = QueryRunner(engine, max_rows=200)
    successes = []
    failures = []
    for idx, item in enumerate(items):
        # .get keeps a malformed item from raising KeyError; the runner then
        # records it as an empty-SQL failure.
        meta = runner.run(item.get("sql"), capture_df=False)
        if meta["success"]:
            successes.append(idx)
        else:
            failures.append({
                "index": idx,
                "nlq": item.get("nlq"),
                "sql": item.get("sql"),
                "error": meta["error"],
            })
    print(f"Ran {len(items)} queries. Success: {len(successes)}. Failures: {len(failures)}.")
    if failures:
        print("Failures (first 5):")
        for failure in failures[:5]:
            print(failure)
    return successes, failures


## Tiny starter NLQ-SQL set (for quick checks)
Main test set lives in data/classicmodels_test_200.json.

## Load static test set
Use the fixed 200-sample NLQ-SQL pairs from data/classicmodels_test_200.json.

In [None]:
import json

# Read the whole file, then parse it; the result is kept in `test_set`.
with open('data/classicmodels_test_200.json', mode='r', encoding='utf-8') as f:
    test_set = json.loads(f.read())
print("Loaded %d test items from data/classicmodels_test_200.json" % len(test_set))


## Load base model/tokenizer (pre-QLoRA placeholder)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hub identifier of the base checkpoint loaded below.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# NOTE(review): token=True presumably relies on a Hugging Face credential already
# stored in this runtime; no login step is visible in this notebook — confirm.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
# Weights are requested in bfloat16 with device_map="auto".
# NOTE(review): assumes the attached accelerator handles bfloat16 — TODO confirm
# for the target runtime.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=True
)

print(f"Tokenizer and model '{model_id}' loaded successfully.")
