# Curriculum Learning su GSM8K – Training sequenziale (Answer → Short-CoT → Full-CoT)
Questo notebook mostra una pipeline **a step sequenziali** per stabilizzare l’apprendimento di uno *student model* su GSM8K.

### Obiettivo operativo
1) **Stage A – Answer-only**: fissare numeracy + formato della risposta (massimizza EM, minimizza rumore).
2) **Stage B – Short-CoT**: introdurre ragionamento breve e strutturato senza far collassare l’accuratezza.
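3) **Stage C – Full-CoT**: distillazione del ragionamento completo (se utile), mantenendo la risposta come priorità.

Nota: le celle di training qui sotto coprono lo **Stage A**, con `SUPERVISION = "answer"` e un curriculum per difficoltà (easy → easy+medium → tutti gli esempi). Per gli stage successivi va cambiata la modalità di supervisione (es. `"cot"`, se supportata da `build_hf_dataset`).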

Alla fine trovi anche:
- **sanity check** per verificare coerenza dei dati
- **difficulty tagging** per creare il curriculum
- un **parser robusto** per EM
- uno scheletro di **weighted loss** (opzionale, ma spesso determinante).


In [None]:
from pathlib import Path
import sys

import json
from pathlib import Path

# Repository root = two levels above the current working directory.
# NOTE(review): assumes the notebook is launched from <repo>/<a>/<b> — confirm.
REPO_ROOT = Path(".").resolve().parents[1]

# Expose the project's `src/` folder to the import system; the membership
# check keeps re-running this cell from stacking duplicate entries.
src_path = (REPO_ROOT / "src").as_posix()
if src_path not in sys.path:
    sys.path[:0] = [src_path]

# Project import: must come after the sys.path tweak above.
from models.student_model import StudentModel, StudentModelConfig

# Preprocessed GSM8K splits and the root folder for this notebook's runs.
PROCESSED_DIR = REPO_ROOT.joinpath("gsm8k-distillation", "data", "processed")
TRAIN_PATH = PROCESSED_DIR / "gsm8k_train_processed.json"
TEST_PATH = PROCESSED_DIR / "gsm8k_test_processed.json"

OUTPUT_ROOT = REPO_ROOT.joinpath("outputs", "curriculum_gsm8k")

# Quick visibility on whether both split files are actually present.
for split_path in (TRAIN_PATH, TEST_PATH):
    print(split_path, split_path.exists())

# Load both splits through the student's own loader (train first, then test).
train_examples, test_examples = (
    StudentModel.load_processed_json(split_path)
    for split_path in (TRAIN_PATH, TEST_PATH)
)

In [None]:
from collections import Counter


def _difficulty_of(ex):
    """Return the ``difficulty_tag`` of a processed training example.

    Examples may be plain dicts or objects exposing a ``difficulty_tag``
    attribute; both shapes are accepted. A missing tag still raises
    ``KeyError`` (dict) or ``AttributeError`` (object).
    """
    return ex["difficulty_tag"] if isinstance(ex, dict) else ex.difficulty_tag


# Bucket the training set by difficulty in a single pass; the relative order
# of examples inside each bucket follows train_examples.
_buckets = {"easy": [], "medium": [], "hard": []}
_unknown_tags = Counter()
for _ex in train_examples:
    _tag = _difficulty_of(_ex)
    if _tag in _buckets:
        _buckets[_tag].append(_ex)
    else:
        _unknown_tags[_tag] += 1

easy, medium, hard = _buckets["easy"], _buckets["medium"], _buckets["hard"]

print("Counts:", len(easy), len(medium), len(hard))
if _unknown_tags:
    # These examples belong to no phase below (not even P3_all), so they are
    # never trained on — surface that instead of dropping them silently.
    print(f"WARNING: excluded examples with unrecognized difficulty_tag: {dict(_unknown_tags)}")

# Cumulative curriculum: each phase re-includes the data of the previous ones
# while the learning rate is lowered as harder examples are added.
CURRICULUM_PHASES = [
    {"name": "P1_easy",          "examples": easy, "epochs": 1, "lr": 3e-5},
    {"name": "P2_easy+medium",   "examples": easy + medium, "epochs": 1, "lr": 2e-5},
    {"name": "P3_all",           "examples": easy + medium + hard, "epochs": 2, "lr": 1e-5},
]

In [None]:
# Config modello
cfg = StudentModelConfig(model_name="google/flan-t5-large")
student = StudentModel(cfg)
print("Device:", student.device)

# Supervisione: 'answer' oppure 'cot' (se la tua build_hf_dataset lo supporta)
SUPERVISION = "answer"

# Tokenizza test una volta
eval_tok = student.build_hf_dataset(test_examples, supervision=SUPERVISION, limit=1000, shuffle=False)
print(eval_tok)

In [None]:
# Run the curriculum phases sequentially on the same `student` object.
#
# Warm start: each phase continues from the weights the previous phase left in
# `student` (the final EM cell evaluates `student` directly, which relies on
# training updating it in place — NOTE(review): confirm StudentModel.train does).
#
# The previous phase's checkpoint is deliberately NOT passed as
# `resume_from_checkpoint`. Trainer-style resuming restores the old optimizer,
# LR scheduler and global step. That (a) overrides this phase's `lr`, and
# (b) makes the trainer skip the first `global_step` batches of the new, larger
# dataset. Resuming is meant for continuing an interrupted run with the same
# configuration, not for starting a new curriculum phase.
last_ckpt = None  # newest checkpoint of the most recent phase (informational)

for phase in CURRICULUM_PHASES:
    name = phase["name"]
    exs = phase["examples"]
    epochs = phase["epochs"]
    lr = phase["lr"]

    if not exs:
        print(f"Skip {name}: empty")
        continue

    print(f"\n=== {name} | examples={len(exs)} | epochs={epochs} | lr={lr} ===")

    train_tok = student.build_hf_dataset(exs, supervision=SUPERVISION, limit=None, shuffle=True)

    out_dir = OUTPUT_ROOT / name
    # Leftover checkpoints from an earlier run would be mixed with (and could
    # outrank) this run's checkpoints below; delete the folder to start clean.
    stale = list(out_dir.glob("checkpoint-*"))
    if stale:
        print(f"WARNING: {out_dir} already contains {len(stale)} checkpoint(s) from a previous run")
    out_dir.mkdir(parents=True, exist_ok=True)

    trainer = student.train(
        train_dataset=train_tok,
        eval_dataset=eval_tok,
        output_dir=str(out_dir),
        num_train_epochs=epochs,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=lr,
        logging_steps=50,
        eval_steps=200,
        save_steps=400,
        # Fresh optimizer/scheduler per phase so `lr` applies and no batches are
        # skipped; the weights carry over through `student` (see note above).
        resume_from_checkpoint=None,
    )

    # Newest checkpoint in this phase's folder, ordered by the numeric step
    # suffix; entries without a numeric suffix are ignored instead of crashing.
    ckpts = sorted(
        (p for p in out_dir.glob("checkpoint-*") if p.name.rsplit("-", 1)[-1].isdigit()),
        key=lambda p: int(p.name.rsplit("-", 1)[-1]),
    )
    last_ckpt = str(ckpts[-1]) if ckpts else None
    print("Last checkpoint:", last_ckpt)

In [None]:
# Evaluate exact match (EM) on the raw test examples (limit=500, beam search
# with 4 beams). Per the original note, the prediction is the last number
# extracted from the generation — NOTE(review): confirm in
# StudentModel.evaluate_exact_match.
from itertools import islice

N_DEBUG_ROWS = 5  # how many (question, generation, pred, gold) rows to show

metrics = student.evaluate_exact_match(test_examples, num_beams=4, limit=500)
print('EM:', metrics['exact_match'], 'n:', metrics['n'])

print(f'\nTop-{N_DEBUG_ROWS} debug rows (question, generation, pred, gold):')
# islice caps the output at N_DEBUG_ROWS whatever the length of pred_examples
# (the original loop printed every row despite the "Top-5" label).
for q, gen, pred, gold in islice(metrics['pred_examples'], N_DEBUG_ROWS):
    print('\nQ:', q[:220])
    print('GEN:', gen[:220])
    print('PRED:', pred, 'GOLD:', gold)