# NL-to-SQL scaffold (classicmodels)

What this notebook does:
- Authenticates to GCP
- Opens a Cloud SQL connection to classicmodels via the Cloud SQL Python Connector and SQLAlchemy
- Defines schema helpers and a read-only QueryRunner tool
- Runs smoke tests and a dataset validator
- Loads the base Llama-3-8B-Instruct model (pre-QLoRA placeholder)

In [24]:
# Auth to Google Cloud (Colab) or skip gracefully elsewhere
try:
    from google.colab import auth
except ModuleNotFoundError:
    auth = None
if auth:
    auth.authenticate_user()
else:
    print("Not running in Colab; ensure GCP auth via gcloud/ADC or service account if needed.")

## Project context
The project ID is hardcoded here for development. In production, read it from an environment variable instead.

In [25]:
import os
project_id = "modified-enigma-476414-h9"  # replace with env var in production
os.environ["GOOGLE_CLOUD_PROJECT"] = project_id
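One way to follow the note above is to read the project ID from the environment and fall back to a development value. This is a minimal sketch: the helper name `resolve_project_id` is hypothetical, and it assumes `GOOGLE_CLOUD_PROJECT` is the variable you want to read, since that is the one this notebook sets.

```python
import os

def resolve_project_id(default=None):
    """Return the project ID from GOOGLE_CLOUD_PROJECT, else the given default."""
    value = os.environ.get("GOOGLE_CLOUD_PROJECT", "").strip()
    if value:
        return value
    if default:
        return default
    raise RuntimeError("Set GOOGLE_CLOUD_PROJECT or pass a default project ID.")

# The development fallback is the ID already used in this notebook.
project_id = resolve_project_id(default="modified-enigma-476414-h9")
```

Failing loudly when neither source is available avoids silently running against the wrong project.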

## Installs
Dependencies are installed from requirements.txt. Keep the versions in that file pinned so runs are reproducible.

In [12]:
import sys
# Prefer installing from pinned requirements.txt for reproducibility
%cd NLtoSQL
!{sys.executable} -m pip install --upgrade pip
!{sys.executable} -m pip install -r requirements.txt


/content/NLtoSQL


## Imports and logger

In [26]:
import os
import logging
from google.cloud.sql.connector import Connector
import sqlalchemy
from sqlalchemy import text
import pymysql
from typing import Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("nl2sql_db")

## Connection params
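Connection parameters are read from environment variables first. During development, any missing value falls back to an interactive prompt.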

In [27]:
from getpass import getpass

INSTANCE_CONNECTION_NAME = os.getenv("INSTANCE_CONNECTION_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASS = os.getenv("DB_PASS")
DB_NAME = os.getenv("DB_NAME", "classicmodels")

if not INSTANCE_CONNECTION_NAME:
    INSTANCE_CONNECTION_NAME = input("Enter INSTANCE_CONNECTION_NAME: ").strip()
if not DB_USER:
    DB_USER = input("Enter DB_USER: ").strip()
if not DB_PASS:
    DB_PASS = getpass("Enter DB_PASS: ")

Enter INSTANCE_CONNECTION_NAME: modified-enigma-476414-h9:europe-west2:classicmodels
Enter DB_USER: root
Enter DB_PASS: ··········


## Connector + engine setup

In [28]:
from contextlib import contextmanager

from sqlalchemy.engine import Engine

connector = Connector()

def getconn():
    """SQLAlchemy creator hook using the Cloud SQL connector."""
    return connector.connect(
        INSTANCE_CONNECTION_NAME,
        "pymysql",
        user=DB_USER,
        password=DB_PASS,
        db=DB_NAME,
    )

engine: Engine = sqlalchemy.create_engine("mysql+pymysql://", creator=getconn, future=True)

@contextmanager
def safe_connection(engine):
    """Yield a connection and clean up after use."""
    conn = None
    try:
        conn = engine.connect()
        yield conn
    finally:
        if conn:
            conn.close()

## Schema exploration helpers

In [29]:
import pandas as pd

def list_tables(engine) -> list:
    """Return a list of table names."""
    with safe_connection(engine) as conn:
        result = conn.execute(text("SHOW TABLES;")).fetchall()
    return [r[0] for r in result]

def get_table_columns(engine, table_name: str) -> pd.DataFrame:
    """Return a DataFrame of columns."""
    query = text("""
        SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = :db AND TABLE_NAME = :table
        ORDER BY ORDINAL_POSITION
    """)
    with safe_connection(engine) as conn:
        df = pd.read_sql(query, conn, params={"db": DB_NAME, "table": table_name})
    return df

ERROR:asyncio:Unclosed client session
client_session: <aiohttp.client.ClientSession object at 0x7d895da20320>
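The `Unclosed client session` message above is the kind of warning that can appear when a `Connector` is discarded without being closed, for example when the setup cell is re-run. The sketch below is a minimal teardown helper, assuming that is the cause here. `shutdown_db` is a hypothetical name; it relies only on the engine exposing `dispose()` and the connector exposing `close()`.

```python
def shutdown_db(engine, connector) -> None:
    """Release pooled DB connections, then close the connector.

    Assumes `engine` exposes dispose() (as a SQLAlchemy Engine does) and
    `connector` exposes close() (as the Cloud SQL Python Connector does).
    """
    try:
        engine.dispose()   # close pooled connections held by SQLAlchemy
    finally:
        connector.close()  # release the connector's background resources
```

Calling `shutdown_db(engine, connector)` before re-creating either object, or at the end of the session, keeps a single connector alive at a time.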


In [30]:
!gcloud auth application-default login



You are running on a Google Compute Engine virtual machine.
The service credentials associated with this virtual machine
will automatically be used by Application Default
Credentials, so it is not necessary to use this command.

If you decide to proceed anyway, your user credentials may be visible
to others with access to this virtual machine. Are you sure you want
to authenticate with your personal account?

Do you want to continue (Y/n)?  Y

Go to the following link in your browser, and complete the sign-in prompts:

    https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com&redirect_uri=https%3A%2F%2Fsdk.cloud.google.com%2Fapplicationdefaultauthcode.html&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fsqlservice.login&state=FuQcPzojcZxc5c2DG9LvOo3YEuo7rS&prompt=consent&token_

In [22]:
!gcloud auth application-default set-quota-project modified-enigma-476414-h9



Credentials saved to file: [/content/.config/application_default_credentials.json]

These credentials will be used by any library that requests Application Default Credentials (ADC).

Quota project "modified-enigma-476414-h9" was added to ADC which can be used by Google client libraries for billing and quota. Note that some services may still bill the project owning the resource.


## Smoke tests

In [31]:
def fetch_sample_customers(limit: int = 10):
    q = text("SELECT customerNumber, customerName, country FROM customers LIMIT :limit;")
    with safe_connection(engine) as conn:
        df = pd.read_sql(q, conn, params={"limit": limit})
    return df

try:
    tables = list_tables(engine)
    logger.info("Tables in classicmodels: %s", tables)
    sample_df = fetch_sample_customers(5)
    display(sample_df)
except Exception as e:
    logger.exception("Smoke test failed: %s", e)

Unnamed: 0,customerNumber,customerName,country
0,103,Atelier graphique,France
1,112,Signal Gift Stores,USA
2,114,"Australian Collectors, Co.",Australia
3,119,La Rochelle Gifts,France
4,121,Baane Mini Imports,Norway


## QueryRunner (read-only tool)

In [32]:
import json
from datetime import datetime, timezone
from typing import Any, Dict

class QueryExecutionError(Exception):
    pass

def now_utc_iso() -> str:
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

class QueryRunner:
    """
    Execute generated SQL safely against the engine, capture results and metadata,
    and keep a history suitable for evaluation and error analysis.
    """
    def __init__(self, engine, max_rows: int = 1000, forbidden_tokens=None):
        self.engine = engine
        self.max_rows = max_rows
        self.history = []
        self.forbidden_tokens = forbidden_tokens or ["drop ", "delete ", "truncate ", "alter ", "create ", "update ", "insert "]

    def _safety_check(self, sql: str) -> None:
        lowered = (sql or "").strip().lower()
        if not lowered:
            raise QueryExecutionError("Empty SQL string")
        for token in self.forbidden_tokens:
            if token in lowered:
                raise QueryExecutionError(f"Destructive SQL token detected: {token.strip()}")

    def run(self, sql: str, params: Optional[Dict[str, Any]] = None, capture_df: bool = True) -> Dict[str, Any]:
        entry = {
            "sql": sql,
            "params": params,
            "timestamp": now_utc_iso(),
            "success": False,
            "rowcount": 0,
            "exec_time_s": None,
            "error": None,
            "columns": None,
            "result_preview": None,
        }
        try:
            self._safety_check(sql)
            start = datetime.now(timezone.utc)
            with safe_connection(self.engine) as conn:
                result = conn.execute(sqlalchemy.text(sql), params or {})
                # Fetch at most max_rows so a large result set is never fully loaded into memory
                rows = result.fetchmany(self.max_rows)
                cols = list(result.keys())
            end = datetime.now(timezone.utc)
            exec_time = (end - start).total_seconds()
            df = None
            if capture_df:
                df = pd.DataFrame(rows, columns=cols)
                if len(df) > self.max_rows:
                    df = df.iloc[: self.max_rows]
            entry.update({
                "success": True,
                "rowcount": min(len(rows), self.max_rows),
                "exec_time_s": exec_time,
                "columns": cols,
                "result_preview": df
            })
        except Exception as e:
            entry.update({
                "error": str(e),
                "success": False
            })
        finally:
            self.history.append(entry)
        return entry

    def last(self):
        return self.history[-1] if self.history else None

    def save_history(self, path: str):
        serializable = []
        for h in self.history:
            s = {k: v for k, v in h.items() if k != "result_preview"}
            serializable.append(s)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(serializable, f, indent=2, default=str)
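The substring check in `_safety_check` matches each token followed by a space, so it would let `DROP\tTABLE` through and would reject a harmless column such as `lastUpdate` when a space follows it. The sketch below shows a stricter alternative using an allowlist of leading keywords plus word-boundary matching. The names `check_read_only`, `READ_ONLY_STARTS`, and `FORBIDDEN_KEYWORDS` are hypothetical, and the allowlist is an assumption rather than the notebook's policy. It is still a heuristic: it scans string literals too, so a literal containing `;` or `update` is rejected.

```python
import re

# Statement types treated as read-only in this sketch (an assumption).
READ_ONLY_STARTS = ("select", "show", "describe", "explain", "with")
FORBIDDEN_KEYWORDS = ("drop", "delete", "truncate", "alter", "create", "update", "insert")
_FORBIDDEN_RE = re.compile(r"\b(" + "|".join(FORBIDDEN_KEYWORDS) + r")\b", re.IGNORECASE)

def check_read_only(sql: str) -> None:
    """Raise ValueError unless sql looks like a single read-only statement."""
    stripped = (sql or "").strip().rstrip(";").strip()
    if not stripped:
        raise ValueError("Empty SQL string")
    if ";" in stripped:
        raise ValueError("Multiple statements are not allowed")
    first_word = stripped.split(None, 1)[0].lower()
    if first_word not in READ_ONLY_STARTS:
        raise ValueError(f"Statement type not allowed: {first_word}")
    match = _FORBIDDEN_RE.search(stripped)
    if match:
        raise ValueError(f"Destructive SQL keyword detected: {match.group(1).lower()}")
```

A check like this complements, rather than replaces, connecting with a database user that only has `SELECT` privileges.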

## QueryRunner quick test

In [33]:
qr = QueryRunner(engine, max_rows=200)
test_sql = "SELECT customerNumber, customerName, country FROM customers LIMIT 10;"
meta = qr.run(test_sql)
print("Success:", meta["success"])
if meta["success"]:
    display(meta["result_preview"])
else:
    print("Error:", meta["error"])

# List and display schema
for table_name in list_tables(engine):
    print(f"\nSchema for table: {table_name}")
    df_columns = get_table_columns(engine, table_name)
    display(df_columns)

Success: True


Unnamed: 0,customerNumber,customerName,country
0,103,Atelier graphique,France
1,112,Signal Gift Stores,USA
2,114,"Australian Collectors, Co.",Australia
3,119,La Rochelle Gifts,France
4,121,Baane Mini Imports,Norway
5,124,Mini Gifts Distributors Ltd.,USA
6,125,Havel & Zbyszek Co,Poland
7,128,"Blauer See Auto, Co.",Germany
8,129,Mini Wheels Co.,USA
9,131,Land of Toys Inc.,USA



Schema for table: customers


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,customerNumber,int,NO,PRI
1,customerName,varchar,NO,
2,contactLastName,varchar,NO,
3,contactFirstName,varchar,NO,
4,phone,varchar,NO,
5,addressLine1,varchar,NO,
6,addressLine2,varchar,YES,
7,city,varchar,NO,
8,state,varchar,YES,
9,postalCode,varchar,YES,



Schema for table: employees


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,employeeNumber,int,NO,PRI
1,lastName,varchar,NO,
2,firstName,varchar,NO,
3,extension,varchar,NO,
4,email,varchar,NO,
5,officeCode,varchar,NO,MUL
6,reportsTo,int,YES,MUL
7,jobTitle,varchar,NO,



Schema for table: offices


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,officeCode,varchar,NO,PRI
1,city,varchar,NO,
2,phone,varchar,NO,
3,addressLine1,varchar,NO,
4,addressLine2,varchar,YES,
5,state,varchar,YES,
6,country,varchar,NO,
7,postalCode,varchar,NO,
8,territory,varchar,NO,



Schema for table: orderdetails


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,orderNumber,int,NO,PRI
1,productCode,varchar,NO,PRI
2,quantityOrdered,int,NO,
3,priceEach,decimal,NO,
4,orderLineNumber,smallint,NO,



Schema for table: orders


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,orderNumber,int,NO,PRI
1,orderDate,date,NO,
2,requiredDate,date,NO,
3,shippedDate,date,YES,
4,status,varchar,NO,
5,comments,text,YES,
6,customerNumber,int,NO,MUL



Schema for table: payments


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,customerNumber,int,NO,PRI
1,checkNumber,varchar,NO,PRI
2,paymentDate,date,NO,
3,amount,decimal,NO,



Schema for table: productlines


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,productLine,varchar,NO,PRI
1,textDescription,varchar,YES,
2,htmlDescription,mediumtext,YES,
3,image,mediumblob,YES,



Schema for table: products


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,productCode,varchar,NO,PRI
1,productName,varchar,NO,
2,productLine,varchar,NO,MUL
3,productScale,varchar,NO,
4,productVendor,varchar,NO,
5,productDescription,text,NO,
6,quantityInStock,smallint,NO,
7,buyPrice,decimal,NO,
8,MSRP,decimal,NO,


## Dataset validation helper
Run the static classicmodels test set against the live DB.

In [None]:
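def validate_test_set(path: str = "data/classicmodels_test_200.json", limit: Optional[int] = None):
    """Run each test item's SQL through QueryRunner; return (success indices, failure records)."""
    import json
    with open(path, "r", encoding="utf-8") as f:
        items = json.load(f)
    if limit is not None:
        items = items[:limit]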

    qr = QueryRunner(engine, max_rows=200)
    successes = []
    failures = []
    for idx, item in enumerate(items):
        meta = qr.run(item["sql"], capture_df=False)
        if meta["success"]:
            successes.append(idx)
        else:
            failures.append({
                "index": idx,
                "nlq": item.get("nlq"),
                "sql": item.get("sql"),
                "error": meta["error"],
            })
    print(f"Ran {len(items)} queries. Success: {len(successes)}. Failures: {len(failures)}.")
    if failures:
        print("Failures (first 5):")
        for f in failures[:5]:
            print(f)
    else:
        print("All queries succeeded in this run.")
    return successes, failures

## Load static test set
Use the fixed 200-sample NLQ-SQL pairs from data/classicmodels_test_200.json.

In [None]:
import json
with open('data/classicmodels_test_200.json', 'r', encoding='utf-8') as f:
    test_set = json.load(f)
print(f"Loaded {len(test_set)} test items from data/classicmodels_test_200.json")

Loaded 200 test items from data/classicmodels_test_200.json
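Before running the loaded items against the database, it can help to confirm that each one carries the two keys that `validate_test_set` reads, `nlq` and `sql`. A minimal sketch, assuming both should be non-empty strings; `find_malformed_items` and `REQUIRED_KEYS` are hypothetical names.

```python
from typing import Any, Dict, List

REQUIRED_KEYS = ("nlq", "sql")  # keys read by validate_test_set

def find_malformed_items(items: List[Any]) -> List[Dict[str, Any]]:
    """Return one record per item that lacks a non-empty string for a required key."""
    problems = []
    for idx, item in enumerate(items):
        if not isinstance(item, dict):
            problems.append({"index": idx, "missing": list(REQUIRED_KEYS)})
            continue
        missing = [
            key for key in REQUIRED_KEYS
            if not isinstance(item.get(key), str) or not item[key].strip()
        ]
        if missing:
            problems.append({"index": idx, "missing": missing})
    return problems
```

Running `find_malformed_items(test_set)` and checking for an empty list separates data-format problems from genuine SQL execution failures.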


## Load base model/tokenizer (pre-QLoRA placeholder)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=True
)

print(f"Tokenizer and model '{model_id}' loaded successfully.")