In [3]:
import json
import sqlite3
from pathlib import Path
from collections import Counter

# ============================================================
# CONFIGURATION
# ============================================================

# Root folder of the Spider dataset (JSON example files and the database/ tree)
SPIDER_PATH = "spider_data"

# The 11 databases reserved for evaluation (not used in your fine-tuning)
EVALUATION_DATABASES = [
    "dorm_1", "allergy_1", "movie_1", "flight_1",
    "driving_school", "cre_Doc_Tracking_DB", "department_store",
    "customers_and_addresses", "activity_1", "network_2",
    "products_gen_characteristics",
]

# Destination folder for the generated files; created up front so later
# writes never fail on a missing directory
OUTPUT_DIR = Path("test_visualizations")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ============================================================
# SCHEMA EXTRACTION
# ============================================================

def extract_schema_from_db(db_path, db_id):
    """Serialize the schema of a SQLite database as compact text.

    The result has the form::

        Database: <db_id>
        Tables:
        - table(col1*, col2, ...)
        FK table.col -> other_table.other_col

    ``*`` marks every column that belongs to the table's primary key,
    including each member of a composite key. All table lines come first,
    followed by all foreign-key lines.

    Args:
        db_path: Location of the SQLite file, passed to ``sqlite3.connect``.
        db_id: Name used in the ``Database:`` header and in error messages.

    Returns:
        The schema text, or ``None`` when the database has no tables or a
        SQLite error occurs (the error is printed, not raised).
    """
    def quote_identifier(name):
        # Double-quote the name (doubling embedded quotes) so reserved words
        # and unusual characters are accepted as a PRAGMA argument.
        return '"' + name.replace('"', '""') + '"'

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [row[0] for row in cursor.fetchall()]

        if not tables:
            return None

        schema_parts = [f"Database: {db_id}", "Tables:"]
        fk_lines = []

        for table in tables:
            quoted = quote_identifier(table)

            # table_info rows: (cid, name, type, notnull, dflt_value, pk).
            # pk is 0 for non-key columns and the 1-based position inside the
            # primary key otherwise, so any positive value means "key column".
            cursor.execute(f"PRAGMA table_info({quoted})")
            col_names = [
                f"{col[1]}*" if col[5] > 0 else col[1]
                for col in cursor.fetchall()
            ]
            schema_parts.append(f"- {table}({', '.join(col_names)})")

            # foreign_key_list rows: (id, seq, table, from, to, ...)
            cursor.execute(f"PRAGMA foreign_key_list({quoted})")
            for fk in cursor.fetchall():
                fk_lines.append(f"FK {table}.{fk[3]} -> {fk[2]}.{fk[4]}")

        # Foreign keys are listed after every table line
        schema_parts.extend(fk_lines)

        return "\n".join(schema_parts)

    except sqlite3.Error as e:
        print(f"Error extracting schema for {db_id}: {e}")
        return None

    finally:
        # Release the connection on every path: success, empty database, error
        if conn is not None:
            conn.close()

# ============================================================
# LOAD SPIDER DATA
# ============================================================

print("=" * 60)
print("CREATING TEST DATASETS FOR 11 NEW DATABASES")
print("=" * 60)
print("\n‚ÑπÔ∏è  Using databases NOT seen during training for fair evaluation")
print("-" * 60)

# Candidate example files, tried in order; the first one that exists is used
train_files = [
    Path(SPIDER_PATH) / "train_spider.json",
    Path(SPIDER_PATH) / "train.json"
]

spider_data = None
for train_file in train_files:
    if train_file.exists():
        print(f"\nüìÇ Loading from: {train_file}")
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default text encoding
        with open(train_file, 'r', encoding='utf-8') as f:
            spider_data = json.load(f)
        print(f"‚úÖ Loaded {len(spider_data)} examples")
        break

if spider_data is None:
    print("‚ùå Error: Could not find train_spider.json or train.json")
    print("Please check your spider_data folder")
    # Raise SystemExit directly: the `exit` helper is injected by the `site`
    # module for interactive use and is not meant to be relied on in programs
    raise SystemExit(1)

# Report how many loaded examples exist for each evaluation database
print("\nüîç Checking available databases in the data...")
db_counts = Counter(ex['db_id'] for ex in spider_data)

print("\nDatabases found in training data:")
print("-" * 60)
for db_id in EVALUATION_DATABASES:
    # Indexing a Counter with an unseen key yields 0
    count = db_counts[db_id]
    if count > 0:
        status = "‚úÖ"
    else:
        status = "‚ùå"
    print(f"{status} {db_id:<35} {count:>4} examples")

# ============================================================
# CREATE TEST FILES FOR EACH DATABASE
# ============================================================

print("\n" + "=" * 60)
print("Creating test files...")
print("-" * 60)

# One record per database for which a test file was written
summary = []

for db_id in EVALUATION_DATABASES:
    print(f"\nüìÅ Processing {db_id}...")

    # Filter examples for this database
    db_examples = [ex for ex in spider_data if ex['db_id'] == db_id]

    if not db_examples:
        print(f"   ‚ö†Ô∏è  No examples found for {db_id} - skipping")
        continue

    # Get database path
    db_path = Path(SPIDER_PATH) / "database" / db_id / f"{db_id}.sqlite"

    if not db_path.exists():
        print(f"   ‚ö†Ô∏è  Database not found: {db_path} - skipping")
        continue

    # Extract schema
    schema = extract_schema_from_db(str(db_path), db_id)

    if not schema:
        print("   ‚ö†Ô∏è  Could not extract schema - skipping")
        continue

    # Take first 50 examples for testing (or all if less than 50)
    test_examples = db_examples[:50]

    # Record the same path that was just checked and opened, so the entry
    # stays correct if SPIDER_PATH is changed; as_posix() keeps the
    # forward-slash form regardless of platform
    sqlite_path = db_path.as_posix()

    # Build one entry per example; ids are 1-based within each database
    test_data = [
        {
            "id": f"{db_id}_{i}",
            "dataset": db_id,
            "db_id": db_id,
            "sqlite_path": sqlite_path,
            "schema_serialized": schema,
            "question": example['question'],
            "gold_query": example['query']
        }
        for i, example in enumerate(test_examples, 1)
    ]

    # Save as JSONL (one JSON object per line), with an explicit encoding so
    # the file does not depend on the platform default
    output_file = OUTPUT_DIR / f"test_{db_id}.jsonl"
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(entry) + '\n' for entry in test_data)

    print(f"   ‚úÖ Created {output_file.name} with {len(test_data)} examples")

    summary.append({
        "database": db_id,
        "total_available": len(db_examples),
        "examples_used": len(test_data),
        "test_file": str(output_file),
        "db_file": str(db_path)
    })

# ============================================================
# SUMMARY
# ============================================================

# Fail before printing the success banner: when nothing was written the run
# must not announce that test datasets were created
if not summary:
    print("\n‚ùå ERROR: No test datasets were created!")
    print("\nPossible issues:")
    print("1. Database names might be different in Spider dataset")
    print("2. SQLite files might not exist")
    print("3. Wrong file path to Spider data")
    # Raise SystemExit directly: the `exit` helper is injected by the `site`
    # module for interactive use and is not meant to be relied on in programs
    raise SystemExit(1)

print("\n" + "=" * 60)
print("‚úÖ TEST DATASETS CREATED")
print("=" * 60)

print("\nüìä Summary:")
print(f"   Databases processed: {len(summary)}")
print(f"   Total test examples: {sum(s['examples_used'] for s in summary)}")

print("\nüìã Dataset Details:")
print("-" * 80)
print(f"{'Database':<30} {'Available':<12} {'Used':<12} {'Test File':<30}")
print("-" * 80)

# One row per database; only the file name of the test file is shown
for s in summary:
    print(f"{s['database']:<30} {s['total_available']:<12} {s['examples_used']:<12} {Path(s['test_file']).name:<30}")

print("-" * 80)

# Save summary as JSON, with an explicit encoding so the file does not depend
# on the platform default
summary_file = OUTPUT_DIR / "evaluation_summary.json"
with open(summary_file, 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print("\nüíæ Files saved:")
print(f"   Test files: {OUTPUT_DIR}/")
print(f"   Summary: {summary_file}")


CREATING TEST DATASETS FOR 11 NEW DATABASES

‚ÑπÔ∏è  Using databases NOT seen during training for fair evaluation
------------------------------------------------------------

üìÇ Loading from: spider_data/train_spider.json
‚úÖ Loaded 7000 examples

üîç Checking available databases in the data...

Databases found in training data:
------------------------------------------------------------
‚úÖ dorm_1                               100 examples
‚úÖ allergy_1                             98 examples
‚úÖ movie_1                               98 examples
‚úÖ flight_1                              96 examples
‚úÖ driving_school                        93 examples
‚úÖ cre_Doc_Tracking_DB                   90 examples
‚úÖ department_store                      88 examples
‚úÖ customers_and_addresses               88 examples
‚úÖ activity_1                            88 examples
‚úÖ network_2                             86 examples
‚úÖ products_gen_characteristics          86 examples

Creating 