# 02 — Test StudentModel (tiny fine-tune + EM eval)

This notebook smoke-tests the `StudentModel` wrapper:
- loads processed GSM8K JSON
- builds tokenized datasets
- runs a short training loop (few steps)
- evaluates Exact Match (EM) on the extracted final answer

Tip: start with a small subset to validate the end-to-end flow, then scale.


In [None]:
from pathlib import Path
import sys

# The root is taken as the grandparent of the current working directory, so the
# result depends on where the kernel was started.
# NOTE(review): this root is used both as the parent of `src` and as the parent
# of the `gsm8k-distillation` data folder — confirm that `src` really sits
# beside, not inside, `gsm8k-distillation`.
REPO_ROOT = Path('.').resolve().parents[1]

# Put the source directory first on the import path; the membership check keeps
# re-runs of this cell from inserting it twice.
src_path = (REPO_ROOT / 'src').as_posix()
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Kept after the sys.path change above, which presumably is what makes the
# `models` package importable.
from models.student_model import StudentModel, StudentModelConfig

PROCESSED_DIR = REPO_ROOT.joinpath('gsm8k-distillation', 'data', 'processed')
train_path = PROCESSED_DIR / 'gsm8k_train_processed.json'
test_path = PROCESSED_DIR / 'gsm8k_test_processed.json'

# Print each data file next to whether it exists on disk.
for data_file in (train_path, test_path):
    print(data_file, data_file.exists())


/Users/marcobonvissuto/Desktop/Università/Magistrale/Secondo Anno/gsm8k-distillation/data/processed/gsm8k_train_processed.json True
/Users/marcobonvissuto/Desktop/Università/Magistrale/Secondo Anno/gsm8k-distillation/data/processed/gsm8k_test_processed.json True


In [13]:
# Read both processed splits through the StudentModel JSON loader.
train_examples = StudentModel.load_processed_json(train_path)
test_examples = StudentModel.load_processed_json(test_path)

# Report split sizes, then the keys of the first training example as a quick
# look at the record layout (indexing [0] requires a non-empty train split).
print(f'train: {len(train_examples)} test: {len(test_examples)}')
first_example = train_examples[0]
print('sample keys:', first_example.keys())


train: 7473 test: 1319
sample keys: dict_keys(['question', 'reasoning', 'answer', 'split', 'index'])


In [14]:
# Build the student wrapper from a config that names the checkpoint to use.
MODEL_NAME = 'google/flan-t5-large'
cfg = StudentModelConfig(model_name=MODEL_NAME)
student = StudentModel(cfg)

# Show which device the wrapper reports.
print('Device:', student.device)


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

Cancellation requested; stopping current tasks.


KeyboardInterrupt: 

In [None]:
# Supervision mode for the targets: 'cot' (distillation) or 'answer' (baseline).
supervision = 'cot'

# Limits passed to build_hf_dataset (presumably the maximum number of examples
# taken from each split — confirm against StudentModel).
# NOTE(review): 7000 is close to the full train split (7473 examples in the
# recorded load output), so this is no longer a small smoke-test subset.
TRAIN_LIMIT = 7000
EVAL_LIMIT = 500

train_tok = student.build_hf_dataset(
    train_examples,
    supervision=supervision,
    limit=TRAIN_LIMIT,
    shuffle=True,
)
eval_tok = student.build_hf_dataset(
    test_examples,
    supervision=supervision,
    limit=EVAL_LIMIT,
    shuffle=False,
)

# Print both tokenized datasets for a quick sanity check.
for tokenized in (train_tok, eval_tok):
    print(tokenized)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]



Map:   0%|          | 0/300 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 2000
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 300
})


In [None]:
# Fine-tune the student on the tokenized datasets (adjust for your GPU/CPU).
# NOTE(review): six epochs is not a brief run — lower num_train_epochs when
# only a quick smoke test is wanted.
train_kwargs = dict(
    output_dir='outputs/student_smoke',
    num_train_epochs=6,
    learning_rate=2e-5,
    logging_steps=25,
    eval_steps=200,
    save_steps=600,
    predict_with_generate=False,
)
trainer = student.train(
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    **train_kwargs,
)

  trainer = Seq2SeqTrainer(


Step,Training Loss,Validation Loss
50,0.9547,0.792637


In [None]:
# Evaluate Exact Match on raw examples (extracting last number from generation)
metrics = student.evaluate_exact_match(test_examples, num_beams=4, limit=500)
print('EM:', metrics['exact_match'], 'n:', metrics['n'])

# The header below promises five rows, so stop after five even when
# metrics['pred_examples'] holds more. enumerate + break is used instead of
# slicing so this works for any iterable, not only lists.
MAX_DEBUG_ROWS = 5
print('\nTop-5 debug rows (question, generation, pred, gold):')
for row_idx, (q, gen, pred, gold) in enumerate(metrics['pred_examples']):
    if row_idx >= MAX_DEBUG_ROWS:
        break
    # Truncate long texts so each debug row stays readable.
    print('\nQ:', q[:220])
    print('GEN:', gen[:220])
    print('PRED:', pred, 'GOLD:', gold)


In [None]:
# Save the fine-tuned model; the path is defined once so the confirmation
# message always matches the save target.
final_dir = 'outputs/student_smoke/final'
student.save(final_dir)
print(f'Saved to {final_dir}')
