In [1]:
# rebuild_train_val_from_spider.py (notebook cell)
import json, sqlite3
from pathlib import Path
from collections import defaultdict

ROOT = Path("spider_data")

# Spider inputs, all resolved under ROOT
TABLES_JSON, TRAIN_SPIDER, TRAIN_OTHERS, DEV_JSON = (
    ROOT / fname
    for fname in ("tables.json", "train_spider.json", "train_others.json", "dev.json")
)
DB_ROOT = ROOT / "database"

# Outputs, written to the current working directory
OUT_TRAIN = Path("train_text2sql.jsonl")
OUT_VAL = Path("val_text2sql.jsonl")
OUT_TEST = Path("test_hospital_1.jsonl")  # separate held-out test file (hospital_1 only)

def load_json(p):
    """Read the UTF-8 text file at ``p`` (a pathlib.Path) and return its decoded JSON value."""
    raw_text = p.read_text(encoding="utf-8")
    return json.loads(raw_text)

# ---- schema render from tables.json (robust) ----
def serialize_schema_from_tables(tables_by_db, db_id: str) -> str:
    """Render one database's schema from its tables.json entry as plain text.

    Output layout::

        Database: <db_id>
        Tables:
        - <table>(<col>, <pk_col>*, ...)
        FK <table>.<col> -> <table>.<col>

    Args:
        tables_by_db: mapping db_id -> tables.json entry (reads the keys
            ``table_names_original``, ``column_names_original`` and optionally
            ``primary_keys`` / ``foreign_keys``).
        db_id: key of the database to render.

    Returns:
        Newline-joined schema text; primary-key columns are suffixed with ``*``.

    Raises:
        KeyError: if ``db_id`` is not in ``tables_by_db``.
    """
    entry = tables_by_db[db_id]
    tables = entry["table_names_original"]
    cols = entry["column_names_original"]     # list of [table_idx, col_name]
    col_table = [c[0] for c in cols]
    col_name = [c[1] for c in cols]

    # primary_keys is a list of column indices. A composite key may be encoded as
    # a nested list (e.g. [[3, 4]]); flatten it so set() never receives an
    # unhashable list.
    pk = set()
    for key in entry.get("primary_keys", []):
        if isinstance(key, (list, tuple)):
            pk.update(key)
        else:
            pk.add(key)
    fks = entry.get("foreign_keys", [])

    # gather columns by table
    cols_by_table = defaultdict(list)
    for idx, (ti, name) in enumerate(zip(col_table, col_name)):
        if ti == -1:  # * or root
            continue
        cols_by_table[ti].append((idx, name))

    lines = [f"Database: {db_id}", "Tables:"]
    for ti, tname in enumerate(tables):
        col_str = ", ".join(
            f"{name}{'*' if idx in pk else ''}"
            for idx, name in cols_by_table.get(ti, [])
        )
        lines.append(f"- {tname}({col_str})")

    # FK pairs are column indices; skip pairs that touch the table -1 pseudo-column
    for src, tgt in fks:
        ti_src = col_table[src]
        ti_tgt = col_table[tgt]
        if ti_src == -1 or ti_tgt == -1:
            continue
        lines.append(f"FK {tables[ti_src]}.{col_name[src]} -> {tables[ti_tgt]}.{col_name[tgt]}")

    return "\n".join(lines)

def ensure_semicolon(sql: str) -> str:
    """Return ``sql`` stripped of surrounding whitespace and terminated by exactly one trailing ';'.

    A falsy input (None, "") becomes ";".
    """
    text = (sql or "").strip()
    if text.endswith(";"):
        return text
    return f"{text};"

# MODIFIED: Added exclude_db parameter
def pick(rows, exclude_db=None):
    """Normalize raw example rows into {question, gold_query, db_id} records.

    Args:
        rows: iterable of example dicts (reads ``question``, ``query`` or
            ``sql``, and ``db_id``).
        exclude_db: if truthy, rows whose ``db_id`` equals it are dropped; the
            number dropped is printed as an ``[info]`` line.

    Returns:
        List of dicts with the stripped ``question``, a ``gold_query`` that is
        stripped and ends in ``;``, and ``db_id``. Rows without a non-blank
        question, a non-blank *string* query, or a db_id are skipped.
    """
    out = []
    excluded_count = 0
    for r in rows:
        dbid = r.get("db_id")

        # Skip excluded database
        if exclude_db and dbid == exclude_db:
            excluded_count += 1
            continue

        q = r.get("question")
        s = r.get("query") or r.get("sql")
        # NOTE(review): the "sql" fallback is presumably Spider's parsed-query
        # structure rather than text; only accept string values so .strip()
        # cannot crash on a dict.
        q = q.strip() if isinstance(q, str) else ""
        s = s.strip() if isinstance(s, str) else ""

        # Checked after stripping so blank text cannot produce "" / ";" records.
        if q and s and dbid:
            out.append({"question": q, "gold_query": ensure_semicolon(s), "db_id": dbid})

    if excluded_count > 0:
        print(f"[info] Excluded {excluded_count} examples from '{exclude_db}'")

    return out

# NEW: Function to pick ONLY hospital_1 (or any single database via db_id)
def pick_only_hospital(rows, db_id="hospital_1"):
    """Return the normalized records of a single database.

    Reuses ``pick`` for normalization, so the held-out test set is built
    exactly like train/val (same stripping, semicolon handling and skip rules).

    Args:
        rows: iterable of raw example dicts.
        db_id: database to keep; defaults to ``"hospital_1"``.

    Returns:
        List of ``{question, gold_query, db_id}`` dicts whose ``db_id`` equals
        ``db_id``, in input order.
    """
    return [r for r in pick(rows) if r["db_id"] == db_id]

tables = load_json(TABLES_JSON)
tables_by_db = {e["db_id"]: e for e in tables}

# Parse each split file once; both the train/val and the test selection reuse
# these lists (pick / pick_only_hospital do not mutate their input rows).
train_spider_raw = load_json(TRAIN_SPIDER)
train_others_raw = load_json(TRAIN_OTHERS)
dev_raw = load_json(DEV_JSON)

# Build train/val WITHOUT hospital_1
train = pick(train_spider_raw, exclude_db="hospital_1") + pick(train_others_raw, exclude_db="hospital_1")
val   = pick(dev_raw, exclude_db="hospital_1")

# Build test set with ONLY hospital_1
test  = pick_only_hospital(train_spider_raw) + \
        pick_only_hospital(train_others_raw) + \
        pick_only_hospital(dev_raw)

# Guard against leakage of the held-out database into train/val.
assert not any(r["db_id"] == "hospital_1" for r in train + val), "hospital_1 leaked into train/val"
assert all(r["db_id"] == "hospital_1" for r in test), "test set contains other databases"

# attach schema text + optional sqlite path if present
def attach_schema(rows):
    """Turn picked rows into output records carrying schema text and sqlite path.

    Args:
        rows: dicts with ``db_id``, ``question`` and ``gold_query`` (as built by ``pick``).

    Returns:
        List of records with keys id, dataset, db_id, sqlite_path,
        schema_serialized, question, gold_query. ``id`` is ``f"{db_id}_{i}"``
        with ``i`` the 1-based position in ``rows``. ``sqlite_path`` is ""
        when the .sqlite file does not exist. Rows whose db_id is missing
        from the module-level ``tables_by_db`` are skipped and reported in a
        ``[warn]`` line.
    """
    out = []
    miss = 0
    # Schema text and sqlite path depend only on db_id, so compute them once per
    # database instead of once per example.
    per_db = {}
    for i, r in enumerate(rows, 1):
        dbid = r["db_id"]
        if dbid not in tables_by_db:
            miss += 1
            continue
        if dbid not in per_db:
            # sqlite is not required for training; we add if exists for later EX
            sqlite_path = DB_ROOT / dbid / f"{dbid}.sqlite"
            per_db[dbid] = (
                serialize_schema_from_tables(tables_by_db, dbid),
                str(sqlite_path) if sqlite_path.exists() else "",
            )
        schema_text, sqlite_str = per_db[dbid]
        out.append({
            "id": f"{dbid}_{i}",
            "dataset": dbid,
            "db_id": dbid,
            "sqlite_path": sqlite_str,
            "schema_serialized": schema_text,
            "question": r["question"],
            "gold_query": r["gold_query"],
        })
    if miss:
        print(f"[warn] skipped {miss} rows with missing tables.json entries")
    return out

# Attach schema/sqlite info to each split (called in train, val, test order).
train_rows, val_rows, test_rows = map(attach_schema, (train, val, test))

def dump_jsonl(path, rows):
    """Write ``rows`` to ``path`` as JSON Lines: UTF-8, one object per line, non-ASCII kept verbatim."""
    serialized = (json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
    with open(path, "w", encoding="utf-8") as handle:
        handle.writelines(serialized)

# Persist each split, then report the per-split sizes.
for out_path, out_rows in ((OUT_TRAIN, train_rows), (OUT_VAL, val_rows), (OUT_TEST, test_rows)):
    dump_jsonl(out_path, out_rows)

print(f"‚úÖ train (excluding hospital_1): {len(train_rows)}  ‚Üí {OUT_TRAIN}")
print(f"‚úÖ val (excluding hospital_1):   {len(val_rows)}    ‚Üí {OUT_VAL}")
print(f"‚úÖ test (ONLY hospital_1):       {len(test_rows)}   ‚Üí {OUT_TEST}")

[info] Excluded 100 examples from 'hospital_1'
‚úÖ train (excluding hospital_1): 8559  ‚Üí train_text2sql.jsonl
‚úÖ val (excluding hospital_1):   1034    ‚Üí val_text2sql.jsonl
‚úÖ test (ONLY hospital_1):       100   ‚Üí test_hospital_1.jsonl


In [1]:
import json
from collections import Counter
import os

# Path to your spider data - adjust if needed
# (directory that main() joins with the split filenames, e.g. train_spider.json, dev.json)
spider_path = "spider_data"

def count_examples_per_database(json_file):
    """Count examples for each database in a Spider JSON file.

    Args:
        json_file: path to a split file holding a JSON list of example dicts,
            each with a ``db_id`` key.

    Returns:
        Counter mapping db_id -> number of examples.
    """
    # Explicit UTF-8: questions may contain non-ASCII text, and the platform
    # default encoding (e.g. cp1252 on Windows) would fail to decode it.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    return Counter(example['db_id'] for example in data)

def _banner(title, leading_newline=True):
    """Print ``title`` between two 60-character '=' rules."""
    print(("\n" if leading_newline else "") + "=" * 60)
    print(title)
    print("=" * 60)


def main(data_dir=None, filenames=None):
    """Print per-database example counts combined over several Spider split files.

    Args:
        data_dir: directory holding the split files; defaults to the
            module-level ``spider_path`` (looked up at call time).
        filenames: split files to combine; defaults to
            ``['train_spider.json', 'train.json', 'dev.json']``.
            NOTE(review): the rebuild cell loads the second training split as
            ``train_others.json``; ``train.json`` is reported "Not found" in
            the recorded output — confirm which file is intended.
    """
    if data_dir is None:
        data_dir = spider_path
    if filenames is None:
        filenames = ['train_spider.json', 'train.json', 'dev.json']

    _banner("SPIDER DATASET - Examples per Database", leading_newline=False)

    all_counts = Counter()
    for filename in filenames:
        filepath = os.path.join(data_dir, filename)
        if os.path.exists(filepath):
            print(f"\nüìÅ Found: {filename}")
            all_counts.update(count_examples_per_database(filepath))
        else:
            print(f"\n‚ùå Not found: {filename}")

    if not all_counts:
        print("\n‚ùå No data files found. Please check your file paths.")
        print(f"   Looking in: {data_dir}")
        return

    _banner("COMBINED TOTALS (sorted by count)")

    # most_common() sorts by count descending; ties keep first-seen order.
    sorted_dbs = all_counts.most_common()

    print(f"\n{'Rank':<6} {'Database':<30} {'Count':<10}")
    print("-" * 60)
    for rank, (db_name, count) in enumerate(sorted_dbs, 1):
        print(f"{rank:<6} {db_name:<30} {count:<10}")

    # Show top 3 recommendations
    _banner("üéØ TOP 3 DATABASES FOR FINE-TUNING:")
    for i, (db_name, count) in enumerate(sorted_dbs[:3], 1):
        print(f"{i}. {db_name}: {count} examples")

    # Also show top 10 for quick reference
    _banner("üìä TOP 10 DATABASES:")
    for i, (db_name, count) in enumerate(sorted_dbs[:10], 1):
        print(f"{i:2d}. {db_name:25s} - {count:4d} examples")

# Inside a notebook kernel __name__ is "__main__", so main() runs whenever this cell executes.
if __name__ == "__main__":
    main()

SPIDER DATASET - Examples per Database

üìÅ Found: train_spider.json

‚ùå Not found: train.json

üìÅ Found: dev.json

COMBINED TOTALS (sorted by count)

Rank   Database                       Count     
------------------------------------------------------------
1      college_2                      170       
2      college_1                      164       
3      hr_1                           124       
4      world_1                        120       
5      store_1                        112       
6      soccer_2                       106       
7      bike_1                         104       
8      music_1                        100       
9      hospital_1                     100       
10     music_2                        100       
11     dorm_1                         100       
12     allergy_1                      98        
13     movie_1                        98        
14     flight_1                       96        
15     driving_school                 93        


In [6]:
import json
import sqlite3
import os
from collections import Counter

def _quote_ident(name):
    """Double-quote an SQLite identifier (doubling embedded quotes) for use in PRAGMA calls."""
    return '"' + name.replace('"', '""') + '"'


def get_schema_serialized(db_path, db_id):
    """Extract schema information from an SQLite database as plain text.

    Output layout::

        Database: <db_id>
        Tables:
        - <table>(<col>, <pk_col>*, ...)
        FK <table>.<col> -> <table>.<col>

    Args:
        db_path: path to the .sqlite file. sqlite3.connect creates an empty
            database when the path does not exist, so callers should check first.
        db_id: name written on the "Database:" line.

    Returns:
        Newline-joined schema text. Every column of the primary key (including
        all members of a composite key) is suffixed with ``*``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Skip SQLite's internal bookkeeping tables (e.g. sqlite_sequence, created
        # for AUTOINCREMENT) — they are not part of the user schema.
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        )
        tables = [row[0] for row in cursor.fetchall()]

        schema_parts = [f"Database: {db_id}", "Tables:"]
        pk_cols_by_table = {}
        for table in tables:
            # Quote the name: reserved words (e.g. "order") or spaces break a bare PRAGMA.
            cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
            columns = cursor.fetchall()
            # table_info's pk field (index 5) is the column's 1-based position in
            # the primary key (0 = not a key column), so composite keys use 2, 3, ...
            pk_cols_by_table[table] = [
                c[1] for c in sorted((c for c in columns if c[5] > 0), key=lambda c: c[5])
            ]
            col_names = [f"{c[1]}*" if c[5] > 0 else c[1] for c in columns]
            schema_parts.append(f"- {table}({', '.join(col_names)})")

        for table in tables:
            cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table)})")
            # rows: (id, seq, table, from, to, on_update, on_delete, match)
            for fk in cursor.fetchall():
                seq, to_table, from_col, to_col = fk[1], fk[2], fk[3], fk[4]
                if to_col is None:
                    # "REFERENCES parent" without a column list targets the parent's
                    # primary key; resolve it when the parent is a known table.
                    parent_pk = pk_cols_by_table.get(to_table, [])
                    if seq < len(parent_pk):
                        to_col = parent_pk[seq]
                schema_parts.append(f"FK {table}.{from_col} -> {to_table}.{to_col}")

        return "\n".join(schema_parts)
    finally:
        # Always release the file handle, even if a PRAGMA fails.
        conn.close()

# Load Spider data
with open('spider_data/train_spider.json', 'r', encoding='utf-8') as f:
    train_data = json.load(f)

# Include all 3 datasets
datasets_to_use = ['hr_1', 'store_1', 'college_2']
combined_dataset = []
counter = {db: 1 for db in datasets_to_use}   # next per-database id suffix
# db_id -> serialized schema, or None if the database is missing/unreadable.
# The schema depends only on the database, so it is built (and any warning
# printed) once per db_id instead of once per example.
schema_cache = {}

print(f"Processing {len(train_data)} examples...")

for example in train_data:
    db_id = example['db_id']

    if db_id not in datasets_to_use:
        continue

    sqlite_path = f"spider_data/database/{db_id}/{db_id}.sqlite"

    if db_id not in schema_cache:
        if not os.path.exists(sqlite_path):
            print(f"‚ö†Ô∏è  Skipping {db_id}: database not found")
            schema_cache[db_id] = None
        else:
            try:
                schema_cache[db_id] = get_schema_serialized(sqlite_path, db_id)
            except Exception as e:
                print(f"‚ö†Ô∏è  Error with {db_id}: {e}")
                schema_cache[db_id] = None

    schema_serialized = schema_cache[db_id]
    if schema_serialized is None:
        continue

    entry = {
        "id": f"{db_id}_{counter[db_id]}",
        "dataset": db_id,
        "db_id": db_id,
        "sqlite_path": sqlite_path,
        "schema_serialized": schema_serialized,
        "question": example['question'],
        "gold_query": example['query']
    }

    combined_dataset.append(entry)
    counter[db_id] += 1

# Save (json.dumps escapes non-ASCII by default; UTF-8 keeps the file encoding explicit)
with open('Data_for_demo_v2.jsonl', 'w', encoding='utf-8') as f:
    for entry in combined_dataset:
        f.write(json.dumps(entry) + '\n')

print("\n" + "="*60)
print("‚úÖ Dataset created!")
print(f"Total examples: {len(combined_dataset)}")
print("\nBreakdown:")
for db_id in datasets_to_use:
    count = counter[db_id] - 1
    print(f"  {db_id}: {count} examples")

Processing 7000 examples...

‚úÖ Dataset created!
Total examples: 406

Breakdown:
  hr_1: 124 examples
  store_1: 112 examples
  college_2: 170 examples


In [4]:
# ============================================================
# CHECK: Verify world_1 database exists
# ============================================================

import os
import sqlite3

# Check if world_1 database exists
world_1_path = "spider_data/database/world_1/world_1.sqlite"

print("Checking for world_1 database...")
print(f"Looking for: {world_1_path}")

if os.path.exists(world_1_path):
    print("‚úÖ world_1.sqlite EXISTS")

    # Check if it has tables. Close the connection even if the query fails
    # (e.g. the file exists but is not a valid SQLite database).
    # NOTE(review): this rebinds the notebook-global `tables`, which the rebuild
    # cell uses for the tables.json entries — rename if both are needed later.
    conn = sqlite3.connect(world_1_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [row[0] for row in cursor.fetchall()]
    finally:
        conn.close()

    print(f"Tables in world_1: {tables}")

else:
    print("‚ùå world_1.sqlite NOT FOUND")
    print("\nLet's check what databases DO exist:")

    db_folder = "spider_data/database"
    if os.path.exists(db_folder):
        databases = [d for d in os.listdir(db_folder) if os.path.isdir(os.path.join(db_folder, d))]
        print(f"\nAvailable databases ({len(databases)}):")
        for db in sorted(databases)[:20]:  # Show first 20 (alphabetical)
            print(f"  - {db}")
    else:
        print("‚ùå spider_data/database folder not found!")

Checking for world_1 database...
Looking for: spider_data/database/world_1/world_1.sqlite
‚úÖ world_1.sqlite EXISTS
Tables in world_1: ['city', 'sqlite_sequence', 'country', 'countrylanguage']


In [5]:
# ============================================================
# CHECK: Find world_1 examples in Spider data
# ============================================================

import json
import os
from collections import Counter

# Load Spider training data
with open('spider_data/train_spider.json', 'r', encoding='utf-8') as f:
    train_data = json.load(f)

# Count world_1 examples
world_1_examples = [ex for ex in train_data if ex['db_id'] == 'world_1']

print(f"Total examples in train_spider.json: {len(train_data)}")
print(f"world_1 examples found: {len(world_1_examples)}")

if len(world_1_examples) > 0:
    print("\n‚úÖ world_1 examples EXIST in training data")
    print("\nSample world_1 example:")
    print(f"  Question: {world_1_examples[0]['question']}")
    print(f"  Query: {world_1_examples[0]['query']}")
else:
    print("\n‚ùå world_1 examples NOT FOUND in train_spider.json")
    print("\nLet me check what db_ids ARE in the data:")

    db_counts = Counter([ex['db_id'] for ex in train_data])

    print(f"\nTop 20 databases by count:")
    for db_id, count in db_counts.most_common(20):
        print(f"  {db_id}: {count}")

    # train_spider.json is only one split. The recorded per-database count over
    # train_spider.json + dev.json lists world_1 with 120 examples while
    # train_spider.json has none, so look in the remaining split files too.
    print("\nChecking the other Spider split files:")
    for other_file in ('train_others.json', 'dev.json'):
        other_path = os.path.join('spider_data', other_file)
        if not os.path.exists(other_path):
            print(f"  {other_file}: not found")
            continue
        with open(other_path, 'r', encoding='utf-8') as f:
            n_world_1 = sum(ex['db_id'] == 'world_1' for ex in json.load(f))
        print(f"  world_1 examples in {other_file}: {n_world_1}")

Total examples in train_spider.json: 7000
world_1 examples found: 0

‚ùå world_1 examples NOT FOUND in train_spider.json

Let me check what db_ids ARE in the data:

Top 20 databases by count:
  college_2: 170
  college_1: 164
  hr_1: 124
  store_1: 112
  soccer_2: 106
  bike_1: 104
  music_1: 100
  hospital_1: 100
  music_2: 100
  dorm_1: 100
  allergy_1: 98
  movie_1: 98
  flight_1: 96
  driving_school: 93
  cre_Doc_Tracking_DB: 90
  department_store: 88
  customers_and_addresses: 88
  activity_1: 88
  network_2: 86
  products_gen_characteristics: 86
