In [None]:
import os
import json
from glob import glob
import torch
import numpy as np

#@title Enter Paths
DATASET_PATH = "data/lisp-simplification" #@param {type:"string"}
CHECKPOINT_PATH = "checkpoints/Lisp-simplification ACT-torch/.../step_..." #@param {type:"string"}
# DATASET_PATH is the directory that must contain vocab.json.
# CHECKPOINT_PATH is a path prefix: prediction shards are found by globbing
# f"{CHECKPOINT_PATH}_all_preds.*".
# NOTE(review): the "..." segments in CHECKPOINT_PATH look like placeholders that
# must be replaced with a real checkpoint step path before running — confirm.
# Token id at which decode_sequence stops reading a sequence (it and everything
# after it are dropped). Presumably the padding id used when the dataset was built — confirm.
PAD_TOKEN_ID = 0

In [None]:
def load_vocab_and_preds(dataset_path: str, checkpoint_path: str):
    """Load the id -> token vocabulary and all saved prediction shards.

    Args:
        dataset_path: Directory containing ``vocab.json``, a JSON object mapping
            token -> id (inverted here to id -> token).
        checkpoint_path: Path prefix of the checkpoint. Every file matching
            ``f"{checkpoint_path}_all_preds.*"`` is loaded with ``torch.load``.
            The prefix is matched literally, even if it contains glob characters.

    Returns:
        ``(rev_vocab, all_preds)``. ``rev_vocab`` maps id -> token. ``all_preds``
        maps each key found in the shards to the per-shard tensors concatenated
        along dim 0, with shards taken in sorted filename order and tensors
        loaded onto the CPU.

    Raises:
        FileNotFoundError: if ``vocab.json`` is missing or no shard file matches.
    """
    from glob import escape as glob_escape  # the file only imports glob.glob

    # Explicit UTF-8: JSON is UTF-8, and tokens may be non-ASCII.
    with open(os.path.join(dataset_path, "vocab.json"), "r", encoding="utf-8") as f:
        vocab = json.load(f)
    rev_vocab = {v: k for k, v in vocab.items()}

    # Escape the prefix so characters such as "[" or "?" in a directory name are
    # matched literally. Sort so the concatenation order is reproducible.
    pattern = f"{glob_escape(checkpoint_path)}_all_preds.*"
    shard_files = sorted(glob(pattern))
    if not shard_files:
        raise FileNotFoundError(f"No prediction shards match {pattern!r}")

    all_preds = {}
    for filename in shard_files:
        # map_location="cpu" lets shards saved from GPU tensors load on CPU-only
        # machines, and keeps the later .numpy() calls valid.
        shard = torch.load(filename, map_location="cpu")
        for key, value in shard.items():
            all_preds.setdefault(key, []).append(value)
        del shard  # drop the shard dict before loading the next one

    return rev_vocab, {k: torch.cat(v, dim=0) for k, v in all_preds.items()}

def decode_sequence(seq: np.ndarray, rev_vocab: dict):
    """Render a sequence of token ids as a space-separated token string.

    Decoding stops at the first ``PAD_TOKEN_ID``: that id and everything after
    it are dropped. Ids that are missing from ``rev_vocab`` are shown as ``'?'``.
    """
    def _until_pad(ids):
        # Yield ids in order, stopping (exclusive) at the first padding id.
        for tid in ids:
            if tid == PAD_TOKEN_ID:
                return
            yield tid

    return ' '.join(rev_vocab.get(tid, '?') for tid in _until_pad(seq))

def test(num_samples_to_show=10):
    """Decode saved predictions and print exact-match accuracy.

    Loads the vocabulary and prediction shards for the module-level
    ``DATASET_PATH`` / ``CHECKPOINT_PATH``. It then decodes inputs, labels and
    argmax predictions with ``decode_sequence`` and prints the first
    ``num_samples_to_show`` samples in detail. Finally it prints the percentage
    of samples whose decoded prediction string equals the decoded label string.

    Args:
        num_samples_to_show: Number of leading samples to print in detail.

    Raises:
        ValueError: if inputs, labels and predictions have different sample counts.
    """
    rev_vocab, all_preds = load_vocab_and_preds(DATASET_PATH, CHECKPOINT_PATH)

    # .cpu() so .numpy() still works if the tensors live on a GPU.
    inputs = all_preds["inputs"].cpu().numpy()
    labels = all_preds["labels"].cpu().numpy()
    # Greedy prediction: argmax over the last axis (presumably the vocab axis).
    preds = all_preds["logits"].argmax(-1).cpu().numpy()

    total_count = len(inputs)
    if not (len(labels) == len(preds) == total_count):
        raise ValueError(
            f"Mismatched sample counts: inputs={total_count}, "
            f"labels={len(labels)}, predictions={len(preds)}"
        )

    print(f"--- Evaluating {total_count} samples ---\n")

    correct_count = 0
    for i, (input_ids, label_ids, pred_ids) in enumerate(zip(inputs, labels, preds)):
        label_str = decode_sequence(label_ids, rev_vocab)
        pred_str = decode_sequence(pred_ids, rev_vocab)

        is_correct = label_str == pred_str
        if is_correct:
            correct_count += 1

        if i < num_samples_to_show:
            # The input string is only needed for display, so decode it lazily.
            input_str = decode_sequence(input_ids, rev_vocab)
            print(f"Sample {i}:")
            print(f"  Input:    {input_str}")
            print(f"  Expected: {label_str}")
            print(f"  Got:      {pred_str}")
            print(f"  Result:   {'Correct' if is_correct else 'Incorrect'}\n")

    print("\n--- Results ---")
    if total_count == 0:
        # Avoid ZeroDivisionError when the shards contain no samples.
        print("Exact match accuracy: n/a (no samples)")
        return
    accuracy = correct_count / total_count * 100
    print(f"Exact match accuracy: {accuracy:.2f}% ({correct_count}/{total_count})")

# Run the evaluation, printing the first 10 samples in detail.
test(num_samples_to_show=10)