# 1. GCP Auth (optional for Colab)

If you're in Colab, this will show a login prompt.

If you're running locally (VS Code / Jupyter), you can ignore the message and rely on gcloud / ADC / service accounts instead.

In [1]:
# Colab exposes google.colab; anywhere else the import fails and we fall back
# to whatever credentials the local environment already provides.
try:
    from google.colab import auth
except ModuleNotFoundError:
    auth = None

if auth is None:
    print("Not running in Colab; ensure GCP auth via gcloud/ADC or service account if needed.")
else:
    # Interactive Google sign-in inside Colab.
    auth.authenticate_user()

# 2. Project configuration

In [2]:
import os

# Honour a GOOGLE_CLOUD_PROJECT that is already set in the environment (as the
# TODO requested); otherwise fall back to the notebook's default project.
project_id = os.environ.get("GOOGLE_CLOUD_PROJECT") or "modified-enigma-476414-h9"
os.environ["GOOGLE_CLOUD_PROJECT"] = project_id

print("GOOGLE_CLOUD_PROJECT set to:", os.environ["GOOGLE_CLOUD_PROJECT"])


# 3. Install dependencies

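If the dependencies from requirements.txt are already installed in this environment, you can comment out this cell. Run it from the directory that contains requirements.txt; otherwise pip fails with "Could not open requirements file" and nothing is installed.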

In [3]:
import sys
# Prefer installing from pinned requirements.txt for reproducibility.
# The `!` lines are IPython shell escapes (not plain Python). Using
# sys.executable makes pip install into the interpreter running this kernel.
# NOTE(review): pip resolves "requirements.txt" relative to the current working
# directory. The recorded output shows it was not found there, so nothing from
# it was installed. Confirm the path (or cd into the repo) before relying on it.
!{sys.executable} -m pip install --upgrade pip
!{sys.executable} -m pip install -r requirements.txt


Collecting pip
  Downloading pip-25.3-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.3-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.3
[31mERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt'[0m[31m
[0m

# 4. Imports & logging

In [None]:
import json
import logging
import os
import re
import time
from datetime import datetime, timezone
from typing import Any, Dict, Optional

import pandas as pd
import sqlalchemy
from sqlalchemy import text
from sqlalchemy.engine import Engine

from google.cloud.sql.connector import Connector
from google.api_core import retry

import pymysql

# Configure the root logger for INFO-level output. Note that basicConfig does
# nothing if the root logger already has handlers (standard library behavior).
logging.basicConfig(level="INFO")
logger = logging.getLogger("nl2sql_db")

# 5. Connection parameters

In [None]:
from getpass import getpass

# Each setting is taken from the environment when it is present and non-empty;
# otherwise the user is prompted. The password prompt does not echo.
INSTANCE_CONNECTION_NAME = (
    os.getenv("INSTANCE_CONNECTION_NAME")
    or input("Enter INSTANCE_CONNECTION_NAME: ").strip()
)
DB_USER = os.getenv("DB_USER") or input("Enter DB_USER: ").strip()
DB_PASS = os.getenv("DB_PASS") or getpass("Enter DB_PASS: ")
# The database name is never prompted for; it defaults to the sample schema.
DB_NAME = os.getenv("DB_NAME", "classicmodels")

print("Using DB:", DB_NAME)

Enter INSTANCE_CONNECTION_NAME: modified-enigma-476414-h9:europe-west2:classicmodels
Enter DB_USER: root
Enter DB_PASS: ··········


# 6. Cloud SQL connector + engine

In [None]:
from contextlib import contextmanager
import atexit

connector = Connector()
# Connector.close() is the connector's cleanup call. Run it when the
# kernel/interpreter exits so the connector's resources are released instead
# of leaked.
atexit.register(connector.close)


def getconn():
    """SQLAlchemy creator hook using the Cloud SQL connector.

    Opens a new pymysql DB-API connection to ``INSTANCE_CONNECTION_NAME`` using
    the ``DB_USER`` / ``DB_PASS`` / ``DB_NAME`` globals.
    """
    return connector.connect(
        INSTANCE_CONNECTION_NAME,
        "pymysql",
        user=DB_USER,
        password=DB_PASS,
        db=DB_NAME,
    )


# Use a creator function so SQLAlchemy delegates connection creation to the connector.
# pool_pre_ping checks each pooled connection on checkout and replaces ones the
# server has dropped, e.g. after the notebook has sat idle for a while.
engine: Engine = sqlalchemy.create_engine(
    "mysql+pymysql://",
    creator=getconn,
    future=True,
    pool_pre_ping=True,
)
# atexit runs handlers last-in-first-out, so the pool is disposed before the
# connector that created its connections is closed.
atexit.register(engine.dispose)


@contextmanager
def safe_connection(engine: Engine):
    """
    Context manager that yields a DB connection and ensures it gets closed.

    Thin wrapper over ``engine.connect()``. The connection is closed on exit,
    even if the body raises.
    """
    with engine.connect() as conn:
        yield conn



# 7. Schema exploration helpers

In [None]:
def list_tables(engine: Engine) -> list:
    """Return the names of all tables in the database the engine connects to.

    Runs MySQL's ``SHOW TABLES`` and keeps the first column of each row.
    """
    show_tables = text("SHOW TABLES;")
    with safe_connection(engine) as conn:
        table_rows = conn.execute(show_tables).fetchall()
    return [name for name, *_ in table_rows]

def get_table_columns(engine: Engine, table_name: str) -> pd.DataFrame:
    """
    Describe the columns of ``table_name`` in the ``DB_NAME`` schema.

    Reads INFORMATION_SCHEMA.COLUMNS and returns one row per column
    (COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY) in table-definition order.
    """
    query = text("""
        SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = :db AND TABLE_NAME = :table
        ORDER BY ORDINAL_POSITION
    """)
    bind_values = {"db": DB_NAME, "table": table_name}
    with safe_connection(engine) as conn:
        return pd.read_sql(query, conn, params=bind_values)


# 8. QueryRunner (read-only executor)

In [None]:
class QueryExecutionError(Exception):
    """Raised by QueryRunner's safety check for empty or destructive SQL."""

def now_utc_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with a ``Z`` suffix."""
    stamp = datetime.now(tz=timezone.utc).isoformat()
    # An aware UTC datetime renders its offset as "+00:00"; use the shorter "Z".
    return stamp.replace("+00:00", "Z")

class QueryRunner:
    """
    Execute generated SQL safely against the engine, capture results and metadata,
    and keep a history suitable for evaluation and error analysis.

    Every call to ``run`` appends its metadata dict to ``self.history``,
    whether the query succeeded or not.

    NOTE: the keyword filter in ``_safety_check`` is a heuristic guard against
    accidental writes, not a security boundary. Use a read-only DB user when
    executing untrusted SQL.
    """

    def __init__(self, engine: Engine, max_rows: int = 1000, forbidden_tokens=None):
        self.engine = engine
        # Maximum rows kept in the preview DataFrame and reported as "rowcount".
        self.max_rows = max_rows
        self.history = []
        # Keywords that cause a query to be rejected before execution. A trailing
        # space (as in the defaults) means "followed by any whitespace".
        self.forbidden_tokens = forbidden_tokens or [
            "drop ", "delete ", "truncate ", "alter ", "create ", "update ", "insert "
        ]

    @staticmethod
    def _token_pattern(token: str) -> Optional[str]:
        """Translate one forbidden token into a regex, or None for a blank token.

        - Matching is done on lower-cased text, so tokens are lower-cased too.
        - Whitespace inside or after the token matches any whitespace (newline
          and tab included), so ``"drop "`` also catches DROP followed by a
          newline.
        - A token starting with a word character must not be preceded by one,
          so ``"update "`` does not fire on an identifier like ``last_update``.
        """
        words = token.lower().split()
        if not words:
            return None
        pattern = r"\s+".join(re.escape(word) for word in words)
        first_char = words[0][0]
        if first_char.isalnum() or first_char == "_":
            pattern = r"(?<!\w)" + pattern
        if token[-1].isspace():
            pattern += r"\s"
        return pattern

    def _safety_check(self, sql: str) -> None:
        """Block obviously destructive statements.

        Raises:
            QueryExecutionError: if ``sql`` is None/blank, or contains one of
                ``self.forbidden_tokens`` (case-insensitive, see ``_token_pattern``).
        """
        lowered = (sql or "").strip().lower()
        if not lowered:
            raise QueryExecutionError("Empty SQL string")
        for token in self.forbidden_tokens:
            pattern = self._token_pattern(token)
            if pattern is not None and re.search(pattern, lowered):
                raise QueryExecutionError(f"Destructive SQL token detected: {token.strip()}")

    def run(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
        capture_df: bool = True,
    ) -> Dict[str, Any]:
        """
        Execute a SELECT-style query, returning metadata and an optional DataFrame preview.

        Args:
            sql: SQL text. It is wrapped in ``sqlalchemy.text``, so ``:name``
                placeholders are bound from ``params``.
            params: Bind parameters for the query (default: none).
            capture_df: When True, store up to ``max_rows`` rows as a DataFrame
                under ``"result_preview"``.

        Returns:
            The history entry, with keys ``sql``, ``params``, ``timestamp``,
            ``success``, ``rowcount`` (capped at ``max_rows``), ``exec_time_s``,
            ``error``, ``columns`` and ``result_preview``. Failures are recorded
            in ``error`` rather than raised.
        """
        entry = {
            "sql": sql,
            "params": params,
            "timestamp": now_utc_iso(),
            "success": False,
            "rowcount": 0,
            "exec_time_s": None,
            "error": None,
            "columns": None,
            "result_preview": None,
        }
        try:
            self._safety_check(sql)
            # perf_counter is monotonic, so the duration cannot be skewed by
            # wall-clock adjustments the way datetime.now() differences can.
            start = time.perf_counter()

            with safe_connection(self.engine) as conn:
                result = conn.execute(sqlalchemy.text(sql), params or {})
                rows = result.fetchall()
                cols = list(result.keys())

            exec_time = time.perf_counter() - start

            df = None
            if capture_df:
                # Truncate before building the frame, instead of building a
                # DataFrame of every row and slicing it afterwards.
                df = pd.DataFrame(rows[: self.max_rows], columns=cols)

            entry.update({
                "success": True,
                "rowcount": min(len(rows), self.max_rows),
                "exec_time_s": exec_time,
                "columns": cols,
                "result_preview": df,
            })
        except Exception as e:
            # Deliberately broad: this is an evaluation harness, so any failure
            # (safety rejection, SQL error, connection problem) is recorded in
            # the history instead of aborting a batch of queries.
            entry.update({
                "error": str(e),
                "success": False,
            })
        finally:
            self.history.append(entry)
        return entry

    def last(self):
        """Return the most recent history entry, or None if nothing has run yet."""
        return self.history[-1] if self.history else None

    def save_history(self, path: str):
        """Persist history (without DataFrames) to JSON for later analysis.

        Values JSON cannot encode natively are written via ``str``.
        """
        serializable = [
            {k: v for k, v in h.items() if k != "result_preview"}
            for h in self.history
        ]
        with open(path, "w", encoding="utf-8") as f:
            json.dump(serializable, f, indent=2, default=str)

Unnamed: 0,customerNumber,customerName,country
0,103,Atelier graphique,France
1,112,Signal Gift Stores,USA
2,114,"Australian Collectors, Co.",Australia
3,119,La Rochelle Gifts,France
4,121,Baane Mini Imports,Norway


# 9. Smoke tests (DB connectivity + schema)

In [None]:
def fetch_sample_customers(limit: int = 10) -> pd.DataFrame:
    """Return up to ``limit`` customers (number, name, country) as a DataFrame.

    The query has no ORDER BY, so which rows come back is up to the server.
    """
    sample_query = text("SELECT customerNumber, customerName, country FROM customers LIMIT :limit;")
    with safe_connection(engine) as conn:
        return pd.read_sql(sample_query, conn, params={"limit": limit})

# Best-effort connectivity/schema check: failures are logged with a traceback
# but do not stop the rest of the notebook.
try:
    tables = list_tables(engine)
    # Use the configured DB_NAME rather than assuming "classicmodels".
    logger.info("Tables in %s: %s", DB_NAME, tables)

    sample_df = fetch_sample_customers(5)
    display(sample_df)

    # Optionally print each table's schema (comment out if too verbose)
    for table_name in tables:
        print(f"\nSchema for table: {table_name}")
        df_columns = get_table_columns(engine, table_name)
        display(df_columns)

except Exception:
    # logger.exception already appends the exception and traceback.
    logger.exception("Smoke test failed")

# 10. Load static NLQ-SQL test set

In [None]:
TEST_SET_PATH = "data/classicmodels_test_200.json"

# Static evaluation set: a JSON list of NLQ/SQL items.
with open(TEST_SET_PATH, encoding="utf-8") as f:
    test_set = json.load(f)

print(f"Loaded {len(test_set)} test items from {TEST_SET_PATH}")

Success: True


Unnamed: 0,customerNumber,customerName,country
0,103,Atelier graphique,France
1,112,Signal Gift Stores,USA
2,114,"Australian Collectors, Co.",Australia
3,119,La Rochelle Gifts,France
4,121,Baane Mini Imports,Norway
5,124,Mini Gifts Distributors Ltd.,USA
6,125,Havel & Zbyszek Co,Poland
7,128,"Blauer See Auto, Co.",Germany
8,129,Mini Wheels Co.,USA
9,131,Land of Toys Inc.,USA



Schema for table: customers


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,customerNumber,int,NO,PRI
1,customerName,varchar,NO,
2,contactLastName,varchar,NO,
3,contactFirstName,varchar,NO,
4,phone,varchar,NO,
5,addressLine1,varchar,NO,
6,addressLine2,varchar,YES,
7,city,varchar,NO,
8,state,varchar,YES,
9,postalCode,varchar,YES,



Schema for table: employees


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,employeeNumber,int,NO,PRI
1,lastName,varchar,NO,
2,firstName,varchar,NO,
3,extension,varchar,NO,
4,email,varchar,NO,
5,officeCode,varchar,NO,MUL
6,reportsTo,int,YES,MUL
7,jobTitle,varchar,NO,



Schema for table: offices


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,officeCode,varchar,NO,PRI
1,city,varchar,NO,
2,phone,varchar,NO,
3,addressLine1,varchar,NO,
4,addressLine2,varchar,YES,
5,state,varchar,YES,
6,country,varchar,NO,
7,postalCode,varchar,NO,
8,territory,varchar,NO,



Schema for table: orderdetails


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,orderNumber,int,NO,PRI
1,productCode,varchar,NO,PRI
2,quantityOrdered,int,NO,
3,priceEach,decimal,NO,
4,orderLineNumber,smallint,NO,



Schema for table: orders


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,orderNumber,int,NO,PRI
1,orderDate,date,NO,
2,requiredDate,date,NO,
3,shippedDate,date,YES,
4,status,varchar,NO,
5,comments,text,YES,
6,customerNumber,int,NO,MUL



Schema for table: payments


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,customerNumber,int,NO,PRI
1,checkNumber,varchar,NO,PRI
2,paymentDate,date,NO,
3,amount,decimal,NO,



Schema for table: productlines


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,productLine,varchar,NO,PRI
1,textDescription,varchar,YES,
2,htmlDescription,mediumtext,YES,
3,image,mediumblob,YES,



Schema for table: products


Unnamed: 0,COLUMN_NAME,DATA_TYPE,IS_NULLABLE,COLUMN_KEY
0,productCode,varchar,NO,PRI
1,productName,varchar,NO,
2,productLine,varchar,NO,MUL
3,productScale,varchar,NO,
4,productVendor,varchar,NO,
5,productDescription,text,NO,
6,quantityInStock,smallint,NO,
7,buyPrice,decimal,NO,
8,MSRP,decimal,NO,


# 11. Optional: validate test SQL against the live DB


In [None]:
from typing import List, Tuple

def validate_test_set(path: str = "data/classicmodels_test_200.json",
                      limit: Optional[int] = None) -> Tuple[list, list]:
    """
    Run every SQL query in a test-set JSON file against the live DB.

    Args:
        path: JSON file containing a list of items with a ``"sql"`` key and,
            optionally, an ``"nlq"`` key.
        limit: If given, only the first ``limit`` items are run (``0`` runs none).

    Returns:
        ``(successes, failures)``: the indices of queries that executed, and one
        dict per failed query with its index, NLQ, SQL and error message. Items
        with no ``"sql"`` are reported as failures instead of aborting the run.
    """
    with open(path, "r", encoding="utf-8") as f:
        items = json.load(f)
    # `is not None` so that limit=0 means "none", not "all".
    if limit is not None:
        items = items[:limit]

    qr_local = QueryRunner(engine, max_rows=200)
    successes: List[int] = []
    failures: List[Dict[str, Any]] = []

    for idx, item in enumerate(items):
        sql = item.get("sql")
        meta = qr_local.run(sql, capture_df=False)
        if meta["success"]:
            successes.append(idx)
        else:
            failures.append({
                "index": idx,
                "nlq": item.get("nlq"),
                "sql": sql,
                "error": meta["error"],
            })

    print(f"Ran {len(items)} queries. Success: {len(successes)}. Failures: {len(failures)}.")
    if failures:
        print("Failures (first 5):")
        for failure in failures[:5]:
            print(failure)
    else:
        print("All queries succeeded in this run.")
    return successes, failures

# Uncomment to run a quick validation (e.g. on first 50)
# successes, failures = validate_test_set(limit=50)

# 12. Hugging Face authentication

In [None]:
import os

from huggingface_hub import notebook_login

# Like the DB settings, prefer a token supplied via the environment. Only show
# the interactive login when HF_TOKEN is missing or empty.
if os.environ.get("HF_TOKEN"):
    print("HF_TOKEN found in environment; skipping interactive login.")
else:
    notebook_login()

# 13. Load Llama-3-8B-Instruct (4-bit where possible)

In [None]:
import importlib.util
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# BitsAndBytesConfig is optional; if it is unusable we fall back to non-quantized.
# transformers ships BitsAndBytesConfig even when the bitsandbytes package is
# not installed, so also require bitsandbytes itself before choosing 4-bit.
try:
    from transformers import BitsAndBytesConfig
    HAS_BNB = importlib.util.find_spec("bitsandbytes") is not None
except ImportError:
    HAS_BNB = False

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# notebook_login() saves the token for huggingface_hub but defines no Python
# variable, so resolve hf_token explicitly: HF_TOKEN env var first, then the
# token saved by a previous login.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    try:
        from huggingface_hub import get_token
    except ImportError:  # older huggingface_hub releases
        from huggingface_hub import HfFolder
        get_token = HfFolder.get_token
    hf_token = get_token()

if not hf_token:
    raise RuntimeError(
        "HF_TOKEN is not set. Run huggingface_hub.login() or set os.environ['HF_TOKEN']."
    )

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)

# Some Llama models don't have a pad token set; fallback to eos_token if needed.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

has_cuda = torch.cuda.is_available()
use_4bit = HAS_BNB and has_cuda

# bfloat16 has native hardware support only from compute capability 8 (Ampere).
# Older GPUs such as the T4 use float16 instead.
gpu_dtype = (
    torch.bfloat16
    if has_cuda and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

if use_4bit:
    print("Using 4-bit quantization with bitsandbytes.")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=gpu_dtype,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=gpu_dtype,
        token=hf_token,
    )
else:
    load_dtype = gpu_dtype if has_cuda else torch.float32
    print(
        "4-bit not available (no GPU or bitsandbytes missing). "
        f"Loading unquantized weights as {load_dtype} on {'GPU' if has_cuda else 'CPU'}."
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto" if has_cuda else None,
        torch_dtype=load_dtype,
        token=hf_token,
    )

model.eval()
print(f"Tokenizer and model '{model_id}' loaded successfully.")
print("Device:", next(model.parameters()).device)