In [16]:
import os
import torch
from hydra import initialize, compose
from models.bcat import BCAT
from utils.muon import Muon
from omegaconf import OmegaConf

# ──────────────────────────────────────────────────────────────────────────────
# 1) cd into repo root so Hydra and imports work
# ──────────────────────────────────────────────────────────────────────────────
# os.chdir("/content/drive/MyDrive/bcat-main")

# ──────────────────────────────────────────────────────────────────────────────
# 2) Build the exact same Hydra config you used in training
# ──────────────────────────────────────────────────────────────────────────────
with initialize(config_path="configs", version_base=None):
    cfg = compose(
        config_name="main",
        overrides=[
            "data=fluids_arena",
            "data.incom_ns_arena_u.folder=/tmp/pde_h5_files",
            "optim=muon",
            "model.n_layer=12",
            "model.dim_emb=512",
            "model.dim_ffn=2048",
            "model.n_head=8",
            "model.kv_cache=false",
            "use_wandb=1",
            "batch_size=1",
            "n_steps_per_epoch=1000",
            "max_epoch=20",
        ],
    )

# ──────────────────────────────────────────────────────────────────────────────
# 3) Instantiate BCAT + Muon optimizer exactly as in training
# ──────────────────────────────────────────────────────────────────────────────
model = BCAT(
    cfg.model,
    x_num          = cfg.data.x_num,
    max_output_dim = cfg.data.max_output_dimension
)

# split parameters into Muon vs AdamW
named = [(n,p) for n,p in model.named_parameters() if p.requires_grad]
adam_keys    = ["embed"]
muon_params, adamw_params = [], []
for name, p in named:
    if p.ndim >= 2 and ".self_attn." in name and not any(k in name for k in adam_keys):
        muon_params.append(p)
    else:
        adamw_params.append(p)

optimizer = Muon(
    lr           = cfg.optim.lr,
    wd           = cfg.optim.weight_decay,
    muon_params  = muon_params,
    adamw_params = adamw_params,
    adamw_betas  = (0.9, cfg.optim.get("beta2", 0.95)),
    adamw_eps    = cfg.optim.get("eps", 1e-8),
)

# ──────────────────────────────────────────────────────────────────────────────
# 4) Load your checkpoint
# ──────────────────────────────────────────────────────────────────────────────
ckpt_path = "/content/drive/MyDrive/bcat-main/checkpoint/bcat/muon_ns_attention_52M/checkpoint.pth"
ckpt = torch.load(ckpt_path, map_location="cpu")

model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])

start_epoch  = ckpt.get("epoch", 0)
best_metrics = ckpt.get("best_metrics", {})
print(f"✅  Resumed from epoch {start_epoch} | best_metrics={best_metrics}")

MissingConfigException: Primary config directory not found.
Check that the config directory '/tmp/configs' exists and readable