In [None]:
# T5 Text-to-SQL Fine-Tuning Script Optimized for Google Colab
# For bachelor thesis on schema-enhanced Text-to-SQL generation
# VERSION: fp16 disabled to debug NaN loss (mixed precision uses bf16 instead)

# Mount Google Drive
# The Drive-backed paths used later in this script live under /content/drive.
# force_remount=True asks Colab to re-mount even if a mount already exists.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device 0 is inspected.
    gpu_name = torch.cuda.get_device_name(0)
    # Divide by 1e9 so the value is reported in (decimal) GB below.
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU device: {gpu_name}")
    print(f"Memory: {gpu_memory:.2f} GB")
    # Name-based heuristic only: any device whose name lacks "A100"/"H100"
    # gets the warning, regardless of how much memory it actually has.
    if 'A100' not in gpu_name and 'H100' not in gpu_name: # Basic check
         print("Warning: T5-Large is memory intensive. Ensure sufficient GPU RAM.")

# Show GPU info
# NOTE: lines starting with "!" are IPython shell escapes; they only work when
# this code runs inside a notebook (e.g. Colab), not as a plain Python script.
!nvidia-smi

# Install required packages
print("Installing required packages...")
!pip install -q datasets transformers evaluate tensorboard accelerate huggingface-hub pandas
print("Packages installed.")

# --- Verify Installation (Optional) ---
# Prints the installed version metadata for each package for the run log.
print("\nVerifying package versions...")
!pip show datasets transformers evaluate tensorboard accelerate huggingface-hub torch pandas
print("-" * 30)

import json
import os
import pandas as pd
import time
import re
from typing import Dict, List, Any
import numpy as np
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed
)

# Seed the random number generators up front so runs are repeatable
set_seed(42)

# --- Configuration ---
# Experiment identity
EXPERIMENT_NAME = "t5_large_sql_types_schema_v5"
SCHEMA_FORMAT = "sql"
MODEL_SIZE = "large"

# Optimisation hyper-parameters
EPOCHS = 25
LEARNING_RATE = 1e-5
BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 8
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1
MAX_GRAD_NORM = 1.0

# Sequence length limits (in tokens) applied during tokenization
MAX_INPUT_LENGTH = 1024
MAX_TARGET_LENGTH = 256

# Set to True to continue from the newest checkpoint in the output directory
RESUME_FROM_CHECKPOINT = False

# Google Drive paths
DRIVE_BASE_DIR = "/content/drive/MyDrive/text2sql"
DRIVE_OUTPUT_DIR = f"{DRIVE_BASE_DIR}/{EXPERIMENT_NAME}"
DRIVE_DATASET_SOURCE_DIR = f"{DRIVE_BASE_DIR}/datasets/spider"
DRIVE_LOGS_DIR = f"{DRIVE_BASE_DIR}/logs/{EXPERIMENT_NAME}"

# Local paths
LOCAL_DATASET_DIR = "/content/datasets/spider"

# Create directories (same order as they are declared above)
for _directory in (DRIVE_BASE_DIR, DRIVE_OUTPUT_DIR, DRIVE_LOGS_DIR, LOCAL_DATASET_DIR):
    os.makedirs(_directory, exist_ok=True)

# Echo the full configuration so it is captured in the run log
_config_banner = [
    f"--- Running Experiment: {EXPERIMENT_NAME} ---",
    f"Model: t5-{MODEL_SIZE}",
    f"Schema Format: {SCHEMA_FORMAT} (with Types)",
    f"Epochs: {EPOCHS}",
    f"Learning Rate: {LEARNING_RATE}",
    f"Per Device Batch Size: {BATCH_SIZE}",
    f"Gradient Accumulation Steps: {GRADIENT_ACCUMULATION_STEPS}",
    f"Effective Batch Size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}",
    f"Weight Decay: {WEIGHT_DECAY}",
    f"Gradient Clipping: {MAX_GRAD_NORM}",
    f"Max Input Length: {MAX_INPUT_LENGTH}",
    f"Max Target Length: {MAX_TARGET_LENGTH}",
    f"Warmup Ratio: {WARMUP_RATIO}",
    f"FP16 Enabled: False",
]
for _banner_line in _config_banner:
    print(_banner_line)

# --- Schema Utilities ---
def load_tables_json(tables_path):
    """Read a tables.json schema file and index its entries by database id.

    Args:
        tables_path: File name/path, resolved relative to LOCAL_DATASET_DIR.

    Returns:
        Dict mapping each entry's 'db_id' to the full entry.

    Raises:
        FileNotFoundError: If the file does not exist (a message is printed
            first, then the original exception propagates).
    """
    schema_file = os.path.join(LOCAL_DATASET_DIR, tables_path)
    print(f"Loading tables.json from: {schema_file}")
    try:
        with open(schema_file, 'r', encoding='utf-8') as handle:
            entries = json.load(handle)
    except FileNotFoundError:
        print(f"Error: tables.json not found at {schema_file}")
        raise

    # Later entries win if the same db_id appears more than once.
    schemas_by_id = {}
    for entry in entries:
        schemas_by_id[entry['db_id']] = entry
    return schemas_by_id

def _collect_primary_key_indices(raw_keys):
    """Flatten a 'primary_keys' value into a set of column indices.

    Accepts plain indices as well as nested lists/tuples (composite keys),
    and treats a missing/None value as "no primary keys".
    """
    indices = set()
    for entry in raw_keys or []:
        if isinstance(entry, (list, tuple)):
            # Composite key: every member column is part of the primary key.
            indices.update(entry)
        else:
            indices.add(entry)
    return indices


def _collect_foreign_key_map(raw_pairs):
    """Map source column index -> referenced column index for well-formed FK pairs."""
    fk_map = {}
    if isinstance(raw_pairs, list):
        for pair in raw_pairs:
            if isinstance(pair, (list, tuple)) and len(pair) == 2:
                source_idx, target_idx = pair
                if isinstance(source_idx, int) and isinstance(target_idx, int):
                    # If a column has several FK entries, the last one wins.
                    fk_map[source_idx] = target_idx
    return fk_map


def get_sql_schema_string(db_id, db_schemas):
    """Create SQL schema string including types, PKs, FKs.

    Args:
        db_id: Key of the database in ``db_schemas``.
        db_schemas: Dict of schema entries; each entry must provide
            'table_names_original', 'column_names_original' (pairs of
            table index and column name) and 'column_types'. 'primary_keys'
            and 'foreign_keys' are optional.

    Returns:
        One "Table: <name>\\nColumns: <col> (<TYPE>)..." section per table
        that has at least one column. Columns are sorted within a table and
        the table sections are sorted as whole strings.

    Raises:
        ValueError: If ``db_id`` is not present in ``db_schemas``.
    """
    if db_id not in db_schemas:
        raise ValueError(f"DB ID '{db_id}' not found")
    schema_info = db_schemas[db_id]
    tables = schema_info['table_names_original']
    columns = schema_info['column_names_original']
    column_types = schema_info['column_types']
    # Flattening avoids the TypeError that set() raises on composite keys
    # stored as nested lists, and tolerates an explicit null value.
    primary_keys = _collect_primary_key_indices(schema_info.get('primary_keys'))
    fk_dict = _collect_foreign_key_map(schema_info.get('foreign_keys'))

    table_defs = []
    for i, table in enumerate(tables):
        table_columns = []
        for col_idx, (tab_idx, col_name) in enumerate(columns):
            if tab_idx != i:
                continue
            col_type = column_types[col_idx].upper()
            col_info = f"{col_name} ({col_type})"
            if col_idx in primary_keys:
                col_info += " (PRIMARY KEY)"
            if col_idx in fk_dict:
                ref_col_idx = fk_dict[col_idx]
                # Silently ignore FK targets that point outside the schema.
                if 0 <= ref_col_idx < len(columns):
                    ref_tab_idx, ref_col_name = columns[ref_col_idx]
                    if 0 <= ref_tab_idx < len(tables):
                        ref_table = tables[ref_tab_idx]
                        col_info += f" (FOREIGN KEY -> {ref_table}.{ref_col_name})"
            table_columns.append(col_info)
        if table_columns:
            table_columns.sort()
            # NOTE(review): "\\n" is a literal backslash + "n", not a newline.
            # It is kept byte-for-byte because it is part of the prompt format;
            # confirm with the evaluation code before changing it.
            table_def = f"Table: {table}\\nColumns: {', '.join(table_columns)}"
            table_defs.append(table_def)
    table_defs.sort()
    return "\\n".join(table_defs)

def get_compact_schema_string(db_id, db_schemas):
    """Render a schema as space-separated ``table(col1, col2)`` groups.

    Tables are emitted in sorted name order with their columns sorted as
    well; tables that own no columns are left out.

    Raises:
        ValueError: If ``db_id`` is not present in ``db_schemas``.
    """
    if db_id not in db_schemas:
        raise ValueError(f"DB ID '{db_id}' not found")

    entry = db_schemas[db_id]
    table_names = entry['table_names_original']
    column_pairs = entry['column_names_original']

    # table name -> sorted column names (only for tables that have columns)
    columns_by_table = {}
    for position, name in enumerate(table_names):
        owned = sorted(col for owner, col in column_pairs if owner == position)
        if owned:
            columns_by_table[name] = owned

    return " ".join(
        f"{name}({', '.join(columns_by_table[name])})"
        for name in sorted(columns_by_table)
    )

def enhance_prompts_with_schema(data_df, db_schemas, schema_format="sql"):
    """Enhance input prompts with schema information.

    Args:
        data_df: DataFrame with 'db_id', 'question' and 'query' columns.
        db_schemas: Schema dict passed through to the schema string builders.
        schema_format: "compact", "sql" or "both"; any other value produces
            prompts without schema text.

    Returns:
        DataFrame with 'input_text' and 'output_text' columns. Examples whose
        prompt cannot be built are skipped with a printed warning.
    """
    enhanced_rows = []
    skipped_count = 0
    total_count = len(data_df)
    print_interval = max(1, total_count // 10)
    print(f"Enhancing {total_count} prompts...")

    # Schema strings depend only on (builder, db_id), so build each one once
    # instead of once per example. Failures are not cached, so every example
    # of a broken database still gets its own warning.
    schema_cache = {}

    def _schema_for(builder, db_id):
        key = (builder, db_id)
        if key not in schema_cache:
            schema_cache[key] = builder(db_id, db_schemas)
        return schema_cache[key]

    # Count rows positionally: the DataFrame index label is not guaranteed to
    # be a 0..n-1 integer (it may be filtered, shuffled or non-numeric).
    for position, (_, example) in enumerate(data_df.iterrows()):
        if position > 0 and position % print_interval == 0:
            print(f"  Processed {position}/{total_count} examples...")
        db_id = example['db_id']
        try:
            # NOTE(review): "\\n" below is a literal backslash + "n", not a
            # newline. It is kept byte-for-byte because it is part of the
            # prompt format; confirm with the evaluation code before changing.
            if schema_format == "compact":
                schema_str = _schema_for(get_compact_schema_string, db_id)
                input_text = f"translate English to SQL: {example['question']} | database: {db_id} | schema: {schema_str}"
            elif schema_format == "sql":
                schema_str = _schema_for(get_sql_schema_string, db_id)
                input_text = f"translate English to SQL: {example['question']} | database: {db_id} | schema:\\n{schema_str}"
            elif schema_format == "both":
                compact_str = _schema_for(get_compact_schema_string, db_id)
                sql_str = _schema_for(get_sql_schema_string, db_id)
                input_text = f"translate English to SQL: {example['question']} | database: {db_id} | schema: {compact_str}\\nDetailed schema:\\n{sql_str}"
            else:
                input_text = f"translate English to SQL: {example['question']} | database: {db_id}"
            output_text = example['query']
            enhanced_rows.append({"input_text": input_text, "output_text": output_text})
        except Exception as e:
            # Deliberate best-effort: one bad example must not abort the run.
            print(f"Warning: Skipping example for db_id '{db_id}': {e}")
            skipped_count += 1
    print(f"  Processed {total_count}/{total_count} examples...")
    if skipped_count > 0:
        print(f"Skipped {skipped_count} examples.")
    if not enhanced_rows:
        print("Warning: No rows enhanced.")
    return pd.DataFrame(enhanced_rows)

# --- Setup Local Dataset from Drive ---
def setup_local_dataset_from_drive():
    """Copy the Spider dataset JSON files from Google Drive to local storage.

    Copies tables.json, train_spider.json and dev.json from
    DRIVE_DATASET_SOURCE_DIR into LOCAL_DATASET_DIR, overwriting existing
    copies.

    Raises:
        FileNotFoundError: If any of the three source files is missing.
        OSError: If a copy fails (a message is printed, then it is re-raised).
    """
    # Local import keeps this block self-contained.
    import shutil

    print(f"\n--- Setting up Dataset ---")
    print(f"Copying dataset from Google Drive path: {DRIVE_DATASET_SOURCE_DIR}")
    drive_tables_path = f"{DRIVE_DATASET_SOURCE_DIR}/tables.json"
    drive_train_path = f"{DRIVE_DATASET_SOURCE_DIR}/train_spider.json"
    drive_dev_path = f"{DRIVE_DATASET_SOURCE_DIR}/dev.json"
    source_paths = [drive_tables_path, drive_train_path, drive_dev_path]

    missing = [p for p in source_paths if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(f"Missing files in Drive: {missing}")

    print("Copying dataset files to local Colab storage...")
    try:
        # shutil.copy raises on failure, unlike an IPython "!cp" shell escape,
        # whose non-zero exit status never becomes a Python exception.
        for source_path in source_paths:
            copied_to = shutil.copy(source_path, LOCAL_DATASET_DIR)
            print(f"'{source_path}' -> '{copied_to}'")
    except OSError as e:
        print(f"Error during file copy: {e}")
        raise
    print("Dataset files copied successfully.")

# --- Run Setup Function ---
# Must run before anything reads from LOCAL_DATASET_DIR: it is what places the
# dataset JSON files there.
setup_local_dataset_from_drive()

# --- Load Database Schemas ---
print("\nLoading database schemas...")
# db_schemas maps db_id -> schema entry and is shared by the train and dev
# preparation steps below.
db_schemas = load_tables_json('tables.json')
print(f"Loaded schemas for {len(db_schemas)} databases.")

# --- Load and Prepare Datasets ---
def load_and_prepare_data(file_path, db_schemas_dict, schema_fmt):
    """Load a dataset JSON file and turn it into a prompt/target Dataset.

    Args:
        file_path: File name/path, resolved relative to LOCAL_DATASET_DIR.
        db_schemas_dict: Schema dict forwarded to enhance_prompts_with_schema.
        schema_fmt: Schema format forwarded to enhance_prompts_with_schema.

    Returns:
        A ``Dataset`` built from the enhanced prompts, or None when no
        prompts could be produced.

    Raises:
        Exception: Any error while reading or processing is printed and then
            re-raised unchanged.
    """
    source_path = os.path.join(LOCAL_DATASET_DIR, file_path)
    print(f"Loading data from: {source_path}")
    try:
        with open(source_path, 'r', encoding='utf-8') as handle:
            raw_examples = json.load(handle)

        prompts_df = enhance_prompts_with_schema(
            pd.DataFrame(raw_examples),
            db_schemas_dict,
            schema_format=schema_fmt,
        )
        # Nothing usable came out of the enhancement step.
        if prompts_df is None or prompts_df.empty:
            return None

        prepared = Dataset.from_pandas(prompts_df)
        print(f"Prepared {len(prepared)} examples from {source_path}.")
        return prepared
    except Exception as err:
        print(f"Error processing data from {source_path}: {err}")
        raise
print("Loading datasets...")
train_dataset = load_and_prepare_data('train_spider.json', db_schemas, SCHEMA_FORMAT)
dev_dataset = load_and_prepare_data('dev.json', db_schemas, SCHEMA_FORMAT)
# load_and_prepare_data signals "no usable examples" with None.
if any(split is None for split in (train_dataset, dev_dataset)):
    raise ValueError("Failed to load train or dev dataset.")

# Print the first couple of training examples so the prompt format can be
# checked by eye before training starts.
print("\nSample Prompts (with types):")
preview_count = min(2, len(train_dataset))
for i in range(preview_count):
    sample = train_dataset[i]
    prompt_text = sample['input_text']
    print(f"\n--- Example {i+1} ---")
    print(f"Input length: {len(prompt_text)} characters")
    # Only the first 600 characters of the prompt are shown.
    print(f"Input: {prompt_text[:600]}...")
    print(f"Output: {sample['output_text']}")

# --- Load Model and Tokenizer ---
# The checkpoint name is derived from MODEL_SIZE (e.g. "large" -> "t5-large").
model_name = f"t5-{MODEL_SIZE}"
print(f"\nLoading model: {model_name}...")
# Tokenizer and model are loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    return_dict=True
)
print("Model and tokenizer loaded.")

# --- Tokenization Function ---
def tokenize_function(examples):
    """Tokenizes input and target text.

    Args:
        examples: Batch dict with 'input_text' and 'output_text' lists
            (as passed by ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer output for the inputs, with the target token ids
        stored under 'labels'. Inputs are truncated to MAX_INPUT_LENGTH and
        targets to MAX_TARGET_LENGTH; no padding is applied here.
    """
    # Missing texts are tokenized as empty strings instead of failing.
    input_texts = ["" if text is None else text for text in examples['input_text']]
    output_texts = ["" if text is None else text for text in examples['output_text']]

    model_inputs = tokenizer(
        input_texts, max_length=MAX_INPUT_LENGTH, truncation=True,
    )
    # `text_target=` is the supported replacement for the deprecated
    # `tokenizer.as_target_tokenizer()` context manager. Targets are tokenized
    # in a separate call because they use a different max_length.
    labels = tokenizer(
        text_target=output_texts, max_length=MAX_TARGET_LENGTH, truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# --- Tokenize the Datasets ---
print("\nTokenizing datasets...")
# remove_columns drops the raw text columns so only the tokenizer outputs
# (plus 'labels') remain in the mapped datasets.
tokenized_train = train_dataset.map(
    tokenize_function, batched=True, remove_columns=train_dataset.column_names, desc="Tokenizing training dataset"
)
tokenized_dev = dev_dataset.map(
    tokenize_function, batched=True, remove_columns=dev_dataset.column_names, desc="Tokenizing development dataset"
)
print(f"Training dataset tokenized: {len(tokenized_train)} examples")
print(f"Development dataset tokenized: {len(tokenized_dev)} examples")

# --- Data Collator ---
# Padding is deferred to batch time: tokenize_function does not pad, and the
# collator pads each batch to its longest sequence.
# NOTE(review): -100 is presumably chosen so padded label positions are
# ignored by the loss - confirm against the Transformers documentation.
print("\nData collator initializing...")
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    padding="longest",
    pad_to_multiple_of=None
)
print(f"Data collator using label_pad_token_id: {data_collator.label_pad_token_id}")

# --- Training Arguments ---
print("\nConfiguring training arguments...")
# Estimated number of optimizer steps: examples per effective batch, times epochs.
total_steps = len(tokenized_train) // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS) * EPOCHS
warmup_steps = int(total_steps * WARMUP_RATIO)
print(f"Total training steps: {total_steps}, Warmup steps: {warmup_steps}")

# Only request bf16 when the device supports it; a hard-coded bf16=True fails
# on GPUs without bfloat16 support. fp16 stays off in every case (it was
# disabled to debug NaN loss), so the fallback is full fp32 precision.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
if not use_bf16:
    print("Warning: bf16 is not supported on this device; training in fp32.")

training_args = Seq2SeqTrainingArguments(
    output_dir=DRIVE_OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE * 2,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=warmup_steps,
    bf16=use_bf16,
    bf16_full_eval=use_bf16,
    fp16=False,
    gradient_checkpointing=True,
    max_grad_norm=MAX_GRAD_NORM,
    evaluation_strategy="epoch",
    logging_strategy="steps",
    logging_steps=100,
    # NOTE(review): TensorBoard event files are written straight to Drive;
    # a Drive disconnect surfaces as a writer-thread error ("Transport
    # endpoint is not connected"). Consider a local logging dir - confirm.
    logging_dir=DRIVE_LOGS_DIR,
    # Evaluate and save once per epoch so the best checkpoint (lowest
    # eval_loss) can be restored at the end of training.
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="tensorboard",
    optim="adafactor",
    seed=42,
)

# --- Trainer Initialization ---
print("\nInitializing Trainer...")
# The dev split is used as the evaluation set; the collator handles per-batch
# padding of inputs and labels.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_dev,
    data_collator=data_collator,
    tokenizer=tokenizer,
)
print("Trainer initialized.")

# --- Main Training Execution ---
print(f"\n--- Starting Training ---")
start_time = time.time()

# Optionally locate the newest "checkpoint-<step>" directory to resume from.
checkpoint = None
if RESUME_FROM_CHECKPOINT:
    if os.path.isdir(DRIVE_OUTPUT_DIR):
        checkpoints = []
        for entry_name in os.listdir(DRIVE_OUTPUT_DIR):
            entry_path = os.path.join(DRIVE_OUTPUT_DIR, entry_name)
            step_text = entry_name[len('checkpoint-'):]
            # Require a numeric step suffix so stray directories such as
            # "checkpoint-backup" cannot crash the int() conversion.
            if entry_name.startswith('checkpoint-') and step_text.isdigit() and os.path.isdir(entry_path):
                checkpoints.append((int(step_text), entry_path))
        if checkpoints:
            latest_checkpoint = max(checkpoints)[1]
            print(f"Resuming from checkpoint: {latest_checkpoint}")
            checkpoint = latest_checkpoint
        else:
            print(f"No checkpoint found in {DRIVE_OUTPUT_DIR}.")
    else:
        print(f"Output directory {DRIVE_OUTPUT_DIR} does not exist.")

try:
    # Evaluate once before training so a NaN loss is spotted immediately.
    print("Running initial evaluation to check for NaN...")
    initial_eval = trainer.evaluate()
    print(f"Initial evaluation result: {initial_eval}")
    if 'eval_loss' in initial_eval and np.isnan(initial_eval['eval_loss']):
        print("ERROR: NaN detected in evaluation loss even with fp16=False. Check data or learning rate.")
        print("Proceeding with training despite initial NaN...")

    train_result = trainer.train(resume_from_checkpoint=checkpoint)

    # --- Post-Training Actions ---
    print("\n--- Training Finished ---")
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)

    print("\nSaving the final model...")
    trainer.save_model()
    trainer.save_state()
    tokenizer.save_pretrained(training_args.output_dir)
    print(f"Model saved to {training_args.output_dir}")

    # Generate training report
    elapsed_time = time.time() - start_time
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    report_path = os.path.join(training_args.output_dir, "training_summary.txt")
    summary_lines = [
        f"Experiment: {EXPERIMENT_NAME}",
        f"Model: t5-{MODEL_SIZE}",
        f"Schema Format: {SCHEMA_FORMAT} (with Types)",
        f"FP16 Enabled: {training_args.fp16}",
        f"BF16 Enabled: {training_args.bf16}",
        f"Epochs Configured: {EPOCHS}",
        f"Epochs Trained: {metrics.get('epoch', 0.0):.2f}",
        f"Training Time: {int(hours)}h {int(minutes)}m {int(seconds)}s",
        f"Learning Rate: {LEARNING_RATE}",
        f"Batch Size (per device): {BATCH_SIZE}",
        f"Gradient Accumulation Steps: {GRADIENT_ACCUMULATION_STEPS}",
        f"Effective Batch Size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}",
        f"Weight Decay: {WEIGHT_DECAY}",
        f"Warmup Ratio: {WARMUP_RATIO} (Steps: {warmup_steps})",
        f"Max Grad Norm: {MAX_GRAD_NORM}",
        f"Max Input Length: {MAX_INPUT_LENGTH}",
        f"Max Target Length: {MAX_TARGET_LENGTH}",
        "",
        "Training Metrics:",
        *(f"  {key}: {value}" for key, value in metrics.items()),
        "",
        f"Model saved to: {training_args.output_dir}",
        "To evaluate the model, use the evaluation script.",
    ]
    # Join with real newlines: the previous "\\n" wrote a literal backslash-n,
    # which produced a single-line, hard-to-read report.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("\n".join(summary_lines) + "\n")
    print(f"Training summary saved to {report_path}")
    print(f"Training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")

except KeyboardInterrupt:
    # Best-effort save so a manual stop does not lose the current weights.
    print("\nTraining interrupted by user. Saving current state...")
    interrupted_path = os.path.join(training_args.output_dir, "interrupted_checkpoint")
    if 'trainer' in locals() and hasattr(trainer, 'save_model'):
        trainer.save_model(interrupted_path)
    if 'tokenizer' in locals() and hasattr(tokenizer, 'save_pretrained'):
        tokenizer.save_pretrained(interrupted_path)
    print(f"Interrupted checkpoint potentially saved to {interrupted_path}")

except Exception as e:
    # Print the full traceback for the run log, then let the error propagate.
    print(f"\n--- An error occurred during training: {e}")
    import traceback
    traceback.print_exc()
    raise

print("\n--- Script Finished ---")


Mounted at /content/drive
CUDA available: True
GPU device: NVIDIA A100-SXM4-40GB
Memory: 42.47 GB
Sun Apr  6 13:09:16 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             54W /  400W |    5119MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------

Tokenizing training dataset:   0%|          | 0/7000 [00:00<?, ? examples/s]



Tokenizing development dataset:   0%|          | 0/1034 [00:00<?, ? examples/s]

Training dataset tokenized: 7000 examples
Development dataset tokenized: 1034 examples

Data collator initializing...
Data collator using label_pad_token_id: -100

Configuring training arguments...
Total training steps: 10925, Warmup steps: 1092

Initializing Trainer...
Trainer initialized.

--- Starting Training ---
Running initial evaluation to check for NaN...


  trainer = Seq2SeqTrainer(


Exception in thread Thread-11:
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/local/lib/python3.11/dist-packages/tensorboard/summary/writer/event_file_writer.py", line 244, in run
    self._run()
  File "/usr/local/lib/python3.11/dist-packages/tensorboard/summary/writer/event_file_writer.py", line 289, in _run
    self._record_writer.flush()
  File "/usr/local/lib/python3.11/dist-packages/tensorboard/summary/writer/record_writer.py", line 43, in flush
    self._writer.flush()
  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/lib/io/file_io.py", line 221, in flush
    self._writable_file.flush()
tensorflow.python.framework.errors_impl.FailedPreconditionError: /content/drive/MyDrive/text2sql/logs/t5_large_sql_types_schema_v4/events.out.tfevents.1743944854.71614af774ea.5565.0; Transport endpoint is not connected


Initial evaluation result: {'eval_loss': 3.3948590755462646, 'eval_model_preparation_time': 1.1534, 'eval_runtime': 21.6196, 'eval_samples_per_second': 47.827, 'eval_steps_per_second': 11.98}


Epoch,Training Loss,Validation Loss,Model Preparation Time
1,3.4024,2.987877,1.1534
2,2.5041,1.91293,1.1534
3,1.4238,1.078967,1.1534
4,1.0722,0.822407,1.1534
5,0.8769,0.721675,1.1534
6,0.7693,0.66026,1.1534
7,0.7139,0.617687,1.1534
8,0.6668,0.58975,1.1534
9,0.6424,0.569253,1.1534
10,0.5964,0.555962,1.1534


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].



--- Training Finished ---
***** train metrics *****
  epoch                    =      24.944
  total_flos               = 444282020GF
  train_loss               =        0.86
  train_runtime            =  9:29:51.99
  train_samples_per_second =       5.118
  train_steps_per_second   =        0.32

Saving the final model...
Model saved to /content/drive/MyDrive/text2sql/t5_large_sql_types_schema_v5
Training summary saved to /content/drive/MyDrive/text2sql/t5_large_sql_types_schema_v5/training_summary.txt
Training completed in 9h 30m 18s
\n--- Script Finished ---
