# 02 — Training

This notebook fine-tunes a T5-small model for topic-controlled question generation. It supports three training modes:

| Mode | Input format | Dataset | Paper name |
|------|-------------|---------|------------|
| `baseline` | `<topic> {t} <context> {text}` | SQuAD baseline | Baseline |
| `topic` | `<topic> {t} <context> {text}` | MixSQuAD (10k) | TopicQG |
| `topic2x` | `<topic> {t} <context> {text}` | MixSQuAD2X (20k) | TopicQG2X |

> Prerequisite: run `01_data_generation.ipynb` first (stages 1–4).
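
The input template in the table can be illustrated with a small helper. This is a minimal sketch, not the pipeline's own code: `build_input` is a hypothetical name, and the whitespace normalisation is an assumption about how examples might be cleaned before tokenisation.

```python
def build_input(topic: str, context: str) -> str:
    """Format one example with the `<topic> {t} <context> {text}` template."""
    # Assumption: collapse runs of whitespace so stray newlines do not
    # end up inside the model input.
    topic = " ".join(topic.split())
    context = " ".join(context.split())
    return f"<topic> {topic} <context> {context}"


example = build_input("Photosynthesis", "Plants convert light\ninto chemical energy.")
print(example)
# <topic> Photosynthesis <context> Plants convert light into chemical energy.
```

Whether `<topic>` and `<context>` are registered as special tokens in the tokenizer is not stated here, so the sketch treats them as plain text markers.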

## Setup

In [None]:
import sys
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    REPO_URL = "https://github.com/YOUR_ORG/YOUR_REPO.git"  # TODO: set your URL
    if not Path('/content/ai4ed-qg').exists():  # skip the clone when the cell is re-run
        !git clone -q {REPO_URL} /content/ai4ed-qg
    %cd /content/ai4ed-qg
    !pip install -q torch transformers datasets accelerate sentencepiece pyyaml tqdm

    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_DIR = Path('/content/drive/MyDrive/ai4ed_qg')

    # Restore data from Drive
    import shutil
    for subdir in ('processed', 'training'):
        src = DRIVE_DIR / subdir
        dst = Path('/content/ai4ed-qg/data') / subdir
        if src.exists():
            shutil.copytree(src, dst, dirs_exist_ok=True)
            print(f"Restored data/{subdir}/ from Drive")
else:
    DRIVE_DIR = None

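import os

# Resolve the repository root: the clone location on Colab, otherwise the
# current directory (stepping up one level when launched from notebooks/).
project_root = Path('/content/ai4ed-qg') if IN_COLAB else Path.cwd()
if project_root.name == 'notebooks':
    project_root = project_root.parent
os.chdir(project_root)
sys.path.insert(0, str(project_root))  # make the `src` package importable
print(f"Working dir: {os.getcwd()}")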

In [None]:
import torch
print(f"PyTorch  : {torch.__version__}")
print(f"CUDA     : {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU      : {torch.cuda.get_device_name(0)}")
    mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"VRAM     : {mem:.1f} GB")
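    # The tips are based on VRAM size only; the GPU names are typical examples.
    if mem >= 35:
        print("Tip: 35+ GB VRAM (e.g. A100): you can increase batch size to 128")
    elif mem >= 15:
        print("Tip: 15+ GB VRAM (e.g. V100/T4): batch size 64 is fine")
    else:
        print("Tip: small GPU: reduce batch size if you hit OOM")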
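
The VRAM thresholds in the cell above can also be expressed as a function, which makes them easy to reuse when setting the `batch` override. This is a sketch only: `suggest_batch_size` is a hypothetical helper, and halving the batch size on small GPUs is an assumption, since the tip above only says to reduce it.

```python
def suggest_batch_size(vram_gb: float, default: int = 64) -> int:
    """Map total VRAM (in GB) to a starting batch size."""
    if vram_gb >= 35:      # large cards such as an A100
        return 128
    if vram_gb >= 15:      # 16 GB cards such as a V100 or T4
        return default
    # Assumption: halve the default on smaller GPUs as a first guess.
    return max(1, default // 2)


for gb in (40.0, 15.8, 8.0):
    print(f"{gb:>5.1f} GB -> batch {suggest_batch_size(gb)}")
```

Treat the result as a starting point; if training still runs out of memory, lower the value further.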

## Initialise Pipeline

In [None]:
from src.pipeline import Pipeline

pipe = Pipeline('config/pipeline.yaml')
pipe.status()

## Configuration

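Defaults live in `config/pipeline.yaml`. To override a value for this session only, uncomment it in the `overrides` dict below; edit `pipeline.yaml` to make the change permanent.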

In [None]:
# Override training config (optional)
# These mirror config/pipeline.yaml — edit pipeline.yaml to make permanent changes.
overrides = {
    # 'model_name': 'google-t5/t5-small',  # or 'google-t5/t5-base'
    # 'batch':       64,
    # 'lr':          1e-3,
    # 'epochs':      50,
    # 'max_input_len':  200,
    # 'max_output_len': 45,
}
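for k, v in overrides.items():
    # Fail loudly on typos instead of silently creating a new attribute
    if not hasattr(pipe.config.training, k):
        raise KeyError(f"Unknown training option: {k!r}")
    setattr(pipe.config.training, k, v)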

t = pipe.config.training
print(f"Model      : {t.model_name}")
print(f"Epochs     : {t.epochs}")
print(f"Batch size : {t.batch}")
print(f"LR         : {t.lr}")
print(f"Max input  : {t.max_input_len} tokens")
print(f"Max output : {t.max_output_len} tokens")
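
The override mechanism can be tried without the pipeline. The sketch below is hedged: `TrainingConfig` is a hypothetical stand-in for `pipe.config.training`, its defaults are copied from the commented-out entries in the `overrides` dict (the real values live in `config/pipeline.yaml`), and `apply_overrides` is an illustrative helper rather than part of `src.pipeline`.

```python
from dataclasses import dataclass, fields


@dataclass
class TrainingConfig:
    # Hypothetical stand-in; defaults mirror the commented-out overrides above.
    model_name: str = "google-t5/t5-small"
    batch: int = 64
    lr: float = 1e-3
    epochs: int = 50
    max_input_len: int = 200
    max_output_len: int = 45


def apply_overrides(cfg: TrainingConfig, overrides: dict) -> TrainingConfig:
    """Set known fields on cfg and reject keys the config does not define."""
    known = {f.name for f in fields(cfg)}
    unknown = sorted(set(overrides) - known)
    if unknown:
        raise KeyError(f"Unknown training option(s): {unknown}")
    for key, value in overrides.items():
        setattr(cfg, key, value)
    return cfg


cfg = apply_overrides(TrainingConfig(), {"batch": 32, "epochs": 10})
print(cfg.batch, cfg.epochs)  # 32 10
```

Rejecting unknown keys matters because a plain `setattr` would accept a misspelt name such as `bach` and leave the real setting unchanged.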

## Train Models

Train one mode or all three. Each run saves its model to `models/{mode}/best_model/`.

> Expected time on a T4 GPU: ~2–5 min/epoch × 50 epochs ≈ 2–4 hours per mode.
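
The time estimate is simple arithmetic on the per-epoch range, and it is worth redoing after changing `epochs`. A minimal sketch, where `training_hours` is a hypothetical helper and the 2 and 5 minutes per epoch are the T4 figures quoted above:

```python
def training_hours(minutes_per_epoch: float, epochs: int = 50) -> float:
    """Total wall-clock training time in hours."""
    return minutes_per_epoch * epochs / 60


low, high = training_hours(2), training_hours(5)
print(f"{low:.1f}-{high:.1f} hours per mode")  # 1.7-4.2 hours per mode
```

Training all three modes back to back therefore takes roughly three times as long.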

In [None]:
# ── TopicQG (main model — train this first) ───────────────────────────────────
model_path = pipe.train(mode='topic', dataset='squad')
print(f"\nSaved to: {model_path}")

In [None]:
# ── Baseline (SQuAD baseline dataset, same input template) ────────────────────
# model_path = pipe.train(mode='baseline', dataset='squad')
# print(f"Saved to: {model_path}")

In [None]:
# ── TopicQG2X (doubled dataset, context order reversed) ──────────────────────
# model_path = pipe.train(mode='topic2x', dataset='squad')
# print(f"Saved to: {model_path}")

## Test Generation

As a quick sanity check, generate one question with the trained `topic` model.

In [None]:
topic   = "Photosynthesis"
context = (
    "Photosynthesis is a process used by plants and other organisms to convert "
    "light energy into chemical energy that can be stored and used. "
    "In plants, photosynthesis converts carbon dioxide and water into glucose "
    "and oxygen using energy from sunlight. Chlorophyll, the green pigment in "
    "leaves, is responsible for absorbing light energy."
)

question = pipe.generate(topic=topic, context=context, mode='topic')
print(f"Topic   : {topic}")
print(f"Question: {question}")
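
Reading the output by eye is enough for one example. For several examples, a cheap automated check may help. The sketch below is an assumption-laden heuristic, not part of the pipeline: `looks_like_question` is a hypothetical helper, and the 45-word cap only loosely mirrors `max_output_len`, which counts tokens rather than words.

```python
def looks_like_question(text: str, max_words: int = 45) -> bool:
    """Heuristic: non-empty, ends with '?', and not suspiciously long."""
    text = text.strip()
    return bool(text) and text.endswith("?") and len(text.split()) <= max_words


samples = [
    "What pigment absorbs light energy in leaves?",
    "",
    "Photosynthesis converts carbon dioxide and water into glucose.",
]
for s in samples:
    print(looks_like_question(s), repr(s))
```

A check like this catches empty or truncated generations only; it says nothing about whether the question is on topic.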

## Save Model to Drive (Colab)

In [None]:
if IN_COLAB and DRIVE_DIR:
    import shutil
    src = project_root / 'models'
    dst = DRIVE_DIR / 'models'
    if src.exists():
        # Merge local models into the Drive copy (existing files are overwritten)
        shutil.copytree(src, dst, dirs_exist_ok=True)
        print(f"Models synced to {dst}")

        # Also offer the models as a zip download
        shutil.make_archive('/content/models', 'zip', src)
        from google.colab import files
        files.download('/content/models.zip')
    else:
        print("No models/ directory found; train a model first")
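
The sync-and-archive pattern above can be tested locally, without Colab or Drive. This is a sketch under stated assumptions: `sync_dir` is a hypothetical helper, the paths are temporary stand-ins for `models/` and the Drive folder, and `dirs_exist_ok` requires Python 3.8 or later.

```python
import shutil
import tempfile
from pathlib import Path


def sync_dir(src: Path, dst: Path) -> bool:
    """Copy src into dst, merging with existing content; False if src is missing."""
    if not src.is_dir():
        return False
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return True


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Fake a trained model directory laid out as models/{mode}/best_model/
    model_dir = root / "models" / "topic" / "best_model"
    model_dir.mkdir(parents=True)
    (model_dir / "config.json").write_text("{}")

    ok = sync_dir(root / "models", root / "drive" / "models")
    copied = (root / "drive" / "models" / "topic" / "best_model" / "config.json").exists()
    missing = sync_dir(root / "does_not_exist", root / "drive" / "other")

    # make_archive appends the .zip suffix and returns the archive path
    archive = shutil.make_archive(str(root / "models_backup"), "zip", root / "models")
    archive_name = Path(archive).name

print(ok, copied, missing, archive_name)  # True True False models_backup.zip
```

Returning `False` for a missing source keeps the cell from crashing when no model has been trained yet.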

## Next Steps

Proceed to **`03_evaluation.ipynb`** to evaluate trained models against the paper's baselines.