In [1]:
import os, json, glob, re
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

import TrajectoryGenerator as TG
import ClassicalPredictor as CP
import data_and_training as DT  # your big module (has ShardedMemmap, fit_normalizers, evaluate helpers)


# ----------------------------
# Helpers
# ----------------------------
def load_cfg_and_model(ckpt_path, num_classes, reg_dim, device):
    """Rebuild a ``CP.MultiTaskTransformer`` from a checkpoint, in eval mode.

    The checkpoint is read as a mapping with a ``"cfg"`` entry (supplying the
    architecture hyper-parameters listed below) and a ``"model_state"`` entry
    (the state dict loaded into the freshly built model).

    Args:
        ckpt_path: Path passed to ``torch.load``.
        num_classes: Forwarded as the model's ``num_classes``.
        reg_dim: Forwarded as the model's ``reg_dim``.
        device: Used both as ``map_location`` when loading and as the device
            the model is moved to.

    Returns:
        Tuple ``(model, cfg, ckpt)``: the model, the ``"cfg"`` entry, and the
        full loaded checkpoint.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    hparams = checkpoint["cfg"]

    # Architecture settings are taken verbatim from the stored config.
    arch_keys = ("d_model", "nhead", "num_layers", "dim_feedforward", "dropout")
    arch_kwargs = {key: hparams[key] for key in arch_keys}

    net = CP.MultiTaskTransformer(
        input_dim=2,
        num_classes=num_classes,
        reg_dim=reg_dim,
        **arch_kwargs,
    )
    net = net.to(device)

    net.load_state_dict(checkpoint["model_state"])
    net.eval()
    return net, hparams, checkpoint


@torch.no_grad()
def masked_mse_norm(yhat, ytrue, mask):
    """Return the masked mean squared error between two tensors as a float.

    Squared differences are multiplied element-wise by ``mask``, summed, and
    divided by the sum of ``mask`` (clamped to at least 1).

    Args:
        yhat: Predicted values, torch tensor (B, D).
        ytrue: Target values, torch tensor (B, D).
        mask: Element weights, torch tensor (B, D).

    Returns:
        Python ``float``: masked squared-error sum over the clamped mask sum.
    """
    squared_err = (yhat - ytrue).pow(2)
    masked_total = (squared_err * mask).sum().item()
    # Clamp so an all-zero mask divides by 1 instead of 0.
    n_active = mask.sum().clamp_min(1.0).item()
    return float(masked_total / n_active)


@torch.no_grad()
def infer_on_loader(model, loader, device, times_t, y_norm_obj):
    """Run the model over a loader and collect predictions in physical units.

    For every batch the model is called as ``model(x, times_t)`` and is
    expected to return ``(logits, yhat_reg_norm)``. The predicted class is the
    argmax of the logits along dim 1. Both the predicted and the target
    regression vectors are mapped from normalized to physical space one row at
    a time with ``y_norm_obj.inverse_one(row, mask_row)``.

    Args:
        model: Callable taking ``(x, times_t)``.
        loader: Iterable yielding ``(x, y_cls, y_reg_norm, y_mask)`` batches.
        device: Device the inputs are moved to before the forward pass.
        times_t: Second positional argument passed to the model unchanged.
        y_norm_obj: Object exposing ``inverse_one(values, mask)``.

    Returns:
        Dict of NumPy arrays concatenated over all batches, with keys
        ``"pred_cls"``, ``"true_cls"``, ``"pred_params"``, ``"true_params"``
        and ``"mask"``.

    Raises:
        ValueError: If the loader yields no batches.
    """
    pred_cls_chunks = []
    true_cls_chunks = []
    pred_param_chunks = []
    true_param_chunks = []
    mask_chunks = []

    for x, y_cls, y_reg_norm, y_mask in loader:
        # Only the inputs take part in the forward pass; the targets are
        # consumed as NumPy arrays, so they are never copied to the device.
        x = x.to(device, non_blocking=True)

        logits, yhat_reg_norm = model(x, times_t)
        pred_cls = logits.argmax(dim=1)

        mask_np = y_mask.cpu().numpy()
        yhat_np = yhat_reg_norm.cpu().numpy()
        ytrue_np = y_reg_norm.cpu().numpy()

        # Invert normalized -> physical using the mask, row by row.
        pred_params = np.stack(
            [y_norm_obj.inverse_one(row, row_mask) for row, row_mask in zip(yhat_np, mask_np)]
        )
        true_params = np.stack(
            [y_norm_obj.inverse_one(row, row_mask) for row, row_mask in zip(ytrue_np, mask_np)]
        )

        pred_cls_chunks.append(pred_cls.cpu().numpy())
        true_cls_chunks.append(y_cls.cpu().numpy())
        pred_param_chunks.append(pred_params)
        true_param_chunks.append(true_params)
        mask_chunks.append(mask_np)

    if not mask_chunks:
        # np.concatenate would fail on an empty list with an unhelpful message.
        raise ValueError("infer_on_loader: loader yielded no batches")

    return {
        "pred_cls": np.concatenate(pred_cls_chunks),
        "true_cls": np.concatenate(true_cls_chunks),
        "pred_params": np.concatenate(pred_param_chunks, axis=0),
        "true_params": np.concatenate(true_param_chunks, axis=0),
        "mask": np.concatenate(mask_chunks, axis=0),
    }


def evaluate_test(model, ds, idx_te, device, times_t, name=""):
    """Evaluate the model on the test subset and print normalized-space metrics.

    Builds a non-shuffled ``DataLoader`` over ``Subset(ds, idx_te)`` and
    accumulates classification accuracy and the masked regression MSE of the
    model outputs against the (normalized) targets.

    Args:
        model: Callable taking ``(x, times_t)`` and returning
            ``(logits, yhat_reg_norm)``.
        ds: Dataset yielding ``(x, y_cls, y_reg_norm, y_mask)`` items.
        idx_te: Array of test indices; must support ``.tolist()``.
        device: Device the batches are moved to.
        times_t: Second positional argument passed to the model unchanged.
        name: Label used in the printed header.

    Returns:
        Tuple ``(dl_te, metrics)`` where ``dl_te`` is the test loader and
        ``metrics`` is a dict with keys ``"acc"`` and ``"masked_reg_mse_norm"``.

    Note:
        ``masked_reg_mse_norm`` is the batch-size-weighted average of the
        per-batch masked MSE, not the masked MSE pooled over all elements;
        the two differ when the number of unmasked entries varies per batch.
    """
    dl_te = DataLoader(
        Subset(ds, idx_te.tolist()),
        batch_size=1024,
        shuffle=False,
        num_workers=2,
        # Pinned host memory is only useful for host -> CUDA copies.
        pin_memory=torch.device(device).type == "cuda",
    )

    # Compute accuracy + normalized masked MSE directly from model outputs/targets
    total = 0
    correct = 0
    reg_mse_sum = 0.0

    with torch.no_grad():
        for x, y_cls, y_reg_norm, y_mask in dl_te:
            x = x.to(device, non_blocking=True)
            y_cls = y_cls.to(device, non_blocking=True)
            y_reg_norm = y_reg_norm.to(device, non_blocking=True)
            y_mask = y_mask.to(device, non_blocking=True)

            logits, yhat_reg_norm = model(x, times_t)
            pred = logits.argmax(dim=1)

            batch_n = int(x.size(0))
            correct += int((pred == y_cls).sum().item())
            total += batch_n

            # Weight each batch's mean by its size so a short final batch
            # does not count as much as a full one.
            reg_mse_sum += masked_mse_norm(yhat_reg_norm, y_reg_norm, y_mask) * batch_n

    # max(..., 1) keeps an empty test split from dividing by zero.
    acc = correct / max(total, 1)
    reg_mse_norm_avg = reg_mse_sum / max(total, 1)

    print(f"\n[{name}] TEST (normalized space)")
    print("  n:", total)
    print("  acc:", acc)
    print("  masked_reg_mse_norm:", reg_mse_norm_avg)

    return dl_te, {"acc": acc, "masked_reg_mse_norm": reg_mse_norm_avg}


# ----------------------------
# Main
# ----------------------------
def _evaluate_task(kind, shards, idx_tr, idx_te, device, times, reg_dim,
                   ckpt_path, fit_n, norm_seed):
    """Run the test-set evaluation for one task and return its predictions.

    ``kind`` ("pk" or "pd") selects the shard file names, is passed to
    ``DT.fit_normalizers`` as ``dataset_kind``, and forms the printed label
    ``"<KIND>_deep"``. Returns the dict produced by ``infer_on_loader``.
    """
    tag = f"{kind.upper()}_deep"

    X = DT.ShardedMemmap(shards, f"X_{kind}.npy")
    y_cls = DT.ShardedMemmap(shards, f"y_{kind}_cls.npy")
    y_reg = DT.ShardedMemmap(shards, f"y_{kind}_reg.npy")
    y_mask = DT.ShardedMemmap(shards, f"y_{kind}_mask.npy")

    # Normalizers are refit from the training split rather than loaded.
    xnorm, ynorm = DT.fit_normalizers(
        dataset_kind=kind, shard_dirs=shards, train_idx=idx_tr,
        fit_n=fit_n, seed=norm_seed
    )
    ds = DT.ShardedMultiTaskSeqDataset(X, y_cls, y_reg, y_mask, xnorm, ynorm)

    times_t = torch.tensor(times, dtype=torch.float32, device=device)

    model, _, _ = load_cfg_and_model(ckpt_path, num_classes=10, reg_dim=reg_dim, device=device)

    # Normalized-space metrics are printed inside evaluate_test; the loader it
    # built is reused for the physical-space pass.
    dl_te, _ = evaluate_test(model, ds, idx_te, device, times_t, name=tag)

    out = infer_on_loader(model, dl_te, device, times_t, ynorm)

    acc = (out["pred_cls"] == out["true_cls"]).mean()
    diff2 = (out["pred_params"] - out["true_params"]) ** 2
    m = out["mask"]
    # Masked MSE pooled over every unmasked element of the test set.
    mse_phys = float((diff2 * m).sum() / max(m.sum(), 1.0))

    print(f"\n[{tag}] TEST (physical space)")
    print("  acc:", float(acc))
    print("  masked_param_mse_phys:", mse_phys)

    return out


def main():
    """Evaluate the PK and PD deep checkpoints on the test split.

    Discovers the dataset shards, loads the saved split indices, evaluates the
    PK model and then the PD model (printing normalized-space and
    physical-space metrics for each), and finally writes both prediction
    dicts to ``./test_outputs`` as ``.npz`` files.
    """
    BASE_DIR = "./dataset"
    SPLITS_NPZ = "./dataset/splits_90_5_5_seed123.npz"

    # Your deep-run best checkpoints:
    PK_CKPT = "./long_runs/pk_deep/PK_long_deep_best.pt"
    PD_CKPT = "./long_runs/pd_deep/PD_long_deep_best.pt"

    # Normalizer refit settings (must match what you used in training)
    FIT_N = 200_000
    PK_NORM_SEED = 1
    PD_NORM_SEED = 2

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    # Discover shards
    shards = DT.discover_mpirank_shards(BASE_DIR)

    # Load splits
    spl = np.load(SPLITS_NPZ)
    idx_tr = spl["idx_tr"].astype(np.int64)
    idx_va = spl["idx_va"].astype(np.int64)
    idx_te = spl["idx_te"].astype(np.int64)

    print("Split sizes:", len(idx_tr), len(idx_va), len(idx_te))

    # PK first, then PD; each call builds memmaps, normalizers, dataset and model.
    pk_out = _evaluate_task(
        "pk", shards, idx_tr, idx_te, device,
        times=TG.PK_TIMES, reg_dim=TG.PK_PARAM_DIM,
        ckpt_path=PK_CKPT, fit_n=FIT_N, norm_seed=PK_NORM_SEED,
    )
    pd_out = _evaluate_task(
        "pd", shards, idx_tr, idx_te, device,
        times=TG.PD_TIMES, reg_dim=TG.PD_PARAM_DIM,
        ckpt_path=PD_CKPT, fit_n=FIT_N, norm_seed=PD_NORM_SEED,
    )

    # =========================
    # Optional: save predictions
    # =========================
    os.makedirs("./test_outputs", exist_ok=True)
    np.savez("./test_outputs/pk_deep_test_preds.npz", **pk_out)
    np.savez("./test_outputs/pd_deep_test_preds.npz", **pd_out)
    print("\nSaved outputs to ./test_outputs/*.npz")

# Entry point: run the full evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

device: cuda
Split sizes: 9000000 500000 500000





[PK_deep] TEST (normalized space)
  n: 500000
  acc: 0.958718
  masked_reg_mse_norm: 0.31668178060120855

[PK_deep] TEST (physical space)
  acc: 0.958718
  masked_param_mse_phys: 1.9237614870071411

[PD_deep] TEST (normalized space)
  n: 500000
  acc: 0.714184
  masked_reg_mse_norm: 0.5159172073498169

[PD_deep] TEST (physical space)
  acc: 0.714184
  masked_param_mse_phys: 11.98989486694336

Saved outputs to ./test_outputs/*.npz
