# 03 - Model Training

This notebook fine-tunes FLAN-T5 on the SoQG dataset for Socratic question generation.

## Objectives
- Load preprocessed tokenized dataset
- Configure Seq2SeqTrainer with appropriate hyperparameters
- Train FLAN-T5-base model
- Save checkpoints and final model
- Log training metrics

## Hardware Requirements
- **Minimum**: T4 GPU (16GB VRAM) - Google Colab Pro
- **Recommended**: A10G (24GB VRAM) - Kaggle
- **Training Time**: ~6 hours for flan-t5-base, 3 epochs

## Hyperparameters (from EACL 2023 paper)
- Learning rate: 5e-5
- Batch size: 4 (with gradient accumulation = 4 → effective 16)
- Epochs: 3-5
- FP16: Enabled for memory efficiency
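
The effective batch size above is the per-device batch size multiplied by the number of gradient accumulation steps. The sketch below spells out that arithmetic. It assumes a single GPU, and the training set size of 10,000 is a made-up number used only for illustration (the real size is printed when the dataset is loaded). The helper names are hypothetical and not part of any library.

```python
import math


def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Samples that contribute to one optimizer update."""
    return per_device_batch * grad_accum_steps * num_devices


def approx_steps_per_epoch(num_samples, per_device_batch, grad_accum_steps, num_devices=1):
    """Rough number of optimizer updates per epoch (rounds the last partial window up)."""
    return math.ceil(num_samples / effective_batch_size(per_device_batch, grad_accum_steps, num_devices))


print(effective_batch_size(4, 4))            # batch size 4, accumulation 4
print(approx_steps_per_epoch(10_000, 4, 4))  # 10,000 is an illustrative dataset size
```

The step count is only an estimate, since the exact figure depends on how the installed Trainer version treats the final, incomplete accumulation window.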

## 1. Environment Check and Setup

In [None]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
import os
import json
from pathlib import Path
from datetime import datetime

from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback
)
from datasets import load_from_disk
import numpy as np

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## 2. Configuration

In [None]:
MODEL_NAME = "google/flan-t5-base"
DATASET_PATH = Path("../datasets/processed/soqg_tokenized")
OUTPUT_DIR = Path("../backend/model_artifacts/soqg_flan_t5")
LOGS_DIR = Path("../experiments/logs")

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
LOGS_DIR.mkdir(parents=True, exist_ok=True)

config = {
    "model_name": MODEL_NAME,
    "learning_rate": 5e-5,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "num_train_epochs": 3,
    "warmup_ratio": 0.1,
    "weight_decay": 0.01,
    "fp16": torch.cuda.is_available(),  # mixed precision only when a GPU is present
    "seed": SEED,
    "max_source_length": 512,
    "max_target_length": 128,
}

print("Training Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")

## 3. Load Dataset

In [None]:
dataset = load_from_disk(DATASET_PATH)
print("Dataset loaded:")
print(f"  Train: {len(dataset['train'])} samples")
print(f"  Validation: {len(dataset['validation'])} samples")
print(f"  Test: {len(dataset['test'])} samples")

In [None]:
print("\nDataset features:")
print(dataset['train'].features)

print("\nSample input_ids length:", len(dataset['train'][0]['input_ids']))
print("Sample labels length:", len(dataset['train'][0]['labels']))

## 4. Load Model and Tokenizer

In [None]:
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

print(f"Model parameters: {model.num_parameters():,}")
print(f"Model dtype: {model.dtype}")

## 5. Data Collator

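The data collator pads each batch dynamically to its longest sequence (rounded up to a multiple of 8), fills label padding with -100 so those positions are ignored by the loss, and, because the model is passed in, builds the decoder inputs from the labels.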

In [None]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8
)
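
To make the label padding concrete, here is a pure-Python sketch of the idea behind `label_pad_token_id=-100` and `pad_to_multiple_of=8`. It is not the library's implementation: `pad_labels` is a hypothetical helper, the token IDs are invented, and right-side padding is assumed.

```python
LABEL_PAD_ID = -100  # positions with this value are skipped by the cross-entropy loss


def pad_labels(batch_labels, pad_id=LABEL_PAD_ID, multiple_of=8):
    """Right-pad every label sequence to the batch maximum, rounded up to a multiple."""
    longest = max(len(seq) for seq in batch_labels)
    target_len = -(-longest // multiple_of) * multiple_of  # ceiling to the next multiple
    return [seq + [pad_id] * (target_len - len(seq)) for seq in batch_labels]


# Two label sequences with made-up token IDs, of length 3 and 9
padded = pad_labels([[5, 6, 1], [7, 8, 9, 10, 11, 12, 13, 14, 1]])
for row in padded:
    print(len(row), row)
```

Padding per batch rather than to a fixed maximum keeps short batches small, and rounding to a multiple of 8 is a common choice for mixed-precision kernels.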

## 6. Define Training Arguments

In [None]:
run_name = f"soqg_flan_t5_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

training_args = Seq2SeqTrainingArguments(
    output_dir=str(OUTPUT_DIR / "checkpoints"),
    run_name=run_name,
    
    num_train_epochs=config["num_train_epochs"],
    per_device_train_batch_size=config["per_device_train_batch_size"],
    per_device_eval_batch_size=config["per_device_eval_batch_size"],
    gradient_accumulation_steps=config["gradient_accumulation_steps"],
    
    learning_rate=config["learning_rate"],
    warmup_ratio=config["warmup_ratio"],
    weight_decay=config["weight_decay"],
    
    fp16=config["fp16"],
    
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    
    logging_dir=str(LOGS_DIR / run_name),
    logging_steps=100,
    report_to=["tensorboard"],
    
    predict_with_generate=True,
    generation_max_length=config["max_target_length"],
    
    seed=config["seed"],
    dataloader_num_workers=2,
)

effective_batch_size = config["per_device_train_batch_size"] * config["gradient_accumulation_steps"]
print(f"Effective batch size: {effective_batch_size}")
# Approximate: assumes a single GPU and ignores the last partial accumulation window
print(f"Optimizer steps per epoch (approx.): {len(dataset['train']) // effective_batch_size}")
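
The arguments set `warmup_ratio` but no scheduler type, so the Trainer falls back to its default linear schedule: the learning rate ramps up from zero during warmup and then decays linearly to zero. The sketch below reproduces that shape in plain Python under stated assumptions. `linear_warmup_lr` is a hypothetical helper, and the total of 1,000 steps is an illustrative value, not the real step count of this run.

```python
import math


def linear_warmup_lr(step, total_steps, peak_lr=5e-5, warmup_ratio=0.1):
    """Learning rate at a given optimizer step for linear warmup followed by linear decay."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Warmup phase: grow linearly from 0 to peak_lr
        return peak_lr * step / max(1, warmup_steps)
    # Decay phase: shrink linearly from peak_lr to 0 at the final step
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_lr(step, total_steps=1000))
```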

## 7. Initialize Trainer

In [None]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    # Recent transformers releases deprecate `tokenizer=` here in favor of `processing_class=`
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

print("Trainer initialized.")
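
With `early_stopping_patience=3` and an evaluation every 500 steps, training halts once the validation loss has failed to improve for three evaluations in a row, which is 1,500 steps. The following is a simplified sketch of that rule, not the callback's actual code. `early_stop_index` is a hypothetical helper and the loss values are invented.

```python
def early_stop_index(eval_losses, patience=3, threshold=0.0):
    """Return the index of the evaluation at which training would stop, or None."""
    best = None
    evals_without_improvement = 0
    for i, loss in enumerate(eval_losses):
        if best is None or loss < best - threshold:
            best = loss  # new best: reset the counter
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
        if evals_without_improvement >= patience:
            return i
    return None


# Invented validation losses, one per evaluation (every 500 steps)
losses = [1.00, 0.90, 0.95, 0.93, 0.91, 0.50]
print(early_stop_index(losses))
```

Because `load_best_model_at_end=True`, the weights restored after stopping are those of the best checkpoint, not the last one.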

## 8. Train Model

This will take approximately 6 hours on a T4 GPU for 3 epochs.

In [None]:
print(f"Starting training at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Run name: {run_name}")
print("="*50)

train_result = trainer.train()

print("="*50)
print(f"Training completed at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

## 9. Save Final Model

In [None]:
trainer.save_model(str(OUTPUT_DIR / "final"))
tokenizer.save_pretrained(str(OUTPUT_DIR / "final"))

print(f"Model saved to {OUTPUT_DIR / 'final'}")

## 10. Log Training Metrics

In [None]:
train_metrics = train_result.metrics
train_metrics["train_samples"] = len(dataset['train'])

print("\nTraining Metrics:")
for key, value in train_metrics.items():
    print(f"  {key}: {value}")

In [None]:
eval_result = trainer.evaluate(dataset['validation'])

print("\nValidation Metrics:")
for key, value in eval_result.items():
    print(f"  {key}: {value}")

## 11. Save Experiment Log

In [None]:
experiment_log = {
    "run_name": run_name,
    "timestamp": datetime.now().isoformat(),
    "config": config,
    "train_metrics": train_metrics,
    "eval_metrics": eval_result,
    "model_path": str(OUTPUT_DIR / "final"),
    "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
}

log_path = LOGS_DIR / f"{run_name}_log.json"
with open(log_path, "w") as f:
    json.dump(experiment_log, f, indent=2, default=str)

print(f"Experiment log saved to {log_path}")
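
Since every run writes its own `<run_name>_log.json`, the logs can be read back later to compare runs. A minimal sketch, assuming the log structure written above and that `eval_metrics` contains an `eval_loss` entry. `load_experiment_logs` is a hypothetical helper, not part of this project.

```python
import json
from pathlib import Path


def load_experiment_logs(logs_dir):
    """Read every *_log.json file in logs_dir, ordered by validation loss (best first)."""
    logs = []
    for path in sorted(Path(logs_dir).glob("*_log.json")):
        with open(path) as f:
            logs.append(json.load(f))
    # Runs without an eval_loss are sorted to the end
    return sorted(logs, key=lambda log: log["eval_metrics"].get("eval_loss", float("inf")))
```

For example, iterating over `load_experiment_logs(LOGS_DIR)` and printing each `run_name` with its `eval_loss` gives a quick ranking of runs.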

## 12. Quick Generation Test

Verify the model generates reasonable outputs before full evaluation.

In [None]:
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

test_inputs = [
    "Generate a Socratic question: clarification: The theory of evolution explains how species change over time through natural selection.",
    "Generate a Socratic question: assumptions: Democracy is the best form of government for all countries.",
    "Generate a Socratic question: reasons_evidence: Climate change is primarily caused by human activities.",
]

print("Sample Generations:")
print("="*60)

for test_input in test_inputs:
    inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    outputs = model.generate(
        **inputs,
        max_length=128,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=2
    )
    
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Split on the first two colons only, so colons inside the context text are preserved
    _, q_type, context = (part.strip() for part in test_input.split(":", 2))
    
    print(f"Type: {q_type}")
    print(f"Input: {context[:80]}...")
    print(f"Output: {generated}")
    print("-"*60)
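
The test inputs share one prompt layout: a fixed instruction, the question type, then the context. A small sketch of building and parsing that layout, inferred only from the three examples in this cell. `build_prompt` and `parse_prompt` are hypothetical helpers, and the preprocessing notebook may format prompts differently.

```python
PROMPT_PREFIX = "Generate a Socratic question"


def build_prompt(question_type, context):
    """Assemble a prompt in the '<prefix>: <type>: <context>' layout."""
    return f"{PROMPT_PREFIX}: {question_type}: {context.strip()}"


def parse_prompt(prompt):
    """Split a prompt back into (question_type, context), keeping colons in the context."""
    _, question_type, context = (part.strip() for part in prompt.split(":", 2))
    return question_type, context


prompt = build_prompt("assumptions", "Democracy is the best form of government for all countries.")
print(prompt)
print(parse_prompt(prompt))
```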

## 13. Training Summary

### Results
- **Final Training Loss**: See metrics above
- **Final Validation Loss**: See metrics above
- **Training Time**: Check experiment log

### Next Steps
1. **04_evaluation.ipynb** - Run full evaluation with BLEU, ROUGE, BERTScore
2. Compare metrics to paper baselines (BLEU-1: 0.172, ROUGE-L: 0.211)
3. Perform manual evaluation on 50 samples

### Model Location
The trained model is saved at:
```
../backend/model_artifacts/soqg_flan_t5/final/
```

Load it with:
```python
from transformers import T5ForConditionalGeneration, T5Tokenizer
model = T5ForConditionalGeneration.from_pretrained("../backend/model_artifacts/soqg_flan_t5/final")
tokenizer = T5Tokenizer.from_pretrained("../backend/model_artifacts/soqg_flan_t5/final")
```