# Evaluate training set metrics for saved checkpoints

This notebook loads model checkpoints from the `checkpoints` directory and evaluates them on the same dataset that was used for training. The metrics are printed for each checkpoint so you can verify that the model was learning during training.

In [None]:
from pathlib import Path
from collections import OrderedDict

import torch
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

from ttt.dataloader import DatasetByPrompt
from ttt.options import DataArguments
from ttt.utils import compute_metrics


In [None]:
# Experiment directory whose `checkpoint-*` subfolders will be evaluated
run_dir = Path(
    "/home/parsa/Codebases/GitHub_Repositories/swarm-distillation-zero-shot/checkpoints/anli_r1/11B_ttt_t0.train.source.validation.anli.none.dev_r1.T0pp.peft.lora.lora_alpha4.lora_drop0.3.bn1.pw1.0.np5.bsz1.ga4.lr2e-5.steps.1000_20250708"
)

# Cache directory used during training, resolved three levels above the run
# directory (adjust if your layout differs)
cache_dir = run_dir.parents[2].joinpath('pretrain_models', 'huggingface')


In [None]:
# List all checkpoints sorted by step number.
# Only `checkpoint-<integer>` directories are kept: the glob would also match
# names such as `checkpoint-best`, whose suffix cannot be converted by int()
# in the sort key below.
checkpoints = sorted(
    (
        p for p in run_dir.glob('checkpoint-*')
        if p.is_dir() and p.name.split('-')[-1].isdecimal()
    ),
    key=lambda p: int(p.name.split('-')[-1]),
)
checkpoints

In [None]:
# Load the dataset used for training (ANLI R1 validation split)
data_args = DataArguments(
    dataset_name='anli',
    prompt_set_name='anli',
    subset_name='none',
    testset_name='dev_r1'
)

# The tokenizer is loaded from the first (lowest-step) checkpoint, so fail
# early with a readable message instead of a bare IndexError when the run
# directory contains no checkpoints.
if not checkpoints:
    raise FileNotFoundError(f"No 'checkpoint-*' directories found under {run_dir}")
tokenizer = AutoTokenizer.from_pretrained(checkpoints[0])

train_ds = DatasetByPrompt(data_args, str(cache_dir), tokenizer)
# Padded label positions are filled with -100.
collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100)
metric = load_metric('accuracy')

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
def evaluate_checkpoint(ckpt_path, batch_size=8):
    """Score one saved checkpoint on ``train_ds`` and return its metrics.

    The model stored at ``ckpt_path`` is loaded onto ``device``, every item of
    the module-level ``train_ds`` is pushed through it in mini-batches, and
    the collected per-item losses are handed to ``compute_metrics``.

    Args:
        ckpt_path: Checkpoint location passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        batch_size: Number of flattened items per forward pass. Defaults to
            8, the value that used to be hard-coded.

    Returns:
        The first element of the pair returned by ``compute_metrics``.

    Raises:
        ValueError: If the model returns a scalar (batch-reduced) loss, since
            one value per item is needed to build the score list.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_path).to(device)
    model.eval()

    # Each dataset entry yields a collection of model inputs plus one gold
    # label; the inputs are flattened into a single list for batching.
    all_data = []
    golds = []
    for idx in range(len(train_ds)):
        example, label = train_ds[idx]
        all_data.extend(example)
        golds.append(label)

    logprobs = []
    # Gradients are never needed during evaluation, so disable them once for
    # the whole loop.
    with torch.no_grad():
        for start in range(0, len(all_data), batch_size):
            batch = collator(all_data[start:start + batch_size])
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = model(**batch).loss
            if loss.dim() == 0:
                # A 0-d tensor's tolist() is a plain float, which cannot be
                # used to extend the list; report the real cause instead.
                raise ValueError(
                    "Model returned a scalar loss; evaluate_checkpoint needs "
                    "one loss value per batch item."
                )
            # NOTE(review): the raw loss values are passed on as `logprobs`;
            # confirm the sign convention expected by compute_metrics.
            logprobs.extend(loss.detach().cpu().tolist())

    results, _ = compute_metrics(
        logprobs, len(train_ds), train_ds.num_choices, train_ds.num_prompts,
        golds, metric
    )
    return results


In [None]:
# Evaluate every checkpoint in the order of `checkpoints`, keyed by the step
# number parsed from the directory name. Each result is printed as soon as it
# is available so progress is visible before the whole run finishes.
metrics_by_step = OrderedDict()
for ckpt in checkpoints:
    step = int(ckpt.name.split('-')[-1])
    metrics_by_step[step] = evaluate_checkpoint(ckpt)
    print(f"step {step}: {metrics_by_step[step]}", flush=True)

metrics_by_step
