In [5]:
import os, json
import numpy as np
import torch
import torch.nn.functional as F

import ClassicalPredictor as CP
import TrajectoryGenerator as TG


# -------------------------
# Normalizer reconstruction
# -------------------------
class XNorm:
    """Input-feature normalizer rebuilt from saved statistics.

    Applies an optional ``log1p`` (after clamping negatives to zero) and then
    standardizes with the stored per-channel mean and standard deviation.

    Args:
        mean: Per-channel mean; stored as a float32 array.
        std: Per-channel standard deviation; stored as float32 and floored
            at ``eps`` so the division in ``transform`` never hits zero.
        log1p: Whether ``transform`` applies ``log1p`` before standardizing.
        eps: Lower bound applied to ``std``.
    """

    def __init__(self, mean, std, log1p=True, eps=1e-6):
        raw_std = np.asarray(std, dtype=np.float32)
        self.mean = np.asarray(mean, dtype=np.float32)
        self.std = np.maximum(raw_std, eps)
        self.log1p = bool(log1p)

    def transform(self, X_NTxC: np.ndarray) -> np.ndarray:
        """Return the standardized copy of a rank-3 feature array.

        The statistics are broadcast along the last axis, so the last axis of
        ``X_NTxC`` must be compatible with ``mean`` / ``std``.
        """
        feats = X_NTxC.astype(np.float32, copy=False)
        if self.log1p:
            # Negative inputs are clamped to 0 so log1p stays defined.
            feats = np.log1p(np.clip(feats, a_min=0.0, a_max=None))
        centered = feats - self.mean[None, None, :]
        scale = self.std[None, None, :]
        return centered / scale


class YNorm:
    """Regression-target de-normalizer rebuilt from saved statistics.

    ``inverse`` maps normalized predictions back through
    ``exp(y * std + mean)``, i.e. the statistics live in log space.

    Args:
        mean: Per-parameter mean; stored as a float32 array.
        std: Per-parameter standard deviation; stored as float32 and floored
            at ``1e-6``.
        eps: Stored on the instance as ``self.eps``. No method visible in this
            class reads it; it does NOT control the ``std`` floor.
    """

    def __init__(self, mean, std, eps=1e-12):
        self.mean = np.asarray(mean, dtype=np.float32)
        self.std = np.maximum(np.asarray(std, dtype=np.float32), 1e-6)
        self.eps = eps

    def inverse(self, y_norm_NxD: np.ndarray, mask_NxD: np.ndarray) -> np.ndarray:
        """Decode normalized predictions for the masked-in parameters.

        Args:
            y_norm_NxD: Normalized predictions, shape ``(N, D)``.
            mask_NxD: Parameter mask. Entries ``> 0.5`` are decoded; all other
                output entries are left at ``0``. Any shape that broadcasts
                to ``y_norm_NxD.shape`` is accepted, so a single ``(D,)``
                per-class mask can be passed without tiling it first.

        Returns:
            float32 array shaped like ``y_norm_NxD``.

        Raises:
            ValueError: If the mask cannot be broadcast to the prediction shape.
        """
        y_norm_NxD = np.asarray(y_norm_NxD).astype(np.float32, copy=False)
        mask = np.asarray(mask_NxD).astype(np.float32, copy=False)

        # Broadcast first so boolean indexing below always sees an (N, D) mask.
        used = np.broadcast_to(mask, y_norm_NxD.shape) > 0.5
        z = y_norm_NxD * self.std[None, :] + self.mean[None, :]
        out = np.zeros_like(y_norm_NxD, dtype=np.float32)
        # exp is evaluated only on used entries, so unused slots cannot
        # trigger overflow warnings.
        out[used] = np.exp(z[used])
        return out


# -------------------------
# Load ckpt + build model
# -------------------------
def load_ckpt_build_model(ckpt_path: str, device, num_classes: int, reg_dim: int):
    """Load a checkpoint and rebuild the model plus its normalizers.

    Args:
        ckpt_path: Path handed to ``torch.load``.
        device: Used both as ``map_location`` and as the model's device.
        num_classes: Passed to the model constructor.
        reg_dim: Passed to the model constructor.

    Returns:
        ``(model, cfg, x_norm, y_norm)``. The model has its weights loaded and
        is in eval mode. ``cfg`` is the checkpoint's ``"cfg"`` entry.
        ``x_norm`` / ``y_norm`` are ``None`` when the checkpoint has no
        (or a ``None``) ``"x_norm"`` / ``"y_norm"`` entry.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
    # installed torch version's `weights_only` default) -- only load trusted files.
    checkpoint = torch.load(ckpt_path, map_location=device)
    cfg = checkpoint["cfg"]

    # Architecture hyper-parameters come from the saved config; input_dim is fixed at 2.
    arch_keys = ("d_model", "nhead", "num_layers", "dim_feedforward", "dropout")
    arch_kwargs = {key: cfg[key] for key in arch_keys}
    net = CP.MultiTaskTransformer(
        input_dim=2,
        num_classes=num_classes,
        reg_dim=reg_dim,
        **arch_kwargs,
    )
    net = net.to(device)
    net.load_state_dict(checkpoint["model_state"])
    net.eval()

    x_stats = checkpoint.get("x_norm")
    if x_stats is None:
        x_norm = None
    else:
        x_norm = XNorm(
            mean=x_stats["mean"],
            std=x_stats["std"],
            log1p=x_stats.get("log1p", True),
        )

    y_stats = checkpoint.get("y_norm")
    if y_stats is None:
        y_norm = None
    else:
        y_norm = YNorm(mean=y_stats["mean"], std=y_stats["std"])

    return net, cfg, x_norm, y_norm


def ensure_times(times_NxT: np.ndarray, expected_T: int):
    """Return the shared time grid (row 0) of a batch as a float32 vector.

    Args:
        times_NxT: Per-trajectory time stamps, shape ``(N, T)``.
        expected_T: Required length of the time axis.

    Returns:
        ``times_NxT[0]`` cast to float32.

    Raises:
        ValueError: If the time axis length differs from ``expected_T`` or the
            batch contains no trajectories.

    Warns:
        RuntimeWarning: If the rows do not all share row 0's time grid. Row 0
            is still returned, so existing callers keep working, but the
            mismatch is no longer silent.
    """
    import warnings

    if times_NxT.shape[1] != expected_T:
        raise ValueError(f"Expected T={expected_T} but got {times_NxT.shape[1]}")
    if times_NxT.shape[0] == 0:
        raise ValueError("Cannot derive a time grid from an empty batch (N=0)")

    reference = times_NxT[0]
    # Only row 0 is returned; flag batches where that is not representative.
    if not np.allclose(times_NxT, reference, equal_nan=True):
        warnings.warn(
            "Trajectories do not share a common time grid; using row 0 for all of them",
            RuntimeWarning,
            stacklevel=2,
        )
    return reference.astype(np.float32)


# -------------------------
# Step 1: predict classes
# -------------------------
def predict_class_probs(arr_Nx3xT, model, device, times_T, x_norm, batch_size=256):
    """Run the model in batches and return softmax class probabilities.

    Args:
        arr_Nx3xT: Array indexed as ``[trajectory, row, time]``; rows 1 and 2
            are used as the two input features.
        model: Callable invoked as ``model(features, times)`` that returns a
            pair whose first element is the class logits.
        device: Device for the input tensors.
        times_T: Time grid passed to every model call.
        x_norm: Optional object with a ``transform`` method applied to the
            ``(N, T, 2)`` features; skipped when ``None``.
        batch_size: Number of trajectories per forward pass.

    Returns:
        NumPy array with one row of probabilities (softmax over axis 1 of the
        logits) per trajectory.
    """
    # Select the two feature rows and move time ahead of channels: (N, T, 2).
    feats = arr_Nx3xT[:, 1:3, :].transpose(0, 2, 1).astype(np.float32)
    if x_norm is not None:
        feats = x_norm.transform(feats)

    time_grid = torch.tensor(times_T, dtype=torch.float32, device=device)

    per_batch = []
    total = feats.shape[0]
    with torch.no_grad():
        for start in range(0, total, batch_size):
            stop = min(total, start + batch_size)
            batch = torch.from_numpy(feats[start:stop]).to(device)
            class_logits, _ = model(batch, time_grid)
            per_batch.append(F.softmax(class_logits, dim=1).cpu().numpy())
    return np.concatenate(per_batch, axis=0)


# -------------------------
# Step 2: force params under one class for ALL
# -------------------------
def force_class_params_for_all(arr_Nx3xT, model, device, times_T, x_norm, y_norm,
                               class_param_masks_KxD, forced_class_id: int, batch_size=256):
    """Decode regression outputs for every trajectory under one class's mask.

    The model's class logits are ignored; only its second output (the
    normalized regression head) is used.

    Args:
        arr_Nx3xT: Array indexed as ``[trajectory, row, time]``; rows 1 and 2
            are used as the two input features.
        model: Callable invoked as ``model(features, times)`` returning
            ``(logits, normalized_params)``.
        device: Device for the input tensors.
        times_T: Time grid passed to every model call.
        x_norm: Optional input normalizer (``transform``); skipped when ``None``.
        y_norm: Optional target normalizer (``inverse(y, mask)``). When
            ``None`` the raw model outputs are returned and the mask is NOT
            applied to them.
        class_param_masks_KxD: One parameter mask row per class.
        forced_class_id: Row of ``class_param_masks_KxD`` applied to everyone.
        batch_size: Number of trajectories per forward pass.

    Returns:
        Dict with ``"forced_class_id"`` (int), ``"forced_mask"`` (float32,
        shape ``(D,)``) and ``"forced_params_phys"`` (float32, shape ``(N, D)``).
    """
    # Select the two feature rows and move time ahead of channels: (N, T, 2).
    feats = arr_Nx3xT[:, 1:3, :].transpose(0, 2, 1).astype(np.float32)
    if x_norm is not None:
        feats = x_norm.transform(feats)

    time_grid = torch.tensor(times_T, dtype=torch.float32, device=device)

    n_traj = feats.shape[0]
    n_params = class_param_masks_KxD.shape[1]
    forced_mask_D = class_param_masks_KxD[forced_class_id].astype(np.float32)  # (D,)

    decoded = np.zeros((n_traj, n_params), dtype=np.float32)
    for start in range(0, n_traj, batch_size):
        stop = min(n_traj, start + batch_size)
        batch = torch.from_numpy(feats[start:stop]).to(device)

        with torch.no_grad():
            _, reg_out = model(batch, time_grid)
            reg_np = reg_out.cpu().numpy().astype(np.float32)  # (B, D)

        if y_norm is None:
            decoded[start:stop] = reg_np
            continue

        # Every row in the batch gets the same forced-class mask.
        batch_mask = np.tile(forced_mask_D[None, :], (stop - start, 1)).astype(np.float32)
        decoded[start:stop] = y_norm.inverse(reg_np, batch_mask)

    return {
        "forced_class_id": int(forced_class_id),
        "forced_mask": forced_mask_D,
        "forced_params_phys": decoded,
    }


def print_per_trajectory_params(task_name: str, forced_class_name: str,
                               param_keys, params_phys_NxD, mask_D, id_start: int):
    """Print one line per trajectory listing the masked-in parameters.

    Args:
        task_name: Label shown in the header line.
        forced_class_name: Class name shown in the header line.
        param_keys: Parameter names, one per column of ``params_phys_NxD``.
        params_phys_NxD: Parameter values, one row per trajectory.
        mask_D: Per-parameter mask; columns with a value ``> 0.5`` are printed.
        id_start: ID printed for the first row; later rows count up from it.
    """
    print(f"\n[{task_name}] Forced class = {forced_class_name}")
    is_used = mask_D > 0.5

    n_rows = params_phys_NxD.shape[0]
    for offset in range(n_rows):
        row = params_phys_NxD[offset]
        shown = [
            f"{key}={row[col]:.6g}"
            for col, key in enumerate(param_keys)
            if is_used[col]
        ]
        traj_id = id_start + offset
        print(f"  ID {traj_id:>3d}: " + ", ".join(shown))


def summarize_ranges(param_keys, params_phys_NxD, mask_D, qs=(5,50,95)):
    """Print percentile summaries for each masked-in parameter.

    Args:
        param_keys: Parameter names, one per column of ``params_phys_NxD``.
        params_phys_NxD: Parameter values, one row per trajectory.
        mask_D: Per-parameter mask; columns with a value ``> 0.5`` are summarized.
        qs: Percentiles to report. The printed labels are derived from these
            values (``p5=...``), so any set of percentiles can be requested,
            not just the default ``(5, 50, 95)``.
    """
    used = mask_D > 0.5
    q_list = list(qs)  # materialize once; qs may be any iterable
    print("\nParameter ranges (percentiles) for forced class:")
    for j, k in enumerate(param_keys):
        if not used[j]:
            continue
        vals = params_phys_NxD[:, j]
        pcts = np.percentile(vals, q_list)
        # Labels follow qs instead of being hard-coded to 5/50/95.
        cells = [f"p{q}={float(v):.6g}" for q, v in zip(q_list, pcts)]
        print(f"  {k:>6s}: " + "  ".join(cells))


# =========================
# RUN EVERYTHING
# =========================
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trial arrays are indexed [trajectory, row, time]; row 0 is used below as the
# time stamps and rows 1-2 as the model's input features.
pk_arr = np.load("./final_inference_set/pkdata.npy")  # (36,3,39)
pd_arr = np.load("./final_inference_set/pddata.npy")  # (48,3,25)

# The model definitions are built only to read class names and per-class
# parameter masks; both builders receive a generator seeded with 0.
pk_models = TG.build_pk_models(np.random.default_rng(0))
pd_models = TG.build_pd_models(np.random.default_rng(0))
pk_class_names = [m.name for m in pk_models]
pd_class_names = [m.name for m in pd_models]
# Stack the per-class masks into (K, D) arrays: one row per class.
pk_masks = np.stack([m.param_mask for m in pk_models], axis=0).astype(np.float32)
pd_masks = np.stack([m.param_mask for m in pd_models], axis=0).astype(np.float32)
pk_param_keys = TG.PK_PARAM_KEYS
pd_param_keys = TG.PD_PARAM_KEYS

# Load trained models
# (the cfg returned by the loader is discarded; num_classes / reg_dim come
# from the class and parameter lists above)
pk_model, _, pk_xn, pk_yn = load_ckpt_build_model(
    "./long_runs/pk_deep/PK_long_deep_best.pt", device, len(pk_class_names), len(pk_param_keys)
)
pd_model, _, pd_xn, pd_yn = load_ckpt_build_model(
    "./long_runs/pd_deep/PD_long_deep_best.pt", device, len(pd_class_names), len(pd_param_keys)
)

# Times from your trial arrays
# (ensure_times returns row 0's time grid, which is then used for every trajectory)
pk_times = ensure_times(pk_arr[:,0,:], pk_arr.shape[2])
pd_times = ensure_times(pd_arr[:,0,:], pd_arr.shape[2])

# ---- Step 1: find most popular predicted class
pk_probs = predict_class_probs(pk_arr, pk_model, device, pk_times, pk_xn, batch_size=64)
pd_probs = predict_class_probs(pd_arr, pd_model, device, pd_times, pd_xn, batch_size=64)

# Hard class assignment per trajectory.
pk_pred = pk_probs.argmax(axis=1)
pd_pred = pd_probs.argmax(axis=1)

# minlength keeps one count per class even for classes that were never predicted.
pk_counts = np.bincount(pk_pred, minlength=len(pk_class_names))
pd_counts = np.bincount(pd_pred, minlength=len(pd_class_names))

# argmax returns the lowest class id when several classes tie for the top count.
pk_best_id = int(pk_counts.argmax())
pd_best_id = int(pd_counts.argmax())

print(f"[PK] most popular class = {pk_class_names[pk_best_id]} (id={pk_best_id}), counts={pk_counts.tolist()}")
print(f"[PD] most popular class = {pd_class_names[pd_best_id]} (id={pd_best_id}), counts={pd_counts.tolist()}")

# ---- Step 2: FORCE all trajectories to use that class's mask, and decode params
pk_forced = force_class_params_for_all(pk_arr, pk_model, device, pk_times, pk_xn, pk_yn, pk_masks, pk_best_id, batch_size=64)
pd_forced = force_class_params_for_all(pd_arr, pd_model, device, pd_times, pd_xn, pd_yn, pd_masks, pd_best_id, batch_size=64)

# ---- Step 3: print per-trajectory params
# NOTE(review): PD IDs start at 1 while PK IDs start at 13 -- presumably the PK
# rows correspond to subjects 13 onward; confirm against the data description.
print_per_trajectory_params(
    task_name="PD",
    forced_class_name=pd_class_names[pd_best_id],
    param_keys=pd_param_keys,
    params_phys_NxD=pd_forced["forced_params_phys"],
    mask_D=pd_forced["forced_mask"],
    id_start=1
)

print_per_trajectory_params(
    task_name="PK",
    forced_class_name=pk_class_names[pk_best_id],
    param_keys=pk_param_keys,
    params_phys_NxD=pk_forced["forced_params_phys"],
    mask_D=pk_forced["forced_mask"],
    id_start=13
)

# ---- Step 4: print summary ranges
print("\n[PD] Summary ranges")
summarize_ranges(pd_param_keys, pd_forced["forced_params_phys"], pd_forced["forced_mask"])

print("\n[PK] Summary ranges")
summarize_ranges(pk_param_keys, pk_forced["forced_params_phys"], pk_forced["forced_mask"])


[PK] most popular class = pk_2c_iv_infusion (id=5), counts=[0, 0, 0, 0, 0, 36, 0, 0, 0, 0]
[PD] most popular class = pd_indirect_inhib_kin (id=5), counts=[1, 1, 3, 0, 2, 35, 3, 1, 1, 1]

[PD] Forced class = pd_indirect_inhib_kin
  ID   1: R0=14.5132, IC50=15.9307, kin=1.96493, kout=0.135555
  ID   2: R0=10.6281, IC50=19.0742, kin=1.45929, kout=0.137421
  ID   3: R0=5.65982, IC50=18.6728, kin=0.988865, kout=0.174128
  ID   4: R0=7.90742, IC50=17.4443, kin=1.10284, kout=0.139295
  ID   5: R0=7.82976, IC50=18.4151, kin=0.939613, kout=0.119909
  ID   6: R0=10.8586, IC50=17.2419, kin=1.45656, kout=0.134096
  ID   7: R0=10.0962, IC50=16.8776, kin=1.36391, kout=0.135386
  ID   8: R0=8.83373, IC50=16.3224, kin=1.03562, kout=0.117117
  ID   9: R0=9.11767, IC50=17.6436, kin=0.649201, kout=0.0710426
  ID  10: R0=7.01487, IC50=18.6815, kin=0.899225, kout=0.128112
  ID  11: R0=10.5496, IC50=17.0292, kin=1.31056, kout=0.124294
  ID  12: R0=6.12399, IC50=18.0029, kin=0.707646, kout=0.115445
  ID  13: