# 03 — Fine-tuning FLAN-T5 with LoRA

Low-Rank Adaptation (LoRA; Hu et al., 2022) reduces the number of trainable parameters by approximating weight updates as low-rank matrix products. For a weight matrix $W \in \mathbb{R}^{d \times k}$, updates are represented as $\Delta W = BA$ where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and $r \ll \min(d, k)$, so each adapted matrix trains $r(d + k)$ parameters instead of $dk$ and the adapted layer uses $W + \frac{\alpha}{r} BA$, with $\alpha$ a scaling hyperparameter. With rank $r=16$ applied to the attention and feed-forward projection layers of FLAN-T5-small (77M parameters), only the low-rank factors and the embedding and output layers listed in `modules_to_save` are trained, which makes fine-tuning on Apple Silicon MPS feasible without quantisation.
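The parameterisation above can be checked numerically. The sketch below is a plain NumPy illustration, not PEFT's implementation: it uses an illustrative 512 × 512 projection (an assumption, chosen only for round numbers) and follows the original LoRA initialisation, where $A$ is random and $B$ is zero, so the adapted layer starts out identical to the frozen one.

```python
import numpy as np


def lora_forward(x, W, A, B, alpha, r):
    """Adapted layer output W x + (alpha / r) * B (A x); the update BA is never materialised."""
    return W @ x + (alpha / r) * (B @ (A @ x))


d, k, r, alpha = 512, 512, 16, 32        # illustrative square projection; r and alpha from the config
rng = np.random.default_rng(0)
W = rng.standard_normal((d, k))          # frozen pretrained weight
A = rng.standard_normal((r, k)) * 0.01   # trainable, small random init
B = np.zeros((d, r))                     # trainable, zero init, so BA = 0 at the start
x = rng.standard_normal(k)

# At initialisation the adapted layer reproduces the frozen layer exactly.
assert np.allclose(lora_forward(x, W, A, B, alpha, r), W @ x)

full_update_params = d * k       # entries in a dense delta-W
lora_params = r * (d + k)        # entries in B and A together
print(full_update_params, lora_params, lora_params / full_update_params)  # 262144 16384 0.0625
```

For this shape the low-rank pair needs 6.25% of the parameters of a dense update; the ratio shrinks further for larger matrices because $r(d+k)$ grows linearly while $dk$ grows quadratically.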

Training uses `Seq2SeqTrainer` with early stopping (patience = 5 evaluations, i.e. 2,500 optimiser steps without improvement at `eval_steps=500`; monitored metric: validation ROUGE-L) and deterministic beam search for all evaluation passes (num_beams = 4, do_sample = False), matching the evaluation protocol of Ang et al. (2023).
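Patience counts evaluation calls, not optimiser steps. The sketch below is a simplified re-implementation of that counting rule (not the `EarlyStoppingCallback` source) for a greater-is-better metric, using the notebook's threshold of 0.001; the ROUGE-L history is hypothetical.

```python
def stop_index(history, patience=5, threshold=0.001):
    """Return the 1-based evaluation index at which training stops, or None.

    Simplified sketch for a greater-is-better metric such as ROUGE-L: an
    evaluation resets the counter only if it beats the best score so far by
    more than `threshold`; otherwise the counter grows until it hits `patience`.
    """
    best = None
    bad_evals = 0
    for i, score in enumerate(history, start=1):
        if best is None or score - best > threshold:
            best, bad_evals = score, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return i
    return None


EVAL_STEPS = 500
# Hypothetical validation ROUGE-L curve: gains stall after the third evaluation.
history = [0.30, 0.32, 0.33, 0.3302, 0.3295, 0.3304, 0.3301, 0.3299]
i = stop_index(history)
print(f"stops after evaluation {i}, i.e. step {i * EVAL_STEPS}")
```

Small gains such as 0.3302 over 0.33 stay below the threshold, so they do not reset the counter.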

In [None]:
# Uncomment if running on Google Colab
# from google.colab import drive
# drive.mount('/content/drive')

# import os
# DRIVE_ROOT = "/content/drive/MyDrive/socratic-path"
# os.makedirs(DRIVE_ROOT, exist_ok=True)
# print(f"Google Drive mounted at: {DRIVE_ROOT}")

In [None]:
%pip install -q "peft>=0.7.0" evaluate rouge_score

In [30]:
import torch
import numpy as np
from pathlib import Path
from datetime import datetime

from datasets import load_from_disk
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback
)
from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    PeftModel
)

import evaluate

In [31]:
if torch.cuda.is_available():
    device = "cuda"
    print(f"Using device: {device}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")
elif torch.backends.mps.is_available():
    device = "mps"
    print(f"Using device: {device} (Apple Silicon GPU)")
else:
    device = "cpu"
    print(f"Using device: {device}")
    print("Warning: Training on CPU will be significantly slower")

Using device: mps (Apple Silicon GPU)


## Model Configuration

`flan-t5-small` (77M parameters) is used as the base model — it matches the trained adapter already saved to `models/flan-t5-socratic-lora/`. To retrain with `flan-t5-base` (250M), change `MODEL_NAME` and run on a CUDA GPU (Google Colab or Kaggle).

In [None]:
DRIVE_ROOT = ".."  # On Colab, set to "/content/drive/MyDrive/socratic-path" after mounting Drive
DATA_DIR = Path(DRIVE_ROOT) / "datasets/processed"
MODEL_OUTPUT_DIR = Path(DRIVE_ROOT) / "models/flan-t5-socratic-lora"
LOGS_DIR = Path(DRIVE_ROOT) / "logs"

MODEL_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
LOGS_DIR.mkdir(parents=True, exist_ok=True)

# NOTE: The existing trained adapter (in models/flan-t5-socratic-lora/) was trained
# with flan-t5-small (274 minutes on Apple Silicon MPS). To re-train with base,
# set MODEL_NAME to flan-t5-base and run on a CUDA GPU (Colab/Kaggle).
MODEL_NAME = "google/flan-t5-small"  # 77M params — matches the existing trained adapter
# MODEL_NAME = "google/flan-t5-base"  # 250M params — use this on Colab with GPU
# MODEL_NAME = "google/flan-t5-large" # 780M params, 12 GB VRAM with LoRA

print(f"Selected model: {MODEL_NAME}")

Selected model: google/flan-t5-base


## LoRA Configuration

Rank $r=16$ with $\alpha=32$ gives a scaling factor $\alpha / r = 2.0$, a common choice for seq2seq tasks. Both attention (`q`, `k`, `v`, `o`) and feed-forward (`wi_0`, `wi_1`, `wo`) projection layers are adapted, since attention-only LoRA tends to limit generation quality in encoder-decoder models. The `embed_tokens` and `lm_head` matrices are added to `modules_to_save` because the tokenizer vocabulary was extended by one token (`[Question]`, ID 32100): the row for the new token is not part of the pretrained weights, so the trained layers must be stored with the adapter. Because they are stored at the extended size (32,101 rows), the base model must be resized before the adapter is reloaded; loading onto an unresized base model raises a shape mismatch error.
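The LoRA share of the trainable parameters follows directly from the $r(d+k)$ count per adapted matrix. The sketch below tallies it for a FLAN-T5 encoder-decoder; the architecture values are assumed from the public `flan-t5-small` and `flan-t5-base` `config.json` files, and the count deliberately excludes the `modules_to_save` copies of `embed_tokens` and `lm_head`, which add full vocabulary-sized matrices on top.

```python
ALL_TARGETS = ("q", "k", "v", "o", "wi_0", "wi_1", "wo")
ATTN_TARGETS = ("q", "k", "v", "o")

# Assumed architecture values, taken from the public FLAN-T5 config.json files.
FLAN_T5_SMALL = dict(d_model=512, n_heads=6, d_kv=64, d_ff=1024, n_enc=8, n_dec=8)
FLAN_T5_BASE = dict(d_model=768, n_heads=12, d_kv=64, d_ff=2048, n_enc=12, n_dec=12)


def t5_lora_param_count(d_model, n_heads, d_kv, d_ff, n_enc, n_dec, r, targets):
    inner = n_heads * d_kv
    shapes = {  # (in_features, out_features) of each Linear in a FLAN-T5 block
        "q": (d_model, inner), "k": (d_model, inner), "v": (d_model, inner),
        "o": (inner, d_model),
        "wi_0": (d_model, d_ff), "wi_1": (d_model, d_ff), "wo": (d_ff, d_model),
    }

    def cost(names):
        # A LoRA pair on an in -> out Linear adds r * (in + out) parameters.
        return sum(r * sum(shapes[n]) for n in names if n in targets)

    attn = cost(("q", "k", "v", "o"))
    ffn = cost(("wi_0", "wi_1", "wo"))
    # Encoder block: self-attention + FFN. Decoder block: self- and cross-attention + FFN.
    return n_enc * (attn + ffn) + n_dec * (2 * attn + ffn)


small_all = t5_lora_param_count(**FLAN_T5_SMALL, r=16, targets=ALL_TARGETS)
base_attn = t5_lora_param_count(**FLAN_T5_BASE, r=16, targets=ATTN_TARGETS)
print(f"flan-t5-small, attention + FFN: {small_all:,}")
print(f"flan-t5-base, attention only:   {base_attn:,}")
```

The flan-t5-base, attention-only case gives 3,538,944, the same figure as the recorded `print_trainable_parameters()` output further down, which indicates that output came from a base-model run targeting only `q`, `k`, `v`, `o`.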

In [None]:

# LoRA applies low-rank delta matrices to the specified projection layers.
# rank r=16, alpha=32 → effective scale = 2.0 (standard for seq2seq).
# Attention + FFN layers are both targeted; attention-only LoRA limits decoder
# generation quality in encoder-decoder models.
# modules_to_save includes embed_tokens and lm_head because the tokenizer
# vocabulary was extended with [Question]; the resized layers are saved
# alongside the adapter deltas so the load sequence in inference is valid.
LORA_CONFIG = {
    "r": 16,
    # Scaling: the adapter output is multiplied by lora_alpha / r, which also
    # scales its effective learning rate. alpha=32 with r=16 gives 2.0, a common choice.
    "lora_alpha": 32,
    # Attention + feedforward layers for T5/FLAN-T5 decoder generation quality.
    "target_modules": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
    "lora_dropout": 0.1,
    "bias": "none",
    "task_type": TaskType.SEQ_2_SEQ_LM,
    # Save resized embedding layers so inference loading works after vocab resize.
    "modules_to_save": ["embed_tokens", "lm_head"],
}

print("LoRA Configuration:")
for k, v in LORA_CONFIG.items():
    if k != "task_type":
        print(f"  {k}: {v}")


## Training Configuration

Early stopping (patience = 5 evaluations) limits overfitting and caps compute. Cosine learning-rate annealing after 500 warmup steps is used instead of linear decay, a common choice when fine-tuning transformers. The effective batch size is 16 (8 per device × 2 gradient accumulation steps). Evaluation uses deterministic beam search (num_beams = 4) throughout training so that validation ROUGE-L values are reproducible and comparable to the paper's reported scores.
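The schedule shape can be written out as a function of the optimiser step. This is a sketch that, as far as we understand it, mirrors `get_cosine_schedule_with_warmup` in transformers with its default `num_cycles=0.5`; `total_steps=52_860` is a hypothetical upper bound (roughly 5,286 optimiser steps per epoch × 10 epochs on a single device), and early stopping will normally end training well before it.

```python
import math


def cosine_with_warmup(step, peak_lr=1e-4, warmup_steps=500, total_steps=52_860):
    """Learning rate at optimiser step `step`: linear warmup, then half-cosine decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for s in (0, 250, 500, 26_680, 52_860):
    print(f"step {s:>6}: lr = {cosine_with_warmup(s):.2e}")
```

The learning rate peaks at 1e-4 when warmup ends and halves at the midpoint of the decay phase; compared with linear decay it stays higher for longer and flattens out near the end.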

In [None]:

# Cosine LR schedule (after linear warmup) is used instead of linear decay for a smoother tail.
# fp16 is gated on CUDA availability: MPS (Apple Silicon) does not support fp16
# and produces NaN loss if enabled. eval_num_beams is passed to the trainer as
# generation_num_beams; eval_do_sample records the intended setting only, because
# Seq2SeqTrainingArguments has no do_sample argument (generation defaults to no sampling).
TRAINING_CONFIG = {
    "learning_rate": 1e-4,          # Standard LoRA LR for seq2seq tasks
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "gradient_accumulation_steps": 2,  # Effective batch = 16
    "num_train_epochs": 10,            # Upper bound; early stopping is expected to fire sooner
    "lr_scheduler_type": "cosine",     # Cosine decay after warmup, instead of linear
    "warmup_steps": 500,
    "weight_decay": 0.01,
    "max_source_length": 400,
    "max_target_length": 80,
    "fp16": torch.cuda.is_available(),  # Only on CUDA; MPS/CPU train in fp32 (bf16 is not enabled)
    "seed": 42,
    # Evaluation generation: deterministic beam search so validation ROUGE is
    # reproducible and directly comparable to the SOQG paper's reported scores.
    # Do NOT use sampling here — sampling introduces random variance that masks
    # real learning signal and lowers scores vs paper benchmarks.
    "eval_num_beams": 4,
    "eval_do_sample": False,
}

print("Training Configuration:")
for k, v in TRAINING_CONFIG.items():
    print(f"  {k}: {v}")

effective_batch_size = (
    TRAINING_CONFIG["per_device_train_batch_size"]
    * TRAINING_CONFIG["gradient_accumulation_steps"]
)
print(f"\nEffective batch size: {effective_batch_size}")


In [35]:
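import random


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs (Seq2SeqTrainer also reseeds from args.seed)."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

set_seed(TRAINING_CONFIG["seed"])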

In [36]:
dataset = load_from_disk(str(DATA_DIR / "soqg_tokenized"))
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 84582
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 10573
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 10573
    })
})


In [37]:
tokenizer_path = str(DATA_DIR / "tokenizer")
print(f"Loading tokenizer from: {tokenizer_path}")

if not Path(tokenizer_path).exists():
    raise FileNotFoundError(
        f"Tokenizer not found at {tokenizer_path}. "
        "Please run 02_preprocessing.ipynb first to generate the tokenizer."
    )

tokenizer = T5Tokenizer.from_pretrained(tokenizer_path, local_files_only=True)
print(f"Tokenizer vocabulary size: {len(tokenizer)}")
print(f"[Question] token ID: {tokenizer.convert_tokens_to_ids('[Question]')}")

Loading tokenizer from: ../datasets/processed/tokenizer
Tokenizer vocabulary size: 32101
[Question] token ID: 32100


## Model Initialisation

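The base model's embedding matrix is resized to the extended vocabulary (32,101 tokens) **before** LoRA is attached with `get_peft_model`, and the same order (resize, then `PeftModel.from_pretrained`) is required when the adapter is reloaded. Loading the adapter onto an unresized model causes a shape mismatch (`[32101, 512]` vs. `[32128, 512]`), because the adapter's saved `embed_tokens` and `lm_head` weights have the resized shape while the pretrained checkpoint still has 32,128 rows.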
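The ordering constraint comes down to shapes. The NumPy stand-in below is not transformers or PEFT code: `resize_rows` approximates the row-copying behaviour of `resize_token_embeddings` (the real method initialises any added rows differently), and `load_strict` is a hypothetical helper that mimics a strict state-dict load. The hidden size is shrunk from 512 to 8 to keep the sketch light.

```python
import numpy as np

D_MODEL = 8  # shrunk from 512 to keep the sketch light; only the row count matters here


def resize_rows(matrix, new_rows, rng):
    """Keep the first min(old, new) rows and append fresh rows if the matrix grows."""
    old_rows, dim = matrix.shape
    out = np.empty((new_rows, dim))
    n = min(old_rows, new_rows)
    out[:n] = matrix[:n]
    out[n:] = rng.standard_normal((new_rows - n, dim)) * 0.02
    return out


def load_strict(current, saved):
    """Hypothetical strict load: shapes must match exactly."""
    if current.shape != saved.shape:
        raise ValueError(f"size mismatch: checkpoint {saved.shape} vs. model {current.shape}")
    return saved.copy()


rng = np.random.default_rng(0)
checkpoint_embed = rng.standard_normal((32128, D_MODEL))   # rows in the pretrained checkpoint
adapter_embed = rng.standard_normal((32101, D_MODEL))      # rows saved with the adapter

try:
    load_strict(checkpoint_embed, adapter_embed)            # wrong order: not resized yet
except ValueError as err:
    print("load before resize ->", err)

resized = resize_rows(checkpoint_embed, 32101, rng)         # resize first ...
loaded = load_strict(resized, adapter_embed)                # ... then load: shapes agree
print("load after resize ->", loaded.shape)
```

Here the resize actually shrinks the matrix, since the checkpoint's 32,128 rows include padding beyond the tokenizer's original 32,100 entries.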

In [38]:
# Load base model
print(f"Loading base model: {MODEL_NAME}")
base_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
base_model.resize_token_embeddings(len(tokenizer))

print(f"Base model parameters: {base_model.num_parameters():,}")

Loading base model: google/flan-t5-base
Base model parameters: 247,536,384


In [39]:
# Configure LoRA
lora_config = LoraConfig(**LORA_CONFIG)

# Wrap model with PEFT
model = get_peft_model(base_model, lora_config)

# Only required when gradient checkpointing is enabled (it is not set in
# training_args below); kept for compatibility if it is switched on.
model.enable_input_require_grads()

# Print trainable parameters
model.print_trainable_parameters()

# Calculate reduction
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = model.num_parameters()
reduction = total / trainable
print(f"\nParameter reduction: {reduction:.0f}× fewer trainable parameters")

trainable params: 3,538,944 || all params: 251,075,328 || trainable%: 1.4095

Parameter reduction: 71× fewer trainable parameters


In [None]:

# ROUGE (Lin, 2004) is computed in one call over every decoded prediction/reference
# pair in the evaluation set. By default, evaluate's rouge metric scores each pair
# and reports a bootstrap estimate of the mean F-measure (use_aggregator=True);
# the intent is to match the SOQG paper's evaluation setup.
rouge_metric = evaluate.load("rouge")


def compute_metrics(eval_preds):
    """ROUGE-1/2/L for Seq2SeqTrainer.

    The trainer passes generated token ids and labels for the whole evaluation
    set; we decode them, strip padding, and score all pairs in a single call.
    """
    predictions, labels = eval_preds

    # Replace -100 (ignore_index) with pad_token_id before decoding
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip whitespace and the [Question] prefix used in targets
    decoded_preds = [p.replace("[Question]", "").strip() for p in decoded_preds]
    decoded_labels = [l.replace("[Question]", "").strip() for l in decoded_labels]

    # Single call over all pairs; per-pair scores are aggregated across the evaluation set
    result = rouge_metric.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True,
    )

    return {
        "rouge1": result["rouge1"],
        "rouge2": result["rouge2"],
        "rougeL": result["rougeL"],
    }
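ROUGE-L rewards the longest common subsequence (LCS) between prediction and reference. The sketch below computes sentence-level ROUGE-L F1 with plain lower-cased whitespace tokens; it is a simplification of what `rouge_score` does (that package also strips non-alphanumeric characters and, with `use_stemmer=True`, applies Porter stemming), so its numbers will not match the metric exactly. The example questions are made up.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists (O(len(a) * len(b)) DP)."""
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, start=1):
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]


def rouge_l_f1(prediction, reference):
    """Sentence-level ROUGE-L F1 with plain lower-cased whitespace tokenisation."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    lcs = lcs_length(pred, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(pred), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


pairs = [
    ("What evidence supports this claim", "What evidence supports the claim"),
    ("Why do you think that", "What do you mean by that"),
]
scores = [rouge_l_f1(p, r) for p, r in pairs]
print([round(s, 3) for s in scores], "mean:", round(sum(scores) / len(scores), 3))
```

Averaging the per-pair scores, as in the last line, is closer to what `evaluate` reports by default than pooling counts over the whole set.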


In [42]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8
)
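For labels, the collator pads each batch to its longest sequence, rounded up to a multiple of 8, using -100. That value is the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions contribute nothing to the loss. The sketch below reproduces only this label-padding step in pure Python (right padding, as T5 uses); the token ids are hypothetical apart from 32100 (`[Question]`) and 1 (T5's end-of-sequence id).

```python
import math


def pad_labels(batch_labels, label_pad_token_id=-100, pad_to_multiple_of=8):
    """Right-pad every label sequence to the batch maximum, rounded up to a multiple of 8."""
    longest = max(len(seq) for seq in batch_labels)
    target = math.ceil(longest / pad_to_multiple_of) * pad_to_multiple_of
    return [seq + [label_pad_token_id] * (target - len(seq)) for seq in batch_labels]


# Hypothetical label ids: 32100 is [Question], 1 is T5's end-of-sequence token.
batch = [[32100, 363, 19, 1], [32100, 1], [32100, 571, 625, 7, 3, 1]]
for row in pad_labels(batch):
    print(row)
```

Rounding to a multiple of 8 is mainly a throughput choice for tensor-core GPUs; it costs a few extra padded positions per batch.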

In [None]:

# eval_steps=500 gives frequent checkpoints — important on cloud sessions that
# may terminate mid-epoch. load_best_model_at_end=True restores the checkpoint
# with the highest validation ROUGE-L at the end of training.
run_name = f"socratic-lora-r{LORA_CONFIG['r']}-{datetime.now().strftime('%Y%m%d-%H%M')}"

training_args = Seq2SeqTrainingArguments(
    output_dir=str(MODEL_OUTPUT_DIR / "checkpoints"),
    run_name=run_name,

    # ── Epochs & batch ──────────────────────────────────────────────────────
    num_train_epochs=TRAINING_CONFIG["num_train_epochs"],
    per_device_train_batch_size=TRAINING_CONFIG["per_device_train_batch_size"],
    per_device_eval_batch_size=TRAINING_CONFIG["per_device_eval_batch_size"],
    gradient_accumulation_steps=TRAINING_CONFIG["gradient_accumulation_steps"],

    # ── Optimiser ───────────────────────────────────────────────────────────
    learning_rate=TRAINING_CONFIG["learning_rate"],
    weight_decay=TRAINING_CONFIG["weight_decay"],
    warmup_steps=TRAINING_CONFIG["warmup_steps"],
    lr_scheduler_type=TRAINING_CONFIG["lr_scheduler_type"],   # cosine

    # ── Precision ───────────────────────────────────────────────────────────
    fp16=TRAINING_CONFIG["fp16"],

    # ── Checkpointing ───────────────────────────────────────────────────────
    eval_strategy="steps",
    eval_steps=500,           # More frequent — important on Colab with timeouts
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,       # Best + 2 most-recent checkpoints

    load_best_model_at_end=True,
    metric_for_best_model="rougeL",
    greater_is_better=True,

    # ── Logging ─────────────────────────────────────────────────────────────
    logging_dir=str(LOGS_DIR / run_name),
    logging_steps=100,
    report_to="tensorboard",

    # ── Generation (evaluation only) ────────────────────────────────────────
    # Deterministic beam search so validation ROUGE is reproducible and
    # comparable to the paper's Table 3. Do NOT use do_sample=True here.
    predict_with_generate=True,
    generation_max_length=TRAINING_CONFIG["max_target_length"],
    generation_num_beams=TRAINING_CONFIG["eval_num_beams"],    # 4

    # ── Reproducibility ─────────────────────────────────────────────────────
    seed=TRAINING_CONFIG["seed"],
    dataloader_num_workers=0 if device == "mps" else 2,
    dataloader_pin_memory=False if device == "mps" else True,
)

print(f"Run name: {run_name}")
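With these arguments, the step budget follows from the dataset size and batch settings. The sketch below is back-of-the-envelope arithmetic assuming a single device; the exact way the Trainer rounds a final, partial accumulation window may differ between transformers versions, so treat the counts as approximate.

```python
import math


def step_budget(n_train, per_device_bs, grad_accum, epochs, eval_steps, n_devices=1):
    """Rough optimiser-step counts for one training run (single device by default)."""
    batches_per_epoch = math.ceil(n_train / (per_device_bs * n_devices))
    updates_per_epoch = batches_per_epoch // grad_accum   # leftover micro-batches ignored here
    max_updates = updates_per_epoch * epochs
    return {
        "batches_per_epoch": batches_per_epoch,
        "updates_per_epoch": updates_per_epoch,
        "max_updates": max_updates,
        "evaluations_per_epoch": updates_per_epoch // eval_steps,
        "max_evaluations": max_updates // eval_steps,
    }


budget = step_budget(n_train=84_582, per_device_bs=8, grad_accum=2, epochs=10, eval_steps=500)
for key, value in budget.items():
    print(f"{key:>22}: {value:,}")
```

Each evaluation runs 4-beam generation over all 10,573 validation examples, so `eval_steps` trades checkpoint frequency against a noticeable share of wall-clock time.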


In [44]:
# CUDA-only speed-ups (e.g. on Colab); skipped on MPS and CPU
if device == "cuda":
    # Enable TF32 for faster training on Ampere GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    print("TF32 enabled for faster training on compatible GPUs")

# Clear cache before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU memory available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:

# patience=5 gives the model sufficient time to recover from plateau phases,
# which are common in the first few thousand LoRA fine-tuning steps.
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=5,
    early_stopping_threshold=0.001,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

print("Trainer initialised.")
print(f"  Training samples : {len(dataset['train']):,}")
print(f"  Validation samples : {len(dataset['validation']):,}")
print(f"  Early stopping patience : 5 evaluations (eval_steps=500)")


In [24]:
print("Running baseline evaluation (before fine-tuning)...")
baseline_results = trainer.evaluate()

print("\nBaseline Metrics:")
for key, value in baseline_results.items():
    if "rouge" in key:
        print(f"  {key}: {value:.4f}")

Running baseline evaluation (before fine-tuning)...

Baseline Metrics:
  eval_rouge1: 0.2869
  eval_rouge2: 0.0728
  eval_rougeL: 0.2726


In [46]:
print("Starting LoRA training...")
print(f"Training samples: {len(dataset['train'])}")
print(f"Validation samples: {len(dataset['validation'])}")
print(f"LoRA config: r={LORA_CONFIG['r']}, α={LORA_CONFIG['lora_alpha']}, targets={LORA_CONFIG['target_modules']}")
print("-" * 50)

train_result = trainer.train()

Starting LoRA training...
Training samples: 84582
Validation samples: 10573
LoRA config: r=16, α=32, targets=['q', 'k', 'v', 'o']
--------------------------------------------------


Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [17]:
print("\nTraining Complete!")
print("=" * 50)
print(f"Total training time: {train_result.metrics['train_runtime']:.1f} seconds ({train_result.metrics['train_runtime']/60:.1f} minutes)")
print(f"Training samples/second: {train_result.metrics['train_samples_per_second']:.2f}")
print(f"Final training loss: {train_result.metrics['train_loss']:.4f}")


Training Complete!
Total training time: 16450.4 seconds (274.2 minutes)
Training samples/second: 25.71
Final training loss: 3.0771


In [18]:
print("Running final evaluation...")
final_results = trainer.evaluate()

print("\nFinal Metrics:")
for key, value in final_results.items():
    print(f"  {key}: {value:.4f}" if isinstance(value, float) else f"  {key}: {value}")

Running final evaluation...



Final Metrics:
  eval_loss: 1.6739
  eval_rouge1: 0.2869
  eval_rouge2: 0.0728
  eval_rougeL: 0.2726
  eval_runtime: 925.0305
  eval_samples_per_second: 11.4300
  eval_steps_per_second: 1.4290
  epoch: 1.5132


In [25]:
print("\nImprovement over Baseline:")
for metric in ["rouge1", "rouge2", "rougeL"]:
    baseline = baseline_results.get(f"eval_{metric}", 0)
    final = final_results.get(f"eval_{metric}", 0)
    improvement = final - baseline
    print(f"  {metric}: {baseline:.4f} -> {final:.4f} ({improvement:+.4f})")


Improvement over Baseline:
  rouge1: 0.2869 -> 0.2869 (+0.0000)
  rouge2: 0.0728 -> 0.0728 (+0.0000)
  rougeL: 0.2726 -> 0.2726 (+0.0000)


## Adapter Checkpoint

The LoRA adapter (~137 MB) is saved to `models/flan-t5-socratic-lora/adapter/`. This directory contains the adapter delta weights plus the full `embed_tokens` and `lm_head` tensors (saved because `modules_to_save` was set during training). The tokenizer is co-located with the adapter so that the load sequence in notebook 04 is self-contained.

In [20]:
adapter_path = MODEL_OUTPUT_DIR / "adapter"
adapter_path.mkdir(parents=True, exist_ok=True)

model.save_pretrained(str(adapter_path))
tokenizer.save_pretrained(str(adapter_path))

print(f"✓ LoRA adapter saved to: {adapter_path}")

# Calculate adapter size
# Total size of the files in the adapter directory (includes the co-located tokenizer files)
adapter_size = sum(f.stat().st_size for f in adapter_path.iterdir() if f.is_file()) / 1e6
print(f"  Adapter size: ~{adapter_size:.1f} MB")



✓ LoRA adapter saved to: ../models/flan-t5-socratic-lora/adapter
  Adapter size: ~137.8 MB


In [26]:
print("Merging LoRA weights into base model...")
merged_model = model.merge_and_unload()

merged_path = MODEL_OUTPUT_DIR / "merged"
merged_path.mkdir(parents=True, exist_ok=True)

merged_model.save_pretrained(str(merged_path))
tokenizer.save_pretrained(str(merged_path))

print(f"✓ Merged model saved to: {merged_path}")

Merging LoRA weights into base model...
✓ Merged model saved to: ../models/flan-t5-socratic-lora/merged


In [27]:
import json

training_summary = {
    "model_name": MODEL_NAME,
    "run_name": run_name,
    "lora_config": {k: v for k, v in LORA_CONFIG.items() if k != "task_type"},
    "training_config": TRAINING_CONFIG,
    # Use the counts captured before merge_and_unload(): merging removes the LoRA
    # layers, so recounting on `model` here would under-report trainable parameters.
    "trainable_params": trainable,
    "total_params": total,
    "trainable_percent": trainable / total * 100,
    "baseline_metrics": {k: float(v) for k, v in baseline_results.items() if isinstance(v, (int, float))},
    "final_metrics": {k: float(v) for k, v in final_results.items() if isinstance(v, (int, float))},
    "training_time_seconds": train_result.metrics['train_runtime']
}

with open(adapter_path / "training_summary.json", "w") as f:
    json.dump(training_summary, f, indent=2)

print(f"✓ Training summary saved.")

✓ Training summary saved.


In [None]:

# Two generation configs: eval_config uses deterministic beam search for
# reproducible ROUGE scores; sample_config uses stochastic sampling for the
# live demo to produce varied outputs. Only use sample_config in the frontend.
eval_config = dict(
    max_length=TRAINING_CONFIG["max_target_length"],
    num_beams=4,
    do_sample=False,
)

sample_config = dict(
    max_length=TRAINING_CONFIG["max_target_length"],
    num_beams=2,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.2,
    no_repeat_ngram_size=3,
)

# Test with question-type-aware prompts (matching training format)
test_cases = [
    ("reasons_evidence",
     "Climate change is not as serious as scientists claim because "
     "the weather has always changed throughout history."),
    ("clarity",
     "Social media is making teenagers more depressed and we should "
     "ban it for anyone under 18."),
    ("implication_consequences",
     "Artificial intelligence will eventually replace all human jobs "
     "and we need to prepare for universal basic income."),
]

model.eval()
print("Sample Generations (beam search — deterministic)")
print("=" * 65)

for q_type, context in test_cases:
    input_text = (
        f"Generate a Socratic question for this context: "
        f"{q_type}: {context}"
    )
    inputs = tokenizer(input_text, return_tensors="pt",
                       max_length=400, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(**inputs, **eval_config)

    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    generated = generated.replace("[Question]", "").strip()

    print(f"Type    : {q_type}")
    print(f"Context : {context[:90]}...")
    print(f"Question: {generated}")
    print("-" * 65)
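The two configs above differ in how the next token is chosen: `eval_config` keeps the highest-scoring beams deterministically, while `sample_config` draws from a temperature-scaled, nucleus-truncated distribution. The sketch below shows only the temperature and top-p step on hypothetical logits; it is a simplification (Hugging Face's top-p warper handles the cut-off boundary slightly differently) and does not model the repetition penalty, `no_repeat_ngram_size`, or the 2-beam sampling in `sample_config`.

```python
import math


def top_p_distribution(logits, temperature=0.8, top_p=0.9):
    """Temperature-scale logits, softmax, then keep the smallest top set with mass >= top_p."""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)
    exps = [math.exp(value - peak) for value in scaled]   # subtract max for stability
    total = sum(exps)
    probs = [e / total for e in exps]

    kept, cumulative = [], 0.0
    for idx in sorted(range(len(probs)), key=lambda i: probs[i], reverse=True):
        kept.append(idx)
        cumulative += probs[idx]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}   # renormalised over the nucleus


# Hypothetical next-token logits for a four-token vocabulary.
dist = top_p_distribution([2.0, 1.0, 0.5, -1.0])
print({token: round(p, 3) for token, p in dist.items()})
```

The lowest-probability token falls outside the 0.9 nucleus and can never be sampled, while the remaining three keep their relative odds; a temperature below 1 sharpens those odds towards the top token.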


In [None]:

# Load sequence: tokenizer → base model → resize embeddings → adapter.
# Resizing before loading the adapter is required because the adapter's
# embed_tokens and lm_head were saved at the extended vocabulary size (32,101).
print("Loading tokenizer from adapter directory...")
tokenizer_inference = T5Tokenizer.from_pretrained(str(adapter_path))
print(f"  Vocab size: {len(tokenizer_inference)}")

print(f"\nLoading base model: {MODEL_NAME}")
base_model_inference = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# CRITICAL: resize BEFORE loading the adapter
base_model_inference.resize_token_embeddings(len(tokenizer_inference))
print(f"  Base model vocab resized to {len(tokenizer_inference)}")

print("\nLoading LoRA adapter...")
model_inference = PeftModel.from_pretrained(base_model_inference, str(adapter_path))
model_inference.eval()

print("\n✓ Model loaded successfully — ready for inference.")

# Quick sanity check
sample_input = "Generate a Socratic question for this context: reasons_evidence: We should ban all fast food."
enc = tokenizer_inference(sample_input, return_tensors="pt")
with torch.no_grad():
    out = model_inference.generate(**enc, max_length=50, num_beams=4)
print(f"\nSanity check: {tokenizer_inference.decode(out[0], skip_special_tokens=True)}")


## TensorBoard

Training logs are written to `logs/{run_name}/`. To inspect the learning curve:

```bash
tensorboard --logdir ../logs
```

The adapter is saved to `models/flan-t5-socratic-lora/adapter/`; the merged model (adapter weights folded into the base model) is saved to `models/flan-t5-socratic-lora/merged/` for deployment contexts where PEFT is not available.
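Merging works because the adapter path is linear: folding $\frac{\alpha}{r} BA$ into $W$ gives the same outputs as running the frozen weight plus the low-rank path, once dropout is inactive at inference. The NumPy sketch below checks this with random stand-ins for trained factors; the shapes are arbitrary and it does not call PEFT.

```python
import numpy as np

rng = np.random.default_rng(1)
d, k, r, alpha = 64, 48, 16, 32
W = rng.standard_normal((d, k))   # frozen base weight
A = rng.standard_normal((r, k))   # trained LoRA factors (random stand-ins here)
B = rng.standard_normal((d, r))
X = rng.standard_normal((k, 5))   # five input vectors as columns

adapter_out = W @ X + (alpha / r) * (B @ (A @ X))   # adapter-style forward: extra low-rank path
W_merged = W + (alpha / r) * (B @ A)                # what merging folds into the weight
merged_out = W_merged @ X                           # plain Linear forward, no adapter

print("max abs difference:", np.max(np.abs(adapter_out - merged_out)))
```

The difference is floating-point rounding only. The trade-off is that the merged checkpoint has the full size of the base model and no longer keeps the adapter separable.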