In [1]:
import csv
import json
import os
import re
import sqlite3
import statistics
import time
from collections import Counter
from pathlib import Path

import sqlglot
import torch
from sqlglot import parse_one
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ============================================================
# CONFIGURATION
# ============================================================

# Directory holding the fine-tuned model and tokenizer (loaded with
# ``from_pretrained`` further down).
FINAL_MODEL_DIR = Path("finetuned_flant5/final_model")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Prompt fed to the model; ``{question}`` and ``{schema}`` are filled in
# per test example.
PROMPT_TEMPLATE = """Question: {question}

Schema:
{schema}

SQL:"""

# Generation parameters
MAX_INPUT_LENGTH = 512  # prompts longer than this many tokens are truncated
GEN_MAX_LENGTH = 256  # ``max_length`` passed to ``model.generate``
GEN_NUM_BEAMS = 4  # beam width for (deterministic) beam search
# Only meaningful when sampling. Generation in this script uses
# do_sample=False, so this value does not influence the output.
GEN_TEMPERATURE = 0.0

# Directory holding the per-database ``test_<db>.jsonl`` files.
TEST_DIR = Path("test_visualizations")

# Where per-database CSVs and the JSON summary are written (created eagerly).
OUTPUT_DIR = Path("evaluation_results")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# All databases to evaluate (Spider database names).
DATABASES = [
    "activity_1",
    "allergy_1",
    "cre_Doc_Tracking_DB",
    "customers_and_addresses",
    "department_store",
    "dorm_1",
    "driving_school",
    "flight_1",
    "movie_1",
    "network_2",
    "products_gen_characteristics",
]

print("=" * 80)
# Derive the count from the list so the banner never goes stale.
print(f"BATCH EVALUATION ON {len(DATABASES)} DATABASES")
print("=" * 80)
print(f"Model: {FINAL_MODEL_DIR}")
print(f"Device: {DEVICE}")
print(f"Databases: {len(DATABASES)}")
print("=" * 80)

# ============================================================
# Load Fine-tuned Model
# ============================================================

print("\nLoading fine-tuned model...")
print("-" * 80)

# Fail early with a clear message. A missing local directory would otherwise be
# handed to ``from_pretrained``, which may try to resolve it as a Hub repo id
# and report a much less obvious error.
if not FINAL_MODEL_DIR.is_dir():
    raise FileNotFoundError(f"Fine-tuned model directory not found: {FINAL_MODEL_DIR}")

tokenizer = AutoTokenizer.from_pretrained(str(FINAL_MODEL_DIR))
model = AutoModelForSeq2SeqLM.from_pretrained(str(FINAL_MODEL_DIR)).to(DEVICE)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

print("✅ Model loaded successfully")

# ============================================================
# SQL Utilities
# ============================================================

def canonical_sql(sql_text):
    """Render *sql_text* in sqlglot's canonical single-line SQLite form.

    Empty or ``None`` input, and any text sqlglot fails to parse or
    re-render, yields ``None`` instead of raising.
    """
    if sql_text:
        try:
            parsed = parse_one(sql_text, read="sqlite")
            return parsed.sql(dialect="sqlite", pretty=False)
        except Exception:
            return None
    return None

def try_execute(conn, sql_text):
    """Run *sql_text* on *conn* and return a ``(rows, error)`` pair.

    On success ``rows`` is a set of row tuples in which every float cell is
    rounded to 6 decimal places, and ``error`` is None. Because a set is
    used, duplicate rows collapse and row order is ignored. If anything
    raises, ``rows`` is None and ``error`` holds the exception message.
    """
    def _clean(cell):
        # Round floats so tiny representation noise does not break equality.
        return round(cell, 6) if isinstance(cell, float) else cell

    try:
        fetched = conn.execute(sql_text).fetchall()
        return {tuple(_clean(cell) for cell in record) for record in fetched}, None
    except Exception as exc:
        return None, str(exc)

# Leading language tag of a fenced code block, e.g. "sql", "SQLite", "sql:".
_SQL_FENCE_TAG = re.compile(r"(?:sqlite|sql)\b:?", re.IGNORECASE)
# Answer labels the model may put in front of the query (matched lower-case).
_ANSWER_PREFIXES = ("sql:", "answer:", "query:")


def extract_sql(text):
    """Extract the SQL statement from raw model output.

    Three kinds of wrapping are removed:

    * Markdown code fences. The first ```-delimited segment that contains
      ``select`` (any case) is kept, and a leading ``sql``/``sqlite`` language
      tag, optionally followed by ``:``, is dropped.
    * Leading ``sql:`` / ``answer:`` / ``query:`` labels (case-insensitive).
      These are stripped repeatedly, so chains like ``answer: sql:`` are
      removed entirely.
    * Anything after the first ``;``. Only the first statement is kept,
      together with its semicolon. No semicolon is added when the text has
      none. The split is naive: a ``;`` inside a string literal also ends the
      statement.

    Returns the cleaned, whitespace-stripped text.
    """
    text = text.strip()

    # Markdown code fence: keep the first segment that looks like a query.
    if "```" in text:
        for part in text.split("```"):
            if "select" in part.lower():
                text = part.strip()
                tag = _SQL_FENCE_TAG.match(text)
                if tag:
                    text = text[tag.end():].strip()
                break

    # Strip answer labels until none remain (handles chained labels).
    stripped = True
    while stripped:
        stripped = False
        for prefix in _ANSWER_PREFIXES:
            if text.lower().startswith(prefix):
                text = text[len(prefix):].strip()
                stripped = True

    # Keep only the first statement (and its terminating ';', if present).
    head, sep, _ = text.partition(";")
    return (head + sep).strip()

def load_jsonl(path):
    """Parse a JSON Lines file into a list, one decoded object per line.

    Surrounding whitespace is trimmed from each line, and blank lines are
    skipped. The file is read as UTF-8.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in map(str.strip, handle) if raw]

# ============================================================
# Evaluate Single Database
# ============================================================

def _generate_sql(prompt):
    """Generate SQL for one prompt with the module-level model and tokenizer.

    Returns ``(extracted_sql, latency_ms)``. The latency covers only the
    ``model.generate`` call; tokenization and decoding are excluded.
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
    ).to(DEVICE)

    # perf_counter is monotonic and high resolution, unlike time.time().
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=GEN_MAX_LENGTH,
            num_beams=GEN_NUM_BEAMS,
            # Decoding is deterministic (do_sample=False), so temperature does
            # not matter; non-positive settings are mapped to the neutral 1.0.
            temperature=GEN_TEMPERATURE if GEN_TEMPERATURE > 0 else 1.0,
            do_sample=False,
        )
    latency_ms = (time.perf_counter() - start) * 1000.0

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return extract_sql(decoded), latency_ms


def _score_execution(conn, pred_sql_norm, gold_sql, gold_sql_norm):
    """Execute prediction and gold on *conn* and return ``(valid, ex_ok, error)``.

    ``valid`` is 1 when the prediction parsed and executed. ``ex_ok`` is 1
    when its result set equals the gold query's result set. ``error`` is a
    message string, or None when nothing failed.
    """
    if pred_sql_norm is None:
        return 0, 0, "ParseError: Could not parse predicted SQL"

    pred_rows, error = try_execute(conn, pred_sql_norm)
    if pred_rows is None:
        return 0, 0, error

    # Fall back to the raw gold text if sqlglot could not normalise it.
    gold_rows, gold_error = try_execute(conn, gold_sql_norm or gold_sql)
    if gold_rows is None:
        return 1, 0, f"Gold SQL failed: {gold_error}"
    return 1, int(pred_rows == gold_rows), None


def evaluate_database(db_name):
    """Evaluate the fine-tuned model on one database and save per-example results.

    Reads ``TEST_DIR/test_<db_name>.jsonl``. Each row needs ``question``,
    ``gold_query`` and ``schema_serialized``; ``id`` is optional. The function
    generates SQL for every question and scores it with three metrics:

    * EM: the sqlglot-canonicalised prediction equals the canonicalised gold.
    * EX: prediction and gold return the same set of rows (see try_execute).
    * Valid: the prediction parsed and executed without error.

    Per-example rows are written to ``OUTPUT_DIR/results_<db_name>.csv``.

    Returns a summary dict with keys database, examples, em, ex, valid,
    median_latency_ms and results_file. Returns None when the test file or
    the SQLite file is missing, or when the test file has no examples.
    """
    print(f"\n{'='*80}")
    print(f"EVALUATING: {db_name.upper()}")
    print(f"{'='*80}")

    # Paths
    test_file = TEST_DIR / f"test_{db_name}.jsonl"
    db_file = Path("spider_data/database") / db_name / f"{db_name}.sqlite"
    results_file = OUTPUT_DIR / f"results_{db_name}.csv"

    if not test_file.exists():
        print(f"❌ Test file not found: {test_file}")
        return None

    if not db_file.exists():
        print(f"❌ Database file not found: {db_file}")
        return None

    test_data = load_jsonl(test_file)
    n_examples = len(test_data)
    print(f"📊 Test examples: {n_examples}")
    if n_examples == 0:
        # Nothing to score; also avoids dividing by zero for the rates below.
        print(f"⚠️  No test examples in {test_file}")
        return None

    # Open the database READ-ONLY. Predicted SQL is model output and may be
    # INSERT/UPDATE/DELETE/DROP. On a writable connection such statements
    # could change what later queries in this run see, and DDL can be
    # committed straight to the benchmark file on disk.
    conn = sqlite3.connect(f"{db_file.resolve().as_uri()}?mode=ro", uri=True)
    try:
        conn.execute("PRAGMA foreign_keys=ON")
        print("✅ Connected to database")

        results = []
        em_count = ex_count = valid_count = 0
        latencies = []

        print(f"\n{'─'*80}")
        print("Running evaluation...")
        print(f"{'─'*80}")

        for i, example in enumerate(test_data, 1):
            question = example["question"]
            gold_sql = example["gold_query"]
            schema = example["schema_serialized"]

            prompt = PROMPT_TEMPLATE.format(question=question, schema=schema)
            pred_sql_raw, gen_time_ms = _generate_sql(prompt)
            latencies.append(gen_time_ms)

            pred_sql_norm = canonical_sql(pred_sql_raw)
            gold_sql_norm = canonical_sql(gold_sql)

            # Exact match: both canonical forms exist and are identical.
            em = int(pred_sql_norm is not None and pred_sql_norm == gold_sql_norm)
            valid, ex_ok, error = _score_execution(conn, pred_sql_norm, gold_sql, gold_sql_norm)

            em_count += em
            ex_count += ex_ok
            valid_count += valid

            results.append({
                "id": example.get("id", f"{db_name}_{i}"),
                "dataset": db_name,
                "question": question,
                "gold_sql": gold_sql,
                "pred_sql_raw": pred_sql_raw,
                "pred_sql_norm": pred_sql_norm or "",
                "em": em,
                "ex": ex_ok,
                "valid_sql": valid,
                "latency_ms": round(gen_time_ms, 2),
                "error": error or "",
            })

            if i % 10 == 0 or i == n_examples:
                print(f"[{i:3d}/{n_examples}] EM={em_count/i:.3f} EX={ex_count/i:.3f} Valid={valid_count/i:.3f}")
    finally:
        # Always release the database handle, even if generation fails midway.
        conn.close()

    # Save per-example results (non-empty: one row per test example).
    with open(results_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(results[0].keys()))
        writer.writeheader()
        writer.writerows(results)

    em_rate = em_count / n_examples
    ex_rate = ex_count / n_examples
    valid_rate = valid_count / n_examples
    # True median: averages the two middle values when the count is even.
    median_latency = statistics.median(latencies)

    print(f"\n{'─'*80}")
    print("RESULTS:")
    print(f"{'─'*80}")
    print(f"  EM:    {em_rate:.1%} ({em_count}/{n_examples})")
    print(f"  EX:    {ex_rate:.1%} ({ex_count}/{n_examples})")
    print(f"  Valid: {valid_rate:.1%} ({valid_count}/{n_examples})")
    print(f"  Latency: {median_latency:.0f}ms")
    print(f"{'─'*80}")
    print(f"💾 Results saved to: {results_file.name}")

    return {
        "database": db_name,
        "examples": n_examples,
        "em": em_rate,
        "ex": ex_rate,
        "valid": valid_rate,
        "median_latency_ms": median_latency,
        "results_file": str(results_file),
    }

# ============================================================
# Run Evaluation on All Databases
# ============================================================

print("\n" + "=" * 80)
print("STARTING BATCH EVALUATION")
print("=" * 80)

# One summary dict per successfully evaluated database.
all_results = []

for i, db_name in enumerate(DATABASES, 1):
    print(f"\n[{i}/{len(DATABASES)}] Processing {db_name}...")

    result = evaluate_database(db_name)

    # evaluate_database returns None when a database had to be skipped.
    if result is not None:
        all_results.append(result)
    else:
        print(f"⚠️  Skipped {db_name}")

    print(f"\n{'─'*80}")

# ============================================================
# Save Summary
# ============================================================

print("\n" + "=" * 80)
print("EVALUATION COMPLETE!")
print("=" * 80)

# Persist the per-database summary dicts produced by evaluate_database.
summary_file = OUTPUT_DIR / "evaluation_summary.json"
with open(summary_file, "w", encoding="utf-8") as f:
    json.dump(all_results, f, indent=2)

print("\n📊 OVERALL SUMMARY:")
print(f"{'─'*80}")
print(f"Databases evaluated: {len(all_results)}/{len(DATABASES)}")
print(f"Total examples: {sum(r['examples'] for r in all_results)}")

print("\n📈 PERFORMANCE BY DATABASE:")
print(f"{'─'*80}")
print(f"{'Database':<35} {'Examples':<10} {'EM':<8} {'EX':<8} {'Valid':<8}")
print(f"{'─'*80}")

for r in all_results:
    print(f"{r['database']:<35} {r['examples']:<10} {r['em']:<8.1%} {r['ex']:<8.1%} {r['valid']:<8.1%}")

print(f"{'─'*80}")

# Unweighted (macro) average over databases: each database counts equally,
# regardless of how many examples it has. Skipped when nothing was evaluated,
# which would otherwise divide by zero.
if all_results:
    n_evaluated = len(all_results)
    avg_em = sum(r["em"] for r in all_results) / n_evaluated
    avg_ex = sum(r["ex"] for r in all_results) / n_evaluated
    avg_valid = sum(r["valid"] for r in all_results) / n_evaluated
    print(f"\n{'AVERAGE':<35} {'':<10} {avg_em:<8.1%} {avg_ex:<8.1%} {avg_valid:<8.1%}")
else:
    print("\nAVERAGE: n/a (no databases were evaluated)")
print(f"{'='*80}")

print("\n💾 Files saved:")
print(f"   Individual results: {OUTPUT_DIR}/results_*.csv")
print(f"   Summary: {summary_file}")

print("\n✅ All evaluations complete!")
print("=" * 80)

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu
BATCH EVALUATION ON 11 DATABASES
Model: finetuned_flant5/final_model
Device: cpu
Databases: 11

Loading fine-tuned model...
--------------------------------------------------------------------------------
✅ Model loaded successfully

STARTING BATCH EVALUATION

[1/11] Processing activity_1...

EVALUATING: ACTIVITY_1
📊 Test examples: 50
✅ Connected to database

────────────────────────────────────────────────────────────────────────────────
Running evaluation...
────────────────────────────────────────────────────────────────────────────────
[ 10/50] EM=0.600 EX=1.000 Valid=1.000
[ 20/50] EM=0.500 EX=1.000 Valid=1.000
[ 30/50] EM=0.633 EX=0.967 Vali