In [2]:
import os
import numpy as np
import torch

import TrajectoryGenerator as TG
import data_and_training as DT  # your big file


# -----------------------------
# Paths
# -----------------------------
BASE_DIR   = "./dataset"  # contains mpirank_*
# The split indices address samples of the dataset under BASE_DIR, so derive
# the splits path from BASE_DIR rather than repeating the directory literal;
# pointing BASE_DIR at another dataset then also picks up that dataset's splits.
SPLITS_NPZ = os.path.join(BASE_DIR, "splits_90_5_5_seed123.npz")
# Output directory for this run (passed to the sweep as out_dir); created up
# front so the sweep can write checkpoints into it.
OUT_DIR    = "./sweeps/pd_deeper128_L8_5ep"
os.makedirs(OUT_DIR, exist_ok=True)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# -----------------------------
# Load splits
# -----------------------------
# np.load on an .npz archive returns an NpzFile that keeps the underlying zip
# file open until it is closed. Use it as a context manager so the handle is
# released once the index arrays have been read. astype() returns fresh
# in-memory copies, so idx_tr/idx_va remain valid after the archive closes.
# Only the train/validation indices are used in this cell; any other split
# stored in the archive is left untouched.
with np.load(SPLITS_NPZ) as spl:
    idx_tr = spl["idx_tr"].astype(np.int64)
    idx_va = spl["idx_va"].astype(np.int64)


# -----------------------------
# Discover shards + build PD memmaps
# -----------------------------
shards = DT.discover_mpirank_shards(BASE_DIR)

# One ShardedMemmap per PD array file. Judging by the file names these are the
# inputs, class labels, regression targets and regression mask; presumably each
# presents the same-named .npy from every shard as one indexable array —
# confirm in data_and_training.
X_pd      = DT.ShardedMemmap(shards, "X_pd.npy")
y_pd_cls  = DT.ShardedMemmap(shards, "y_pd_cls.npy")
y_pd_reg  = DT.ShardedMemmap(shards, "y_pd_reg.npy")
y_pd_mask = DT.ShardedMemmap(shards, "y_pd_mask.npy")

# Time grid from TrajectoryGenerator, handed to the sweep as `times_t`.
# Create it directly on the target device instead of allocating on the CPU
# and copying it over with .to(device).
times_t = torch.tensor(TG.PD_TIMES, dtype=torch.float32, device=device)


# -----------------------------
# Fit normalizers (train-only, deterministic)
# -----------------------------
# Only training indices (idx_tr) feed the fit, so the validation split does not
# influence the scaling. fit_n appears to cap how many training samples are used
# for the fit, and the fixed seed makes that selection repeatable. seed=2 is the
# value used for earlier PD runs; keep it unchanged so results stay comparable.
_pd_norm_kwargs = {
    "dataset_kind": "pd",
    "shard_dirs": shards,
    "train_idx": idx_tr,
    "fit_n": 200_000,
    "seed": 2,
}
xnorm_pd, ynorm_pd = DT.fit_normalizers(**_pd_norm_kwargs)

# Bundle the four PD arrays with the fitted input/target normalizers.
_pd_arrays = (X_pd, y_pd_cls, y_pd_reg, y_pd_mask)
ds_pd = DT.ShardedMultiTaskSeqDataset(*_pd_arrays, xnorm_pd, ynorm_pd)


# -----------------------------
# Deeper-128 config (num_layers=8) for 5 epochs
# -----------------------------
# A single configuration: 8 layers at width 128, trained for 5 epochs as a quick
# diagnostic. Periodic checkpointing is off; patience=0 presumably disables
# early stopping (confirm in hyperparam_sweep).
cfg_deeper128 = {
    # architecture
    "d_model": 128,
    "nhead": 8,
    "num_layers": 8,
    "dim_feedforward": 512,
    "dropout": 0.10,
    # training
    "lr": 2e-4,
    "weight_decay": 1e-2,
    "lambda_reg": 1.0,
    "batch_size": 256,
    "epochs": 5,
    "warmup_frac": 0.08,
    "amp": True,
    "grad_clip": 1.0,
    # data loading / evaluation
    "num_workers": 2,
    "eval_batch_size": 1024,
    # checkpointing / stopping
    "save_every_epochs": 0,  # off for quick diagnostic
    "patience": 0,
}

pd_results = DT.hyperparam_sweep(
    task_name="PD_deeper128_L8_5ep",
    dataset=ds_pd,
    idx_tr=idx_tr.tolist(),
    idx_va=idx_va.tolist(),
    num_classes=10,
    reg_dim=TG.PD_PARAM_DIM,
    device=device,
    out_dir=OUT_DIR,
    configs=[cfg_deeper128],
    times_t=times_t,
)

print("Done. PD deeper-128 (L=8) diagnostic.")
# Only one config was swept, so the first (and only) entry is the result.
_pd_trial = pd_results[0]
print("Best:", _pd_trial["best"])
print("Ckpt:", _pd_trial["ckpt"])




[PD_deeper128_L8_5ep_trial000] epoch=1/5 (20.0%) val_loss=3.0041 acc=0.1220 reg=0.7243
[PD_deeper128_L8_5ep_trial000] epoch=2/5 (40.0%) val_loss=2.9966 acc=0.1256 reg=0.7220
[PD_deeper128_L8_5ep_trial000] epoch=3/5 (60.0%) val_loss=2.9667 acc=0.1404 reg=0.7170
[PD_deeper128_L8_5ep_trial000] epoch=4/5 (80.0%) val_loss=2.9352 acc=0.1526 reg=0.7127
[PD_deeper128_L8_5ep_trial000] epoch=5/5 (100.0%) val_loss=2.9262 acc=0.1559 reg=0.7114
Done. PD deeper-128 (L=8) diagnostic.
Best: {'val_loss': 2.9261925871582033, 'loss': 2.9261925871582033, 'acc': 0.155928, 'reg_mse_masked': 0.7113775536880493, 'n': 500000}
Ckpt: sweeps/pd_deeper128_L8_5ep/PD_deeper128_L8_5ep_trial000_best.pt
