# Text-to-SQL V4 - FIXED Schema Handling

**Root Cause Analysis:**
- V2/V3 failed because the schema wasn't being extracted properly
- The model only saw `DATABASE db_name` instead of the actual tables/columns
- Without schema context, the model can't learn table/column names

**V4 Fixes:**
1. Diagnose the dataset structure first
2. Robust schema extraction with fallbacks
3. Schema included directly in each Spider example's model input
4. Fixed regex escaping in SQL normalization

---

In [None]:
# Cell 1: Install
!pip install -q transformers>=4.35.0 datasets>=2.14.0 accelerate>=0.24.0
!pip install -q torch sentencepiece pandas numpy tqdm

In [None]:
# Cell 2: Imports
import torch
import numpy as np
import pandas as pd
import json
import re
import warnings
from collections import defaultdict

warnings.filterwarnings('ignore')

print("=" * 60)
print("TEXT-TO-SQL V4 - FIXED SCHEMA")
print("=" * 60)
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

# Checkpoint choice follows the hardware: t5-small on CPU-only sessions,
# t5-base when a GPU is present.
if not torch.cuda.is_available():
    MODEL_NAME = "google-t5/t5-small"
    print("WARNING: No GPU!")
else:
    GPU_NAME = torch.cuda.get_device_name(0)
    GPU_MEM = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {GPU_NAME} ({GPU_MEM:.1f} GB)")
    MODEL_NAME = "google-t5/t5-base"

print(f"Model: {MODEL_NAME}")
print("=" * 60)

In [None]:
# Cell 3: Load Dataset & DIAGNOSE STRUCTURE
from datasets import load_dataset

print("Loading Spider dataset...")
dataset = load_dataset("xlangai/spider")

n_train_rows, n_val_rows = len(dataset['train']), len(dataset['validation'])
print(f"\nTrain: {n_train_rows} | Validation: {n_val_rows}")

# Dump one raw record so the available fields (and their formats) are visible
# before any schema-extraction logic depends on them.
print("\n" + "=" * 60)
print("DATASET STRUCTURE DIAGNOSIS")
print("=" * 60)

sample = dataset['train'][0]
print(f"\nAvailable fields: {list(sample.keys())}")

print("\n--- Sample Values ---")
for field, value in sample.items():
    if isinstance(value, str) and len(value) > 100:
        print(f"{field}: {value[:100]}...")
    elif isinstance(value, list) and value:
        print(f"{field}: {type(value).__name__}[{len(value)}] = {value[:3]}...")
    else:
        print(f"{field}: {value}")

print("\n" + "=" * 60)

In [None]:
# Cell 4: ROBUST Schema Extraction
# FROM/JOIN table references, matched in the order they appear in the query.
# The leading \b avoids matching "from"/"join" inside a longer identifier.
_TABLE_REF_RE = re.compile(r'\b(?:from|join)\s+(\w+)')


def _iter_column_entries(column_names):
    """Yield ``(table_idx, column_name)`` pairs from any supported column layout.

    Supported layouts:
      * list of ``[table_idx, col_name]`` pairs,
      * list of ``{'table_id': ..., 'column_name': ...}`` dicts,
      * dict of parallel lists ``{'table_id': [...], 'column_name': [...]}``
        (the columnar form Hugging Face ``datasets`` uses for a sequence of
        structs).
    Entries in any other shape are skipped.
    """
    if isinstance(column_names, dict):
        yield from zip(column_names.get('table_id', []),
                       column_names.get('column_name', []))
        return
    for col_info in column_names:
        if isinstance(col_info, (list, tuple)) and len(col_info) >= 2:
            yield col_info[0], col_info[1]
        elif isinstance(col_info, dict):
            yield col_info.get('table_id', -1), col_info.get('column_name', '')


def extract_schema(example, allow_query_fallback=True):
    """
    Build a compact, single-line schema string for one Spider example.

    Strategies are tried in order; the first one that yields anything wins:

    1. ``db_table_names`` + ``db_column_names``, rendered (lower-cased) as
       ``table(col1, col2) | table2(col3)``. See ``_iter_column_entries`` for
       the accepted ``db_column_names`` layouts.
    2. A raw ``schema`` field longer than 10 characters, truncated to 500.
    3. Table names taken from the FROM/JOIN clauses of the gold ``query``,
       rendered as ``table(*)`` in first-seen order.
    4. ``"database: <db_id>"`` (``db_id`` defaults to ``"database"``).

    Args:
        example: Mapping for one dataset row.
        allow_query_fallback: Whether strategy 3 may be used. NOTE: strategy 3
            reads the gold SQL, so the tables it lists come from the answer
            itself; on evaluation data that leaks label information into the
            model input. Pass ``False`` to skip it.

    Returns:
        The schema string.
    """
    schema_parts = []

    # Strategy 1: db_table_names + db_column_names (if available)
    table_names = example.get('db_table_names') or []
    column_names = example.get('db_column_names') or []

    if table_names and column_names:
        table_cols = defaultdict(list)
        for table_idx, col_name in _iter_column_entries(column_names):
            try:
                table_idx = int(table_idx)
            except (TypeError, ValueError):
                continue
            # Skip the -1 sentinel (presumably Spider's global "*" column) and
            # any index outside the table list; other negative values would
            # otherwise silently index from the end of the list.
            if not 0 <= table_idx < len(table_names):
                continue
            table_cols[str(table_names[table_idx]).lower()].append(str(col_name).lower())

        schema_parts = [f"{tbl}({', '.join(cols)})" for tbl, cols in table_cols.items()]

    # Strategy 2: use a 'schema' field directly if present
    if not schema_parts and 'schema' in example:
        schema_str = str(example['schema'])
        if len(schema_str) > 10:  # Non-trivial schema
            return schema_str[:500]  # Truncate if too long

    # Strategy 3: table names from the query's FROM/JOIN clauses
    if not schema_parts and allow_query_fallback:
        query = str(example.get('query', '')).lower()
        # dict.fromkeys de-duplicates while keeping first-seen order; a set
        # would make the order depend on string hash randomization
        # (PYTHONHASHSEED), so the same example could render differently
        # between runs.
        tables = dict.fromkeys(_TABLE_REF_RE.findall(query))
        schema_parts = [f"{t}(*)" for t in tables]

    # Strategy 4: just use db_id
    if not schema_parts:
        db_id = example.get('db_id', 'database')
        return f"database: {db_id}"

    return " | ".join(schema_parts)

# Test on first few examples
# Smoke-test the extractor on the first three training rows.
print("Testing schema extraction:")
for idx in range(3):
    row = dataset['train'][idx]
    rendered = extract_schema(row)
    print(f"\nExample {idx}:")
    print(f"  Question: {row['question'][:60]}...")
    if len(rendered) > 80:
        print(f"  Schema: {rendered[:80]}...")
    else:
        print(f"  Schema: {rendered}")

In [None]:
# Cell 5: SQL Normalization (FIXED - no regex issues)
# One alternation with the multi-character operators listed first, so ">=",
# "<=", "!=" and "<>" are matched as single units. Replacing operators one at a
# time with str.replace re-splits them (">=" -> "> =", "<>" -> "< >").
_SQL_OPERATOR_RE = re.compile(r'\s*(>=|<=|!=|<>|=|>|<)\s*')


def normalize_sql(sql):
    """
    Canonicalize a SQL string for comparison and for use as a training target.

    Steps: lower-case everything, put exactly one space around comparison
    operators, follow every comma with a space, collapse runs of whitespace,
    and drop a trailing semicolon.

    NOTE: the whole string is lower-cased and re-spaced, including string
    literals, so quoted values are altered as well.

    Args:
        sql: SQL text (any object; converted with ``str``). Falsy input
            (``None``, ``""``) yields ``""``.

    Returns:
        The normalized SQL string.
    """
    if not sql:
        return ""

    sql = str(sql).strip().lower()

    # Space out comparison operators in a single pass.
    sql = _SQL_OPERATOR_RE.sub(r' \1 ', sql)

    # Normalize commas
    sql = sql.replace(',', ', ')

    # Collapse all whitespace runs to single spaces
    sql = ' '.join(sql.split())

    # Remove trailing semicolon
    return sql.rstrip(';').strip()

# Test
# Spot-check the normalizer on a few hand-written queries.
test_cases = [
    "SELECT COUNT(*)  FROM  students   WHERE gpa>3.5;",
    "select * from users where age >= 18",
    "SELECT name,age FROM people",
]

print("SQL Normalization Test:")
for raw_sql in test_cases:
    print(f"  {raw_sql}\n  → {normalize_sql(raw_sql)}\n")

In [None]:
# Cell 6: Preprocessing with VERIFIED Schema
def preprocess_v4(example):
    """
    Turn one raw Spider row into a seq2seq training pair.

    Returns a dict with:
      * ``input_text``:  ``"translate to SQL: <question> | schema: <schema>"``,
        where the schema comes from ``extract_schema``;
      * ``target_text``: the gold ``query`` passed through ``normalize_sql``.
    """
    fields = {key: str(example.get(key, '')).strip() for key in ('question', 'query')}
    prompt = f"translate to SQL: {fields['question']} | schema: {extract_schema(example)}"
    return {
        "input_text": prompt,
        "target_text": normalize_sql(fields['query']),
    }

print("Preprocessing...")
processed = dataset.map(preprocess_v4, num_proc=4)

# Show a few processed rows so it is obvious whether the schema text actually
# made it into the model input.
print("\n" + "=" * 60)
print("PREPROCESSING VERIFICATION")
print("=" * 60)

processed_train = processed['train']
for i in range(3):
    row = processed_train[i]
    print(f"\nExample {i}:")
    print(f"Input: {row['input_text'][:150]}...")
    print(f"Target: {row['target_text']}")


def _has_meaningful_schema(text):
    """True when more than 20 characters follow the first 'schema:' marker."""
    return 'schema:' in text and len(text.split('schema:')[1]) > 20


# Coverage check over the whole training split.
has_schema = sum(map(_has_meaningful_schema, processed_train['input_text']))
n_processed = len(processed_train)
print(f"\nExamples with meaningful schema: {has_schema}/{n_processed} ({100*has_schema/n_processed:.1f}%)")

In [None]:
# Cell 7: Tokenization
from transformers import AutoTokenizer

# The tokenizer must match the checkpoint selected earlier (MODEL_NAME).
print(f"Loading tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Token budgets: inputs carry question + schema, targets carry the SQL.
MAX_INPUT, MAX_TARGET = 512, 256

def tokenize(examples):
    """
    Batched tokenization for ``Dataset.map(batched=True)``.

    Encodes ``input_text`` (truncated to MAX_INPUT tokens) as model inputs and
    ``target_text`` (truncated to MAX_TARGET tokens) as ``labels``. No padding
    is applied here; the data collator pads per batch.

    NOTE(review): the T5 SentencePiece vocabulary is reported to lack some
    characters that appear in SQL (e.g. "<"); verify that targets survive an
    encode/decode round trip.
    """
    encoded = tokenizer(
        examples["input_text"],
        max_length=MAX_INPUT,
        truncation=True,
        padding=False,
    )
    label_ids = tokenizer(
        text_target=examples["target_text"],
        max_length=MAX_TARGET,
        truncation=True,
        padding=False,
    )["input_ids"]
    encoded["labels"] = label_ids
    return encoded

print("Tokenizing...")
# Drop the text columns; only token ids and labels go to the trainer.
source_columns = processed['train'].column_names
tokenized = processed.map(
    tokenize,
    batched=True,
    num_proc=4,
    remove_columns=source_columns,
)

print(f"Done. Columns: {tokenized['train'].column_names}")

In [None]:
# Cell 8: Load Model
from transformers import AutoModelForSeq2SeqLM

print(f"Loading model: {MODEL_NAME}")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Recompute activations during backprop to save memory (the training
# arguments also request gradient checkpointing).
model.gradient_checkpointing_enable()

n_params = model.num_parameters()
print(f"Parameters: {n_params:,}")

In [None]:
# Cell 9: Training Configuration
from transformers import (
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)

# Pads inputs and labels per batch; label padding is -100 so padded positions
# are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    padding=True
)

# Mixed precision. T5 checkpoints are widely reported to overflow (NaN/inf
# loss) under fp16, so use bf16 on GPUs with compute capability >= 8 and fall
# back to fp16 only on older GPUs. CPU runs stay in fp32.
_HAS_CUDA = torch.cuda.is_available()
USE_BF16 = _HAS_CUDA and torch.cuda.get_device_capability(0)[0] >= 8
USE_FP16 = _HAS_CUDA and not USE_BF16

# V4 Settings - balanced approach
training_args = Seq2SeqTrainingArguments(
    output_dir="./t2sql_v4",

    num_train_epochs=25,           # More epochs for convergence
    learning_rate=2e-4,            # Middle ground
    warmup_ratio=0.06,
    lr_scheduler_type="cosine",

    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,

    weight_decay=0.01,

    fp16=USE_FP16,
    bf16=USE_BF16,
    gradient_checkpointing=True,
    label_smoothing_factor=0.1,

    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    load_best_model_at_end=True,
    # Checkpoint selection uses eval_loss; the generation-based metrics below
    # are reported but do not pick the best model.
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # Beam-search generation runs at every evaluation (every 200 steps) so
    # compute_metrics can score decoded SQL.
    predict_with_generate=True,
    generation_max_length=MAX_TARGET,
    generation_num_beams=4,

    logging_steps=50,
    report_to="none",
    dataloader_num_workers=2,
    seed=42,
)

# Derive the effective batch size from the arguments instead of restating it.
effective_batch = (
    training_args.train_batch_size
    * training_args.gradient_accumulation_steps
    * training_args.world_size
)
precision = "bf16" if USE_BF16 else ("fp16" if USE_FP16 else "fp32")

print("V4 Configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  LR: {training_args.learning_rate}")
print(f"  Effective batch: {effective_batch}")
print(f"  Precision: {precision}")

In [None]:
# Cell 10: Metrics
VOCAB_SIZE = len(tokenizer)


def _decode_ids(ids):
    """Decode a batch of token-id rows into strings (special tokens dropped).

    -100 entries (ignored label positions and, presumably, the padding the
    trainer adds when concatenating generations of different lengths) are
    replaced by the pad id. Any remaining id is clipped into
    ``[0, VOCAB_SIZE - 1]`` so the tokenizer never sees an invalid id.
    """
    ids = np.asarray(ids)
    ids = np.where(ids != -100, ids, tokenizer.pad_token_id)
    ids = np.clip(ids, 0, VOCAB_SIZE - 1)
    return tokenizer.batch_decode(ids, skip_special_tokens=True)


def compute_metrics(eval_pred):
    """
    Score generated SQL against the reference SQL.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays as passed
            by the trainer with ``predict_with_generate=True``. If
            ``predictions`` is a tuple, its first element is used.

    Returns:
        dict with ``exact_match`` (equality after ``str.strip``) and
        ``normalized_match`` (equality after ``normalize_sql``), both fractions
        in [0, 1]. Both are 0.0 when there is nothing to score or decoding
        fails; a decoding failure is reported with a printed warning.
    """
    predictions, labels = eval_pred
    # Some models return a tuple of outputs; the generated ids come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    try:
        pred_texts = _decode_ids(predictions)
        label_texts = _decode_ids(labels)
    except (OverflowError, TypeError, ValueError) as err:
        # Best-effort: keep evaluation running, but make the failure visible.
        print(f"WARNING: compute_metrics could not decode outputs: {err!r}")
        return {"exact_match": 0.0, "normalized_match": 0.0}

    pairs = list(zip(pred_texts, label_texts))
    total = len(pairs)
    if total == 0:
        return {"exact_match": 0.0, "normalized_match": 0.0}

    exact = sum(pred.strip() == label.strip() for pred, label in pairs)
    normalized = sum(normalize_sql(pred) == normalize_sql(label) for pred, label in pairs)

    return {
        "exact_match": exact / total,
        "normalized_match": normalized / total,
    }

print("Metrics ready.")

In [None]:
# Cell 11: Initialize Trainer
# Wire model, data, collator and metrics into one Seq2SeqTrainer.
train_ds, eval_ds = tokenized["train"], tokenized["validation"]
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print(f"Trainer ready. Train: {len(train_ds)} | Eval: {len(eval_ds)}")

In [None]:
# Cell 12: Verification
print("Final verification...")

# Push two real training rows through the collator and one forward pass so
# collation/device problems surface before the long training run.
sample_batch = data_collator([tokenized['train'][idx] for idx in range(2)])
print("✓ Collator OK")

model.eval()
with torch.no_grad():
    batch_on_device = {name: tensor.to(model.device) for name, tensor in sample_batch.items()}
    out = model(**batch_on_device)
print(f"✓ Forward OK (loss: {out.loss.item():.2f})")

print("\n" + "=" * 60)
print("V4 READY TO TRAIN")
print("=" * 60)

In [None]:
# Cell 13: TRAIN
print("=" * 60)
print("STARTING V4 TRAINING")
print("=" * 60)
print(f"Model: {MODEL_NAME}")
# Read from training_args so the banner always reflects the configuration that
# actually runs, rather than a hand-copied value.
print(f"Epochs: {training_args.num_train_epochs:g} | LR: {training_args.learning_rate:g}")
print("=" * 60)

# Release unused cached GPU memory before training starts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

result = trainer.train()

print("\n" + "=" * 60)
print("V4 TRAINING COMPLETE!")
print("=" * 60)
print(f"Train loss: {result.training_loss:.4f}")
print(f"Time: {result.metrics['train_runtime']/3600:.2f} hours")

In [None]:
# Cell 14: Evaluate
print("Evaluating...\n")

eval_results = trainer.evaluate()

print("=" * 60)
print("V4 RESULTS")
print("=" * 60)
print(f"Eval Loss: {eval_results['eval_loss']:.4f}")
print(f"Exact Match: {eval_results['eval_exact_match']*100:.2f}%")
print(f"Normalized Match: {eval_results['eval_normalized_match']*100:.2f}%")
print("=" * 60)

# Grade on normalized match (%): the first threshold that is met wins,
# otherwise "NEEDS WORK".
nm = eval_results['eval_normalized_match'] * 100
grade = next(
    (label for floor, label in ((40, "EXCELLENT"), (30, "GOOD"), (20, "ACCEPTABLE")) if nm >= floor),
    "NEEDS WORK",
)
print(f"\nGrade: {grade}")

In [None]:
# Cell 15: Save Model
OUTPUT = "./t2sql_final_v4"

print(f"Saving to {OUTPUT}...")
trainer.save_model(OUTPUT)
tokenizer.save_pretrained(OUTPUT)

# Hyperparameters are read from training_args rather than restated, so the
# report cannot drift from the configuration that actually ran.
report = {
    "version": "v4_fixed_schema",
    "model": MODEL_NAME,
    "epochs": training_args.num_train_epochs,
    "learning_rate": training_args.learning_rate,
    "train_loss": result.training_loss,
    "eval_loss": eval_results['eval_loss'],
    "exact_match_pct": eval_results['eval_exact_match'] * 100,
    "normalized_match_pct": eval_results['eval_normalized_match'] * 100,
    "training_hours": result.metrics['train_runtime'] / 3600
}

with open("report_v4.json", "w", encoding="utf-8") as f:
    json.dump(report, f, indent=2)

print("Saved!")

In [None]:
# Cell 16: Test Predictions
from transformers import pipeline

print("Testing V4 model...\n")

gen = pipeline(
    "text2text-generation",
    model=OUTPUT,
    device=0 if torch.cuda.is_available() else -1
)


def predict(q, schema):
    """
    Generate SQL for question ``q`` given a schema string.

    The prompt uses the same "translate to SQL: ... | schema: ..." layout as
    ``preprocess_v4``. Decoding reuses the evaluation settings (MAX_TARGET
    length cap, ``training_args.generation_num_beams`` beams), so spot checks
    decode the same way evaluation does.

    NOTE(review): training schemas come from ``extract_schema``; if the Spider
    rows lack table/column fields, those prompts look like ``table(*)`` rather
    than the ``table(col, ...)`` form used in the tests below. Confirm that the
    formats match before reading too much into these outputs.
    """
    inp = f"translate to SQL: {q} | schema: {schema}"
    out = gen(inp, max_length=MAX_TARGET, num_beams=training_args.generation_num_beams)
    return out[0]['generated_text']

tests = [
    ("How many students are there?", "students(id, name, age, gpa)"),
    ("Find students with GPA above 3.5", "students(id, name, gpa)"),
    ("List departments with average salary", "employees(id, name, dept, salary)"),
    ("Count employees per department", "employees(id, name, department)"),
]

for q, s in tests:
    sql = predict(q, s)
    print(f"Q: {q}")
    print(f"Schema: {s}")
    print(f"SQL: {sql}")
    print()

In [None]:
# Cell 17: Zip Model
import os
import shutil

# Archive the saved model directory. Paths are derived from OUTPUT so the zip
# always contains the directory the model was actually saved to.
print("Zipping model...")
archive_path = shutil.make_archive(
    "t2sql_v4_model",
    "zip",
    root_dir=os.path.dirname(os.path.abspath(OUTPUT)),
    base_dir=os.path.basename(os.path.normpath(OUTPUT)),
)
print(f"Created: {os.path.basename(archive_path)}")

print("\n" + "=" * 60)
print("V4 FINAL REPORT")
print("=" * 60)
print(json.dumps(report, indent=2))
print("=" * 60)
print("\nDownload: t2sql_v4_model.zip")