In [3]:
import os, json
import numpy as np
import torch
import torch.nn.functional as F

import ClassicalPredictor as CP
import TrajectoryGenerator as TG


# -------------------------
# Normalizer reconstruction
# -------------------------
class XNorm:
    """
    Per-channel input normalizer rebuilt from checkpoint statistics.

    transform() optionally applies log1p (after clipping negatives to 0) and then
    standardizes each channel with the stored mean/std. The std is floored at
    `eps` so that constant channels do not cause division by zero.
    """

    def __init__(self, mean, std, log1p=True, eps=1e-6):
        self.mean = np.asarray(mean, dtype=np.float32)
        floored_std = np.maximum(np.asarray(std, dtype=np.float32), eps)
        self.std = floored_std
        self.log1p = bool(log1p)

    def transform(self, X_NTxC: np.ndarray) -> np.ndarray:
        """Normalize an (N, T, C) feature array; returns float32 of the same shape."""
        feats = X_NTxC.astype(np.float32, copy=False)
        if self.log1p:
            # Negative inputs are clipped to 0 so log1p stays finite.
            feats = np.log1p(np.clip(feats, 0.0, None))
        centre = self.mean[np.newaxis, np.newaxis, :]
        scale = self.std[np.newaxis, np.newaxis, :]
        return (feats - centre) / scale


class YNorm:
    """
    Inverse of the regression-target normalization used in training.

    In training you did:
      y_norm[used] = (log(y_phys[used]) - mean[used]) / std[used]
    So inverse is:
      y_phys = exp(y_norm * std + mean)  (then zero-out unused)

    Args:
        mean, std: per-dimension statistics of log(y_phys), shape (D,).
        eps: lower bound applied to std (same role as in XNorm).
    """
    def __init__(self, mean, std, eps=1e-6):
        self.mean = np.asarray(mean, dtype=np.float32)
        self.std  = np.maximum(np.asarray(std, dtype=np.float32), eps)

    def inverse(self, y_norm_NxD: np.ndarray, mask_NxD: np.ndarray) -> np.ndarray:
        """
        Map normalized predictions back to physical units.

        Args:
            y_norm_NxD: normalized predictions, shape (N, D).
            mask_NxD: per-dimension usage mask, broadcastable to (N, D); 0 marks unused dims.

        Returns:
            float32 array (N, D): exp(y_norm * std + mean) scaled by the mask, with exact
            zeros wherever the mask is 0.
        """
        y_norm_NxD = np.asarray(y_norm_NxD, dtype=np.float32)
        mask_NxD   = np.asarray(mask_NxD, dtype=np.float32)
        log_y = y_norm_NxD * self.std[None, :] + self.mean[None, :]
        # Values in unused dims are arbitrary here; exponentiating them can overflow
        # to inf, and inf * 0 == nan. Neutralize them *before* exp so masked dims
        # come out as exactly 0 instead of NaN.
        log_y = np.where(mask_NxD != 0, log_y, np.float32(0.0))
        y_phys = np.exp(log_y) * mask_NxD
        return y_phys.astype(np.float32, copy=False)


# -------------------------
# Helpers
# -------------------------
def ensure_times(times_NxT: np.ndarray, expected_T: int) -> np.ndarray:
    """
    Validate a per-trajectory time matrix and return the single grid used for all rows.

    Args:
        times_NxT: time stamps, shape (N, T).
        expected_T: required number of time points.

    Returns:
        Row 0 of `times_NxT` as float32, shape (T,).

    Raises:
        ValueError: if the array is not 2-D, has the wrong T, or has no rows.

    Downstream inference feeds this one grid to every trajectory. If the rows do not
    share (approximately) the same grid, a warning is emitted; row 0 is still returned.
    """
    import warnings

    times_NxT = np.asarray(times_NxT)
    if times_NxT.ndim != 2:
        raise ValueError(f"Expected a 2-D (N, T) time array, got shape {times_NxT.shape}")
    if times_NxT.shape[1] != expected_T:
        raise ValueError(f"Expected T={expected_T}, got {times_NxT.shape[1]}")
    if times_NxT.shape[0] == 0:
        raise ValueError("times_NxT has no rows; cannot infer a time grid")

    # Compare in the original dtype so the float32 cast does not create false mismatches.
    if not np.allclose(times_NxT, times_NxT[:1], equal_nan=True):
        warnings.warn(
            f"Time grids differ across the {times_NxT.shape[0]} rows; "
            "using row 0 as the common grid for all trajectories.",
            stacklevel=2,
        )
    return times_NxT[0].astype(np.float32)


def load_ckpt_build_model(ckpt_path: str, device, num_classes: int, reg_dim: int):
    """
    Load a training checkpoint and rebuild the multitask model plus its normalizers.

    Args:
        ckpt_path: checkpoint file; must contain "cfg", "model_state", "x_norm", "y_norm".
        device: torch device for the model (also used as torch.load's map_location).
        num_classes: number of classification outputs (K).
        reg_dim: number of regression outputs (D).

    Returns:
        (model, cfg, x_norm, y_norm): the model in eval mode on `device`, the saved
        config dict, and the XNorm / YNorm rebuilt from the checkpoint statistics.

    Raises:
        RuntimeError: if the checkpoint lacks x_norm or y_norm (checked before the
        model is constructed, so a bad checkpoint fails fast).
    """
    # NOTE(review): torch.load unpickles the file and can execute arbitrary code;
    # only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device)

    # norms (strongly recommended; otherwise accuracy will drop)
    for key in ("x_norm", "y_norm"):
        if ckpt.get(key) is None:
            raise RuntimeError(f"{ckpt_path} does not contain {key}. Re-save ckpts with x_norm/y_norm.")

    cfg = ckpt["cfg"]

    model = CP.MultiTaskTransformer(
        input_dim=2,  # run_logits_and_regnorm feeds channels 1:3 -> 2 features
        num_classes=num_classes,
        reg_dim=reg_dim,
        d_model=cfg["d_model"],
        nhead=cfg["nhead"],
        num_layers=cfg["num_layers"],
        dim_feedforward=cfg["dim_feedforward"],
        dropout=cfg["dropout"],
    ).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    x_stats, y_stats = ckpt["x_norm"], ckpt["y_norm"]
    x_norm = XNorm(
        mean=_ckpt_stat_to_numpy(x_stats["mean"]),
        std=_ckpt_stat_to_numpy(x_stats["std"]),
        log1p=x_stats.get("log1p", True),
    )
    y_norm = YNorm(
        mean=_ckpt_stat_to_numpy(y_stats["mean"]),
        std=_ckpt_stat_to_numpy(y_stats["std"]),
    )

    return model, cfg, x_norm, y_norm


def _ckpt_stat_to_numpy(value) -> np.ndarray:
    """Convert a checkpoint statistic (list, ndarray or torch tensor on any device) to a numpy array."""
    # map_location=device may have put saved tensors on the GPU; np.array() cannot read those.
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return np.array(value)


@torch.no_grad()
def run_logits_and_regnorm(
    arr_Nx3xT: np.ndarray,
    model,
    device,
    times_T: np.ndarray,
    x_norm: XNorm,
    batch_size: int = 256,
):
    """
    Batched, gradient-free forward pass over a stack of trajectories.

    Args:
      arr_Nx3xT: trajectories, shape (N, 3, T); channels 1 and 2 are the model features.
      model: callable returning (logits, normalized regression output) for (features, times).
      device: torch device the batches and time grid are moved to.
      times_T: shared time grid, shape (T,).
      x_norm: normalizer applied to the (N, T, 2) feature array before inference.
      batch_size: trajectories per forward pass.

    Returns:
      probs (N,K), pred_cls (N,), yhat_norm (N,D)
    """
    assert arr_Nx3xT.ndim == 3 and arr_Nx3xT.shape[1] == 3
    n_traj = arr_Nx3xT.shape[0]

    # features: (dose, conc) or (conc, biom) — channels 1:3, moved to (N, T, 2)
    feats = arr_Nx3xT[:, 1:3, :].transpose(0, 2, 1).astype(np.float32)
    feats = x_norm.transform(feats)

    # The time grid is shared by every batch, so build it once.
    times_t = torch.tensor(times_T, dtype=torch.float32, device=device)

    prob_parts, label_parts, reg_parts = [], [], []
    for start in range(0, n_traj, batch_size):
        stop = min(n_traj, start + batch_size)
        batch = torch.from_numpy(feats[start:stop]).to(device, non_blocking=True)
        logits, yhat_norm = model(batch, times_t)

        batch_probs = F.softmax(logits, dim=1).cpu().numpy().astype(np.float32)
        prob_parts.append(batch_probs)
        label_parts.append(batch_probs.argmax(axis=1).astype(np.int64))
        reg_parts.append(yhat_norm.cpu().numpy().astype(np.float32))

    return tuple(
        np.concatenate(parts, axis=0)
        for parts in (prob_parts, label_parts, reg_parts)
    )


def summarize_ranges(params_NxD: np.ndarray, mask_D: np.ndarray, keys: list, qs=(5, 50, 95)):
    """
    Percentile summary of each active parameter column.

    Args:
        params_NxD: parameter values, shape (N, D).
        mask_D: per-dimension mask; entries > 0.5 mark a column as active.
        keys: parameter names, one per column, in column order.
        qs: percentiles to report.

    Returns:
        {name: {"p<q>": float, ...}} for active columns only, in `keys` order.
    """
    active = mask_D > 0.5
    return {
        name: {f"p{q}": float(np.percentile(params_NxD[:, col], q)) for q in qs}
        for col, name in enumerate(keys)
        if active[col]
    }


def explain_best_class_for_all(
    probs_NxK: np.ndarray,
    pred_cls_N: np.ndarray,
    yhat_norm_NxD: np.ndarray,
    y_norm: YNorm,
    class_names: list[str],
    class_masks_KxD: np.ndarray,
    param_keys: list[str],
):
    """
    Explain a cohort through its single most frequently predicted class.

    Steps:
      1) find the most popular predicted class (ties -> lowest class id)
      2) treat ALL trajectories as that class => invert params using that class mask
      3) summarize parameter ranges (p5 / p50 / p95) over active parameters

    `probs_NxK` is accepted for interface symmetry but not used.

    Returns:
      (summary dict, physical params (N, D) under the best-class mask, best class id)
    """
    n_classes = len(class_names)
    counts = np.bincount(pred_cls_N, minlength=n_classes)
    best_k = int(np.argmax(counts))
    best_name = class_names[best_k]
    best_mask = class_masks_KxD[best_k].astype(np.float32)

    # Same mask row for every trajectory.
    n_traj = yhat_norm_NxD.shape[0]
    mask_all = np.repeat(best_mask[np.newaxis, :], repeats=n_traj, axis=0)
    params_phys_all = y_norm.inverse(yhat_norm_NxD, mask_all)

    seen_counts = {name: int(c) for name, c in zip(class_names, counts) if c > 0}
    ranges = summarize_ranges(params_phys_all, best_mask, param_keys, qs=(5, 50, 95))

    summary = {
        "most_popular_class": best_name,
        "most_popular_class_id": best_k,
        "class_counts": seen_counts,
        "param_ranges_p5_p50_p95": ranges,
    }
    return summary, params_phys_all, best_k


def _ensure_parent_dir(path) -> None:
    """Create the directory that will contain `path`; no-op when `path` is a bare filename."""
    parent = os.path.dirname(os.fspath(path))
    if parent:  # os.makedirs("") raises FileNotFoundError
        os.makedirs(parent, exist_ok=True)


def run_trial_explain(
    pk_npy: str,
    pd_npy: str,
    pk_ckpt: str,
    pd_ckpt: str,
    out_json: str = "./final_inference_set/pkpd_bestclass_summary.json",
    out_npz: str  = "./final_inference_set/pkpd_bestclass_params.npz",
    batch_size: int = 64,
):
    """
    End-to-end PK/PD inference and "best class for all" explanation.

    For each modality (PK, PD), the function:
      - loads trajectories and the checkpoint;
      - classifies every trajectory;
      - picks the most popular class;
      - inverts the regression head for all trajectories under that class's mask.
    It then writes and prints the summaries.

    Args:
        pk_npy, pd_npy: .npy arrays of shape (N, 3, T); channel 0 holds the time grid.
        pk_ckpt, pd_ckpt: checkpoint paths (see load_ckpt_build_model).
        out_json: JSON file receiving {"pk": summary, "pd": summary}. Its parent
            directory is created if needed.
        out_npz: NPZ file receiving class predictions, probabilities and physical
            params. Its parent directory is created if needed. np.savez appends
            ".npz" when missing; the printed path reflects that.
        batch_size: inference batch size.

    Returns:
        {"pk": pk_summary, "pd": pd_summary}
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pk_arr = np.load(pk_npy)  # (N,3,39)
    pd_arr = np.load(pd_npy)  # (N,3,25)

    pk_times = ensure_times(pk_arr[:, 0, :], pk_arr.shape[2])
    pd_times = ensure_times(pd_arr[:, 0, :], pd_arr.shape[2])

    # class libraries + masks (fixed seed: only names / masks are used here)
    pk_models = TG.build_pk_models(np.random.default_rng(0))
    pd_models = TG.build_pd_models(np.random.default_rng(0))
    pk_names = [m.name for m in pk_models]
    pd_names = [m.name for m in pd_models]
    pk_masks = np.stack([m.param_mask for m in pk_models], axis=0).astype(np.float32)
    pd_masks = np.stack([m.param_mask for m in pd_models], axis=0).astype(np.float32)

    pk_param_keys = TG.PK_PARAM_KEYS
    pd_param_keys = TG.PD_PARAM_KEYS

    # load models + norms
    pk_model, _, pk_xnorm, pk_ynorm = load_ckpt_build_model(pk_ckpt, device, len(pk_names), len(pk_param_keys))
    pd_model, _, pd_xnorm, pd_ynorm = load_ckpt_build_model(pd_ckpt, device, len(pd_names), len(pd_param_keys))

    # run inference (logits + regnorm)
    pk_probs, pk_pred_cls, pk_yhat_norm = run_logits_and_regnorm(pk_arr, pk_model, device, pk_times, pk_xnorm, batch_size)
    pd_probs, pd_pred_cls, pd_yhat_norm = run_logits_and_regnorm(pd_arr, pd_model, device, pd_times, pd_xnorm, batch_size)

    # explain “best class for all”
    pk_summary, pk_params_all, pk_best_k = explain_best_class_for_all(
        pk_probs, pk_pred_cls, pk_yhat_norm, pk_ynorm, pk_names, pk_masks, pk_param_keys
    )
    pd_summary, pd_params_all, pd_best_k = explain_best_class_for_all(
        pd_probs, pd_pred_cls, pd_yhat_norm, pd_ynorm, pd_names, pd_masks, pd_param_keys
    )

    # write outputs — each file gets its own parent dir (they need not share one)
    _ensure_parent_dir(out_json)
    # NOTE: json.dump writes non-finite floats as NaN/Infinity (non-standard JSON).
    with open(out_json, "w") as f:
        json.dump({"pk": pk_summary, "pd": pd_summary}, f, indent=2)

    npz_path = os.fspath(out_npz)
    if not npz_path.endswith(".npz"):
        npz_path += ".npz"  # np.savez would append it anyway; keep the printed path truthful
    _ensure_parent_dir(npz_path)
    np.savez(
        npz_path,
        pk_pred_cls=pk_pred_cls,
        pk_probs=pk_probs,
        pk_best_class_id=pk_best_k,
        pk_bestclass_params_phys=pk_params_all,  # (N, Dpk), masked dims kept, others 0
        pd_pred_cls=pd_pred_cls,
        pd_probs=pd_probs,
        pd_best_class_id=pd_best_k,
        pd_bestclass_params_phys=pd_params_all,  # (N, Dpd), masked dims kept, others 0
    )

    print("Wrote:", out_json)
    print("Wrote:", npz_path)

    print("\n[PK] most popular class:", pk_summary["most_popular_class"], "counts:", pk_summary["class_counts"])
    print("[PK] param ranges:", pk_summary["param_ranges_p5_p50_p95"])

    print("\n[PD] most popular class:", pd_summary["most_popular_class"], "counts:", pd_summary["class_counts"])
    print("[PD] param ranges:", pd_summary["param_ranges_p5_p50_p95"])

    return {"pk": pk_summary, "pd": pd_summary}

In [4]:
def print_summary_block(kind: str, summary: dict):
    """
    Pretty-print one modality's summary (as returned by explain_best_class_for_all).

    Args:
        kind: header label, e.g. "pk" or "pd" (printed upper-cased).
        summary: dict with "most_popular_class", "most_popular_class_id",
            "class_counts" ({name: count}) and "param_ranges_p5_p50_p95"
            ({param: {"p5": ..., "p50": ..., "p95": ...}}).

    Class counts are listed from most to least frequent. A quantile that is
    missing from a parameter's entry is shown as "n/a" instead of raising.
    """
    def _fmt(value) -> str:
        # None (a missing quantile) cannot be formatted with ".6g".
        return "n/a" if value is None else f"{value:.6g}"

    print("\n" + "=" * 80)
    print(f"{kind.upper()} SUMMARY")
    print("=" * 80)

    print(f"Most popular class: {summary['most_popular_class']} (id={summary['most_popular_class_id']})")
    print("\nClass counts:")
    for k, v in sorted(summary["class_counts"].items(), key=lambda kv: -kv[1]):
        print(f"  {k:>24s} : {v}")

    print("\nParameter ranges (p5 / p50 / p95) for MOST popular class:")
    pr = summary["param_ranges_p5_p50_p95"]
    if not pr:
        print("  (no parameters active for this class?)")
    else:
        for p_name, qs in pr.items():
            p5, p50, p95 = (_fmt(qs.get(key)) for key in ("p5", "p50", "p95"))
            print(f"  {p_name:>10s} : {p5} / {p50} / {p95}")




In [5]:
# Inputs: trajectory arrays of shape (N, 3, T) (channel 0 = time grid) and the
# trained PK / PD checkpoints (must contain x_norm / y_norm statistics).
pk_npy_path = "./final_inference_set/pkdata.npy"
pd_npy_path = "./final_inference_set/pddata.npy"
pk_ckpt_path = "./long_runs/pk_deep/PK_long_deep_best.pt"
pd_ckpt_path = "./long_runs/pd_deep/PD_long_deep_best.pt"


# Run PK + PD inference, write the JSON summary and NPZ params, keep summaries in `out`.
out = run_trial_explain(
    pk_npy=pk_npy_path,
    pd_npy=pd_npy_path,
    pk_ckpt=pk_ckpt_path,
    pd_ckpt=pd_ckpt_path,
    out_json="./final_inference_set/pkpd_bestclass_summary.json",
    out_npz="./final_inference_set/pkpd_bestclass_params.npz",
    batch_size=64,
)

# Human-readable per-modality report of the summaries returned above.
print_summary_block("pk", out["pk"])
print_summary_block("pd", out["pd"])

Wrote: ./final_inference_set/pkpd_bestclass_summary.json
Wrote: ./final_inference_set/pkpd_bestclass_params.npz

[PK] most popular class: pk_2c_iv_infusion counts: {'pk_2c_iv_infusion': 36}
[PK] param ranges: {'CL': {'p5': 0.02954680845141411, 'p50': 0.048646895214915276, 'p95': 0.0873776227235794}, 'Vc': {'p5': 1.1349785327911377, 'p50': 1.3215688467025757, 'p95': 1.543166697025299}, 'Vp': {'p5': 6.172096490859985, 'p50': 8.03603458404541, 'p95': 10.4442298412323}, 'Q': {'p5': 0.1753370277583599, 'p50': 0.33222708106040955, 'p95': 0.7330343127250671}, 'tau': {'p5': 3.435567319393158, 'p50': 5.0334038734436035, 'p95': 6.821564316749573}}

[PD] most popular class: pd_indirect_inhib_kin counts: {'pd_direct_linear': 1, 'pd_direct_emax': 1, 'pd_direct_sigmoid': 3, 'pd_effect_comp_emax': 2, 'pd_indirect_inhib_kin': 35, 'pd_indirect_stim_kin': 3, 'pd_indirect_inhib_kout': 1, 'pd_indirect_stim_kout': 1, 'pd_transit_delay': 1}
[PD] param ranges: {'R0': {'p5': 5.926348447799683, 'p50': 8.757164