In [2]:
# eval_flan_hospital1.py
import csv
import json
import sqlite3
import statistics
import time
from pathlib import Path

import sqlglot
import torch
from sqlglot import parse_one
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ---------- CONFIG ----------
# Evaluation examples (read one JSON object per line) and the CSV that
# receives the per-example results.
EVAL_JSONL = Path("test_hospital_1.jsonl")
RESULTS_CSV = Path("results_hospital_1_flan.csv")

# Model identifier handed to the tokenizer/model from_pretrained() calls.
MODEL_NAME = "juierror/flan-t5-text2sql-with-schema-v2"

# Keyword arguments forwarded unchanged to model.generate().
GEN_KW = {
    "max_new_tokens": 256,
    "num_beams": 4,
    "do_sample": False,
    "no_repeat_ngram_size": 3,
}

# Backend preference: CUDA first, then Apple's MPS, otherwise plain CPU.
def _detect_device():
    """Return the name of the preferred torch backend available here."""
    if torch.cuda.is_available():
        return "cuda"
    # The hasattr guard covers torch builds that do not expose backends.mps.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"  # Apple Silicon GPU
    return "cpu"


DEVICE = _detect_device()

print(f"Using device: {DEVICE}")

# ----------------------------

# Prompt template: the question, the serialized schema, and a trailing "SQL:"
# cue, each separated by a blank line. Filled in with str.format().
INSTR = "\n\n".join(
    [
        "Question: {question}",
        "{schema}",
        "SQL:",
    ]
)

def load_eval(jsonl_path):
    """Read a JSON Lines file and return its records as a list.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped, every other line is parsed with json.loads.
    """
    with open(jsonl_path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(entry) for entry in stripped_lines if entry]

def canonical_sql(sql_text):
    """Return the SQLite-dialect rendering of sql_text, or None.

    The text is parsed with sqlglot's parse_one and re-emitted on a single
    line so that two queries can be compared as strings for exact match.
    None is returned for empty/None input and whenever parsing or
    rendering raises.
    """
    canonical = None
    if sql_text:
        try:
            tree = parse_one(sql_text, read="sqlite")
            canonical = tree.sql(dialect="sqlite", pretty=False)
        except Exception:
            # Best effort: anything that cannot be parsed has no canonical form.
            canonical = None
    return canonical

def extract_sql(text):
    """Extract the SQL statement from raw model output.

    Steps, in order:
      1. If the text contains Markdown code fences, pick the first segment
         that looks like SQL (contains "select" or "with "). Segments that
         sit inside a fence are tried before the surrounding prose, so an
         introduction such as "Here is the select query:" is not mistaken
         for the query itself.
      2. Strip a leading "sql:", "answer:", "query:" or "sqlite:" label.
      3. Keep everything up to and including the first semicolon.

    Args:
        text: Decoded model output.

    Returns:
        The extracted SQL string, stripped of surrounding whitespace. If no
        fenced or unfenced segment looks like SQL, the text is returned with
        only steps 2 and 3 applied.
    """
    t = text.strip()

    # Remove code fences if present
    if "```" in t:
        parts = t.split("```")
        # After splitting on the fence marker, odd-indexed pieces are the
        # fenced contents and even-indexed pieces are the text around them.
        candidates = [(seg, True) for seg in parts[1::2]]
        candidates += [(seg, False) for seg in parts[0::2]]
        for seg, in_fence in candidates:
            seg_l = seg.lower()
            if "select" in seg_l or "with " in seg_l:
                t = seg.strip()
                if in_fence:
                    # Drop the language tag that follows an opening fence.
                    # Only fenced text carries one; the longer tag is tested
                    # first so "sqlite" is not cut down to "ite".
                    lowered = t.lower()
                    for tag in ("sqlite", "sql"):
                        if lowered.startswith(tag):
                            t = t[len(tag):].strip()
                            break
                break

    # Drop leading labels
    for head in ["sql:", "answer:", "query:", "sqlite:"]:
        if t.lower().startswith(head):
            t = t[len(head):].strip()

    # Take up to first semicolon if present
    if ";" in t:
        t = t.split(";", 1)[0] + ";"

    return t.strip()

def try_execute(conn, sql_text):
    """Run sql_text on conn and return (rows, error).

    On success, rows is a set of tuples (so row order and duplicate rows are
    not preserved) with every float value rounded to 6 decimal places, and
    error is None. If anything raises, the result is (None, str(exception)).
    """
    try:
        fetched = conn.execute(sql_text).fetchall()
        result = {
            tuple(
                round(cell, 6) if isinstance(cell, float) else cell
                for cell in record
            )
            for record in fetched
        }
    except Exception as exc:
        return None, str(exc)
    return result, None

def _generate_sql(tok, model, prompt):
    """Run the model on prompt and return (decoded_text, latency_ms).

    The latency covers tokenization and generation, not decoding.
    """
    # perf_counter is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    t0 = time.perf_counter()
    inp = tok(prompt, return_tensors="pt", truncation=True, max_length=512)

    # Move input to device if needed
    if DEVICE != "cpu":
        inp = {k: v.to(DEVICE) for k, v in inp.items()}

    out_ids = model.generate(**inp, **GEN_KW)
    gen_ms = (time.perf_counter() - t0) * 1000.0

    return tok.decode(out_ids[0], skip_special_tokens=True), gen_ms


def _score_prediction(conn, pred_sql_raw, gold_sql):
    """Compare one predicted query with its gold query.

    Returns (pred_sql_norm, em, ex_ok, valid, exec_err) where em, ex_ok and
    valid are 0/1 flags and exec_err is None or a short error description.
    """
    # Normalize for EM
    pred_sql_norm = canonical_sql(pred_sql_raw)
    gold_sql_norm = canonical_sql(gold_sql)

    em = int(
        pred_sql_norm is not None
        and gold_sql_norm is not None
        and pred_sql_norm == gold_sql_norm
    )

    if pred_sql_norm is None:
        return pred_sql_norm, em, 0, 0, "ParseError"

    # Execution accuracy: the prediction must run, and its result set must
    # equal the gold query's result set.
    valid = 0
    ex_ok = 0
    pred_rows, exec_err = try_execute(conn, pred_sql_norm)
    if pred_rows is not None:
        valid = 1
        # Fall back to the raw gold text when it could not be canonicalized.
        gold_rows, gold_err = try_execute(conn, gold_sql_norm or gold_sql)
        if gold_rows is not None:
            ex_ok = int(pred_rows == gold_rows)
        else:
            exec_err = f"Gold failed: {gold_err}"

    return pred_sql_norm, em, ex_ok, valid, exec_err


def _evaluate_examples(conn, tok, model, data):
    """Score every example, printing running metrics every 10 examples.

    Returns (results, latencies): one result dict per example and the
    unrounded generation latencies in milliseconds.
    """
    n = len(data)
    results = []
    latencies = []
    valid_cnt = em_cnt = ex_cnt = 0

    for i, ex in enumerate(data, 1):
        question = ex["question"]
        gold_sql = ex["gold_query"]

        # Build prompt
        prompt = INSTR.format(question=question, schema=ex["schema_serialized"])

        pred_text, gen_ms = _generate_sql(tok, model, prompt)
        pred_sql_raw = extract_sql(pred_text)
        pred_sql_norm, em, ex_ok, valid, exec_err = _score_prediction(
            conn, pred_sql_raw, gold_sql
        )

        # Accumulate metrics
        valid_cnt += valid
        em_cnt += em
        ex_cnt += ex_ok
        latencies.append(gen_ms)

        results.append({
            "id": ex["id"],
            "question": question,
            "gold_sql": gold_sql,
            "pred_sql_raw": pred_sql_raw,
            "pred_sql_norm": pred_sql_norm or "",
            "em": em,
            "ex": ex_ok,
            "valid_sql": valid,
            "latency_ms": round(gen_ms, 2),
            "error": exec_err or "",
        })

        # Progress updates
        if i % 10 == 0 or i == n:
            print(f"[{i}/{n}] EM={em_cnt/i:.3f} EX={ex_cnt/i:.3f} Valid={valid_cnt/i:.3f}")

    return results, latencies


def main():
    """Evaluate MODEL_NAME on EVAL_JSONL and report EM / EX / valid-SQL rates.

    Each example must provide the keys "id", "question", "gold_query" and
    "schema_serialized"; the database is taken from the "sqlite_path" of the
    first example only, so all examples are run against that one database.
    Per-example results are written to RESULTS_CSV and a summary is printed.
    Returns early (after printing an error) when the input file is missing
    or empty, or when the database file does not exist.
    """
    # Check file exists
    if not EVAL_JSONL.exists():
        print(f"ERROR: {EVAL_JSONL} not found!")
        print(f"Make sure you have the test_hospital_1.jsonl file in the current directory.")
        return

    data = load_eval(EVAL_JSONL)
    print(f"Loaded {len(data)} examples from {EVAL_JSONL}")

    # Without this guard an empty file would fail on data[0] below.
    if not data:
        print(f"ERROR: {EVAL_JSONL} contains no examples; nothing to evaluate.")
        return
    n = len(data)

    # Get SQLite path from first example
    sqlite_path = Path(data[0]["sqlite_path"])
    if not sqlite_path.exists():
        print(f"ERROR: SQLite database not found: {sqlite_path}")
        return

    print(f"Connecting to database: {sqlite_path}")
    # Model-generated SQL is untrusted: open the database read-only so a
    # generated INSERT/UPDATE/DELETE/DROP cannot alter the benchmark data.
    conn = sqlite3.connect(f"{sqlite_path.resolve().as_uri()}?mode=ro", uri=True)
    try:
        conn.execute("PRAGMA foreign_keys=ON")

        # Load model
        print(f"Loading model: {MODEL_NAME} on {DEVICE}...")
        tok = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

        # Move to device ONLY if not CPU
        if DEVICE != "cpu":
            model.to(DEVICE)

        model.eval()  # Set to evaluation mode

        print(f"\nEvaluating on {n} examples...")
        print("-" * 50)

        results, latencies = _evaluate_examples(conn, tok, model, data)
    finally:
        # Release the database handle even if loading or generation fails.
        conn.close()

    # Save results
    with open(RESULTS_CSV, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(results[0].keys()))
        w.writeheader()
        w.writerows(results)

    # Summary
    em_rate = sum(r["em"] for r in results) / n
    ex_rate = sum(r["ex"] for r in results) / n
    valid_rate = sum(r["valid_sql"] for r in results) / n
    # statistics.median averages the two middle values for an even count,
    # unlike indexing the sorted list at len // 2.
    med_latency = statistics.median(latencies)

    print("\n" + "=" * 50)
    print("=== SUMMARY ===")
    print(f"Model: {MODEL_NAME}")
    print(f"Examples: {n}")
    print(f"Exact Match (EM):        {em_rate:.3%}")
    print(f"Execution Accuracy (EX): {ex_rate:.3%}")
    print(f"Valid-SQL rate:          {valid_rate:.3%}")
    print(f"Median gen latency:      {med_latency:.1f} ms")
    print(f"Saved: {RESULTS_CSV}")
    print("=" * 50)

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()

Using device: mps
Loaded 100 examples from test_hospital_1.jsonl
Connecting to database: spider_data/database/hospital_1/hospital_1.sqlite
Loading model: juierror/flan-t5-text2sql-with-schema-v2 on mps...

Evaluating on 100 examples...
--------------------------------------------------
[10/100] EM=0.000 EX=0.200 Valid=0.200
[20/100] EM=0.150 EX=0.400 Valid=0.400
[30/100] EM=0.167 EX=0.367 Valid=0.433
[40/100] EM=0.150 EX=0.300 Valid=0.400
[50/100] EM=0.180 EX=0.300 Valid=0.400
[60/100] EM=0.167 EX=0.350 Valid=0.467
[70/100] EM=0.157 EX=0.314 Valid=0.457
[80/100] EM=0.163 EX=0.300 Valid=0.425
[90/100] EM=0.144 EX=0.267 Valid=0.378
[100/100] EM=0.150 EX=0.280 Valid=0.380

=== SUMMARY ===
Model: juierror/flan-t5-text2sql-with-schema-v2
Examples: 100
Exact Match (EM):        15.000%
Execution Accuracy (EX): 28.000%
Valid-SQL rate:          38.000%
Median gen latency:      2778.4 ms
Saved: results_hospital_1_flan.csv
