# Text-to-SQL V6 - Structured Output (Kaggle Split)

**Dataset:** WikiSQL from Salesforce GitHub
- **56k train / 8k validation / 15k test**
- Schema (table headers) in every example
- Downloads directly — no HuggingFace script issues

**Key Difference from V5:**
- **V5**: Model outputs raw SQL text
- **V6**: Model outputs **structured indices** (sel, agg, conds)
  - At inference, indices are converted to valid SQL
  - More generalizable across schemas

**Expected:** 50-70% exact-match accuracy on the structured output (validation split)

---

In [None]:
# Cell 1: Install
!pip install -q transformers>=4.35.0 datasets>=2.14.0 accelerate>=0.24.0
!pip install -q torch sentencepiece

In [None]:
# Cell 2: Setup
import os
# Set before any tokenizer is created in this session (the tokenizer is loaded in Cell 5).
# NOTE(review): presumably silences the HF tokenizers fork/parallelism warning
# triggered by DataLoader workers (Cell 7 uses dataloader_num_workers=2) — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import numpy as np
import json
import warnings
# Suppresses ALL Python warnings for the rest of the session, including
# deprecation notices coming from the libraries used below.
warnings.filterwarnings('ignore')

# Banner: rule, title, rule - one line each.
print("=" * 60, "TEXT-TO-SQL V6 - STRUCTURED OUTPUT (KAGGLE)", "=" * 60, sep="\n")
print(f"CUDA: {torch.cuda.is_available()}")

# Use the larger checkpoint only when a GPU is present; fall back to the
# small one on CPU.
if torch.cuda.is_available():
    MODEL_NAME = "google-t5/t5-base"
else:
    MODEL_NAME = "google-t5/t5-small"
print(f"Model: {MODEL_NAME}")

In [None]:
# Cell 3: Download WikiSQL directly from source
import urllib.request
import tarfile
import json
import os

DATA_URL = "https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2"
DATA_DIR = "./wikisql_data"
_ARCHIVE = "data.tar.bz2"  # temporary download target, removed after extraction

if not os.path.exists(DATA_DIR):
    # Only announce the download when it actually happens.
    print("Downloading WikiSQL from Salesforce GitHub...")
    try:
        urllib.request.urlretrieve(DATA_URL, _ARCHIVE)
        print("Extracting...")
        with tarfile.open(_ARCHIVE, "r:bz2") as tar:
            # Use the "data" extraction filter when this Python provides it, so
            # archive members cannot escape the working directory; older
            # interpreters fall back to the unfiltered call.
            extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
            tar.extractall(".", **extract_kwargs)
        # The archive unpacks into ./data; move it to its final name.
        os.rename("data", DATA_DIR)
    finally:
        # Do not leave the archive behind, whether or not the steps above succeeded.
        if os.path.exists(_ARCHIVE):
            os.remove(_ARCHIVE)
    print("Done!")
else:
    print("Already downloaded.")

print(f"\nFiles: {os.listdir(DATA_DIR)}")

In [None]:
# Cell 4: Load and process WikiSQL - STRUCTURED FORMAT
from datasets import Dataset

# Aggregation keywords, looked up by the integer "agg" value of an example;
# index 0 (empty string) means "no aggregation".
AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
# Comparison operators, looked up by the operator index (second item) of a
# condition triple.
OPS = ['=', '>', '<', '>=', '<=', '!=']  # Full operator set (WikiSQL uses 0-2)

def encode_structured_label(sel, agg, conds):
    """Serialize a (sel, agg, conds) triple into the flat seq2seq target string.

    The result has the form ``SEL:<sel> | AGG:<agg> | CONDS:<c1>;<c2>;...``,
    where each condition is rendered as ``<item0>,<item1>,<item2>`` from its
    first three items.  When there are no conditions the string ends with a
    bare ``CONDS:``.
    """
    rendered = []
    for cond in conds or ():
        rendered.append(f"{cond[0]},{cond[1]},{cond[2]}")
    return f"SEL:{sel} | AGG:{agg} | CONDS:" + ";".join(rendered)

def load_wikisql_split(split_name):
    """Load one WikiSQL split as a list of seq2seq training examples.

    Reads ``{DATA_DIR}/{split_name}.jsonl`` (one question per line) and
    ``{DATA_DIR}/{split_name}.tables.jsonl`` (one table per line), joins each
    question to its table through ``table_id`` and builds two strings per
    example.

    Args:
        split_name: File stem of the split, e.g. "train", "dev" or "test".

    Returns:
        A list of dicts with exactly two string fields:
        ``input_text``  - ``"translate to SQL: <question> | schema: <name>(<col>, ...)"``
        ``target_text`` - the structured label from ``encode_structured_label``.

    Raises:
        KeyError: if a question references a table id missing from the tables file.
    """
    main_file = f"{DATA_DIR}/{split_name}.jsonl"
    tables_file = f"{DATA_DIR}/{split_name}.tables.jsonl"

    # Index tables by id. The files are read as UTF-8 explicitly so the result
    # does not depend on the platform's default locale encoding.
    tables = {}
    with open(tables_file, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            table = json.loads(line)
            tables[table["id"]] = table

    examples = []
    with open(main_file, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            table = tables[row["table_id"]]

            question = row["question"]
            header = table["header"]
            # Tables without a "name" field fall back to the literal "table".
            table_name = table.get("name", "table")

            # Schema string: name(col1, col2, ...)
            schema = f"{table_name}({', '.join(header)})"

            # Input: question + schema
            input_text = f"translate to SQL: {question} | schema: {schema}"

            # Output: structured indices, encoded as a string to avoid Arrow type issues
            query = row["sql"]
            target_text = encode_structured_label(query["sel"], query["agg"], query["conds"])

            # Only string fields are kept, for Arrow compatibility
            examples.append({
                "input_text": input_text,
                "target_text": target_text,
            })

    return examples

print("Loading WikiSQL...")
# Splits are read in the order train -> dev -> test.
train_data, val_data, test_data = [
    load_wikisql_split(split) for split in ("train", "dev", "test")
]

print(f"Train: {len(train_data)} | Val: {len(val_data)} | Test: {len(test_data)}")

# Wrap the train and validation lists as HuggingFace datasets
# (test_data stays a plain list).
train_ds, val_ds = Dataset.from_list(train_data), Dataset.from_list(val_data)

# Show the first training example; the input is cut to 100 characters.
print("\nSample structured output:")
print(f"Input: {train_ds[0]['input_text'][:100]}...")
print(f"Target: {train_ds[0]['target_text']}")

In [None]:
# Cell 5: Tokenize
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize(batch):
    """Tokenize a batch of examples for seq2seq training.

    Inputs are truncated to 256 tokens and targets to 128 tokens; the target
    token ids are attached to the encoded inputs under the "labels" key.
    """
    encoded = tokenizer(batch["input_text"], truncation=True, max_length=256)
    encoded_targets = tokenizer(
        text_target=batch["target_text"],
        truncation=True,
        max_length=128,
    )
    encoded["labels"] = encoded_targets["input_ids"]
    return encoded

print("Tokenizing...")
# Same mapping for both splits: batched tokenization, dropping the original
# text columns so only the tokenizer output remains.
train_tok, val_tok = [
    split.map(tokenize, batched=True, remove_columns=split.column_names)
    for split in (train_ds, val_ds)
]

print(f"Train: {len(train_tok)} | Val: {len(val_tok)}")

In [None]:
# Cell 6: Model
from transformers import AutoModelForSeq2SeqLM

# Load the pretrained seq2seq checkpoint selected in Cell 2.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Gradient checkpointing is switched on here and also requested again through
# `gradient_checkpointing=True` in the training arguments (Cell 7).
model.gradient_checkpointing_enable()
print(f"Loaded {MODEL_NAME} ({model.num_parameters():,} params)")

In [None]:
# Cell 7: Training Config
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

# Dynamic padding per batch; label positions are padded with -100, which
# compute_metrics (Cell 8) swaps back to the pad token before decoding.
collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100)

args = Seq2SeqTrainingArguments(
    output_dir="./t2sql_v6_structured",
    
    # Optimisation schedule
    num_train_epochs=5,
    learning_rate=1e-4,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    max_grad_norm=1.0,
    
    # 16 sequences x 2 accumulation steps = 32 sequences per device per optimizer step
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    
    # Mixed precision only when a GPU is present
    # NOTE(review): T5 checkpoints are reported to be numerically fragile under fp16 — watch for NaN/zero loss.
    fp16=torch.cuda.is_available(),
    gradient_checkpointing=True,
    
    # Evaluate and checkpoint on the same 1000-step cadence; the best
    # checkpoint is chosen by lowest eval_loss, NOT by the exact_match metric.
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    
    # Evaluation runs generation (beam search, 4 beams, up to 128 tokens)
    predict_with_generate=True,
    generation_max_length=128,
    generation_num_beams=4,
    
    logging_steps=100,
    report_to="none",
    dataloader_num_workers=2,
    seed=42,
)

print(f"Epochs: {args.num_train_epochs} | LR: {args.learning_rate}")

In [None]:
# Cell 8: Structured Decoding & Metrics (using convert_wikisql.py approach)
import re

# Full operator set (WikiSQL only uses 0-2, but we support all)
# NOTE: re-declares the OPS list from Cell 4 with identical contents, so this
# cell can be run on its own as far as OPS is concerned.
OPS = ['=', '>', '<', '>=', '<=', '!=']

# Lowercase words that must be wrapped in double quotes when they are used as
# column identifiers.
SQL_RESERVED = {
    'order', 'group', 'table', 'index', 'select', 'from', 'where', 'join',
    'left', 'right', 'inner', 'outer', 'on', 'as', 'and', 'or', 'not',
    'limit', 'offset', 'union', 'all', 'distinct', 'null', 'is', 'like',
    'between', 'in', 'exists', 'case', 'when', 'then', 'else', 'end',
    'count', 'sum', 'avg', 'min', 'max', 'having', 'by', 'asc', 'desc',
    'primary', 'key', 'foreign', 'references', 'constraint', 'unique',
    'check', 'default', 'create', 'alter', 'drop', 'insert', 'update', 'delete',
    'to', 'with', 'into', 'values', 'set', 'call', 'return', 'returning',
    'out', 'inout', 'procedure', 'function', 'trigger', 'view', 'schema',
    'current', 'timestamp', 'user', 'session', 'system', 'date', 'time',
    'datetime', 'year', 'month', 'day', 'hour', 'minute', 'second',
}


def clean_column_name(col, used_names=None):
    """Turn an arbitrary header string into a lowercase SQL identifier.

    Steps, in order: replace every character outside ``[a-zA-Z0-9_]`` with an
    underscore, collapse underscore runs and trim them from both ends, use
    ``col`` for an empty result or prefix ``col_`` to a leading digit,
    lowercase, and double-quote the name if it is in ``SQL_RESERVED``.

    Args:
        col: Raw column header.
        used_names: Optional set of identifiers already handed out.  When
            given, a clashing name gets ``_1``, ``_2``, ... appended (inside
            the quotes for quoted names) and the final name is added to the set.

    Returns:
        The cleaned (and, if requested, de-duplicated) identifier.
    """
    ident = re.sub(r'[^a-zA-Z0-9_]', '_', col)
    ident = re.sub(r'_+', '_', ident).strip('_')

    if not ident:
        ident = 'col'
    elif ident[0].isdigit():
        ident = 'col_' + ident
    ident = ident.lower()

    # Only reserved words are ever quoted, so this flag also tells the
    # de-duplication step where the suffix has to go.
    is_quoted = ident in SQL_RESERVED
    if is_quoted:
        ident = f'"{ident}"'

    if used_names is None:
        return ident

    candidate = ident
    counter = 0
    while candidate in used_names:
        counter += 1
        if is_quoted:
            candidate = f'"{ident[1:-1]}_{counter}"'
        else:
            candidate = f'{ident}_{counter}'
    used_names.add(candidate)
    return candidate

def get_column_names(headers):
    """Clean every header in order, keeping the resulting identifiers unique.

    A single set of already-issued names is shared across the calls to
    ``clean_column_name`` so that later duplicates receive numeric suffixes.
    """
    seen = set()
    names = []
    for header in headers:
        names.append(clean_column_name(header, seen))
    return names

def value_to_sql(value):
    """Render a Python value as a SQL literal.

    ``None`` becomes ``NULL``; strings are single-quoted with embedded single
    quotes doubled; anything else is passed through ``str()``.
    """
    if value is None:
        return 'NULL'
    if not isinstance(value, str):
        return str(value)
    return "'" + value.replace("'", "''") + "'"

def decode_structured_output(text):
    """Parse a ``SEL:.. | AGG:.. | CONDS:..`` string into its components.

    Parsing is best-effort and never raises: fields that are missing or not
    integers are reported as ``None`` (``sel``/``agg``) or skipped (individual
    conditions), so one malformed piece does not discard the rest.

    Args:
        text: Model output or label string.  Non-string input yields the
            empty result.

    Returns:
        ``(sel, agg, conds)`` where ``sel`` and ``agg`` are ints or ``None``
        and ``conds`` is a list of ``[column_index, operator_index, value]``
        with ``value`` kept as a string.

    Note:
        Conditions are separated by ``;``, so a value that itself contains a
        semicolon cannot be recovered from this format.
    """
    sel = agg = None
    conds = []
    if not isinstance(text, str):
        return sel, agg, conds

    # CONDS is the last field of the format, so everything after the marker
    # belongs to it - including values that happen to contain " | ".
    head, _, cond_str = text.partition("CONDS:")

    for part in head.split("|"):
        part = part.strip()
        try:
            if part.startswith("SEL:"):
                sel = int(part[4:])
            elif part.startswith("AGG:"):
                agg = int(part[4:])
        except ValueError:
            # Non-numeric index: leave that field as None and keep going.
            continue

    for cond in cond_str.strip().split(";"):
        # Split on the first two commas only so commas inside the value survive.
        vals = cond.split(",", 2)
        if len(vals) < 3:
            continue
        try:
            conds.append([int(vals[0]), int(vals[1]), vals[2]])
        except ValueError:
            # Skip just this malformed condition.
            continue

    return sel, agg, conds

def structured_to_sql(sel, agg, conds, header, table_name="table"):
    """Build a lowercase SQL string from structured indices.

    Keywords, identifiers and the table name are emitted in lowercase.  String
    literals in the WHERE clause keep their original case, because lowercasing
    them would change what the query matches.

    Args:
        sel: Index into ``header`` of the selected column.  An index outside
            the header falls back to the placeholder column name ``col``.
        agg: Index into ``AGG_OPS``.  ``0`` or any out-of-range value means no
            aggregation.
        conds: Iterable of ``(column_index, operator_index, value)`` triples.
            Conditions whose column index is not in the header are dropped;
            an out-of-range operator index falls back to ``=``.
        header: List of raw column headers of the table.
        table_name: Table name used in the FROM clause.

    Returns:
        The SQL string, or ``""`` when ``sel`` or ``agg`` is ``None``.
    """
    if sel is None or agg is None:
        return ""

    # Cleaned identifiers, addressable by original column position.
    col_map = dict(enumerate(get_column_names(header)))

    # SELECT clause. The explicit lower bound stops negative indices from
    # silently picking an operator from the end of AGG_OPS.
    select_col = col_map.get(sel, 'col')
    if 0 < agg < len(AGG_OPS):
        select_expr = f"{AGG_OPS[agg].lower()}({select_col})"
    else:
        select_expr = select_col
    sql = f"select {select_expr} from {table_name.lower()}"

    # WHERE clause
    where_parts = []
    for c_idx, c_op, c_val in conds or ():
        if c_idx not in col_map:
            continue
        op_str = OPS[c_op] if 0 <= c_op < len(OPS) else '='
        val_sql = value_to_sql(c_val)
        if not isinstance(c_val, str):
            # Non-string literals (e.g. NULL) stay lowercase like the rest of
            # the statement; quoted string values are left untouched.
            val_sql = val_sql.lower()
        where_parts.append(f"{col_map[c_idx]} {op_str} {val_sql}")
    if where_parts:
        sql += " where " + " and ".join(where_parts)

    return sql

def compute_metrics(pred):
    """Compute string-level exact match between generated and reference labels.

    Args:
        pred: The ``(predictions, label_ids)`` pair handed over by the trainer.
            ``predictions`` may arrive wrapped in a tuple, in which case its
            first element is used.

    Returns:
        ``{"exact_match": fraction}`` - the share of examples whose decoded,
        whitespace-stripped prediction equals the decoded label (``0.0`` for
        an empty evaluation set).
    """
    preds, labels = pred
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    max_id = len(tokenizer) - 1

    # -100 marks padded/ignored positions; map it to the real pad token
    # explicitly, then clip so no id falls outside the tokenizer's range.
    preds = np.clip(np.where(preds != -100, preds, pad_id), 0, max_id)
    labels = np.clip(np.where(labels != -100, labels, pad_id), 0, max_id)

    pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)

    if not pred_str:
        return {"exact_match": 0.0}
    exact = sum(p.strip() == l.strip() for p, l in zip(pred_str, label_str))
    return {"exact_match": exact / len(pred_str)}

# Test decoding with actual WikiSQL example
# Smoke test: push a hand-written label through the decoder and the SQL builder.
# NOTE: this assigns the module-level names sel, agg, conds and sql.
test_out = "SEL:5 | AGG:0 | CONDS:3,0,SOUTH AUSTRALIA"
sel, agg, conds = decode_structured_output(test_out)
print(f"Decoded: sel={sel}, agg={agg}, conds={conds}")

# Actual WikiSQL header for table 1-1000181-1
test_header = ['State/territory', 'Text/background colour', 'Format', 'Current slogan', 'Current series', 'Notes']
# "t1" is an arbitrary table name used only for this demonstration.
sql = structured_to_sql(sel, agg, conds, test_header, "t1")
print(f"SQL: {sql}")

In [None]:
# Cell 9: Trainer
# Wires together the model (Cell 6), the training arguments and collator
# (Cell 7), the tokenized train/validation splits (Cell 5) and the
# exact-match metric (Cell 8).
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics,
)
print("Trainer ready")

In [None]:
# Cell 10: Quick sanity check
# Run a single forward pass on the first four training examples before
# committing to a full training run.
batch = collator([train_tok[i] for i in range(4)])
model.eval()  # inference mode for the probe
with torch.no_grad():
    # Move every tensor of the batch to the model's device, then read the scalar loss.
    loss = model(**{k: v.to(model.device) for k, v in batch.items()}).loss.item()
print(f"Initial loss: {loss:.2f}")
# NOTE: a loss of 10 or more prints nothing further - there is no failure branch.
if loss < 10:
    print("✓ Ready to train")
model.train()  # back to training mode

In [None]:
# Cell 11: TRAIN
print("=" * 50)
print("TRAINING V6 - STRUCTURED OUTPUT")
# Epoch count is read from the training arguments so the banner cannot drift
# from the actual configuration.
print(f"{len(train_tok)} examples | {args.num_train_epochs:g} epochs")
print("=" * 50)

# Release cached GPU memory before the run (no-op on CPU-only machines).
if torch.cuda.is_available():
    torch.cuda.empty_cache()
result = trainer.train()

print(f"\nDone! Loss: {result.training_loss:.4f}")

In [None]:
# Cell 12: Evaluate
# Run evaluation with the trainer and keep the metrics dict for the report cell.
ev = trainer.evaluate()

# Banner followed by the two headline numbers.
print("=" * 50, "V6 STRUCTURED OUTPUT RESULTS", "=" * 50, sep="\n")
print(f"Eval Loss: {ev['eval_loss']:.4f}")
print(f"Exact Match: {ev['eval_exact_match']*100:.1f}%")

In [None]:
# Cell 13: Inference Demo
from transformers import pipeline

# Write the model and the tokenizer into the same directory so it can be
# loaded by path right below.
# NOTE(review): with load_best_model_at_end=True (Cell 7) this is presumably the
# best checkpoint by eval_loss rather than the last step — confirm.
trainer.save_model("./t2sql_v6_structured_final")
tokenizer.save_pretrained("./t2sql_v6_structured_final")

# Reload the saved directory as a generation pipeline: device 0 when CUDA is
# available, otherwise -1.
gen = pipeline("text2text-generation", model="./t2sql_v6_structured_final",
               device=0 if torch.cuda.is_available() else -1)

def text_to_sql_structured(question, schema):
    """Convert a natural-language question to SQL via structured prediction.

    Args:
        question: The question to translate.
        schema: Table description of the form ``name(col1, col2, ...)``; it is
            placed in the prompt as given and also parsed for the table name
            and the column headers.  Column names containing commas are not
            supported by this format.

    Returns:
        The SQL string from ``structured_to_sql``; empty when the model output
        could not be decoded into a SEL and an AGG index.

    Raises:
        ValueError: if ``schema`` contains no opening parenthesis.
    """
    # Parse the schema first so a malformed one fails before any generation.
    spec = schema.strip()
    table_name, paren, columns = spec.partition('(')
    if not paren:
        raise ValueError(f"schema must look like 'table(col1, col2)': {schema!r}")
    # Drop the closing parenthesis only if it is really there, instead of
    # blindly cutting off the last character.
    if columns.endswith(')'):
        columns = columns[:-1]
    table_name = table_name.strip()
    headers = [h.strip() for h in columns.split(',')]

    # Same prompt format as used for training.
    inp = f"translate to SQL: {question} | schema: {schema}"
    out = gen(inp, max_length=128, num_beams=4)[0]['generated_text']

    # Decode structured output and render it against the parsed headers
    sel, agg, conds = decode_structured_output(out)
    return structured_to_sql(sel, agg, conds, headers, table_name)

# Hand-written (question, schema) pairs for a quick qualitative check.
tests = [
    ("How many players are there?", "players(id, name, age, team)"),
    ("What is the total population?", "countries(name, population, area)"),
    ("Show all products under $50", "products(id, name, price, category)"),
]

print("\nTest predictions (structured → SQL):")
# NOTE: the loop rebinds the module-level names q, schema and sql.
for q, schema in tests:
    sql = text_to_sql_structured(q, schema)
    print(f"Q: {q}")
    print(f"SQL: {sql}\n")

In [None]:
# Cell 14: Save & Report
import shutil

# Summary of the run: sizes and losses come from earlier cells
# (train_tok: Cell 5, result: Cell 11, ev: Cell 12).
report = {
    "version": "v6_structured_output",
    "dataset": "WikiSQL (direct download)",
    "train_examples": len(train_tok),
    "train_loss": result.training_loss,
    "eval_loss": ev['eval_loss'],
    "exact_match": ev['eval_exact_match']*100,
}

# Write through a context manager so the file is flushed and closed
# before the cell finishes.
with open("report_v6.json", "w", encoding="utf-8") as report_file:
    json.dump(report, report_file, indent=2)

# Zip the saved model directory (relative to the current directory).
shutil.make_archive("t2sql_v6_structured", "zip", ".", "t2sql_v6_structured_final")

print("=" * 50)
print("SAVED")
print("=" * 50)
print(json.dumps(report, indent=2))
print("\nDownload: t2sql_v6_structured.zip")