In [1]:
!pip -q install datasets transformers peft accelerate bitsandbytes sqlglot sqlite-utils tqdm


In [2]:
import os, json, re, sqlite3, io, zipfile, pathlib, random
from pathlib import Path
from typing import Dict, List, Tuple
import pandas as pd
from datasets import Dataset, DatasetDict
from tqdm import tqdm
import sqlglot


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Root folder of the Spider files; every input path below is derived from it.
SPIDER_DIR = Path("./spider_data")

# Question/SQL splits, the schema description file and the SQLite database folder.
TRAIN_JSON = SPIDER_DIR.joinpath("train_spider.json")
TRAIN_OTHERS_JSON = SPIDER_DIR.joinpath("train_others.json")
DEV_JSON = SPIDER_DIR.joinpath("dev.json")
TABLES_JSON = SPIDER_DIR.joinpath("tables.json")
DB_ROOT = SPIDER_DIR.joinpath("database")

# Derived artifacts are written next to (not inside) the Spider folder.
PREPARED_DIR = SPIDER_DIR.parent.joinpath("prepared")
PREPARED_DIR.mkdir(parents=True, exist_ok=True)

# Report which inputs are present (TRAIN_OTHERS_JSON is not part of this check).
for p in (TRAIN_JSON, DEV_JSON, TABLES_JSON, DB_ROOT):
    print(p, "MISSING" if not p.exists() else "OK")


spider_data/train_spider.json OK
spider_data/dev.json OK
spider_data/tables.json OK
spider_data/database OK


In [4]:
# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# default text encoding.
with open(TABLES_JSON, "r", encoding="utf-8") as f:
    TABLES = json.load(f)

def build_schema_index(tables_json) -> Dict[str, dict]:
    """Index the parsed ``tables.json`` entries by ``db_id``.

    Args:
        tables_json: iterable of per-database dicts providing the keys
            ``db_id``, ``table_names_original``, ``column_names_original``
            (pairs of ``(table_idx, col_name)``), ``column_types``,
            ``primary_keys`` and ``foreign_keys``.

    Returns:
        A dict mapping each ``db_id`` to a dict with:
          * ``"tables"``: the original table names, in file order;
          * ``"tcols"``: table name -> list of ``(col_name, is_pk, col_type)``;
          * ``"fk_pairs"``: list of ``(from_table, from_col, to_table, to_col)``.
    """
    idx = {}
    for db in tables_json:
        db_id = db["db_id"]
        table_names = db["table_names_original"]
        columns = db["column_names_original"]  # (table_idx, col_name)
        column_types = db["column_types"]
        fks = db["foreign_keys"]               # list of [from_col_idx, to_col_idx]

        # Primary-key entries are column indices. An entry that is itself a
        # list/tuple of indices (composite key) is flattened; passing such an
        # entry to set() directly would raise TypeError (unhashable list).
        pks = set()
        for pk in db["primary_keys"]:
            if isinstance(pk, (list, tuple)):
                pks.update(pk)
            else:
                pks.add(pk)

        tcols = {t: [] for t in table_names}
        for col_idx, (t_i, c_name) in enumerate(columns):
            if t_i == -1:  # special star row, belongs to no table
                continue
            tcols[table_names[t_i]].append((c_name, col_idx in pks, column_types[col_idx]))

        fk_pairs = []
        for fr, to in fks:
            fr_ti, fr_col = columns[fr]
            to_ti, to_col = columns[to]
            if fr_ti == -1 or to_ti == -1:  # ignore links that touch the star row
                continue
            fk_pairs.append((table_names[fr_ti], fr_col, table_names[to_ti], to_col))

        idx[db_id] = {
            "tables": table_names,
            "tcols": tcols,
            "fk_pairs": fk_pairs,
        }
    return idx

# Built once here; serialize_schema reads SCHEMA_IDX as a notebook-level lookup.
SCHEMA_IDX = build_schema_index(TABLES)
# Display the number of indexed databases and the first five db_ids.
len(SCHEMA_IDX), list(SCHEMA_IDX.keys())[:5]


(166, ['perpetrator', 'college_2', 'flight_company', 'icfp_1', 'body_builder'])

In [5]:
def serialize_schema(
    db_id: str,
    max_tables: int = 8,
    max_cols_per_table: int = 10,
    max_fks: int = 12,
) -> str:
    """Render one database schema from ``SCHEMA_IDX`` as compact prompt text.

    Layout: a ``Database:`` line, a ``Tables:`` section with one
    ``table(col, ...)`` line per table (primary-key columns are prefixed with
    ``*``), and, when there is at least one foreign key to show, a
    ``Foreign Keys:`` section with ``table.col -> table.col`` lines.

    Args:
        db_id: key into ``SCHEMA_IDX``; an unknown id raises ``KeyError``.
        max_tables: keep at most this many tables (in index order).
        max_cols_per_table: keep at most this many columns per table.
        max_fks: keep at most this many foreign-key lines.

    The limits are applied by slicing, so ``None`` disables a limit.
    Foreign keys are truncated only by ``max_fks``; they are not filtered by
    which tables survived ``max_tables``.

    Returns:
        The schema description as a newline-joined string.
    """
    schema = SCHEMA_IDX[db_id]

    lines = [f"Database: {db_id}", "Tables:"]
    for table in schema["tables"][:max_tables]:
        shown_cols = schema["tcols"][table][:max_cols_per_table]
        cols_str = ", ".join(
            f"*{name}" if is_pk else name for name, is_pk, _ in shown_cols
        )
        lines.append(f"  {table}({cols_str})")

    # Slice first so a limit of 0 does not leave a dangling section header.
    shown_fks = schema["fk_pairs"][:max_fks]
    if shown_fks:
        lines.append("Foreign Keys:")
        lines.extend(f"  {ft}.{fc} -> {tt}.{tc}" for ft, fc, tt, tc in shown_fks)
    return "\n".join(lines)

# Smoke test: render the schema of the first indexed database.
some_db = list(SCHEMA_IDX)[0]
print(serialize_schema(db_id=some_db))


Database: perpetrator
Tables:
  perpetrator(*Perpetrator_ID, People_ID, Date, Year, Location, Country, Killed, Injured)
  people(*People_ID, Name, Height, Weight, Home Town)
Foreign Keys:
  perpetrator.People_ID -> people.People_ID


In [6]:
# System instruction placed at the top of every prompt.
INSTR = (
    "You are a helpful assistant that writes correct SQL for the given question and database schema. "
    "Output ONLY the SQL query; do not include explanations."
)

def build_prompt(question: str, db_id: str) -> str:
    """Assemble the full prompt text for one question.

    The prompt consists of the SYSTEM instruction, the USER question, the
    schema rendered by ``serialize_schema(db_id)`` and a short rules list.
    It ends with a newline after the last rule.
    """
    schema_txt = serialize_schema(db_id)
    segments = [
        f"SYSTEM: {INSTR}\n",
        f"USER:\n",
        f"Question: {question}\n\n",
        f"Schema:\n{schema_txt}\n\n",
        f"Rules:\n",
        f"- Only use tables/columns from {db_id}.\n",
        f"- Output only SQL, no commentary.\n",
    ]
    return "".join(segments)


In [7]:
def load_json_list(path: Path):
    """Load and return the parsed contents of the JSON file at ``path``.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

# Parse the train split first, then the dev split.
train_raw, dev_raw = (load_json_list(split_path) for split_path in (TRAIN_JSON, DEV_JSON))

def to_examples(items):
    """Convert raw dataset items into flat prompt/gold rows.

    Each item must provide ``question``, ``query`` and ``db_id``. The returned
    list holds one dict per item with the keys ``db_id``, ``question``,
    ``prompt`` (built with ``build_prompt``) and ``sql_gold``.
    """
    def _as_row(item):
        # Read the three source fields before building the prompt.
        question, gold_sql, db_id = item["question"], item["query"], item["db_id"]
        return {
            "db_id": db_id,
            "question": question,
            "prompt": build_prompt(question, db_id),
            "sql_gold": gold_sql,
        }

    return [_as_row(item) for item in items]

# Build prompt/gold rows for both splits (train first, then dev).
train_rows, dev_rows = to_examples(train_raw), to_examples(dev_raw)

# Display the split sizes and the keys carried by each row.
len(train_rows), len(dev_rows), train_rows[0].keys()


(7000, 1034, dict_keys(['db_id', 'question', 'prompt', 'sql_gold']))

In [8]:
# Wrap both row lists as datasets; the splits are keyed "train" and "dev".
ds = DatasetDict({
    split_name: Dataset.from_list(split_rows)
    for split_name, split_rows in (("train", train_rows), ("dev", dev_rows))
})
print(ds)
# Preview: first 500 characters of the first training prompt, then its gold SQL.
print(
    ds["train"][0]["prompt"][:500],
    "\nGOLD:",
    ds["train"][0]["sql_gold"],
)


DatasetDict({
    train: Dataset({
        features: ['db_id', 'question', 'prompt', 'sql_gold'],
        num_rows: 7000
    })
    dev: Dataset({
        features: ['db_id', 'question', 'prompt', 'sql_gold'],
        num_rows: 1034
    })
})
SYSTEM: You are a helpful assistant that writes correct SQL for the given question and database schema. Output ONLY the SQL query; do not include explanations.
USER:
Question: How many heads of the departments are older than 56 ?

Schema:
Database: department_management
Tables:
  department(*Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees)
  head(*head_ID, name, born_state, age)
  management(*department_ID, head_ID, temporary_acting)
Foreign Keys:
  management.head_ID - 
GOLD: SELECT count(*) FROM head WHERE age  >  56


In [9]:
# Export both splits as JSON Lines files under PREPARED_DIR.
out = PREPARED_DIR
out.mkdir(parents=True, exist_ok=True)

train_out, dev_out = (out / file_name for file_name in ("spider_train.jsonl", "spider_dev.jsonl"))

# force_ascii=False is passed so non-ASCII text is not escaped in the output.
ds["train"].to_json(train_out, orient="records", lines=True, force_ascii=False)
ds["dev"].to_json(dev_out, orient="records", lines=True, force_ascii=False)

print("Wrote:", train_out, dev_out)


Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 118.73ba/s]
Creating json from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 280.43ba/s]

Wrote: prepared/spider_train.jsonl prepared/spider_dev.jsonl





In [12]:
def normalize_sql(s: str, dialect_in: str = "mysql", dialect_out: str = "mysql") -> str:
    """Rewrite ``s`` in a consistent style so exact-match is less sensitive to formatting.

    The query is passed through ``sqlglot.transpile`` (reading ``dialect_in``,
    writing ``dialect_out``, with ``normalize=True`` and ``pretty=False``) and
    the first resulting statement is returned.

    Best effort by design: if anything goes wrong (including an empty result
    list), the input string is returned unchanged.
    """
    options = {
        "read": dialect_in,
        "write": dialect_out,
        "normalize": True,
        "pretty": False,
    }
    try:
        statements = sqlglot.transpile(s, **options)
        return statements[0]
    except Exception:
        # Deliberate fallback: unparseable SQL is kept verbatim.
        return s

# Quick sanity: count how many of the first dev gold queries sqlglot cannot parse.
# The sample size is capped by the split length so a short split does not raise
# IndexError, and the reported denominator is the number actually checked.
# Queries are parsed with the dialect normalize_sql reads by default ("mysql"),
# so the count reflects the queries normalize_sql would fail to normalize.
n_check = min(50, len(ds["dev"]))
bad_parse = 0
for i in range(n_check):
    sql = ds["dev"][i]["sql_gold"]
    try:
        sqlglot.parse_one(sql, read="mysql")
    except Exception:
        bad_parse += 1
print(f"sqlglot could not parse {bad_parse}/{n_check} dev gold queries")

print("Before:", ds["dev"][0]["sql_gold"])
print("After :", normalize_sql(ds["dev"][0]["sql_gold"]))


sqlglot could not parse 0/50 dev gold queries
Before: SELECT count(*) FROM singer
After : SELECT COUNT(*) FROM singer


In [15]:
# Slicing a split returns a dict of lists (column-oriented), so the "prompt"
# column can be picked out as a plain list of strings.
n = min(1000, len(ds["train"]))
subset = ds["train"][:n]
prompts = subset["prompt"]

# Character lengths of the sampled prompts (integer average).
lens = list(map(len, prompts))
print(
    "Sampled prompt char lengths — min/avg/max:",
    min(lens),
    sum(lens) // len(lens),
    max(lens),
)

# Integer indexing returns a single row as a dict; show one full prompt + gold pair.
k = 3
row = ds["train"][k]
print(row["prompt"])
print("\nGOLD:", row["sql_gold"])


Sampled prompt char lengths — min/avg/max: 476 925 1585
SYSTEM: You are a helpful assistant that writes correct SQL for the given question and database schema. Output ONLY the SQL query; do not include explanations.
USER:
Question: What are the maximum and minimum budget of the departments?

Schema:
Database: department_management
Tables:
  department(*Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees)
  head(*head_ID, name, born_state, age)
  management(*department_ID, head_ID, temporary_acting)
Foreign Keys:
  management.head_ID -> head.head_ID
  management.department_ID -> department.Department_ID

Rules:
- Only use tables/columns from department_management.
- Output only SQL, no commentary.


GOLD: SELECT max(budget_in_billions) ,  min(budget_in_billions) FROM department


In [16]:
import sqlite3
import pandas as pd

def run_sqlite(db_root: Path, db_id: str, sql: str):
    """Execute ``sql`` against ``<db_root>/<db_id>/<db_id>.sqlite``.

    The database is opened read-only through a SQLite URI, so:
      * a missing database file is reported as a failure instead of being
        silently created as an empty database;
      * a query cannot modify the database file.

    Opening the connection happens inside the ``try`` block, so open failures
    are returned through the same ``(ok, payload)`` contract as query errors.

    Args:
        db_root: folder containing one sub-folder per database id.
        db_id: database id; also the stem of the ``.sqlite`` file.
        sql: the query to run.

    Returns:
        ``(True, DataFrame)`` with the query result on success, otherwise
        ``(False, error_message)``.
    """
    db_path = db_root / db_id / f"{db_id}.sqlite"
    con = None
    try:
        # as_uri() needs an absolute path and percent-encodes special characters.
        con = sqlite3.connect(f"{db_path.resolve().as_uri()}?mode=ro", uri=True)
        df = pd.read_sql_query(sql, con)
        return True, df
    except Exception as e:
        return False, str(e)
    finally:
        if con is not None:
            con.close()

# Test that the gold SQL executes for the first few dev samples.
# The sample size is capped by the split length so a short split does not raise
# IndexError.
n_probe = min(20, len(ds["dev"]))
ok_cnt, fail_cnt = 0, 0
errors = []
for i in range(n_probe):
    ex = ds["dev"][i]
    # The payload is bound to `result`, not `out`, so the notebook-level `out`
    # (the PREPARED_DIR path used by the export cell) is not overwritten.
    ok, result = run_sqlite(DB_ROOT, ex["db_id"], ex["sql_gold"])
    if ok:
        ok_cnt += 1
    else:
        fail_cnt += 1
        errors.append((ex["db_id"], ex["question"], result))

print(f"Gold SQL execution — OK: {ok_cnt}, Fail: {fail_cnt}")
if errors:
    print("First error example:\nDB:", errors[0][0], "\nQ:", errors[0][1], "\nErr:", errors[0][2])


Gold SQL execution — OK: 20, Fail: 0
