In [1]:
pip install transformers sqlglot accelerate



In [2]:
# eval_arctic_hospital1_colab.py
import json, time, csv, sqlite3
from pathlib import Path

import torch
import sqlglot
from sqlglot import parse_one
from transformers import AutoTokenizer, AutoModelForCausalLM

# ============== CONFIG (Colab main dir) ==============
EVAL_JSONL  = Path("eval_hospital_1.jsonl")
SQLITE_DB   = Path("hospital_1.sqlite")     # <- use local file directly
RESULTS_CSV = Path("results_hospital_1_arctic.csv")

MODEL_NAME = "Snowflake/Arctic-Text2SQL-R1-7B"  # change if needed

GEN_KW = dict(
    max_new_tokens=256,
    temperature=0.0,
    top_p=1.0,
    do_sample=False,
    num_beams=1,
    repetition_penalty=1.05,
)

MAX_INPUT_TOKENS = 3500

if torch.cuda.is_available():
    DEVICE = "cuda"
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
# =====================================================

SYS_INSTR = (
    "You are a Text-to-SQL system. "
    "Given a question and a database schema, output one valid SQLite SQL query. "
    "Use only tables/columns from the schema. Prefer explicit JOIN ... ON ... . "
    "Return only the SQL; no explanation."
)

def build_prompt(question: str, schema: str) -> str:
    # simple chat-style prompt for decoder-only models
    return (
        f"<|system|>\n{SYS_INSTR}\n<|end|>\n"
        f"<|user|>\nQuestion: {question}\n\nSchema:\n{schema}\n\nSQL:\n<|end|>\n"
        f"<|assistant|>\n"
    )

def load_eval(jsonl_path: Path):
    rows = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                ex = json.loads(line)
                # force sqlite path to the local file regardless of JSON content
                ex["sqlite_path"] = str(SQLITE_DB)
                rows.append(ex)
    return rows

def canonical_sql(sql_text: str) -> str | None:
    try:
        ast = parse_one(sql_text, read="sqlite")
        return ast.sql(dialect="sqlite", pretty=False)
    except Exception:
        return None

def extract_sql(text: str) -> str:
    t = text.strip()
    if "```" in t:
        for seg in t.split("```"):
            if "select" in seg.lower() or "with " in seg.lower():
                t = seg.strip()
                break
    for head in ("sql:", "answer:", "query:", "<|assistant|>",):
        if t.lower().startswith(head):
            t = t[len(head):].strip()
    if ";" in t:
        t = t.split(";", 1)[0] + ";"
    return t.strip()

def try_execute(conn, sql_text: str):
    try:
        cur = conn.execute(sql_text)
        rows = cur.fetchall()
        normd = []
        for r in rows:
            nr = []
            for v in r:
                nr.append(round(v, 6) if isinstance(v, float) else v)
            normd.append(tuple(nr))
        return set(normd), None
    except Exception as e:
        return None, str(e)

def truncate_to_tokens(tokenizer, text: str, max_tokens: int) -> str:
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(ids) <= max_tokens:
        return text
    ids = ids[-max_tokens:]  # keep tail (question+schema end)
    return tokenizer.decode(ids, skip_special_tokens=True)

def main():
    assert EVAL_JSONL.exists(), f"Missing {EVAL_JSONL} (upload it to Colab)"
    assert SQLITE_DB.exists(),  f"Missing {SQLITE_DB} (upload DB file to Colab)"

    data = load_eval(EVAL_JSONL)

    # DB
    conn = sqlite3.connect(str(SQLITE_DB))
    conn.execute("PRAGMA foreign_keys=ON")

    # Model
    print(f"Loading {MODEL_NAME} on {DEVICE} ...")
    dtype = torch.float16 if DEVICE in ("cuda", "mps") else torch.float32
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype)
    model.to(DEVICE)

    # pad token safety
    if model.config.pad_token_id is None and tok.pad_token_id is not None:
        model.config.pad_token_id = tok.pad_token_id
    if tok.pad_token_id is None and tok.eos_token_id is not None:
        tok.pad_token_id = tok.eos_token_id

    results = []
    n = len(data)
    em_cnt = ex_cnt = valid_cnt = 0
    latencies = []

    print(f"Evaluating on {n} examples from {SQLITE_DB.name} ...")

    for i, ex in enumerate(data, 1):
        q = ex["question"]
        schema = ex["schema_serialized"]
        gold_sql = ex["gold_query"]

        prompt = build_prompt(q, schema)
        prompt = truncate_to_tokens(tok, prompt, MAX_INPUT_TOKENS)

        t0 = time.time()
        inputs = tok(prompt, return_tensors="pt").to(DEVICE)
        gen = model.generate(**inputs, **GEN_KW)
        gen_ms = (time.time() - t0) * 1000.0

        text_out = tok.decode(gen[0], skip_special_tokens=True)
        if "<|assistant|>" in text_out:
            text_out = text_out.split("<|assistant|>", 1)[-1]
        pred_sql_raw = extract_sql(text_out)

        pred_sql_norm = canonical_sql(pred_sql_raw)
        gold_sql_norm = canonical_sql(gold_sql)

        em = int(pred_sql_norm is not None and gold_sql_norm is not None and pred_sql_norm == gold_sql_norm)

        valid = 0
        ex_ok = 0
        err = None

        if pred_sql_norm is not None:
            pred_rows, err = try_execute(conn, pred_sql_norm)
            if pred_rows is not None:
                valid = 1
                gold_rows, gerr = try_execute(conn, gold_sql_norm or gold_sql)
                if gold_rows is not None:
                    ex_ok = int(pred_rows == gold_rows)
                else:
                    err = f"Gold failed: {gerr}"
        else:
            err = "ParseError"

        em_cnt += em
        ex_cnt += ex_ok
        valid_cnt += valid
        latencies.append(gen_ms)

        results.append({
            "id": ex["id"],
            "question": q,
            "gold_sql": gold_sql,
            "pred_sql_raw": pred_sql_raw,
            "pred_sql_norm": pred_sql_norm or "",
            "em": em,
            "ex": ex_ok,
            "valid_sql": valid,
            "latency_ms": round(gen_ms, 2),
            "error": err or ""
        })

        if i % 10 == 0 or i == n:
            print(f"[{i}/{n}] EM={em_cnt/i:.3f} EX={ex_cnt/i:.3f} Valid={valid_cnt/i:.3f}")

    # save
    with open(RESULTS_CSV, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(results[0].keys()))
        w.writeheader()
        w.writerows(results)

    print("\n=== SUMMARY ===")
    print(f"Model: {MODEL_NAME}")
    print(f"Examples: {n}")
    print(f"Exact Match (EM):       {em_cnt/n:.3%}")
    print(f"Execution Accuracy (EX):{ex_cnt/n:.3%}")
    print(f"Valid-SQL rate:         {valid_cnt/n:.3%}")
    med = sorted(latencies)[len(latencies)//2]
    print(f"Median gen latency:     {med:.1f} ms")
    print(f"Saved: {RESULTS_CSV}")

if __name__ == "__main__":
    main()


Loading Snowflake/Arctic-Text2SQL-R1-7B on cuda ...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/845 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.91G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/466M [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.91G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/121 [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Evaluating on 100 examples from hospital_1.sqlite ...
[10/100] EM=0.000 EX=0.400 Valid=0.900
[20/100] EM=0.000 EX=0.700 Valid=0.950
[30/100] EM=0.000 EX=0.767 Valid=0.967
[40/100] EM=0.000 EX=0.725 Valid=0.975
[50/100] EM=0.000 EX=0.720 Valid=0.980
[60/100] EM=0.000 EX=0.733 Valid=0.983
[70/100] EM=0.000 EX=0.686 Valid=0.971
[80/100] EM=0.000 EX=0.700 Valid=0.975
[90/100] EM=0.000 EX=0.700 Valid=0.978
[100/100] EM=0.000 EX=0.710 Valid=0.980

=== SUMMARY ===
Model: Snowflake/Arctic-Text2SQL-R1-7B
Examples: 100
Exact Match (EM):       0.000%
Execution Accuracy (EX):71.000%
Valid-SQL rate:         98.000%
Median gen latency:     1657.0 ms
Saved: results_hospital_1_arctic.csv
