## Importing Libraries and Setting up Paths

In [1]:
import json
import csv
import time
import sqlite3
import os
from pathlib import Path

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
from datasets import Dataset
import sqlglot
from sqlglot import parse_one

# Check environment
print("=" * 60)
print("ENVIRONMENT CHECK")
print("=" * 60)
print(f"PyTorch version: {torch.__version__}")
print(f"Transformers version: {__import__('transformers').__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    DEVICE = "cuda"
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    DEVICE = "cpu"
    print("WARNING: No GPU detected! Training will be very slow.")

print(f"Device selected: {DEVICE}")
print("=" * 60)

# ============================================================
# File Paths
# ============================================================

# Data paths
TRAIN_JSONL = Path("train_text2sql.jsonl")
VAL_JSONL = Path("val_text2sql.jsonl")
TEST_JSONL = Path("test_hospital_1.jsonl")

# Database path
SQLITE_DB = Path("hospital_1.sqlite")

# Output directories
OUTPUT_DIR = Path("finetuned_flant5")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

RESULTS_CSV = Path("results_hospital_1_finetuned_flant5.csv")

print("\nFILE PATHS:")
print("-" * 60)
print(f"Training data:      {TRAIN_JSONL}")
print(f"Validation data:    {VAL_JSONL}")
print(f"Test data:          {TEST_JSONL}")
print(f"SQLite DB:          {SQLITE_DB}")
print(f"Output directory:   {OUTPUT_DIR}")
print(f"Results CSV:        {RESULTS_CSV}")
print("-" * 60)

# Verify files exist
print("\nFILE VERIFICATION:")
print("-" * 60)
files_to_check = {
    "Training data": TRAIN_JSONL,
    "Validation data": VAL_JSONL,
    "Test data": TEST_JSONL,
    "SQLite DB": SQLITE_DB,
}

all_exist = True
for name, path in files_to_check.items():
    exists = path.exists()
    status = "✓ EXISTS" if exists else "✗ MISSING"
    print(f"{name:20s}: {status}")
    if not exists:
        all_exist = False

print("-" * 60)

if all_exist:
    print("\n✅ All required files found!")
else:
    print("\n⚠️  WARNING: Some files are missing. Please upload them before proceeding.")

ENVIRONMENT CHECK
PyTorch version: 2.8.0+cu126
Transformers version: 4.57.1
CUDA available: True
CUDA device: NVIDIA A100-SXM4-40GB
GPU Memory: 39.56 GB
Device selected: cuda

FILE PATHS:
------------------------------------------------------------
Training data:      train_text2sql.jsonl
Validation data:    val_text2sql.jsonl
Test data:          test_hospital_1.jsonl
SQLite DB:          hospital_1.sqlite
Output directory:   finetuned_flant5
Results CSV:        results_hospital_1_finetuned_flant5.csv
------------------------------------------------------------

FILE VERIFICATION:
------------------------------------------------------------
Training data       : ✓ EXISTS
Validation data     : ✓ EXISTS
Test data           : ✓ EXISTS
SQLite DB           : ✓ EXISTS
------------------------------------------------------------

✅ All required files found!


## Setting up Configuration

In [2]:
print("=" * 60)
print("CONFIGURATION PARAMETERS")
print("=" * 60)

# ============================================================
# Model Selection
# ============================================================

# We'll use the base Flan-T5 model and fine-tune it on our data
BASE_MODEL_NAME = "juierror/flan-t5-text2sql-with-schema-v2"  # ~250M parameters

# Alternative options (uncomment if you want):
# BASE_MODEL_NAME = "google/flan-t5-small"  # ~80M parameters (faster, less accurate)
# BASE_MODEL_NAME = "google/flan-t5-large"  # ~780M parameters (slower, more accurate)

print("\nModel Selection:")
print("-" * 60)
print(f"Base model: {BASE_MODEL_NAME}")

# ============================================================
# Training Hyperparameters
# ============================================================

# NUM_EPOCHS = 2
BATCH_SIZE = 8  # Adjust based on your GPU memory
# LEARNING_RATE = 1e-5
NUM_EPOCHS = 4
LEARNING_RATE = 2e-5
WARMUP_STEPS = 500
WEIGHT_DECAY = 0.01
LOGGING_STEPS = 50
SAVE_STEPS = 500
EVAL_STEPS = 500

print("\nTraining Hyperparameters:")
print("-" * 60)
print(f"Epochs:              {NUM_EPOCHS}")
print(f"Batch size:          {BATCH_SIZE}")
print(f"Learning rate:       {LEARNING_RATE}")
print(f"Warmup steps:        {WARMUP_STEPS}")
print(f"Weight decay:        {WEIGHT_DECAY}")
print(f"Logging steps:       {LOGGING_STEPS}")
print(f"Save steps:          {SAVE_STEPS}")
print(f"Eval steps:          {EVAL_STEPS}")

# ============================================================
# Generation Parameters (for inference)
# ============================================================

MAX_INPUT_LENGTH = 512   # Maximum tokens for input (question + schema)
MAX_TARGET_LENGTH = 256  # Maximum tokens for output (SQL)

GEN_MAX_LENGTH = 256
GEN_NUM_BEAMS = 4
GEN_TEMPERATURE = 0.0  # 0 = greedy decoding

print("\nGeneration Parameters:")
print("-" * 60)
print(f"Max input length:    {MAX_INPUT_LENGTH}")
print(f"Max target length:   {MAX_TARGET_LENGTH}")
print(f"Generation max len:  {GEN_MAX_LENGTH}")
print(f"Num beams:           {GEN_NUM_BEAMS}")
print(f"Temperature:         {GEN_TEMPERATURE}")

# ============================================================
# Prompt Template
# ============================================================

# This is how we'll format the input to the model
PROMPT_TEMPLATE = """Question: {question}

Schema:
{schema}

SQL:"""

print("\nPrompt Template:")
print("-" * 60)
print(PROMPT_TEMPLATE.format(
    question="<question here>",
    schema="<schema here>"
))

# ============================================================
# Other Settings
# ============================================================

SEED = 42
FP16 = True if DEVICE == "cuda" else False  # Use mixed precision on GPU

print("\nOther Settings:")
print("-" * 60)
print(f"Random seed:         {SEED}")
print(f"FP16 (mixed prec):   {FP16}")

# Set random seed
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("\n✅ Random seed set for reproducibility")

CONFIGURATION PARAMETERS

Model Selection:
------------------------------------------------------------
Base model: juierror/flan-t5-text2sql-with-schema-v2

Training Hyperparameters:
------------------------------------------------------------
Epochs:              4
Batch size:          8
Learning rate:       2e-05
Warmup steps:        500
Weight decay:        0.01
Logging steps:       50
Save steps:          500
Eval steps:          500

Generation Parameters:
------------------------------------------------------------
Max input length:    512
Max target length:   256
Generation max len:  256
Num beams:           4
Temperature:         0.0

Prompt Template:
------------------------------------------------------------
Question: <question here>

Schema:
<schema here>

SQL:

Other Settings:
------------------------------------------------------------
Random seed:         42
FP16 (mixed prec):   True

✅ Random seed set for reproducibility


## Loading Base Model and Tokenizer

In [3]:
print("=" * 60)
print("LOADING BASE MODEL & TOKENIZER")
print("=" * 60)

# ============================================================
# Load Tokenizer
# ============================================================

print("\nLoading tokenizer...")
print("-" * 60)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

print(f"✅ Tokenizer loaded: {BASE_MODEL_NAME}")
print(f"   Vocab size: {len(tokenizer)}")
print(f"   Model max length: {tokenizer.model_max_length}")

# Test tokenization
test_text = "SELECT * FROM Department WHERE DepartmentID = 1 ;"
test_tokens = tokenizer(test_text, return_tensors="pt")

print(f"\nTokenization test:")
print(f"   Input: {test_text}")
print(f"   Token IDs shape: {test_tokens['input_ids'].shape}")
print(f"   Token IDs: {test_tokens['input_ids'][0][:20].tolist()}...")

# ============================================================
# Load Model
# ============================================================

print("\n" + "-" * 60)
print("Loading model...")
print("-" * 60)

model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_NAME)

# Move to device
model = model.to(DEVICE)

print(f"✅ Model loaded: {BASE_MODEL_NAME}")

# ============================================================
# Model Statistics
# ============================================================

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Statistics:")
print(f"   Total parameters:      {total_params:,}")
print(f"   Trainable parameters:  {trainable_params:,}")
print(f"   Model size (approx):   {total_params * 4 / 1024 / 1024:.2f} MB")

# ============================================================
# Test Generation (before fine-tuning)
# ============================================================

print("\n" + "-" * 60)
print("Testing generation (before fine-tuning)...")
print("-" * 60)

test_prompt = PROMPT_TEMPLATE.format(
    question="How many physicians are there?",
    schema="""Database: hospital_1
Tables:
- Physician(EmployeeID*, Name, Position, SSN)"""
)

print(f"Test prompt:\n{test_prompt}\n")

# Tokenize and generate
inputs = tokenizer(test_prompt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True).to(DEVICE)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=GEN_MAX_LENGTH,
        num_beams=GEN_NUM_BEAMS,
        temperature=GEN_TEMPERATURE if GEN_TEMPERATURE > 0 else 1.0,
        do_sample=False
    )

generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"Generated SQL (baseline, before fine-tuning):")
print(f"   {generated_sql}")

print("\n✅ Model is ready for fine-tuning!")

LOADING BASE MODEL & TOKENIZER

Loading tokenizer...
------------------------------------------------------------
✅ Tokenizer loaded: juierror/flan-t5-text2sql-with-schema-v2
   Vocab size: 32101
   Model max length: 512

Tokenization test:
   Input: SELECT * FROM Department WHERE DepartmentID = 1 ;
   Token IDs shape: torch.Size([1, 16])
   Token IDs: [3, 23143, 14196, 1429, 21680, 1775, 549, 17444, 427, 1775, 4309, 3274, 209, 3, 117, 1]...

------------------------------------------------------------
Loading model...
------------------------------------------------------------
✅ Model loaded: juierror/flan-t5-text2sql-with-schema-v2

Model Statistics:
   Total parameters:      247,536,384
   Trainable parameters:  247,536,384
   Model size (approx):   944.28 MB

------------------------------------------------------------
Testing generation (before fine-tuning)...
------------------------------------------------------------
Test prompt:
Question: How many physicians are there?

Schem

## Data Loading and Pre-processing

In [4]:
print("=" * 60)
print("DATA LOADING AND PRE-PROCESSING")
print("=" * 60)

# ============================================================
# Load JSONL Files
# ============================================================

def load_jsonl(path):
    """Load JSONL file and return list of dictionaries."""
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows

print("\nLoading datasets...")
print("-" * 60)

train_data = load_jsonl(TRAIN_JSONL)
val_data = load_jsonl(VAL_JSONL)
test_data = load_jsonl(TEST_JSONL)

print(f"Training examples:   {len(train_data)}")
print(f"Validation examples: {len(val_data)}")
print(f"Test examples:       {len(test_data)}")

# Show sample
print(f"\nSample training example:")
print(f"   Question: {train_data[0]['question']}")
print(f"   Gold SQL: {train_data[0]['gold_query']}")
print(f"   Schema (first 100 chars): {train_data[0]['schema_serialized'][:100]}...")

# ============================================================
# Preprocessing Function
# ============================================================

def preprocess_function(examples):
    """
    Preprocess examples for fine-tuning.

    Args:
        examples: Dictionary with lists of questions, schemas, and SQL queries

    Returns:
        Dictionary with tokenized inputs and labels
    """
    # Build input prompts
    inputs = []
    for question, schema in zip(examples['question'], examples['schema_serialized']):
        prompt = PROMPT_TEMPLATE.format(question=question, schema=schema)
        inputs.append(prompt)

    # Tokenize inputs
    model_inputs = tokenizer(
        inputs,
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
        padding=False  # We'll pad dynamically in the data collator
    )

    # Tokenize targets (SQL queries)
    labels = tokenizer(
        text_target=examples['gold_query'],
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
        padding=False
    )

    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

# ============================================================
# Convert to HuggingFace Dataset Format
# ============================================================

print("\n" + "-" * 60)
print("Converting to HuggingFace Dataset format...")
print("-" * 60)

# Convert lists of dicts to dict of lists (HuggingFace format)
train_dataset = Dataset.from_dict({
    'question': [ex['question'] for ex in train_data],
    'schema_serialized': [ex['schema_serialized'] for ex in train_data],
    'gold_query': [ex['gold_query'] for ex in train_data]
})

val_dataset = Dataset.from_dict({
    'question': [ex['question'] for ex in val_data],
    'schema_serialized': [ex['schema_serialized'] for ex in val_data],
    'gold_query': [ex['gold_query'] for ex in val_data]
})

print(f"✅ Datasets converted")
print(f"   Train dataset: {len(train_dataset)} examples")
print(f"   Val dataset:   {len(val_dataset)} examples")

# ============================================================
# Tokenize Datasets
# ============================================================

print("\n" + "-" * 60)
print("Tokenizing datasets (this may take a few minutes)...")
print("-" * 60)

# # IN COLAB: Use multiprocessing for faster tokenization
# tokenized_train = train_dataset.map(
#     preprocess_function,
#     batched=True,
#     num_proc=4,  # Use 4 processes in Colab
#     remove_columns=train_dataset.column_names,
#     desc="Tokenizing training data"
# )

# tokenized_val = val_dataset.map(
#     preprocess_function,
#     batched=True,
#     num_proc=4,  # Use 4 processes in Colab
#     remove_columns=val_dataset.column_names,
#     desc="Tokenizing validation data"
# )

# ON LAPTOP: Use single process (no multiprocessing on Mac sometimes causes issues)
tokenized_train = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names,
    desc="Tokenizing training data"
)

tokenized_val = val_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=val_dataset.column_names,
    desc="Tokenizing validation data"
)

print(f"✅ Tokenization complete")
print(f"   Tokenized train: {len(tokenized_train)} examples")
print(f"   Tokenized val:   {len(tokenized_val)} examples")

# Show tokenized example
print(f"\nTokenized example (first training sample):")
print(f"   Input IDs length:  {len(tokenized_train[0]['input_ids'])}")
print(f"   Label IDs length:  {len(tokenized_train[0]['labels'])}")
print(f"   Input IDs (first 20): {tokenized_train[0]['input_ids'][:20]}")
print(f"   Label IDs (first 20): {tokenized_train[0]['labels'][:20]}")

# ============================================================
# Data Collator
# ============================================================

print("\n" + "-" * 60)
print("Setting up data collator...")
print("-" * 60)

# Data collator handles dynamic padding
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    max_length=MAX_INPUT_LENGTH
)

print(f"✅ Data collator ready")
print(f"   Will pad batches dynamically during training")

DATA LOADING AND PRE-PROCESSING

Loading datasets...
------------------------------------------------------------
Training examples:   8559
Validation examples: 1034
Test examples:       100

Sample training example:
   Question: How many heads of the departments are older than 56 ?
   Gold SQL: SELECT count(*) FROM head WHERE age  >  56;
   Schema (first 100 chars): Database: department_management
Tables:
- department(Department_ID*, Name, Creation, Ranking, Budget...

------------------------------------------------------------
Converting to HuggingFace Dataset format...
------------------------------------------------------------
✅ Datasets converted
   Train dataset: 8559 examples
   Val dataset:   1034 examples

------------------------------------------------------------
Tokenizing datasets (this may take a few minutes)...
------------------------------------------------------------


Tokenizing training data:   0%|          | 0/8559 [00:00<?, ? examples/s]

Tokenizing validation data:   0%|          | 0/1034 [00:00<?, ? examples/s]

✅ Tokenization complete
   Tokenized train: 8559 examples
   Tokenized val:   1034 examples

Tokenized example (first training sample):
   Input IDs length:  134
   Label IDs length:  17
   Input IDs (first 20): [11860, 10, 571, 186, 7701, 13, 8, 10521, 33, 2749, 145, 11526, 3, 58, 10248, 51, 9, 10, 20230, 10]
   Label IDs (first 20): [3, 23143, 14196, 3476, 599, 1935, 61, 21680, 819, 549, 17444, 427, 1246, 2490, 11526, 117, 1]

------------------------------------------------------------
Setting up data collator...
------------------------------------------------------------
✅ Data collator ready
   Will pad batches dynamically during training


## Fine Tuning the Model

In [5]:
print("=" * 60)
print("FINE-TUNING SETUP")
print("=" * 60)

# ============================================================
# Training Arguments
# ============================================================

print("\nSetting up training arguments...")
print("-" * 60)

training_args = Seq2SeqTrainingArguments(
    output_dir=str(OUTPUT_DIR),

    # Training hyperparameters
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=WARMUP_STEPS,

    # Evaluation and logging
    eval_strategy="steps",
    eval_steps=EVAL_STEPS,
    save_strategy="steps",
    save_steps=SAVE_STEPS,
    logging_steps=LOGGING_STEPS,

    # Generation settings for evaluation
    predict_with_generate=True,
    generation_max_length=GEN_MAX_LENGTH,
    generation_num_beams=GEN_NUM_BEAMS,

    # Performance optimizations
    # # IN COLAB: Use these settings for GPU
    # fp16=FP16,
    # dataloader_num_workers=2,

    # # ON LAPTOP: Use these settings instead
    fp16=False,
    dataloader_num_workers=0,

    # Model saving
    save_total_limit=2,  # Only keep 2 best checkpoints
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # Other settings
    report_to="none",  # Disable wandb/tensorboard
    seed=SEED,
)

print(f"✅ Training arguments configured")
print(f"   Output directory: {OUTPUT_DIR}")
print(f"   Total epochs: {NUM_EPOCHS}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Learning rate: {LEARNING_RATE}")

# ============================================================
# Initialize Trainer
# ============================================================

print("\n" + "-" * 60)
print("Initializing Seq2SeqTrainer...")
print("-" * 60)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print(f"✅ Trainer initialized")
print(f"   Training samples: {len(tokenized_train)}")
print(f"   Validation samples: {len(tokenized_val)}")
print(f"   Total training steps: {trainer.state.max_steps if hasattr(trainer.state, 'max_steps') else 'calculating...'}")

# ============================================================
# Start Fine-Tuning
# ============================================================

print("\n" + "=" * 60)
print("STARTING FINE-TUNING")
print("=" * 60)
print("\nThis will take approximately:")
# IN COLAB: ~15-30 minutes on GPU
print("   - On GPU (Colab): ~15-30 minutes")
# ON LAPTOP: ~3-5 hours on CPU/MPS
print("   - On CPU/MPS: ~3-5 hours")
print("\nTraining progress will be displayed below...")
print("=" * 60)

# Start training
train_result = trainer.train()

print("\n" + "=" * 60)
print("TRAINING COMPLETE!")
print("=" * 60)

# Print training metrics
print(f"\nTraining Metrics:")
print(f"   Total runtime: {train_result.metrics['train_runtime']:.2f} seconds")
print(f"   Samples per second: {train_result.metrics['train_samples_per_second']:.2f}")
print(f"   Final train loss: {train_result.metrics['train_loss']:.4f}")

# ============================================================
# Evaluate on Validation Set
# ============================================================

print("\n" + "-" * 60)
print("Evaluating on validation set...")
print("-" * 60)

eval_result = trainer.evaluate()

print(f"✅ Evaluation complete")
print(f"\nValidation Metrics:")
print(f"   Eval loss: {eval_result['eval_loss']:.4f}")
print(f"   Eval runtime: {eval_result['eval_runtime']:.2f} seconds")
print(f"   Samples per second: {eval_result['eval_samples_per_second']:.2f}")

FINE-TUNING SETUP

Setting up training arguments...
------------------------------------------------------------
✅ Training arguments configured
   Output directory: finetuned_flant5
   Total epochs: 4
   Batch size: 8
   Learning rate: 2e-05

------------------------------------------------------------
Initializing Seq2SeqTrainer...
------------------------------------------------------------
✅ Trainer initialized
   Training samples: 8559
   Validation samples: 1034
   Total training steps: 0

STARTING FINE-TUNING

This will take approximately:
   - On GPU (Colab): ~15-30 minutes
   - On CPU/MPS: ~3-5 hours

Training progress will be displayed below...


  trainer = Seq2SeqTrainer(


Step,Training Loss,Validation Loss
500,0.2158,0.387498
1000,0.1676,0.399717
1500,0.1431,0.393052
2000,0.1148,0.405055
2500,0.1268,0.403906
3000,0.1138,0.413282
3500,0.1134,0.410019
4000,0.1099,0.410447


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].



TRAINING COMPLETE!

Training Metrics:
   Total runtime: 1405.97 seconds
   Samples per second: 24.35
   Final train loss: 0.1669

------------------------------------------------------------
Evaluating on validation set...
------------------------------------------------------------




✅ Evaluation complete

Validation Metrics:
   Eval loss: 0.3875
   Eval runtime: 7.18 seconds
   Samples per second: 143.92


## Saving the Model

In [6]:
print("=" * 60)
print("SAVING FINE-TUNED MODEL")
print("=" * 60)

# ============================================================
# Save Model and Tokenizer
# ============================================================

print("\nSaving model and tokenizer...")
print("-" * 60)

# The trainer already saved checkpoints during training
# Now we'll save the final best model explicitly

FINAL_MODEL_DIR = OUTPUT_DIR / "final_model"
FINAL_MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Save model
trainer.save_model(str(FINAL_MODEL_DIR))

# Save tokenizer
tokenizer.save_pretrained(str(FINAL_MODEL_DIR))

print(f"✅ Model saved to: {FINAL_MODEL_DIR}")
print(f"✅ Tokenizer saved to: {FINAL_MODEL_DIR}")

# ============================================================
# Save Training Configuration
# ============================================================

print("\n" + "-" * 60)
print("Saving training configuration...")
print("-" * 60)

training_config = {
    "base_model": BASE_MODEL_NAME,
    "num_epochs": NUM_EPOCHS,
    "batch_size": BATCH_SIZE,
    "learning_rate": LEARNING_RATE,
    "warmup_steps": WARMUP_STEPS,
    "weight_decay": WEIGHT_DECAY,
    "max_input_length": MAX_INPUT_LENGTH,
    "max_target_length": MAX_TARGET_LENGTH,
    "gen_max_length": GEN_MAX_LENGTH,
    "gen_num_beams": GEN_NUM_BEAMS,
    "seed": SEED,
    "final_train_loss": train_result.metrics.get('train_loss', 'N/A'),
    "final_eval_loss": eval_result.get('eval_loss', 'N/A'),
    "training_runtime_seconds": train_result.metrics.get('train_runtime', 'N/A'),
}

config_path = OUTPUT_DIR / "training_config.json"
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(training_config, f, indent=2, ensure_ascii=False)

print(f"✅ Training config saved to: {config_path}")

# ============================================================
# Model Files Summary
# ============================================================

print("\n" + "-" * 60)
print("Saved Files Summary:")
print("-" * 60)

print(f"\nFinal model directory: {FINAL_MODEL_DIR}")
print(f"  Contains:")
print(f"  - pytorch_model.bin (model weights)")
print(f"  - config.json (model configuration)")
print(f"  - tokenizer files")

print(f"\nCheckpoints directory: {OUTPUT_DIR}")
print(f"  Contains intermediate checkpoints from training")

print(f"\nTraining config: {config_path}")

print("\n✅ All files saved successfully!")

SAVING FINE-TUNED MODEL

Saving model and tokenizer...
------------------------------------------------------------
✅ Model saved to: finetuned_flant5/final_model
✅ Tokenizer saved to: finetuned_flant5/final_model

------------------------------------------------------------
Saving training configuration...
------------------------------------------------------------
✅ Training config saved to: finetuned_flant5/training_config.json

------------------------------------------------------------
Saved Files Summary:
------------------------------------------------------------

Final model directory: finetuned_flant5/final_model
  Contains:
  - pytorch_model.bin (model weights)
  - config.json (model configuration)
  - tokenizer files

Checkpoints directory: finetuned_flant5
  Contains intermediate checkpoints from training

Training config: finetuned_flant5/training_config.json

✅ All files saved successfully!


## Evaluation

In [7]:
print("=" * 60)
print("EVALUATION ON HOSPITAL_1")
print("=" * 60)

# ============================================================
# Load Fine-tuned Model (if needed)
# ============================================================

print("\nLoading fine-tuned model for evaluation...")
print("-" * 60)

# If you're continuing from training, the model is already loaded
# If you're starting fresh, uncomment these lines:
# eval_model = AutoModelForSeq2SeqLM.from_pretrained(str(FINAL_MODEL_DIR))
# eval_tokenizer = AutoTokenizer.from_pretrained(str(FINAL_MODEL_DIR))
# eval_model = eval_model.to(DEVICE)

# For now, we'll use the trainer's model
eval_model = trainer.model
eval_tokenizer = tokenizer

print(f"✅ Model ready for evaluation")

# ============================================================
# SQL Utilities
# ============================================================

def canonical_sql(sql_text):
    """Normalize SQL to canonical form using sqlglot."""
    if not sql_text:
        return None
    try:
        ast = parse_one(sql_text, read="sqlite")
        return ast.sql(dialect="sqlite", pretty=False)
    except Exception:
        return None


def try_execute(conn, sql_text):
    """Execute SQL query and return result set."""
    try:
        cur = conn.execute(sql_text)
        rows = cur.fetchall()

        # Normalize floats
        normalized = []
        for row in rows:
            norm_row = []
            for val in row:
                if isinstance(val, float):
                    norm_row.append(round(val, 6))
                else:
                    norm_row.append(val)
            normalized.append(tuple(norm_row))

        return set(normalized), None

    except Exception as e:
        return None, str(e)


def extract_sql(text):
    """Extract SQL from model output."""
    text = text.strip()

    # Remove markdown code blocks if present
    if "```" in text:
        parts = text.split("```")
        for part in parts:
            if "select" in part.lower() or "SELECT" in part:
                text = part.strip()
                if text.lower().startswith("sql"):
                    text = text[3:].strip()
                break

    # Remove common prefixes
    for prefix in ["sql:", "answer:", "query:"]:
        if text.lower().startswith(prefix):
            text = text[len(prefix):].strip()

    # Ensure semicolon
    if ";" in text:
        text = text.split(";", 1)[0] + ";"

    return text.strip()

print("\n✅ SQL utilities defined")

# ============================================================
# Connect to Database
# ============================================================

print("\n" + "-" * 60)
print("Connecting to database...")
print("-" * 60)

conn = sqlite3.connect(str(SQLITE_DB))
conn.execute("PRAGMA foreign_keys=ON")

print(f"✅ Connected to: {SQLITE_DB}")

# ============================================================
# Evaluation Loop
# ============================================================

print("\n" + "=" * 60)
print("RUNNING EVALUATION")
print("=" * 60)

results = []
n_examples = len(test_data)

em_count = 0
ex_count = 0
valid_count = 0
latencies = []

print(f"\nEvaluating on {n_examples} examples from hospital_1...")
print("-" * 60)

eval_model.eval()

for i, example in enumerate(test_data, 1):
    question = example['question']
    gold_sql = example['gold_query']
    schema = example['schema_serialized']

    # Build prompt
    prompt = PROMPT_TEMPLATE.format(question=question, schema=schema)

    # Tokenize
    inputs = eval_tokenizer(
        prompt,
        return_tensors="pt",
        max_length=MAX_INPUT_LENGTH,
        truncation=True
    ).to(DEVICE)

    # Generate SQL
    start_time = time.time()

    with torch.no_grad():
        outputs = eval_model.generate(
            **inputs,
            max_length=GEN_MAX_LENGTH,
            num_beams=GEN_NUM_BEAMS,
            temperature=GEN_TEMPERATURE if GEN_TEMPERATURE > 0 else 1.0,
            do_sample=False
        )

    gen_time_ms = (time.time() - start_time) * 1000.0
    latencies.append(gen_time_ms)

    # Decode
    pred_sql_raw = eval_tokenizer.decode(outputs[0], skip_special_tokens=True)
    pred_sql_raw = extract_sql(pred_sql_raw)

    # Normalize
    pred_sql_norm = canonical_sql(pred_sql_raw)
    gold_sql_norm = canonical_sql(gold_sql)

    # ============================================================
    # Compute Metrics
    # ============================================================

    # Exact Match (EM)
    em = int(
        pred_sql_norm is not None and
        gold_sql_norm is not None and
        pred_sql_norm == gold_sql_norm
    )

    # Execution Accuracy (EX) and Valid SQL
    valid = 0
    ex_ok = 0
    error = None

    if pred_sql_norm is not None:
        # Try to execute predicted SQL
        pred_rows, error = try_execute(conn, pred_sql_norm)

        if pred_rows is not None:
            valid = 1  # SQL is valid

            # Execute gold SQL
            gold_rows, gold_error = try_execute(conn, gold_sql_norm or gold_sql)

            if gold_rows is not None:
                # Compare result sets
                ex_ok = int(pred_rows == gold_rows)
            else:
                error = f"Gold SQL failed: {gold_error}"
    else:
        error = "ParseError: Could not parse predicted SQL"

    # Update counters
    em_count += em
    ex_count += ex_ok
    valid_count += valid

    # Store result
    results.append({
        "id": example.get("id", f"test_{i}"),
        "question": question,
        "gold_sql": gold_sql,
        "pred_sql_raw": pred_sql_raw,
        "pred_sql_norm": pred_sql_norm or "",
        "em": em,
        "ex": ex_ok,
        "valid_sql": valid,
        "latency_ms": round(gen_time_ms, 2),
        "error": error or ""
    })

    # Progress update
    if i % 10 == 0 or i == n_examples:
        print(f"[{i}/{n_examples}] EM={em_count/i:.3f} EX={ex_count/i:.3f} Valid={valid_count/i:.3f}")

# ============================================================
# Save Results
# ============================================================

print("\n" + "-" * 60)
print("Saving results...")
print("-" * 60)

with open(RESULTS_CSV, "w", newline="", encoding="utf-8") as f:
    if results:
        writer = csv.DictWriter(f, fieldnames=list(results[0].keys()))
        writer.writeheader()
        writer.writerows(results)

print(f"✅ Results saved to: {RESULTS_CSV}")

# ============================================================
# Summary Statistics
# ============================================================

em_rate = em_count / n_examples
ex_rate = ex_count / n_examples
valid_rate = valid_count / n_examples
median_latency = sorted(latencies)[len(latencies) // 2] if latencies else 0

print("\n" + "=" * 60)
print("EVALUATION SUMMARY")
print("=" * 60)
print(f"\nModel: Fine-tuned Flan-T5")
print(f"Base Model: {BASE_MODEL_NAME}")
print(f"Test Dataset: hospital_1")
print(f"Examples: {n_examples}")
print(f"\nMetrics:")
print(f"  Exact Match (EM):        {em_rate:.3%} ({em_count}/{n_examples})")
print(f"  Execution Accuracy (EX): {ex_rate:.3%} ({ex_count}/{n_examples})")
print(f"  Valid-SQL rate:          {valid_rate:.3%} ({valid_count}/{n_examples})")
print(f"\nPerformance:")
print(f"  Median generation time:  {median_latency:.1f} ms")
print(f"\nBaseline Comparison (Flan-T5 before fine-tuning):")
print(f"  Baseline EM:  15.0%")
print(f"  Baseline EX:  28.0%")
print(f"  Baseline Valid: 38.0%")
print(f"\nImprovement:")
print(f"  EM improvement:    {(em_rate - 0.15) * 100:+.1f} percentage points")
print(f"  EX improvement:    {(ex_rate - 0.28) * 100:+.1f} percentage points")
print(f"  Valid improvement: {(valid_rate - 0.38) * 100:+.1f} percentage points")
print(f"\nResults saved to: {RESULTS_CSV}")
print("=" * 60)

# Close database connection
conn.close()

EVALUATION ON HOSPITAL_1

Loading fine-tuned model for evaluation...
------------------------------------------------------------
✅ Model ready for evaluation

✅ SQL utilities defined

------------------------------------------------------------
Connecting to database...
------------------------------------------------------------
✅ Connected to: hospital_1.sqlite

RUNNING EVALUATION

Evaluating on 100 examples from hospital_1...
------------------------------------------------------------
[10/100] EM=0.200 EX=0.600 Valid=0.700
[20/100] EM=0.250 EX=0.700 Valid=0.750
[30/100] EM=0.367 EX=0.667 Valid=0.767
[40/100] EM=0.450 EX=0.725 Valid=0.825
[50/100] EM=0.440 EX=0.740 Valid=0.860
[60/100] EM=0.383 EX=0.683 Valid=0.850
[70/100] EM=0.414 EX=0.671 Valid=0.843
[80/100] EM=0.438 EX=0.662 Valid=0.838
[90/100] EM=0.389 EX=0.611 Valid=0.767
[100/100] EM=0.370 EX=0.610 Valid=0.770

------------------------------------------------------------
Saving results...
----------------------------------