In [1]:
import os, glob, json, time, math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset

import ClassicalPredictor as CP
import TrajectoryGenerator as TG


# --------------------------
# Sharded memmap utilities
# --------------------------
import os, glob
import re

def discover_mpirank_shards(base_dir: str):
    """Return the ``mpirank_*`` sub-directories of *base_dir* in rank order.

    Directories whose name ends in ``mpirank_<digits>`` are ordered by that
    integer; any other ``mpirank_*`` directory is placed after them, keeping
    lexicographic order among themselves.

    Raises:
        FileNotFoundError: if no ``mpirank_*`` directory exists under *base_dir*.
    """
    pattern = os.path.join(base_dir, "mpirank_*")
    found = [p for p in sorted(glob.glob(pattern)) if os.path.isdir(p)]
    if not found:
        raise FileNotFoundError(f"No mpirank_* dirs found under: {base_dir}")

    rank_re = re.compile(r"mpirank_(\d+)$")

    def _rank(path):
        hit = rank_re.search(path)
        # Names without a numeric suffix sort last.
        return 10**9 if hit is None else int(hit.group(1))

    # sorted() is stable, so ties keep the lexicographic order from above.
    return sorted(found, key=_rank)

def shard_n_written(shard_dir: str) -> int:
    """Return the ``n_written`` value stored in ``<shard_dir>/meta.npz``.

    Raises:
        FileNotFoundError: if ``meta.npz`` does not exist in *shard_dir*.
    """
    meta_path = os.path.join(shard_dir, "meta.npz")
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"Missing meta.npz in {shard_dir}")
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # use it as a context manager so the handle is released right away
    # (this function is called once per shard per array).
    with np.load(meta_path) as meta:
        return int(meta["n_written"])

class ShardedMemmap:
    """
    Virtual concatenation over shards for a single .npy array.
    Uses numpy memmap to avoid loading into RAM.

    Only the first ``n_written`` rows of each shard (as reported by
    ``shard_n_written``) are exposed. Supports ``len()`` and integer
    indexing, including negative indices.
    """
    def __init__(self, shard_dirs, filename: str, dtype=None):
        """Open *filename* in every directory of *shard_dirs* as a read-only memmap.

        Args:
            shard_dirs: shard directories, in concatenation order.
            filename: name of the .npy file expected in each shard directory.
            dtype: accepted for interface compatibility and ignored; arrays
                keep their on-disk dtype (converting a memmap would copy it).

        Raises:
            FileNotFoundError: if a shard lacks *filename* (or its meta.npz).
            ValueError: if a shard file has fewer rows than its ``n_written``.
        """
        self.arrs = []
        self.shard_dirs = shard_dirs
        offsets = [0]

        for sd in shard_dirs:
            n = shard_n_written(sd)
            path = os.path.join(sd, filename)
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing {path}")
            arr = np.load(path, mmap_mode="r")
            if arr.shape[0] < n:
                raise ValueError(f"{path} has {arr.shape[0]} rows but meta says n_written={n}")
            # keep only valid rows (slicing a memmap stays lazy)
            self.arrs.append(arr[:n])
            offsets.append(offsets[-1] + n)

        # cum[k] is the global index of the first row of shard k; cum[-1] is the total.
        self.cum = np.array(offsets, dtype=np.int64)

    def __len__(self):
        return int(self.cum[-1])

    def _locate(self, idx: int):
        """Map a global index to ``(shard_id, local_index)``.

        Negative indices count from the end, like a list.

        Raises:
            IndexError: if *idx* is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        pos = int(idx)
        if pos < 0:
            pos += total
        if not 0 <= pos < total:
            raise IndexError(f"index {int(idx)} out of range for ShardedMemmap of length {total}")
        # shard_id such that cum[shard_id] <= pos < cum[shard_id+1];
        # side="right" also skips over empty shards (repeated offsets).
        sid = int(np.searchsorted(self.cum, pos, side="right") - 1)
        local = int(pos - self.cum[sid])
        return sid, local

    def __getitem__(self, idx: int):
        sid, local = self._locate(idx)
        return self.arrs[sid][local]


# --------------------------
# Split indices (90/5/5)
# --------------------------
def make_split_indices(N: int, seed: int = 0, frac_train=0.90, frac_val=0.05, frac_test=0.05):
    """Shuffle ``range(N)`` with a seeded generator and cut it into three parts.

    The train and validation sizes are ``int(N * frac)``; the test split takes
    whatever remains, so the three splits always cover all N indices.

    Returns:
        Tuple ``(idx_train, idx_val, idx_test)`` of index arrays.
    """
    assert abs(frac_train + frac_val + frac_test - 1.0) < 1e-9
    shuffled = np.random.default_rng(seed).permutation(N)
    val_start = int(N * frac_train)
    test_start = val_start + int(N * frac_val)
    return (
        shuffled[:val_start],
        shuffled[val_start:test_start],
        shuffled[test_start:],
    )


# --------------------------
# Minimal normalizers (fast & safe)
# --------------------------
class XChannelStandardizer:
    """
    Optional log1p compression followed by per-channel standardization,
    for inputs whose last axis is the channel axis (e.g. shaped (T, C)).
    """
    def __init__(self, log1p=True, eps=1e-6):
        self.log1p = log1p
        self.eps = eps
        self.mean = None  # per-channel mean, set by fit()
        self.std = None   # per-channel std (floored at eps), set by fit()

    def _compress(self, values):
        """Apply log1p to values clipped at 0 when enabled; otherwise pass through."""
        if not self.log1p:
            return values
        return np.log1p(np.clip(values, 0.0, None))

    def fit(self, X_TxC: np.ndarray):
        """Estimate per-channel mean/std from stacked sequences.

        Accepts either a 3-D array, which is flattened to 2-D over its leading
        axes, or an array that is already (rows, C).
        """
        rows = X_TxC.reshape(-1, X_TxC.shape[-1]) if X_TxC.ndim == 3 else X_TxC
        rows = self._compress(rows)
        self.mean = rows.mean(axis=0)
        # Floor the std so constant channels do not divide by zero later.
        self.std = np.maximum(rows.std(axis=0), self.eps)

    def transform_one(self, x_TxC: np.ndarray) -> np.ndarray:
        """Standardize one sequence using the statistics computed by fit()."""
        as_f32 = x_TxC.astype(np.float32, copy=False)
        return (self._compress(as_f32) - self.mean) / self.std


class MaskedLogParamStandardizer:
    """
    Log-space standardizer for positive parameters: take log(p), then
    standardize each dimension using statistics from masked-in entries only.
    """
    def __init__(self, eps=1e-12):
        self.eps = eps      # lower clip applied before taking the log
        self.mean = None    # per-dimension mean of log-values, set by fit()
        self.std = None     # per-dimension std of log-values, set by fit()

    def fit(self, Y: np.ndarray, M: np.ndarray):
        """Compute per-dimension log-space mean/std.

        Y and M are both (N, D); an entry counts as used when its mask value
        exceeds 0.5. Dimensions with no used entry keep mean 0 and std 1.
        """
        values = Y.astype(np.float64)
        weights = M.astype(np.float64)
        n_dims = values.shape[1]
        mu = np.zeros((n_dims,), dtype=np.float64)
        variance = np.ones((n_dims,), dtype=np.float64)

        for d in range(n_dims):
            active = weights[:, d] > 0.5
            if not active.any():
                continue  # defaults (0, 1) already in place
            logs = np.log(np.clip(values[active, d], self.eps, None))
            mu[d] = logs.mean()
            variance[d] = logs.var() + 1e-6  # small floor keeps std strictly positive

        self.mean = mu.astype(np.float32)
        self.std = np.sqrt(variance).astype(np.float32)

    def transform_one(self, y: np.ndarray, m: np.ndarray) -> np.ndarray:
        """Normalize one parameter vector; masked-out entries are returned as 0."""
        y32 = y.astype(np.float32, copy=False)
        m32 = m.astype(np.float32, copy=False)
        active = m32 > 0.5
        normed = np.zeros_like(y32, dtype=np.float32)
        log_vals = np.log(np.clip(y32[active], self.eps, None))
        normed[active] = (log_vals - self.mean[active]) / self.std[active]
        return normed

    def inverse_one(self, y_norm: np.ndarray, m: np.ndarray) -> np.ndarray:
        """Map a normalized vector back to original (exp) space; masked-out entries are 0."""
        z32 = y_norm.astype(np.float32, copy=False)
        m32 = m.astype(np.float32, copy=False)
        active = m32 > 0.5
        params = np.zeros_like(z32, dtype=np.float32)
        params[active] = np.exp(z32[active] * self.std[active] + self.mean[active])
        return params


# --------------------------
# Dataset wrapping sharded arrays
# --------------------------
class ShardedMultiTaskSeqDataset(Dataset):
    """
    Map-style dataset over four index-aligned sharded arrays.

    Each item is a tuple:
      x: torch.float32, the stored sample transposed from (2, T) to (T, 2)
         and normalized with x_norm
      y_cls: torch.long scalar class label
      y_reg: torch.float32 regression targets normalized with y_norm
      y_mask: torch.float32 mask for the regression targets
    """
    def __init__(self, X_2xT: ShardedMemmap, y_cls: ShardedMemmap, y_reg: ShardedMemmap, y_mask: ShardedMemmap,
                 x_norm: XChannelStandardizer, y_norm: MaskedLogParamStandardizer):
        self.X_2xT = X_2xT
        self.y_cls = y_cls
        self.y_reg = y_reg
        self.y_mask = y_mask
        self.x_norm = x_norm
        self.y_norm = y_norm

    def __len__(self):
        # Length is taken from the input array only.
        return len(self.X_2xT)

    def __getitem__(self, idx):
        # Inputs: swap axes so time comes first, then normalize.
        features = self.x_norm.transform_one(np.transpose(self.X_2xT[idx], (1, 0)))

        label = int(self.y_cls[idx])
        # np.array(...) copies the row out of the read-only memmap.
        params = np.array(self.y_reg[idx], dtype=np.float32)
        mask = np.array(self.y_mask[idx], dtype=np.float32)
        params_norm = self.y_norm.transform_one(params, mask)

        x_t = torch.from_numpy(features).float()
        cls_t = torch.tensor(label, dtype=torch.long)
        reg_t = torch.from_numpy(params_norm).float()
        mask_t = torch.from_numpy(mask).float()
        return (x_t, cls_t, reg_t, mask_t)


# --------------------------
# Helper: fit normalizers on a subset of TRAIN indices
# --------------------------
def fit_normalizers(dataset_kind: str,
                    shard_dirs,
                    train_idx: np.ndarray,
                    fit_n: int = 200_000,
                    seed: int = 0):
    """Fit input and target normalizers on a random subset of training indices.

    Args:
        dataset_kind: "pk" selects the PK arrays; any other value selects PD.
        shard_dirs: shard directories passed through to ShardedMemmap.
        train_idx: candidate (training) indices to sample from.
        fit_n: maximum number of indices to sample, without replacement.
        seed: seed for the sampling generator.

    Returns:
        ``(x_norm, y_norm)``: a fitted XChannelStandardizer and a fitted
        MaskedLogParamStandardizer.
    """
    rng = np.random.default_rng(seed)
    fit_n = min(fit_n, len(train_idx))
    sel = rng.choice(train_idx, size=fit_n, replace=False)

    if dataset_kind == "pk":
        names = ("X_pk.npy", "y_pk_reg.npy", "y_pk_mask.npy")
    else:
        names = ("X_pd.npy", "y_pd_reg.npy", "y_pd_mask.npy")
    X, Y, M = [ShardedMemmap(shard_dirs, name) for name in names]

    # Input statistics use at most 50k sequences (pragmatic cap to keep RAM sane).
    x_sel = sel[:min(50_000, fit_n)]
    xs = np.stack(
        [np.transpose(X[int(i)], (1, 0)).astype(np.float32, copy=False) for i in x_sel],
        axis=0,
    )  # (n, T, 2)
    x_norm = XChannelStandardizer(log1p=True)
    x_norm.fit(xs)

    # Target statistics use at most 200k samples.
    y_sel = sel[:min(200_000, fit_n)]
    ys = np.stack([np.array(Y[int(i)], dtype=np.float32) for i in y_sel], axis=0)
    ms = np.stack([np.array(M[int(i)], dtype=np.float32) for i in y_sel], axis=0)
    y_norm = MaskedLogParamStandardizer()
    y_norm.fit(ys, ms)

    return x_norm, y_norm



# Root directory that holds the mpirank_* shard directories.
BASE = "./dataset"
shards = discover_mpirank_shards(BASE)
print("Found shards:", len(shards), "first/last:", shards[0], shards[-1])

# Load sharded arrays (memmaps): inputs, class labels, regression targets and
# regression masks, once for the PK files and once for the PD files.
X_pk = ShardedMemmap(shards, "X_pk.npy")
y_pk_cls = ShardedMemmap(shards, "y_pk_cls.npy")
y_pk_reg = ShardedMemmap(shards, "y_pk_reg.npy")
y_pk_mask = ShardedMemmap(shards, "y_pk_mask.npy")

X_pd = ShardedMemmap(shards, "X_pd.npy")
y_pd_cls = ShardedMemmap(shards, "y_pd_cls.npy")
y_pd_reg = ShardedMemmap(shards, "y_pd_reg.npy")
y_pd_mask = ShardedMemmap(shards, "y_pd_mask.npy")

# N is taken from the PK inputs and later reused for the PD split as well.
# NOTE(review): presumably PK and PD arrays have the same length — confirm.
N = len(X_pk)
print("Total samples:", N)




Found shards: 50 first/last: ./dataset/mpirank_0 ./dataset/mpirank_49
Total samples: 10000000


In [2]:
# Split indices once and reuse for PK+PD
idx_tr, idx_va, idx_te = make_split_indices(N, seed=123, frac_train=0.90, frac_val=0.05, frac_test=0.05)

# Save split indices for reproducibility
np.savez(os.path.join(BASE, "splits_90_5_5_seed123.npz"),
         idx_tr=idx_tr, idx_va=idx_va, idx_te=idx_te)

# Fit normalizers (train-only) for PK + PD.
# Each call samples its own subset of idx_tr (different seeds).
xnorm_pk, ynorm_pk = fit_normalizers("pk", shards, idx_tr, fit_n=200_000, seed=1)
xnorm_pd, ynorm_pd = fit_normalizers("pd", shards, idx_tr, fit_n=200_000, seed=2)

# Full-length datasets; train/val/test selection happens later through Subset.
ds_pk = ShardedMultiTaskSeqDataset(X_pk, y_pk_cls, y_pk_reg, y_pk_mask, xnorm_pk, ynorm_pk)
ds_pd = ShardedMultiTaskSeqDataset(X_pd, y_pd_cls, y_pd_reg, y_pd_mask, xnorm_pd, ynorm_pd)


from pathlib import Path

def masked_mse(pred, target, mask):
    """Mean squared error over the entries selected by *mask*.

    pred/target/mask share the shape (B, D). The sum of masked squared errors
    is divided by the mask total, floored at 1 so an all-zero mask gives 0
    rather than NaN.
    """
    masked_sq_err = (pred - target).pow(2) * mask
    n_active = mask.sum().clamp_min(1.0)
    return masked_sq_err.sum() / n_active

@torch.no_grad()
def evaluate(model, loader, device, times_t, lambda_reg=1.0):
    """Run *model* over *loader* in eval mode and return aggregate metrics.

    For each batch the loss is cross-entropy plus ``lambda_reg`` times the
    masked MSE; per-batch values are weighted by batch size before averaging.

    Returns:
        dict with "loss", "acc", "reg_mse_masked" (averages over all samples)
        and "n" (the number of samples seen).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    sum_loss, sum_reg = 0.0, 0.0
    n_correct, n_seen = 0, 0

    for x, y_cls, y_reg, y_mask in loader:
        x, y_cls, y_reg, y_mask = [
            t.to(device, non_blocking=True) for t in (x, y_cls, y_reg, y_mask)
        ]

        # IMPORTANT: the model takes the time grid as a second argument
        logits, yhat_reg = model(x, times_t)

        cls_term = criterion(logits, y_cls)
        reg_term = masked_mse(yhat_reg, y_reg, y_mask)
        batch_loss = cls_term + lambda_reg * reg_term

        bsz = x.size(0)
        sum_loss += float(batch_loss.item()) * bsz
        sum_reg += float(reg_term.item()) * bsz
        n_correct += int((logits.argmax(dim=1) == y_cls).sum().item())
        n_seen += int(bsz)

    denom = max(n_seen, 1)  # avoid division by zero on an empty loader
    return {
        "loss": sum_loss / denom,
        "acc": n_correct / denom,
        "reg_mse_masked": sum_reg / denom,
        "n": n_seen
    }


from pathlib import Path
import torch
import json
import math
import os
import torch.nn as nn

def train_one_config(
    run_name: str,
    dataset,
    idx_tr,
    idx_va,
    num_classes: int,
    reg_dim: int,
    device,
    out_dir: str,
    cfg: dict,
    times_t,
    resume_training: bool = False,
    resume_ckpt_path: str | None = None,
):
    """Train one ``CP.MultiTaskTransformer`` configuration with checkpointing.

    Per epoch: train on ``Subset(dataset, idx_tr)``, validate on
    ``Subset(dataset, idx_va)``, append a JSON record to
    ``<out_dir>/<run_name>_log.jsonl``, save ``<run_name>_best.pt`` whenever
    the validation loss improves by more than ``min_delta``, optionally save a
    periodic checkpoint under ``<out_dir>/checkpoints``, and optionally stop
    early.

    Args:
        run_name: prefix for the log and checkpoint file names.
        dataset: dataset (or Subset) yielding ``(x, y_cls, y_reg, y_mask)``.
        idx_tr, idx_va: index collections for the training/validation subsets.
        num_classes, reg_dim: output sizes passed to the model constructor.
        device: device for the model and batches.
        out_dir: output directory (created if missing).
        cfg: hyperparameters. Required keys: batch_size, d_model, nhead,
            num_layers, dim_feedforward, dropout, lr, weight_decay, epochs,
            lambda_reg. Optional keys (default): num_workers (2),
            eval_batch_size (512), warmup_frac (0.05), label_smoothing (0.0),
            amp (True), grad_clip (0.0), save_every_epochs (0 = off),
            patience (0 = off), min_delta (0.0).
        times_t: time grid passed to the model as its second argument.
        resume_training: if True, restore state from a checkpoint first.
        resume_ckpt_path: checkpoint to resume from; when None the run's
            ``_best.pt`` is used if it exists.

    Returns:
        ``(best, best_path)``: the best validation metrics dict (with an extra
        "val_loss" key) and the best checkpoint path as a string.

    Raises:
        FileNotFoundError: if resuming is requested but no checkpoint is found.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    ckpt_dir = out_dir / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    best_path = out_dir / f"{run_name}_best.pt"
    log_path = out_dir / f"{run_name}_log.jsonl"

    def _as_list(x):
        return np.asarray(x, dtype=np.float32).tolist()

    # If dataset might be a Subset, unwrap it to reach the normalizers
    base_ds = dataset.dataset if hasattr(dataset, "dataset") else dataset
    x_norm_obj = getattr(base_ds, "x_norm", None)
    y_norm_obj = getattr(base_ds, "y_norm", None)

    # ----------------------------
    # DataLoaders
    # ----------------------------
    num_workers = cfg.get("num_workers", 2)
    dl_tr = DataLoader(
        Subset(dataset, idx_tr),
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )
    dl_va = DataLoader(
        Subset(dataset, idx_va),
        batch_size=cfg.get("eval_batch_size", 512),
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )

    # ----------------------------
    # Model
    # ----------------------------
    model = CP.MultiTaskTransformer(
        input_dim=2,
        num_classes=num_classes,
        reg_dim=reg_dim,
        d_model=cfg["d_model"],
        nhead=cfg["nhead"],
        num_layers=cfg["num_layers"],
        dim_feedforward=cfg["dim_feedforward"],
        dropout=cfg["dropout"],
    ).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    # Schedule: linear warmup followed by cosine decay, stepped once per batch.
    steps_per_epoch = len(dl_tr)
    total_steps = steps_per_epoch * cfg["epochs"]
    warmup_steps = int(total_steps * cfg.get("warmup_frac", 0.05))

    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)

    ce = nn.CrossEntropyLoss(label_smoothing=cfg.get("label_smoothing", 0.0))

    # CUDA AMP is only meaningful on a CUDA device; elsewhere run in full
    # precision with a disabled scaler instead of relying on the
    # torch.cuda.amp fallbacks.
    use_amp = bool(cfg.get("amp", True)) and torch.device(device).type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    save_every = int(cfg.get("save_every_epochs", 0))   # 0 disables
    patience = int(cfg.get("patience", 0))              # 0 disables
    min_delta = float(cfg.get("min_delta", 0.0))

    # ----------------------------
    # Run state (possibly overwritten by resume below)
    # ----------------------------
    start_epoch = 0
    step = 0
    best = {"val_loss": float("inf")}
    best_epoch = -1
    bad_epochs = 0

    def _normalizer_state(norm_obj, with_log1p):
        # Serializable snapshot of a normalizer so inference can reproduce it.
        if norm_obj is None:
            return None
        state = {
            "mean": _as_list(norm_obj.mean),
            "std": _as_list(norm_obj.std),
        }
        if with_log1p:
            state["log1p"] = bool(getattr(norm_obj, "log1p", True))
        return state

    def _checkpoint_payload(epoch, val_metrics):
        # Single source of truth for both periodic and best checkpoints.
        # Reads the *current* step/best/best_epoch/bad_epochs at call time.
        return {
            "run_name": run_name,
            "epoch": epoch,
            "step": step,
            "cfg": cfg,
            "model_state": model.state_dict(),
            "opt_state": opt.state_dict(),
            "sched_state": sched.state_dict(),
            "scaler_state": scaler.state_dict(),
            "val": val_metrics,
            "best_val": best if best["val_loss"] < float("inf") else None,
            "best_epoch": best_epoch,
            "bad_epochs": bad_epochs,
            "x_norm": _normalizer_state(x_norm_obj, with_log1p=True),
            "y_norm": _normalizer_state(y_norm_obj, with_log1p=False),
        }

    if resume_training:
        if resume_ckpt_path is None:
            # default: resume from best if exists
            resume_ckpt_path = str(best_path) if best_path.exists() else None

        if resume_ckpt_path is None or not os.path.exists(resume_ckpt_path):
            raise FileNotFoundError(f"resume_training=True but checkpoint not found: {resume_ckpt_path}")

        ckpt = torch.load(resume_ckpt_path, map_location=device)

        # restore weights + optimizer/scheduler/scaler
        model.load_state_dict(ckpt["model_state"])
        if ckpt.get("opt_state") is not None:
            opt.load_state_dict(ckpt["opt_state"])
        if ckpt.get("sched_state") is not None:
            sched.load_state_dict(ckpt["sched_state"])
        # A checkpoint written with AMP disabled carries an empty scaler
        # state, which an enabled GradScaler refuses to load; skip it.
        if ckpt.get("scaler_state") and scaler.is_enabled():
            scaler.load_state_dict(ckpt["scaler_state"])

        # restore counters
        start_epoch = int(ckpt.get("epoch", -1)) + 1
        step = int(ckpt.get("step", 0))

        # restore best tracking if present
        if ckpt.get("best_val") is not None:
            best = {"val_loss": ckpt["best_val"]["loss"], **ckpt["best_val"]}
            # Older checkpoints lack these keys; fall back to previous behavior.
            best_epoch = int(ckpt.get("best_epoch", ckpt.get("epoch", -1)))
            bad_epochs = int(ckpt.get("bad_epochs", 0))

        print(f"[{run_name}] Resuming from {resume_ckpt_path}")
        print(f"  start_epoch={start_epoch}, step={step}, best_val_loss={best['val_loss']:.6f}")

    # ----------------------------
    # Training loop
    # ----------------------------
    for epoch in range(start_epoch, cfg["epochs"]):
        model.train()
        for x, y_cls, y_reg, y_mask in dl_tr:
            x = x.to(device, non_blocking=True)
            y_cls = y_cls.to(device, non_blocking=True)
            y_reg = y_reg.to(device, non_blocking=True)
            y_mask = y_mask.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits, yhat_reg = model(x, times_t)
                loss_cls = ce(logits, y_cls)
                loss_reg = masked_mse(yhat_reg, y_reg, y_mask)
                loss = loss_cls + cfg["lambda_reg"] * loss_reg

            scaler.scale(loss).backward()
            if cfg.get("grad_clip", 0.0) > 0:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["grad_clip"])
            scaler.step(opt)
            scaler.update()
            sched.step()

            step += 1

        # ---- validate
        val_metrics = evaluate(model, dl_va, device, times_t, lambda_reg=cfg["lambda_reg"])

        # ---- log append
        rec = {"epoch": epoch, "step": step, "cfg": cfg, "val": val_metrics}
        with open(log_path, "a") as f:
            f.write(json.dumps(rec) + "\n")

        # ---- best tracking first, so any checkpoint written this epoch
        #      records up-to-date best_val / best_epoch / bad_epochs
        improved = (val_metrics["loss"] < best["val_loss"] - min_delta)
        if improved:
            best = {"val_loss": val_metrics["loss"], **val_metrics}
            best_epoch = epoch
            bad_epochs = 0
            torch.save(_checkpoint_payload(epoch, val_metrics), best_path)
        elif patience:
            bad_epochs += 1

        # ---- periodic checkpoint
        if save_every and ((epoch + 1) % save_every == 0):
            ckpt_path = ckpt_dir / f"{run_name}_epoch{epoch+1:04d}.pt"
            torch.save(_checkpoint_payload(epoch, val_metrics), ckpt_path)

        # ---- early stopping
        if (not improved) and patience and bad_epochs >= patience:
            print(f"[{run_name}] Early stop at epoch={epoch+1} (best epoch={best_epoch+1}, best val_loss={best['val_loss']:.4f})")
            break

        pct = 100.0 * (epoch + 1) / cfg["epochs"]
        print(f"[{run_name}] epoch={epoch+1}/{cfg['epochs']} ({pct:.1f}%) "
              f"val_loss={val_metrics['loss']:.4f} acc={val_metrics['acc']:.4f} reg={val_metrics['reg_mse_masked']:.4f}")

    return best, str(best_path)


def hyperparam_sweep(task_name: str,
                     dataset: Dataset,
                     idx_tr: np.ndarray,
                     idx_va: np.ndarray,
                     num_classes: int,
                     reg_dim: int,
                     device,
                     out_dir: str,
                     configs: list,
                     times_t):
    """Train every config in *configs* and write a JSON summary of the results.

    Trial ``k`` is run through ``train_one_config`` under the name
    ``f"{task_name}_trial{k:03d}"``. The summary file
    ``<out_dir>/<task_name>_sweep_summary.json`` is rewritten after every
    finished trial, so an interrupted sweep still leaves the completed
    trials on disk.

    Returns:
        List of dicts with keys "run", "best", "ckpt" and "cfg", one per config.
    """
    # Create the directory here so the summary can be written even when
    # configs is empty and train_one_config never runs.
    os.makedirs(out_dir, exist_ok=True)
    summary_path = os.path.join(out_dir, f"{task_name}_sweep_summary.json")
    results = []

    def _write_summary():
        with open(summary_path, "w") as f:
            json.dump(results, f, indent=2)

    for k, cfg in enumerate(configs):
        run_name = f"{task_name}_trial{k:03d}"
        best, ckpt_path = train_one_config(
            run_name=run_name,
            dataset=dataset,
            idx_tr=idx_tr,
            idx_va=idx_va,
            num_classes=num_classes,
            reg_dim=reg_dim,
            device=device,
            out_dir=out_dir,
            cfg=cfg,
            times_t=times_t
        )
        results.append({"run": run_name, "best": best, "ckpt": ckpt_path, "cfg": cfg})
        # Persist progress after each trial; long sweeps may be killed midway.
        _write_summary()

    # Final write also covers the empty-configs case.
    _write_summary()
    return results

In [3]:
N_epochs = 3  # epochs per trial in the short hyperparameter sweep configs below

In [4]:
# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sweep grid: four model/optimizer settings, each trained for N_epochs.
# The same list is used for both the PK and the PD sweep.
configs = [
    dict(d_model=64,  nhead=4, num_layers=4, dim_feedforward=256, dropout=0.10,
         lr=3e-4, weight_decay=1e-2, lambda_reg=1.0, batch_size=256, epochs=N_epochs,
         warmup_frac=0.05, amp=True, grad_clip=1.0, num_workers=2),

    dict(d_model=128, nhead=8, num_layers=4, dim_feedforward=512, dropout=0.10,
         lr=3e-4, weight_decay=1e-2, lambda_reg=1.0, batch_size=256, epochs=N_epochs,
         warmup_frac=0.05, amp=True, grad_clip=1.0, num_workers=2),

    dict(d_model=128, nhead=8, num_layers=6, dim_feedforward=512, dropout=0.15,
         lr=2e-4, weight_decay=2e-2, lambda_reg=0.7, batch_size=256, epochs=N_epochs,
         warmup_frac=0.08, amp=True, grad_clip=1.0, num_workers=2),

    dict(d_model=192, nhead=8, num_layers=6, dim_feedforward=768, dropout=0.15,
         lr=2e-4, weight_decay=1e-2, lambda_reg=1.2, batch_size=192, epochs=N_epochs,
         warmup_frac=0.08, amp=True, grad_clip=1.0, num_workers=2),
]


In [5]:
# Sampling-time grids from TrajectoryGenerator, placed on the training device;
# they are passed to the model as its second argument.
pk_times_t = torch.tensor(TG.PK_TIMES, dtype=torch.float32, device=device)  # (39,)
pd_times_t = torch.tensor(TG.PD_TIMES, dtype=torch.float32, device=device)  # (25,)

In [6]:
# Root output directory for this sweep; PK and PD runs go to separate sub-directories.
OUT = "/pscratch/sd/b/by1997/pkpd/training_runs/sweep1"
os.makedirs(OUT, exist_ok=True)

# PK sweep
pk_results = hyperparam_sweep(
    task_name="PK",
    dataset=ds_pk,
    idx_tr=idx_tr,
    idx_va=idx_va,
    num_classes=10,
    reg_dim=TG.PK_PARAM_DIM,
    device=device,
    out_dir=os.path.join(OUT, "pk"),
    configs=configs,
    times_t=pk_times_t
)

# PD sweep (same split indices and config grid as PK)
pd_results = hyperparam_sweep(
    task_name="PD",
    dataset=ds_pd,
    idx_tr=idx_tr,
    idx_va=idx_va,
    num_classes=10,
    reg_dim=TG.PD_PARAM_DIM,
    device=device,
    out_dir=os.path.join(OUT, "pd"),
    configs=configs,
    times_t=pd_times_t
)

print("Sweep finished.")




[PK_trial000] epoch=1/3 (33.3%) val_loss=1.6950 acc=0.5129 reg=0.5235
[PK_trial000] epoch=2/3 (66.7%) val_loss=1.5962 acc=0.5337 reg=0.4968
[PK_trial000] epoch=3/3 (100.0%) val_loss=1.5877 acc=0.5359 reg=0.4940




[PK_trial001] epoch=1/3 (33.3%) val_loss=1.5981 acc=0.5355 reg=0.4997
[PK_trial001] epoch=2/3 (66.7%) val_loss=1.5480 acc=0.5465 reg=0.4793
[PK_trial001] epoch=3/3 (100.0%) val_loss=1.5243 acc=0.5530 reg=0.4725
[PK_trial002] epoch=1/3 (33.3%) val_loss=1.4663 acc=0.5318 reg=0.5045
[PK_trial002] epoch=2/3 (66.7%) val_loss=1.4344 acc=0.5391 reg=0.4935
[PK_trial002] epoch=3/3 (100.0%) val_loss=1.4108 acc=0.5457 reg=0.4850
[PK_trial003] epoch=1/3 (33.3%) val_loss=1.7200 acc=0.5329 reg=0.5004
[PK_trial003] epoch=2/3 (66.7%) val_loss=1.6352 acc=0.5498 reg=0.4766
[PK_trial003] epoch=3/3 (100.0%) val_loss=1.6111 acc=0.5563 reg=0.4711
[PD_trial000] epoch=3/3 (100.0%) val_loss=2.9992 acc=0.1249 reg=0.7229
[PD_trial001] epoch=1/3 (33.3%) val_loss=3.0031 acc=0.1234 reg=0.7234
[PD_trial001] epoch=2/3 (66.7%) val_loss=2.9945 acc=0.1275 reg=0.7217
[PD_trial001] epoch=3/3 (100.0%) val_loss=2.9897 acc=0.1296 reg=0.7208
[PD_trial002] epoch=1/3 (33.3%) val_loss=2.7873 acc=0.1223 reg=0.7240
[PD_trial002] e

In [7]:
import json, os

# Reload the PK sweep summary written by hyperparam_sweep.
summary_path = "/pscratch/sd/b/by1997/pkpd/training_runs/sweep1/pk/PK_sweep_summary.json"
with open(summary_path, "r") as f:
    results = json.load(f)

# sort by val loss (ascending) and show up to the ten best runs
results_sorted = sorted(results, key=lambda r: r["best"]["loss"])
for r in results_sorted[:10]:
    print(r["run"], "val_loss:", r["best"]["loss"], "val_acc:", r["best"]["acc"], "reg_mse:", r["best"]["reg_mse_masked"])


PK_trial002 val_loss: 1.410815221710205 val_acc: 0.545748 reg_mse: 0.4849929124298096
PK_trial001 val_loss: 1.5243376971664429 val_acc: 0.55303 reg_mse: 0.47247732859611513
PK_trial000 val_loss: 1.5876962135772705 val_acc: 0.53591 reg_mse: 0.49399089193344115
PK_trial003 val_loss: 1.6110592265319825 val_acc: 0.556278 reg_mse: 0.4710990595493317


In [8]:
import json
import numpy as np

# Per-epoch JSONL log written by train_one_config for the first PK trial.
log_path = "/pscratch/sd/b/by1997/pkpd/training_runs/sweep1/pk/PK_trial000_log.jsonl"

# One JSON record per line; collect the validation curves.
epochs, vloss, vacc, vreg = [], [], [], []
with open(log_path, "r") as f:
    for line in f:
        rec = json.loads(line)
        epochs.append(rec["epoch"])
        vloss.append(rec["val"]["loss"])
        vacc.append(rec["val"]["acc"])
        vreg.append(rec["val"]["reg_mse_masked"])

# Logged epochs are 0-based (the training loop stores the loop index).
print("min val_loss:", np.min(vloss), "at epoch", epochs[int(np.argmin(vloss))])
print("max val_acc:", np.max(vacc), "at epoch", epochs[int(np.argmax(vacc))])


min val_loss: 1.5876962135772705 at epoch 2
max val_acc: 0.53591 at epoch 2


In [9]:
def score(r, w_acc=1.0, w_reg=0.2):
    """Combined score of a sweep result (higher is better).

    Rewards validation accuracy and penalizes masked regression MSE, both
    read from ``r["best"]``.
    """
    metrics = r["best"]
    reward = w_acc * metrics["acc"]
    penalty = w_reg * metrics["reg_mse_masked"]
    return reward - penalty

# Pick the run with the highest combined accuracy/regression score
# (default weights of score()).
best_balanced = max(results, key=lambda r: score(r))
print("Best balanced:", best_balanced["run"], best_balanced["best"])


Best balanced: PK_trial003 {'val_loss': 1.6110592265319825, 'loss': 1.6110592265319825, 'acc': 0.556278, 'reg_mse_masked': 0.4710990595493317, 'n': 500000}


In [11]:
# Display the full record (metrics, checkpoint path, cfg) of the selected run.
best_balanced

{'run': 'PK_trial003',
 'best': {'val_loss': 1.6110592265319825,
  'loss': 1.6110592265319825,
  'acc': 0.556278,
  'reg_mse_masked': 0.4710990595493317,
  'n': 500000},
 'ckpt': '/pscratch/sd/b/by1997/pkpd/training_runs/sweep1/pk/PK_trial003_best.pt',
 'cfg': {'d_model': 192,
  'nhead': 8,
  'num_layers': 6,
  'dim_feedforward': 768,
  'dropout': 0.15,
  'lr': 0.0002,
  'weight_decay': 0.01,
  'lambda_reg': 1.2,
  'batch_size': 192,
  'epochs': 3,
  'warmup_frac': 0.08,
  'amp': True,
  'grad_clip': 1.0,
  'num_workers': 2}}

In [12]:
import json
import os

def write_best_cfg_json(results, out_path: str):
    """Write the cfg of the lowest-``val_loss`` run in *results* to *out_path*.

    Args:
        results: list of dicts as returned by hyperparam_sweep (each with
            "run", "best" and "cfg" keys).
        out_path: destination JSON file; missing parent directories are created.

    Returns:
        ``(cfg, best_item)``: a shallow copy of the winning cfg and the
        winning result entry.

    Raises:
        ValueError: if *results* is empty.
    """
    if not results:
        raise ValueError("write_best_cfg_json: results is empty, nothing to select")

    best_item = min(results, key=lambda r: r["best"]["val_loss"])
    cfg = dict(best_item["cfg"])  # copy so callers can edit it freely

    # Optionally: override epochs for long training here
    # cfg["epochs"] = 400
    # cfg["save_every_epochs"] = 10

    # dirname is "" for a bare file name, and os.makedirs("") raises.
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(cfg, f, indent=2)

    print("Wrote cfg to:", out_path)
    print("Best run:", best_item["run"])
    print("Best val:", best_item["best"])
    return cfg, best_item

In [13]:
# Persist the lowest-val_loss config of each sweep as the starting point for
# the final runs.
cfg_pk, best_pk = write_best_cfg_json(
    pk_results,
    out_path="/pscratch/sd/b/by1997/pkpd/configs/cfg_final_pk.json"
)

cfg_pd, best_pd = write_best_cfg_json(
    pd_results,
    out_path="/pscratch/sd/b/by1997/pkpd/configs/cfg_final_pd.json"
)

Wrote cfg to: /pscratch/sd/b/by1997/pkpd/configs/cfg_final_pk.json
Best run: PK_trial002
Best val: {'val_loss': 1.410815221710205, 'loss': 1.410815221710205, 'acc': 0.545748, 'reg_mse_masked': 0.4849929124298096, 'n': 500000}
Wrote cfg to: /pscratch/sd/b/by1997/pkpd/configs/cfg_final_pd.json
Best run: PD_trial002
Best val: {'val_loss': 2.7805983345947265, 'loss': 2.7805983345947265, 'acc': 0.12658, 'reg_mse_masked': 0.7227154406700135, 'n': 500000}


In [14]:
# Turn the selected PK sweep config into a long-run config and overwrite the
# JSON file written above.
cfg_pk["epochs"] = 150
cfg_pk["save_every_epochs"] = 10   # better than 50 if jobs die every 6h
cfg_pk["eval_batch_size"] = 1024
with open("/pscratch/sd/b/by1997/pkpd/configs/cfg_final_pk.json","w") as f:
    json.dump(cfg_pk, f, indent=2)

In [15]:
# Same long-run overrides for the PD config; overwrites its JSON file.
cfg_pd["epochs"] = 150
cfg_pd["save_every_epochs"] = 10   # better than 50 if jobs die every 6h
cfg_pd["eval_batch_size"] = 1024
with open("/pscratch/sd/b/by1997/pkpd/configs/cfg_final_pd.json","w") as f:
    json.dump(cfg_pd, f, indent=2)

In [10]:
import os, copy
import torch
import TrajectoryGenerator as TG

# Choose your best cfg from the sweep summary (or hardcode it)
# Start from the cfg of the run selected by score(); deep-copied so the sweep
# results are not mutated by the overrides below.
cfg_long = copy.deepcopy(best_balanced["cfg"])   # or best_by_acc["cfg"] etc.

# Make it "long"
cfg_long["epochs"] = 300
cfg_long["save_every_epochs"] = 50   # periodic checkpoints
cfg_long["patience"] = 20            # early stop if no improvement for 20 epochs
cfg_long["min_delta"] = 1e-4
cfg_long["eval_batch_size"] = 1024   # faster validation
cfg_long["amp"] = True
cfg_long["grad_clip"] = 1.0

# (Optional) often helps for long runs: a slightly smaller LR
cfg_long["lr"] = cfg_long["lr"] * 0.7

OUT_LONG = "/pscratch/sd/b/by1997/pkpd/training_runs/long/pk"
os.makedirs(OUT_LONG, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pk_times_t = torch.tensor(TG.PK_TIMES, dtype=torch.float32, device=device)

# Fresh long PK run; the best checkpoint is written to
# OUT_LONG/PK_best_long_best.pt by train_one_config.
best_long, best_ckpt_path = train_one_config(
    run_name="PK_best_long",
    dataset=ds_pk,
    idx_tr=idx_tr,
    idx_va=idx_va,
    num_classes=10,
    reg_dim=TG.PK_PARAM_DIM,
    device=device,
    out_dir=OUT_LONG,
    cfg=cfg_long,
    times_t=pk_times_t,
    resume_training=False,     # <-- from scratch
    resume_ckpt_path=None
)

print("Best val metrics:", best_long)
print("Best checkpoint:", best_ckpt_path)

KeyboardInterrupt: 

In [None]:
# Run when we want to resume training.
# With resume_ckpt_path=None, train_one_config resumes from
# <out_dir>/PK_best_long_best.pt and raises FileNotFoundError if it is missing.
best, ckpt = train_one_config(
    run_name="PK_best_long",
    dataset=ds_pk,
    idx_tr=idx_tr,
    idx_va=idx_va,
    num_classes=10,
    reg_dim=TG.PK_PARAM_DIM,
    device=device,
    out_dir="/pscratch/sd/b/by1997/pkpd/training_runs/long/pk",
    cfg=cfg_long,
    times_t=pk_times_t,
    resume_training=True,
    resume_ckpt_path=None
)

In [None]:
from torch.utils.data import DataLoader, Subset

# Loader over the held-out PK test split, in fixed order (no shuffling).
dl_te = DataLoader(
    Subset(ds_pk, idx_te),
    batch_size=1024,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [None]:
import torch

# Rebuild the model from the hyperparameters stored in the checkpoint, then
# load its weights.
ckpt_path = best_balanced["ckpt"]   # or best_by_acc["ckpt"], etc.
ckpt = torch.load(ckpt_path, map_location=device)

cfg = ckpt["cfg"]
model = CP.MultiTaskTransformer(
    input_dim=2,
    num_classes=10,
    reg_dim=TG.PK_PARAM_DIM,
    d_model=cfg["d_model"],
    nhead=cfg["nhead"],
    num_layers=cfg["num_layers"],
    dim_feedforward=cfg["dim_feedforward"],
    dropout=cfg["dropout"],
).to(device)

model.load_state_dict(ckpt["model_state"])
model.eval()  # inference mode; infer_on_loader does not switch modes itself

pk_times_t = torch.tensor(TG.PK_TIMES, dtype=torch.float32, device=device)


In [None]:
@torch.no_grad()
def infer_on_loader(model, loader, device, times_t, y_norm_obj):
    """Run *model* over *loader* and collect predictions as NumPy arrays.

    Regression outputs and the (normalized) regression targets from the loader
    are both mapped back to parameter space with ``y_norm_obj.inverse_one``,
    one sample at a time. The model's train/eval mode is left as the caller
    set it.

    Returns:
        dict with "pred_cls", "true_cls", "pred_params", "true_params" and
        "mask", each concatenated over all batches.
    """
    collected = {
        "pred_cls": [],
        "true_cls": [],
        "pred_params": [],
        "true_params": [],
        "mask": [],
    }

    def _to_param_space(batch_norm, batch_mask):
        # inverse_one works on a single sample, so apply it row by row.
        return np.stack([
            y_norm_obj.inverse_one(row, row_mask)
            for row, row_mask in zip(batch_norm, batch_mask)
        ])

    for x, y_cls, y_reg_norm, y_mask in loader:
        x = x.to(device, non_blocking=True)
        y_cls = y_cls.to(device, non_blocking=True)
        y_reg_norm = y_reg_norm.to(device, non_blocking=True)
        y_mask = y_mask.to(device, non_blocking=True)

        logits, yhat_reg_norm = model(x, times_t)

        # move everything to CPU numpy
        mask_np = y_mask.cpu().numpy()
        collected["pred_cls"].append(logits.argmax(dim=1).cpu().numpy())
        collected["true_cls"].append(y_cls.cpu().numpy())

        # invert normalized predictions, then normalized targets
        collected["pred_params"].append(_to_param_space(yhat_reg_norm.cpu().numpy(), mask_np))
        collected["true_params"].append(_to_param_space(y_reg_norm.cpu().numpy(), mask_np))
        collected["mask"].append(mask_np)

    return {
        "pred_cls": np.concatenate(collected["pred_cls"]),
        "true_cls": np.concatenate(collected["true_cls"]),
        "pred_params": np.concatenate(collected["pred_params"], axis=0),
        "true_params": np.concatenate(collected["true_params"], axis=0),
        "mask": np.concatenate(collected["mask"], axis=0),
    }

# Run the loaded PK model over the test loader; ynorm_pk maps regression
# outputs back to parameter space.
pk_test_out = infer_on_loader(model, dl_te, device, pk_times_t, ynorm_pk)
print("pred_cls shape:", pk_test_out["pred_cls"].shape)
print("pred_params shape:", pk_test_out["pred_params"].shape)


In [None]:
from sklearn.metrics import accuracy_score

# Classification accuracy on the test split.
acc = accuracy_score(pk_test_out["true_cls"], pk_test_out["pred_cls"])
print("Test accuracy:", acc)

# masked MSE in physical space
# NOTE: "true_params" were reconstructed by inverting the normalized targets
# in infer_on_loader, not read from the raw target arrays.
diff2 = (pk_test_out["pred_params"] - pk_test_out["true_params"])**2
m = pk_test_out["mask"]
mse = (diff2 * m).sum() / np.maximum(m.sum(), 1.0)  # floor avoids 0/0 on an empty mask
print("Test masked param MSE (physical):", mse)
