# flan-t5-large-spider-text2sql

## Data Prep

In [1]:
import datasets
import torch
import numpy as np

# Hub identifiers for the base checkpoint and the two datasets loaded below.
MODEL_NAME = 'google/flan-t5-large'
SPIDER_DB = 'xlangai/spider'
SCHEMA_DB = 'richardr1126/spider-schema'

# Handed to the tokenizer as `model_max_length` when it is created.
MAX_LENGTH = 512

dataset = datasets.load_dataset(SPIDER_DB)
schemas = datasets.load_dataset(SCHEMA_DB)

# Lookup table: db_id -> raw schema string taken from the schema dataset.
schema_dic = {
    row['db_id']: row['Schema (values (type))']
    for row in schemas['train']
}

# Drop the columns that the preprocessing step never reads.
dataset = dataset.remove_columns(
    ['query_toks', 'query_toks_no_value', 'question_toks']
)


In [2]:
# adding schemas to the main dataset

import re

def get_cleaned_schema(schema_raw: str):
    """
    Convert a raw schema description into a compact one.

    Input looks like:
    "table : col1 (type) , col2 (type) | table2 : col3 (type)"

    Output looks like:
    "table(col1, col2) | table2(col3)"

    A falsy input yields "". Any "|"-separated chunk that does not contain
    exactly one ":" is dropped from the result.
    """
    if not schema_raw:
        return ""

    def _render_table(table_chunk):
        # A well-formed chunk is "<table name> : <column list>".
        pieces = table_chunk.split(':')
        if len(pieces) != 2:
            return None  # malformed entry, skipped by the caller

        name, columns = (piece.strip() for piece in pieces)

        # Strip the parenthesised type annotations such as "(text)".
        columns = re.sub(r'\s*\([^)]*\)', '', columns)

        column_names = ", ".join(col.strip() for col in columns.split(','))
        return f"{name}({column_names})"

    rendered = (_render_table(chunk) for chunk in schema_raw.split('|'))
    return " | ".join(entry for entry in rendered if entry is not None)



def fix_question(question, db_id, prefix):
    '''
        Builds the model input text for a single question.

        The prefix goes first, and the cleaned schema looked up for
        `db_id` in `schema_dic` goes last. The result looks like this:
            <prefix>: <question> | Schemas: <table-1>(column-1, column-2, ...) | <table-2>(column-1, column-2, ...)
    '''
    cleaned_schema = get_cleaned_schema(schema_dic[db_id])
    return "".join([prefix, ": ", question, " | ", "Schemas: ", cleaned_schema])


def preprocess(batch):
    """
    Tokenize one batch of examples for seq2seq training.

    Reads the 'query', 'db_id' and 'question' columns of `batch`, builds the
    prompt for every question with `fix_question`, tokenizes the prompts as
    inputs and the queries as targets, and returns the tokenized inputs with
    the target ids stored under "labels".
    """
    task_prefix = 'Translate English to SQL'
    target_queries = batch['query']

    prompts = [
        fix_question(question, db_id, task_prefix)
        for db_id, question in zip(batch['db_id'], batch['question'])
    ]

    model_inputs = tokenizer(prompts, truncation=True)
    targets = tokenizer(text_target=target_queries, truncation=True)

    model_inputs["labels"] = targets["input_ids"]
    return model_inputs


In [3]:
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# apply preprocessing

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_NAME,
    use_fast=True,
    model_max_length=MAX_LENGTH
)

# os.cpu_count() returns None when the number of CPUs cannot be determined;
# fall back to 1 so the arithmetic below never raises a TypeError.
map_workers = ((os.cpu_count() or 1) // 3) + 1

# Tokenize both splits with identical settings, replacing the raw columns
# with the tokenized ones.
for split_name in ('train', 'validation'):
    dataset[split_name] = dataset[split_name].map(
        preprocess,
        batched=True,
        batch_size=64,
        remove_columns=dataset[split_name].column_names,
        num_proc=map_workers,
    )

# split validation set to test and validation (half each, fixed seed)
val_split = dataset['validation'].train_test_split(test_size=0.5, seed=42)
dataset['test'] = val_split['train']
dataset['validation'] = val_split['test']
del val_split
dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 7000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 517
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 517
    })
})

## Training Setup

### compute_metrics & Callbacks

In [4]:
!pip install -U evaluate -q
!pip install -U rouge_score -q

import evaluate
import numpy as np
from transformers import TrainerCallback

rouge_metric = evaluate.load("rouge")

def compute_metrics(eval_preds):
    """
    Compute evaluation metrics from generated token ids.

    Args:
        eval_preds: a (predictions, labels) pair of token-id arrays. When
            predictions is a tuple, only its first element is used.

    Returns:
        A dict with:
            "exact_match": fraction of predictions equal to their reference
                after lower-casing and removing spaces.
            "rouge1": the "rouge1" value reported by the ROUGE metric.
    """
    preds, labels = eval_preds

    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is the ignore index used for label padding; predictions may carry
    # the same sentinel (presumably from batch padding - verify against the
    # Trainer version in use). Map it to the pad token explicitly instead of
    # relying on the clip below, which only works when pad_token_id is 0.
    preds = np.asarray(preds)
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)

    # Keep every remaining id inside the tokenizer's vocabulary for decoding.
    preds = np.clip(preds, 0, tokenizer.vocab_size - 1)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    def normalize_sql(s):
        # Case- and space-insensitive comparison form.
        return s.lower().replace(" ", "").strip()

    exact_matches = [
        1 if normalize_sql(p) == normalize_sql(l) else 0
        for p, l in zip(decoded_preds, decoded_labels)
    ]

    result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)

    return {
        "exact_match": sum(exact_matches) / len(exact_matches),
        "rouge1": result["rouge1"],
    }


class PrinterCallback(TrainerCallback):
    """Prints the model's prediction for one fixed question after each evaluation."""

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Generate SQL for a hardcoded prompt and print it next to the gold query.

        Uses the notebook-level `tokenizer`, `model` and `GENERATION_MAX_LENGTH`.
        """
        print("\n*** DEBUGGING GENERATION ***")

        # Fixed prompt and reference, so no dataset access is needed here.
        input_text = "Translate English to SQL: How many players are from each country? | Schemas: players(player_id, first_name, last_name, hand, birth_date, country_code) | matches(best_of, draw_size, loser_age, loser_entry, loser_hand, loser_ht, loser_id, loser_ioc, loser_name, loser_rank, loser_rank_points, loser_seed, match_num, minutes, round, score, surface, tourney_date, tourney_id, tourney_level, tourney_name, winner_age, winner_entry, winner_hand, winner_ht, winner_id, winner_ioc, winner_name, winner_rank, winner_rank_points, winner_seed, year) | rankings(ranking_date, ranking, player_id, ranking_points, tours"
        gold_label = "SELECT count(*) , country_code FROM players GROUP BY country_code"

        encoded = tokenizer(input_text, return_tensors="pt").to(model.device)
        generation_options = {
            "max_length": GENERATION_MAX_LENGTH,
            "num_beams": 4,
            "repetition_penalty": 1.2,
        }
        # Inference only: no gradients are needed.
        with torch.no_grad():
            generated = model.generate(**encoded, **generation_options)

        prediction = tokenizer.decode(generated[0], skip_special_tokens=True)
        report_lines = [
            "QUESTION: How many players are from each country?",
            f"PREDICTION: {prediction}",
            f"GOLD LABEL: {gold_label}",
            "******************************\n",
        ]
        print("\n".join(report_lines))

### Training Args

In [5]:
from transformers import TrainingArguments, Seq2SeqTrainingArguments, GenerationConfig
from pathlib import Path

SAVE_DIR = Path('./models')
# Effective per-device batch size = TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
TRAIN_BATCH_SIZE = 4
EVAL_BATCH_SIZE = 64
LR = 0.0003
MAX_GRAD_NORM = 2
EPOCH = 10
WEIGHT_DECAY = 0.0005
GRADIENT_ACCUMULATION_STEPS = 4
WARMUP_RATIO = 0.1
LABEL_SMOOTHING = 0
GENERATION_MAX_LENGTH = 128

# os.cpu_count() returns None when the number of CPUs cannot be determined;
# fall back to 1 so the worker count is always an int.
DATALOADER_WORKERS = ((os.cpu_count() or 1) // 3) * 2

training_args = Seq2SeqTrainingArguments(
    output_dir=SAVE_DIR,

    # --- OPTIMIZER ---
    # optim="adafactor",
    optim="adamw_torch",
    # torch_compile=True,
    # torch_compile_backend="inductor",
    # torch_compile_mode="reduce-overhead",

    # --- EVALUATION STRATEGY ---
    # Evaluate and checkpoint once per epoch; the checkpoint with the highest
    # exact_match is reloaded when training ends.
    eval_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=4,
    load_best_model_at_end=True,
    metric_for_best_model='exact_match',
    greater_is_better=True,

    # --- BATCHING & OPTIMIZATION ---
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LR,
    max_grad_norm=MAX_GRAD_NORM,
    num_train_epochs=EPOCH,
    weight_decay=WEIGHT_DECAY,
    # NOTE(review): transformers warns that warmup_ratio is deprecated
    # (removal announced for v5.2) in favour of `warmup_steps`; migrate once
    # the minimum supported transformers version is settled.
    warmup_ratio=WARMUP_RATIO,
    label_smoothing_factor=LABEL_SMOOTHING,
    lr_scheduler_type='cosine',

    # --- GENERATION ---
    predict_with_generate=True,
    generation_max_length=GENERATION_MAX_LENGTH,

    # --- LOGGING & HARDWARE ---
    logging_steps=2,            # log every 2 update steps
    fp16=False,
    bf16=True,
    report_to='none',

    # -----
    generation_config=GenerationConfig(
        # repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        decoder_start_token_id=tokenizer.pad_token_id,
        num_beams=4,
    ),
    dataloader_num_workers=DATALOADER_WORKERS,
    group_by_length=True,
)

warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.


### Model Setup

In [6]:
# loading model

#!pip install peft

from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model, TaskType

model = AutoModelForSeq2SeqLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_NAME,
    # torch_dtype=torch.bfloat16  # Load directly in BF16
)

# Take the special token ids from the tokenizer; decoding starts from the pad token.
for config_field, token_id in (
    ('decoder_start_token_id', tokenizer.pad_token_id),
    ('eos_token_id', tokenizer.eos_token_id),
    ('pad_token_id', tokenizer.pad_token_id),
):
    setattr(model.config, config_field, token_id)

# Module names that receive LoRA adapters.
base_target_modules = ["q", "v", "k", "o"]
ffn_target_modules = ["wi_0", "wi_1", "wo"]  # include FFN this time

lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    bias='none',
    target_modules=base_target_modules + ffn_target_modules,
)

# Wrap the base model with the adapters described by lora_config.
model = get_peft_model(model, lora_config)

Loading weights:   0%|          | 0/558 [00:00<?, ?it/s]



In [7]:
# # Cast base model to bf16, but keep LoRA adapters in fp32
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         param.data = param.data.to(torch.float32)  # LoRA stays fp32
#     else:
#         param.data = param.data.to(torch.bfloat16)  # base stays bf16


# # Verify
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(f"TRAINABLE: {param.dtype}")  # should be float32
#         break
# for name, param in model.named_parameters():
#     if not param.requires_grad:
#         print(f"FROZEN: {param.dtype}")  # should be bfloat16
#         break

In [8]:
# from torch.optim import AdamW

# # Recreate batch
# loader = trainer.get_train_dataloader()
# batch = next(iter(loader))
# batch = {k: v.to(model.device) for k, v in batch.items()}

# optimizer = AdamW(model.parameters(), lr=1e-3)

# for step in range(100):
#     optimizer.zero_grad()
#     outputs = model(**batch)
#     loss = outputs.loss
#     loss.backward()
#     optimizer.step()
#     if step % 10 == 0:
#         print(f"Step {step}: loss={loss.item():.4f}")

In [9]:
# data collator
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

collator_options = {
    "tokenizer": tokenizer,
    "model": model,
    "padding": True,
    # -100 is the default; PyTorch loss functions ignore it automatically
    "label_pad_token_id": -100,
    # Optimization for TPU/GPU cores
    "pad_to_multiple_of": 8,
}
data_collator_fn = DataCollatorForSeq2Seq(**collator_options)

# trainer
from transformers import Seq2SeqTrainer

trainer_options = {
    "model": model,
    "args": training_args,
    "data_collator": data_collator_fn,
    "train_dataset": dataset['train'],
    "eval_dataset": dataset['validation'],
    "compute_metrics": compute_metrics,
    "callbacks": [PrinterCallback],
}
trainer = Seq2SeqTrainer(**trainer_options)

## Train

In [None]:
# Run fine-tuning with the trainer configured above; `results` keeps
# whatever trainer.train() returns.
results = trainer.train()

Epoch,Training Loss,Validation Loss,Exact Match,Rouge1
1,11.822417,3.050445,0.001934,0.307812



*** DEBUGGING GENERATION ***
QUESTION: How many players are from each country?
PREDICTION: SELECT count(*) FROM player_id
GOLD LABEL: SELECT count(*) , country_code FROM players GROUP BY country_code
******************************

