# Text-to-SQL V5 - STABLE Training

**V4 Issue:** Training loss diverged (112 → NaN).

**V5 Fixes:**
- Lower LR: 5e-5 (was 2e-4)
- Gradient clipping: max_grad_norm=1.0
- Schema truncation: max 200 chars
- FP32 training (disable FP16 for stability)
- Sanity checks on data

**SETUP:**
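1. Add the Kaggle dataset `jeromeblanchet/yale-universitys-spider-10-nlp-dataset` as a notebook input
2. Enable a GPU accelerator (without one, the notebook falls back to `t5-small`)
3. Run all cells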

---

In [None]:
# Cell 1: Install
!pip install -q transformers>=4.35.0 datasets>=2.14.0 accelerate>=0.24.0
!pip install -q torch sentencepiece pandas numpy tqdm

In [None]:
# Cell 2: Setup
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Fix tokenizer warning

import json
import torch
import numpy as np
import warnings
import glob
from collections import defaultdict

warnings.filterwarnings("ignore")

print("=" * 60, "TEXT-TO-SQL V5 - STABLE", "=" * 60, sep="\n")

# Locate the Spider dataset: the first entry under the Kaggle input directory
# whose name contains "spider" (case-insensitive) wins.
KAGGLE_INPUT = "/kaggle/input"
_input_folders = os.listdir(KAGGLE_INPUT) if os.path.exists(KAGGLE_INPUT) else []
SPIDER_PATH = next(
    (
        os.path.join(KAGGLE_INPUT, folder)
        for folder in _input_folders
        if "spider" in folder.lower()
    ),
    None,
)

if SPIDER_PATH is None:
    raise FileNotFoundError("Add Spider dataset: jeromeblanchet/yale-universitys-spider-10-nlp-dataset")

print(f"Spider path: {SPIDER_PATH}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

# Pick the model size from the available hardware: t5-small on CPU,
# t5-base when a CUDA device is present.
if not torch.cuda.is_available():
    MODEL_NAME = "google-t5/t5-small"
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    MODEL_NAME = "google-t5/t5-base"

print(f"Model: {MODEL_NAME}")
print("=" * 60)

In [None]:
# Cell 3: Load Schema
# Search the dataset tree for tables.json; if nothing matches, fall back to
# the conventional location directly under SPIDER_PATH.
tables_files = (
    glob.glob(f"{SPIDER_PATH}/**/tables.json", recursive=True)
    or [f"{SPIDER_PATH}/tables.json"]
)

with open(tables_files[0]) as f:
    tables_data = json.load(f)

# Index every schema entry by its db_id for constant-time lookup later on.
SCHEMA_LOOKUP = dict((db["db_id"], db) for db in tables_data)
print(f"Loaded {len(SCHEMA_LOOKUP)} database schemas")

In [None]:
# Cell 4: Load Train/Dev
# glob returns matches in arbitrary, filesystem-dependent order, and the
# "train*.json" pattern can match more than one file, so blindly taking [0]
# may load a different split from run to run. Sort the candidates and move the
# canonical file name to the front.
# NOTE(review): "train_spider.json" / "dev.json" are assumed to be the
# canonical Spider file names - confirm against the attached dataset.
train_files = sorted(
    f for f in glob.glob(f"{SPIDER_PATH}/**/train*.json", recursive=True)
    if "tables" not in f.lower()
)
dev_files = sorted(
    f for f in glob.glob(f"{SPIDER_PATH}/**/dev*.json", recursive=True)
    if "tables" not in f.lower()
)

# Fail with an actionable message instead of a bare IndexError on [0].
if not train_files or not dev_files:
    raise FileNotFoundError(
        f"Could not find train*.json and dev*.json under {SPIDER_PATH} "
        f"(train matches: {len(train_files)}, dev matches: {len(dev_files)})"
    )

# list.sort is stable: the preferred name (key False) moves to the front while
# the remaining candidates keep their alphabetical order.
train_files.sort(key=lambda path: os.path.basename(path) != "train_spider.json")
dev_files.sort(key=lambda path: os.path.basename(path) != "dev.json")

# Explicit encoding so decoding does not depend on the platform default.
with open(train_files[0], encoding="utf-8") as f:
    train_data = json.load(f)
with open(dev_files[0], encoding="utf-8") as f:
    dev_data = json.load(f)

print(f"Train file: {train_files[0]} | Dev file: {dev_files[0]}")
print(f"Train: {len(train_data)} | Dev: {len(dev_data)}")

In [None]:
# Cell 5: Schema Serialization (TRUNCATED for stability)
MAX_SCHEMA_LEN = 200  # Character budget for the serialized schema (before the "..." marker)

def serialize_schema(db_id):
    """Render the schema of ``db_id`` as ``table(col, col) | table(col)``.

    Table and column names are lower-cased with spaces replaced by
    underscores. At most 8 columns are listed per table (followed by "..."
    when more exist), and the whole string is cut to MAX_SCHEMA_LEN
    characters with "..." appended when it is longer than that.

    Returns ``db_id`` unchanged when the database is unknown or when no
    usable column entry is found.
    """
    if db_id not in SCHEMA_LOOKUP:
        return db_id

    entry = SCHEMA_LOOKUP[db_id]
    tables = entry.get("table_names", [])
    table_count = len(tables)

    # Group normalized column names under their normalized table name,
    # preserving first-seen order.
    grouped = {}
    for record in entry.get("column_names", []):
        well_formed = isinstance(record, list) and len(record) >= 2
        if not well_formed:
            continue
        owner_idx, raw_column = record[0], record[1]
        # Entries pointing outside the table list (e.g. index -1) are skipped.
        if owner_idx < 0 or owner_idx >= table_count:
            continue
        owner = tables[owner_idx].lower().replace(" ", "_")
        column = raw_column.lower().replace(" ", "_")
        grouped.setdefault(owner, []).append(column)

    if not grouped:
        return db_id

    rendered = []
    for tbl, cols in grouped.items():
        overflow = "..." if len(cols) > 8 else ""
        rendered.append(f"{tbl}({', '.join(cols[:8])}{overflow})")
    compact = " | ".join(rendered)

    if len(compact) <= MAX_SCHEMA_LEN:
        return compact
    return compact[:MAX_SCHEMA_LEN] + "..."

# Test: print the serialized schema of the first two training examples,
# shortening the display to 80 characters.
for i in range(2):
    db_id = train_data[i]["db_id"]
    schema = serialize_schema(db_id)
    if len(schema) > 80:
        print(f"{db_id}: {schema[:80]}...")
    else:
        print(f"{db_id}: {schema}")

In [None]:
# Cell 6: SQL Normalization
def normalize_sql(sql):
    """Return a canonical form of ``sql`` for training targets and comparison.

    The value is converted to ``str``, lower-cased, has every whitespace run
    collapsed to one space, and has trailing semicolons removed. Falsy input
    (``None``, ``""``, ``0``) yields an empty string.
    """
    if not sql:
        return ""
    collapsed = " ".join(str(sql).strip().lower().split())
    return collapsed.rstrip(";").strip()

print("SQL normalization ready.")

In [None]:
# Cell 7: Create Dataset with VALIDATION
from datasets import Dataset

def process_examples(data_list):
    """Turn raw Spider records into ``{"input_text", "target_text"}`` dicts.

    A record is dropped (and counted) when its question or query is empty
    after stripping, or when the built input exceeds 1000 characters or the
    normalized target exceeds 500 characters. The number of dropped records
    is printed when it is non-zero.
    """
    processed = []
    skipped = 0

    for item in data_list:
        question = str(item.get("question", "")).strip()
        query = str(item.get("query", "")).strip()

        if question and query:
            schema = serialize_schema(item.get("db_id", ""))
            input_text = f"translate to SQL: {question} | schema: {schema}"
            target_text = normalize_sql(query)

            # Keep only examples that fit the character budgets.
            if len(input_text) <= 1000 and len(target_text) <= 500:
                processed.append({
                    "input_text": input_text,
                    "target_text": target_text
                })
                continue

        # Reached for both empty and over-long examples.
        skipped += 1

    if skipped > 0:
        print(f"  Skipped {skipped} invalid examples")

    return processed

print("Processing train...")
train_processed = process_examples(train_data)
print("Processing dev...")
dev_processed = process_examples(dev_data)

# Wrap both processed lists as datasets.Dataset objects.
train_dataset, dev_dataset = (
    Dataset.from_list(rows) for rows in (train_processed, dev_processed)
)

print(f"\nFinal: Train {len(train_dataset)} | Dev {len(dev_dataset)}")

# Show the first two training samples (input shortened to 100 characters).
print("\n--- Samples ---")
for i in range(2):
    print(
        f"Input: {train_dataset[i]['input_text'][:100]}...",
        f"Target: {train_dataset[i]['target_text']}",
        "",
        sep="\n",
    )

In [None]:
# Cell 8: Tokenization with length checks
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Token budgets for the encoder input and the decoder target.
MAX_INPUT = 384   # Reduced from 512 for safety
MAX_TARGET = 128  # Reduced from 256

def tokenize(examples):
    """Tokenize a batch of examples for seq2seq training.

    Inputs are truncated to MAX_INPUT tokens and targets to MAX_TARGET
    tokens. No padding is applied here. The target token ids are stored
    under the ``labels`` key of the returned encoding.
    """
    shared = {"truncation": True, "padding": False}
    encoded = tokenizer(examples["input_text"], max_length=MAX_INPUT, **shared)
    target_encoding = tokenizer(
        text_target=examples["target_text"], max_length=MAX_TARGET, **shared
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded

print("Tokenizing...")
# Replace the text columns with token ids; the original columns are dropped.
train_tokenized = train_dataset.map(
    tokenize,
    remove_columns=train_dataset.column_names,
    batched=True,
)
dev_tokenized = dev_dataset.map(
    tokenize,
    remove_columns=dev_dataset.column_names,
    batched=True,
)

print(f"Done. Train: {len(train_tokenized)} | Dev: {len(dev_tokenized)}")

# Check token lengths on (at most) the first 100 training examples.
probe_count = min(100, len(train_tokenized))
sample_lens = list(
    map(len, (train_tokenized[i]['input_ids'] for i in range(probe_count)))
)
print(f"Avg input length: {np.mean(sample_lens):.0f} tokens")

In [None]:
# Cell 9: Load Model
from transformers import AutoModelForSeq2SeqLM

# Load the pretrained seq2seq checkpoint selected in the setup cell (MODEL_NAME).
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Turn on gradient checkpointing on the model itself; the training arguments
# also request it via gradient_checkpointing=True.
model.gradient_checkpointing_enable()
# Report the parameter count with thousands separators.
print(f"Model: {MODEL_NAME} ({model.num_parameters():,} params)")

In [None]:
# Cell 10: STABLE Training Config
from transformers import (
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)

# Collator that pads each batch dynamically (tokenization applied no padding).
# Label positions are padded with -100; compute_metrics maps -100 back to the
# pad token id before decoding.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    padding=True
)

# V5 STABLE CONFIG
# NOTE(review): the `eval_strategy` keyword is presumably only accepted by
# recent transformers releases (older ones spell it `evaluation_strategy`) -
# confirm against the installed version.
training_args = Seq2SeqTrainingArguments(
    output_dir="./t2sql_v5",
    
    # === STABILITY FIXES ===
    num_train_epochs=15,
    learning_rate=5e-5,            # LOWER (was 2e-4)
    warmup_ratio=0.1,              # Warm up over the first 10% of training
    max_grad_norm=1.0,             # Gradient clipping
    lr_scheduler_type="linear",    # Simple linear decay
    
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,  # 8 x 4 = effective batch of 32 per device
    
    weight_decay=0.01,
    fp16=False,                     # DISABLE FP16 for stability
    gradient_checkpointing=True,
    label_smoothing_factor=0.0,     # No label smoothing
    
    # Evaluate and checkpoint on the same 250-step cadence so that every
    # evaluation has a matching checkpoint for load_best_model_at_end.
    eval_strategy="steps",
    eval_steps=250,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,        # Lower eval_loss is better
    
    # Evaluation decodes with generate() so compute_metrics receives token ids.
    predict_with_generate=True,
    generation_max_length=MAX_TARGET,
    generation_num_beams=4,
    
    logging_steps=50,
    report_to="none",
    dataloader_num_workers=0,       # Disable multiprocessing
    seed=42,
)

# Echo the key stability settings as actually stored on training_args.
print("V5 STABLE CONFIG:")
print(f"  LR: {training_args.learning_rate} (was 2e-4)")
print(f"  Grad clip: {training_args.max_grad_norm}")
print(f"  FP16: {training_args.fp16} (disabled for stability)")
print(f"  Epochs: {training_args.num_train_epochs}")

In [None]:
# Cell 11: Metrics
VOCAB_SIZE = len(tokenizer)

def compute_metrics(eval_pred):
    """Compute exact-match and normalized-match rates for generated SQL.

    Args:
        eval_pred: a ``(predictions, labels)`` pair of token-id arrays. If
            ``predictions`` is a tuple, its first element is used as the ids.
            Positions equal to -100 in either array are treated as padding.

    Returns:
        A dict with ``exact_match`` (decoded strings equal after stripping)
        and ``normalized_match`` (equal after ``normalize_sql``), both as
        fractions in [0, 1]. If decoding fails, the error is printed and both
        metrics are reported as 0.0 so that training is not interrupted.
    """
    predictions, labels = eval_pred

    # Some models hand back (generated_ids, extra_outputs); keep only the ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id

    # Map the -100 padding marker to the real pad id, exactly as is done for
    # labels, instead of relying on np.clip turning -100 into token id 0
    # (which only decodes cleanly when the pad id happens to be 0).
    predictions = np.where(predictions != -100, predictions, pad_id)
    # Safety: clip to valid range
    predictions = np.clip(predictions, 0, VOCAB_SIZE - 1)

    try:
        pred_texts = tokenizer.batch_decode(predictions, skip_special_tokens=True)

        labels = np.where(labels != -100, labels, pad_id)
        labels = np.clip(labels, 0, VOCAB_SIZE - 1)
        label_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)
    except Exception as e:
        # Deliberate best-effort: a metric failure must not abort training.
        print(f"Metric error: {e}")
        return {"exact_match": 0.0, "normalized_match": 0.0}

    pairs = list(zip(pred_texts, label_texts))
    exact = sum(1 for pred, gold in pairs if pred.strip() == gold.strip())
    normalized = sum(
        1 for pred, gold in pairs if normalize_sql(pred) == normalize_sql(gold)
    )
    total = len(pred_texts)

    return {
        "exact_match": exact / total if total > 0 else 0.0,
        "normalized_match": normalized / total if total > 0 else 0.0
    }

print("Metrics ready.")

In [None]:
# Cell 12: Trainer
# Wire together the model, the training arguments, the tokenized train/dev
# splits, the padding collator and the metric function.
# NOTE(review): newer transformers releases appear to prefer
# `processing_class=` over `tokenizer=` - confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=dev_tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print(f"Trainer ready. Train: {len(train_tokenized)} | Dev: {len(dev_tokenized)}")

In [None]:
# Cell 13: Sanity Check - Test one batch
# Run a single forward pass (no gradients) on the first four training examples
# and classify the resulting loss before committing to a long training run.
print("Sanity check...")

test_batch = [train_tokenized[i] for i in range(4)]
collated = data_collator(test_batch)

model.eval()
try:
    with torch.no_grad():
        device = next(model.parameters()).device
        batch = {k: v.to(device) for k, v in collated.items()}
        out = model(**batch)
        initial_loss = out.loss.item()
finally:
    # Always hand the model back in training mode, even if the forward pass
    # raised.
    model.train()

print(f"Initial loss: {initial_loss:.2f}")

# NaN compares False against every threshold, so it has to be tested
# explicitly - otherwise a NaN loss would pass this check silently.
loss_is_finite = bool(np.isfinite(initial_loss))

if not loss_is_finite or initial_loss > 20:
    if loss_is_finite:
        print("WARNING: High initial loss!")
    else:
        print("WARNING: Initial loss is not finite (NaN/Inf)!")
    print("Checking data...")
    for i in range(2):
        inp = tokenizer.decode(test_batch[i]['input_ids'][:50])
        lbl = tokenizer.decode([t for t in test_batch[i]['labels'][:30] if t != -100])
        print(f"  Input: {inp}...")
        print(f"  Label: {lbl}")
elif initial_loss < 10:
    print("✓ Loss looks normal. Ready to train.")
else:
    # Between the "normal" (< 10) and "high" (> 20) thresholds.
    print("NOTE: Initial loss is between 10 and 20 - watch the first logging steps.")

In [None]:
# Cell 14: TRAIN
print("=" * 60)
print("STARTING V5 STABLE TRAINING")
print("=" * 60)
print(f"Model: {MODEL_NAME}")
# Report the values actually stored on training_args rather than hard-coded
# copies, so the banner cannot drift from the configuration in use.
print(
    f"LR: {training_args.learning_rate} | "
    f"Epochs: {training_args.num_train_epochs} | "
    f"Grad clip: {training_args.max_grad_norm}"
)
print(f"FP16: {'ON' if training_args.fp16 else 'OFF (stability)'}")
print("=" * 60)

# Release cached GPU memory left over from earlier cells before training.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

result = trainer.train()

print("\n" + "=" * 60)
print("V5 TRAINING COMPLETE!")
print("=" * 60)
print(f"Final train loss: {result.training_loss:.4f}")
# train_runtime is divided by 3600 to display hours.
print(f"Time: {result.metrics['train_runtime']/3600:.2f} hours")

In [None]:
# Cell 15: Evaluate
# Final evaluation on the dataset registered with the trainer as eval_dataset.
print("Evaluating...")
eval_results = trainer.evaluate()

# The match metrics come from compute_metrics as fractions, hence the * 100
# to display percentages. The keys read here carry an "eval_" prefix.
print("\n" + "=" * 60)
print("V5 RESULTS")
print("=" * 60)
print(f"Eval Loss: {eval_results['eval_loss']:.4f}")
print(f"Exact Match: {eval_results['eval_exact_match']*100:.2f}%")
print(f"Normalized Match: {eval_results['eval_normalized_match']*100:.2f}%")
print("=" * 60)

In [None]:
# Cell 16: Save
# Directory that receives the final model and tokenizer; the inference cell
# loads the pipeline from this same path.
OUTPUT = "./t2sql_final_v5"

trainer.save_model(OUTPUT)
tokenizer.save_pretrained(OUTPUT)

# Summary of the run: training loss from trainer.train() plus the final
# evaluation metrics, with match rates converted to percentages.
report = {
    "version": "v5_stable",
    "model": MODEL_NAME,
    "train_loss": result.training_loss,
    "eval_loss": eval_results['eval_loss'],
    "exact_match_pct": eval_results['eval_exact_match'] * 100,
    "normalized_match_pct": eval_results['eval_normalized_match'] * 100,
}

# Persist the report next to the notebook as pretty-printed JSON.
with open("report_v5.json", "w") as f:
    json.dump(report, f, indent=2)

print(f"Saved to {OUTPUT}")

In [None]:
# Cell 17: Test
from transformers import pipeline

# Reload the saved model through a generation pipeline: GPU 0 when CUDA is
# available, otherwise CPU (-1).
gen = pipeline(
    "text2text-generation",
    model=OUTPUT,
    device=(-1 if not torch.cuda.is_available() else 0),
)

def predict(q, schema):
    """Generate SQL for question ``q`` given a serialized ``schema`` string.

    Builds the same prompt layout used for the training inputs and returns
    the text of the first generated candidate (beam search, 4 beams, at most
    128 tokens).
    """
    candidates = gen(
        f"translate to SQL: {q} | schema: {schema}",
        num_beams=4,
        max_length=128,
    )
    return candidates[0]["generated_text"]

# (question, serialized schema) pairs used as a quick smoke test.
tests = [
    (question, serialize_schema(database))
    for question, database in (
        ("How many singers are there?", "concert_singer"),
        ("Show all stadium names", "concert_singer"),
        ("Find pets older than 3 years", "pets_1"),
    )
]

print("Test predictions:")
for q, s in tests:
    sql = predict(q, s)
    print(f"Q: {q}", f"SQL: {sql}\n", sep="\n")

In [None]:
# Cell 18: Zip
import shutil

# Zip the "t2sql_final_v5" directory (relative to the current directory) into
# t2sql_v5_model.zip.
# NOTE(review): the directory name is written out here rather than derived
# from OUTPUT in the save cell - keep the two in sync.
shutil.make_archive("t2sql_v5_model", "zip", ".", "t2sql_final_v5")
print("Created: t2sql_v5_model.zip")

# Closing banner followed by the run report as pretty-printed JSON.
print("\n" + "=" * 60)
print("V5 COMPLETE")
print("=" * 60)
print(json.dumps(report, indent=2))