# Text-to-SQL V2 - Optimized for Higher Accuracy

**Fixes Applied:**
- Learning rate: 5e-5 → 1e-4 (research-backed)
- Epochs: 10 → 20 (proper convergence)
- Removed early stopping (let it train fully)
- Better SQL normalization (case-insensitive matching)
- Metric: eval_loss (more stable than exact_match)
- Added a normalized-match accuracy metric (case- and whitespace-insensitive) alongside strict exact match

**Target:** 30-45% exact match

---

In [None]:
# Cell 1: Install
!pip install -q transformers>=4.35.0 datasets>=2.14.0 accelerate>=0.24.0
!pip install -q torch sentencepiece sqlparse pandas numpy tqdm

In [None]:
# Cell 2: Imports and Setup
import torch
import numpy as np
import pandas as pd
import json
import re
import warnings
from collections import defaultdict

# Suppress all warnings for the rest of the session.
warnings.filterwarnings('ignore')

banner = "=" * 60
cuda_ready = torch.cuda.is_available()

print(banner)
print("TEXT-TO-SQL V2 - OPTIMIZED")
print(banner)
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {cuda_ready}")

# Pick the checkpoint by hardware: t5-base when a CUDA device is present,
# t5-small otherwise.
if not cuda_ready:
    print("WARNING: No GPU!")
    MODEL_NAME = "google-t5/t5-small"
else:
    GPU_NAME = torch.cuda.get_device_name(0)
    GPU_MEM = torch.cuda.get_device_properties(0).total_memory / 1e9  # bytes -> GB
    print(f"GPU: {GPU_NAME} ({GPU_MEM:.1f} GB)")
    MODEL_NAME = "google-t5/t5-base"

print(f"\nModel: {MODEL_NAME}")
print(banner)

In [None]:
# Cell 3: Load Spider Dataset
from datasets import load_dataset

print("Loading Spider dataset...")
dataset = None
load_errors = {}

# Try each dataset id in order and stop at the first one that loads.
for source in ["xlangai/spider", "spider"]:
    try:
        dataset = load_dataset(source)
    except Exception as err:
        # Best-effort fallback to the next id, but keep the reason so a total
        # failure is diagnosable. `Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit propagate.
        load_errors[source] = err
        print(f"Could not load {source}: {type(err).__name__}: {err}")
        continue
    print(f"Loaded from {source}")
    break

if dataset is None:
    details = "; ".join(f"{src}: {err!r}" for src, err in load_errors.items())
    raise RuntimeError(f"Could not load Spider dataset ({details})")

print(f"Train: {len(dataset['train'])} | Validation: {len(dataset['validation'])}")

In [None]:
# Cell 4: Improved Schema Serialization
def serialize_schema(example):
    """Serialize an example's database schema as ``table[col:type,...]`` text.

    Args:
        example: Mapping for one dataset row. Reads the optional keys
            ``db_table_names`` (sequence of table names),
            ``db_column_names`` (entries shaped like
            ``(table_index, column_name)``), ``db_column_types`` (type per
            column, aligned by position with ``db_column_names``) and
            ``db_id``.

    Returns:
        Space-separated ``table[col:type,col:type]`` groups with lower-cased
        table and column names, in first-seen table order. Falls back to
        ``example['db_id']`` (or ``'db'``) when no usable column is found or
        the schema fields are malformed.
    """
    fallback = example.get('db_id', 'db')
    try:
        table_names = example.get('db_table_names') or []
        column_names = example.get('db_column_names') or []
        column_types = example.get('db_column_types') or []

        table_cols = defaultdict(list)

        for idx, col_info in enumerate(column_names):
            if not isinstance(col_info, (list, tuple)) or len(col_info) < 2:
                continue

            table_idx, col_name = col_info[0], col_info[1]

            # Skip the "*" pseudo-column (table index -1) and any other
            # out-of-range index; a bare `== -1` test would let -2, -3, ...
            # wrap around and attach the column to the wrong table.
            if not 0 <= table_idx < len(table_names):
                continue

            table_name = table_names[table_idx].lower()
            col_str = str(col_name).lower()

            # Append the type when one is available for this column position.
            if idx < len(column_types) and column_types[idx]:
                col_str += f":{column_types[idx]}"

            table_cols[table_name].append(col_str)
    except (AttributeError, IndexError, KeyError, TypeError):
        # Malformed schema fields: degrade to the database id rather than
        # failing the whole preprocessing map.
        return fallback

    if not table_cols:
        return fallback
    return " ".join(f"{tbl}[{','.join(cols)}]" for tbl, cols in table_cols.items())

print("Schema function defined.")

In [None]:
# Cell 5: SQL Normalization (Critical for accuracy)
# Single- or double-quoted SQL literal, with doubled quotes as escapes.
_SQL_STRING_LITERAL = re.compile(r"'(?:[^']|'')*'" r'|"(?:[^"]|"")*"')
# Placeholder left in the text while a literal is stashed away.
_SQL_LITERAL_SLOT = re.compile(r"\x00(\d+)\x00")


def normalize_sql(sql):
    """Normalize SQL text for case- and whitespace-insensitive comparison.

    Keywords and identifiers are lower-cased, whitespace is collapsed,
    spacing around comparison operators, commas and parentheses is made
    uniform, and a trailing semicolon is dropped.

    Quoted string literals are left untouched: lower-casing or re-spacing
    them would change the values the query compares against.

    Args:
        sql: SQL text. ``None`` or an empty string yields ``""``.

    Returns:
        The normalized SQL string.
    """
    if not sql:
        return ""

    literals = []

    def _stash(match):
        # Swap the literal for an opaque slot that the rewrites below ignore.
        literals.append(match.group(0))
        return f"\x00{len(literals) - 1}\x00"

    # Drop any NUL bytes first so they cannot be mistaken for a slot marker.
    text = _SQL_STRING_LITERAL.sub(_stash, sql.replace("\x00", "").strip())
    text = text.lower()

    # Normalize whitespace
    text = re.sub(r'\s+', ' ', text)

    # Normalize around operators, commas and parentheses
    text = re.sub(r'\s*([=<>!]+)\s*', r' \1 ', text)
    text = re.sub(r'\s*,\s*', ', ', text)
    text = re.sub(r'\s*\(\s*', '(', text)
    text = re.sub(r'\s*\)\s*', ') ', text)

    # Remove trailing semicolon
    text = text.rstrip(';').strip()

    # The rewrites above can introduce double spaces; collapse once more.
    text = re.sub(r'\s+', ' ', text).strip()

    # Put the original literals back, byte-for-byte.
    return _SQL_LITERAL_SLOT.sub(lambda m: literals[int(m.group(1))], text)

# Test
test_sql = "SELECT COUNT(*)  FROM  students   WHERE gpa>3.5;"
print(f"Before: {test_sql}")
print(f"After:  {normalize_sql(test_sql)}")

In [None]:
# Cell 6: Preprocessing
def preprocess(example):
    """Build the model prompt and the normalized SQL target for one row.

    Reads ``question`` and ``query`` from ``example`` (missing keys are
    treated as empty strings) and returns a dict with ``input_text`` and
    ``target_text``.
    """
    raw_question = example.get('question', '')
    raw_query = example.get('query', '')

    question = str(raw_question).strip()
    schema = serialize_schema(example)

    # Prompt layout: task prefix, then the question, then the schema text.
    prompt = f"translate to SQL: {question} | schema: {schema}"

    # The gold query goes through the same normalization used for scoring.
    gold_sql = normalize_sql(str(raw_query).strip())

    return {"input_text": prompt, "target_text": gold_sql}

print("Preprocessing...")
processed = dataset.map(preprocess, num_proc=4)

# Show the first training row as a sanity check.
first_row = processed['train'][0]
print(f"\nSample:")
print(f"Input: {first_row['input_text'][:100]}...")
print(f"Target: {first_row['target_text']}")

In [None]:
# Cell 7: Tokenization
from transformers import AutoTokenizer

print(f"Loading tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Token budgets for the prompt and for the SQL target.
MAX_INPUT = 512
MAX_TARGET = 256

def tokenize(examples):
    """Tokenize a batch of prompts and targets.

    Prompts are truncated to ``MAX_INPUT`` tokens and targets to
    ``MAX_TARGET``. Nothing is padded here. The target token ids are stored
    under ``labels`` in the returned encoding.
    """
    shared = {"truncation": True, "padding": False}

    model_inputs = tokenizer(
        examples["input_text"],
        max_length=MAX_INPUT,
        **shared,
    )
    label_encoding = tokenizer(
        text_target=examples["target_text"],
        max_length=MAX_TARGET,
        **shared,
    )

    model_inputs["labels"] = label_encoding["input_ids"]
    return model_inputs

print("Tokenizing...")
# Drop every pre-tokenization column so only the tokenizer outputs remain.
source_columns = processed['train'].column_names
tokenized = processed.map(
    tokenize,
    batched=True,
    num_proc=4,
    remove_columns=source_columns,
)

print(f"Done. Columns: {tokenized['train'].column_names}")

In [None]:
# Cell 8: Load Model
from transformers import AutoModelForSeq2SeqLM

print(f"Loading model: {MODEL_NAME}")

# Load the seq2seq checkpoint chosen in the setup cell, then switch on
# gradient checkpointing for it.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.gradient_checkpointing_enable()

param_count = model.num_parameters()
print(f"Parameters: {param_count:,}")

In [None]:
# Cell 9: OPTIMIZED Training Config
from transformers import (
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)

# Batches are padded by the collator; label padding uses -100, the value the
# metrics cell maps back to the pad token before decoding.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    label_pad_token_id=-100,
)

training_args = Seq2SeqTrainingArguments(
    output_dir="./t2sql_v2",
    seed=42,
    report_to="none",
    logging_steps=50,
    dataloader_num_workers=2,

    # --- Optimisation schedule (the three values changed from V1) ---
    learning_rate=1e-4,             # V1: 5e-5
    num_train_epochs=20,            # V1: 10
    warmup_ratio=0.05,              # V1: 0.1
    lr_scheduler_type="cosine",
    weight_decay=0.01,
    label_smoothing_factor=0.1,

    # --- Batching: 8 per device x 4 accumulation steps ---
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,

    # --- Memory / precision ---
    gradient_checkpointing=True,
    fp16=torch.cuda.is_available(),

    # --- Evaluate and checkpoint every 200 steps; best = lowest eval_loss ---
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # --- Generation settings used during evaluation ---
    predict_with_generate=True,
    generation_num_beams=4,
    generation_max_length=MAX_TARGET,
)

divider = "=" * 60
summary_lines = (
    divider,
    "OPTIMIZED CONFIGURATION",
    divider,
    f"Epochs: {training_args.num_train_epochs} (was 10)",
    f"Learning rate: {training_args.learning_rate} (was 5e-5)",
    f"Warmup: {training_args.warmup_ratio} (was 0.1)",
    "Metric: eval_loss (was exact_match)",
    "Early stopping: DISABLED",
    divider,
)
for summary_line in summary_lines:
    print(summary_line)

In [None]:
# Cell 10: Improved Metrics
VOCAB_SIZE = len(tokenizer)

# Returned when a batch cannot be decoded or is empty.
_EMPTY_METRICS = {"exact_match": 0.0, "normalized_match": 0.0}


def _decode_token_ids(token_ids):
    """Decode a batch of token ids to text, treating -100 as padding."""
    ids = np.asarray(token_ids)
    # -100 marks ignored/padded positions and cannot be decoded; map it to
    # the pad token, which `skip_special_tokens` then removes.
    ids = np.where(ids != -100, ids, tokenizer.pad_token_id)
    # Guard against any other id outside the tokenizer's vocabulary.
    ids = np.clip(ids, 0, VOCAB_SIZE - 1)
    return tokenizer.batch_decode(ids, skip_special_tokens=True)


def compute_metrics(eval_pred):
    """Score generated SQL against the reference SQL.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays. If
            ``predictions`` is a tuple, its first element is used.

    Returns:
        Dict with ``exact_match`` (decoded strings equal after stripping)
        and ``normalized_match`` (equal after ``normalize_sql``), each a
        fraction in ``[0, 1]``. Both are 0.0 when the batch is empty or
        cannot be decoded; a decode failure is also printed so it is not
        mistaken for a genuine 0% score.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    try:
        pred_texts = _decode_token_ids(predictions)
        label_texts = _decode_token_ids(labels)
    except Exception as err:
        # Keep evaluation best-effort, but say why the scores are zero.
        print(f"WARNING: compute_metrics could not decode the batch "
              f"({type(err).__name__}: {err}); reporting 0.0")
        return dict(_EMPTY_METRICS)

    total = len(pred_texts)
    if total == 0:
        return dict(_EMPTY_METRICS)

    pairs = list(zip(pred_texts, label_texts))
    # Exact match (strict)
    exact_matches = sum(pred.strip() == label.strip() for pred, label in pairs)
    # Normalized match (case-insensitive, whitespace-normalized)
    normalized_matches = sum(
        normalize_sql(pred) == normalize_sql(label) for pred, label in pairs
    )

    return {
        "exact_match": exact_matches / total,
        "normalized_match": normalized_matches / total,
    }

print("Metrics function ready (includes normalized matching).")

In [None]:
# Cell 11: Initialize Trainer (NO early stopping)
# No `callbacks` argument is passed, so no early-stopping callback is
# registered by this notebook.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
)

print("Trainer ready.")
n_train = len(tokenized['train'])
n_eval = len(tokenized['validation'])
print(f"Train: {n_train} | Eval: {n_eval}")

In [None]:
# Cell 12: Quick Verification
print("Verifying setup...")

# Collate two real training rows to confirm padding works.
test_batch = [tokenized['train'][i] for i in range(2)]
collated = data_collator(test_batch)
print(f"✓ Collator OK")

# One no-grad forward pass on the collated batch, on the model's own device.
model.eval()
with torch.no_grad():
    out = model(**{k: v.to(model.device) for k, v in collated.items()})
print(f"✓ Forward OK (loss: {out.loss.item():.2f})")

# Metrics, part 1: random ids must still produce both metric keys.
fake_pred = np.random.randint(0, VOCAB_SIZE, (2, 10))
fake_label = np.random.randint(0, VOCAB_SIZE, (2, 10))
m = compute_metrics((fake_pred, fake_label))
missing_keys = {"exact_match", "normalized_match"} - set(m)
if missing_keys:
    raise RuntimeError(f"compute_metrics is missing keys: {sorted(missing_keys)}")

# Metrics, part 2: identical ids must score 100%. compute_metrics reports 0.0
# when decoding fails, so a check that only calls it can never fail.
# The labels get one extra -100 column to exercise the padding path.
probe_ids = np.array(tokenizer(["select count(*) from students"])["input_ids"])
probe_labels = np.concatenate(
    [probe_ids, np.full((probe_ids.shape[0], 1), -100)], axis=1
)
probe_scores = compute_metrics((probe_ids, probe_labels))
if probe_scores["exact_match"] != 1.0 or probe_scores["normalized_match"] != 1.0:
    raise RuntimeError(f"compute_metrics failed the round-trip check: {probe_scores}")
print(f"✓ Metrics OK")

print("\n" + "=" * 60)
print("ALL CHECKS PASSED - READY TO TRAIN")
print("=" * 60)

In [None]:
# Cell 13: TRAIN
rule = "=" * 60

# Pre-training banner, printed line by line.
for banner_line in (
    rule,
    "STARTING OPTIMIZED TRAINING",
    rule,
    f"Model: {MODEL_NAME}",
    "Epochs: 20 | LR: 1e-4 | Batch: 32",
    "Expected time: ~2-3 hours on P100",
    rule,
    "\nYou can close the browser. Training will continue.\n",
):
    print(banner_line)

# Release cached GPU memory before the training run starts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

result = trainer.train()

train_hours = result.metrics['train_runtime'] / 3600  # seconds -> hours
print("\n" + rule)
print("TRAINING COMPLETE!")
print(rule)
print(f"Train loss: {result.training_loss:.4f}")
print(f"Time: {train_hours:.2f} hours")

In [None]:
# Cell 14: Evaluate
print("Evaluating...\n")

eval_results = trainer.evaluate()

rule = "=" * 60
print(rule)
print("RESULTS")
print(rule)
print(f"Eval Loss: {eval_results['eval_loss']:.4f}")
print(f"Exact Match: {eval_results['eval_exact_match']*100:.2f}%")
nm = eval_results['eval_normalized_match'] * 100  # normalized match, in percent
print(f"Normalized Match: {nm:.2f}%")
print(rule)

# Grade by normalized match: the first threshold reached wins, highest first.
grade = "NEEDS MORE TRAINING"
for floor, label in ((40, "EXCELLENT"), (30, "GOOD"), (20, "ACCEPTABLE")):
    if nm >= floor:
        grade = label
        break
print(f"\nGrade: {grade}")

In [None]:
# Cell 15: Save Model
OUTPUT = "./t2sql_final_v2"

print(f"Saving to {OUTPUT}...")
# Model weights and tokenizer files go into the same directory.
trainer.save_model(OUTPUT)
tokenizer.save_pretrained(OUTPUT)

# Summary of the run; fractions are stored as percentages, runtime in hours.
exact_pct = eval_results['eval_exact_match'] * 100
normalized_pct = eval_results['eval_normalized_match'] * 100
hours_spent = result.metrics['train_runtime'] / 3600

report = dict(
    version="v2_optimized",
    model=MODEL_NAME,
    epochs=20,
    learning_rate="1e-4",
    train_loss=result.training_loss,
    eval_loss=eval_results['eval_loss'],
    exact_match_pct=exact_pct,
    normalized_match_pct=normalized_pct,
    training_hours=hours_spent,
)

with open("report_v2.json", "w") as f:
    f.write(json.dumps(report, indent=2))

print("Saved!")

In [None]:
# Cell 16: Test Predictions
from transformers import pipeline

print("Testing model...\n")

# First CUDA device when available, otherwise CPU (-1).
device_index = 0 if torch.cuda.is_available() else -1
gen = pipeline(
    "text2text-generation",
    model=OUTPUT,
    device=device_index,
)

def predict(q, s):
    """Generate SQL for question ``q`` given schema text ``s``.

    Builds the same ``translate to SQL: ... | schema: ...`` prompt used in
    preprocessing and returns the text of the first generated candidate.
    """
    prompt = f"translate to SQL: {q} | schema: {s}"
    candidates = gen(prompt, max_length=256, num_beams=4)
    top_candidate = candidates[0]
    return top_candidate['generated_text']

# (question, schema) pairs used as a quick smoke test.
tests = [
    ("How many students are there?", "students[id,name,age,gpa]"),
    ("Find students with GPA above 3.5", "students[id,name,gpa]"),
    ("List all departments and their average salary", "employees[id,name,dept,salary]"),
    ("Show the names of courses with more than 100 students", "courses[id,name,student_count]"),
]

for question_text, schema_text in tests:
    generated_sql = predict(question_text, schema_text)
    print(f"Q: {question_text}\nSQL: {generated_sql}\n")

In [None]:
# Cell 17: Zip Model & Final Report
import shutil

# Zip the model folder (archive root is the current directory, so the zip
# contains the "t2sql_final_v2" folder itself).
print("Zipping model...")
shutil.make_archive("t2sql_v2_model", "zip", ".", "t2sql_final_v2")
print("Created: t2sql_v2_model.zip")

# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n
# instead of a blank line.
print("\n" + "=" * 60)
print("FINAL REPORT V2")
print("=" * 60)
print(json.dumps(report, indent=2))
print("=" * 60)
print("\nFiles to download:")
print("  1. t2sql_v2_model.zip (zipped model)")
print("  2. t2sql_final_v2/ (model folder)")
print("  3. report_v2.json (metrics)")
print("\nTraining complete!")