# JEPA Model Inspection

Inspect a JEPA pretraining config and checkpoint: hyperparameters, parameter counts, and checkpoint metadata.

In [None]:
from __future__ import annotations

import json
import os
from pathlib import Path

import pandas as pd
import torch
import copy


def find_project_root(start: Path) -> Path:
    """Return the nearest ancestor of ``start`` that looks like the project root.

    ``start`` is resolved, then it and each of its parents are checked
    (nearest first) for both a ``src`` and a ``configs`` directory.

    Args:
        start: Location to begin the upward search from.

    Returns:
        The first directory (``start`` itself or one of its ancestors) that
        contains both a ``src/`` and a ``configs/`` directory.

    Raises:
        RuntimeError: If no such directory exists up to the filesystem root.
    """
    resolved = start.resolve()
    for candidate in (resolved, *resolved.parents):
        # Require real directories: a stray *file* named "src" or "configs"
        # must not be mistaken for the project layout.
        if (candidate / "src").is_dir() and (candidate / "configs").is_dir():
            return candidate
    raise RuntimeError("Could not locate project root containing src/ and configs/")


# Locate the project from the notebook's current working directory.
PROJECT_ROOT = find_project_root(Path.cwd())
print(f"Project root: {PROJECT_ROOT}")

import sys

# Put <project root>/src at the front of sys.path (only once) so the project
# packages imported below can be resolved.
_src_dir = str(PROJECT_ROOT / "src")
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)

from models.jepa.jepa import JEPA
from models.time_series.patchTransformer import PatchTSTEncoder


Project root: C:\python\koulu\Gradu


In [1]:
# -----------------------------
# User parameters
# -----------------------------
# Path to the JEPA pretraining config (JSON). A relative path is resolved
# against the detected project root; an absolute path is used as-is.
JEPA_CONFIG_PATH = "configs/jepa_pretrain6.json"

# Optional checkpoint override (resolved the same way as the config path).
# Set to None to auto-pick best.pt if it exists, otherwise the most recently
# modified epoch*.pt inside <checkpoint_root>/<model_name>.
JEPA_CHECKPOINT_PATH = None

# If True, print the first STATE_KEYS_N keys of the checkpoint's state_dict.
SHOW_STATE_KEYS = False
STATE_KEYS_N = 40


In [2]:
def resolve_project_path(path_value: str | Path) -> Path:
    """Turn a config-style path into a concrete ``Path``.

    Absolute paths are returned unchanged (not resolved); relative paths are
    joined onto ``PROJECT_ROOT`` and resolved.
    """
    candidate = Path(path_value)
    return candidate if candidate.is_absolute() else (PROJECT_ROOT / candidate).resolve()


def load_json(path: Path) -> dict:
    """Read the UTF-8 encoded JSON file at ``path`` and return the parsed content."""
    with open(path, encoding="utf-8") as handle:
        payload = json.loads(handle.read())
    return payload


def pick_checkpoint(checkpoint_dir: Path) -> Path | None:
    """Choose the checkpoint file to inspect inside ``checkpoint_dir``.

    ``best.pt`` wins whenever it exists. Otherwise the ``epoch*.pt`` file with
    the newest modification time is returned, or ``None`` when there is none.
    """
    preferred = checkpoint_dir / "best.pt"
    if preferred.exists():
        return preferred

    def modified_at(ckpt: Path) -> float:
        return ckpt.stat().st_mtime

    # max() keeps the first of several equally-new files, like a stable
    # descending sort followed by taking element 0.
    return max(checkpoint_dir.glob("epoch*.pt"), key=modified_at, default=None)


# Load the pretraining config and split it into its sections.
cfg_path = resolve_project_path(JEPA_CONFIG_PATH)
cfg = load_json(cfg_path)

# "model_name" and "jepa_model" are required (KeyError when absent); every
# other section falls back to an empty dict.
model_name = cfg["model_name"]
paths_cfg = cfg.get("paths", {})
jepa_cfg = cfg["jepa_model"]
train_cfg, dataset_cfg, loss_cfg = (
    cfg.get(section, {}) for section in ("training", "dataset", "loss")
)

# Checkpoints for this run live under <checkpoint_root>/<model_name>.
checkpoint_root = resolve_project_path(paths_cfg.get("checkpoint_root", "checkpoints"))
checkpoint_dir = checkpoint_root / model_name

# An explicit user override takes precedence over auto-selection.
ckpt_path = (
    pick_checkpoint(checkpoint_dir)
    if JEPA_CHECKPOINT_PATH is None
    else resolve_project_path(JEPA_CHECKPOINT_PATH)
)

print(
    f"Config: {cfg_path}",
    f"Model name: {model_name}",
    f"Checkpoint dir: {checkpoint_dir}",
    f"Checkpoint selected: {ckpt_path}",
    sep="\n",
)


NameError: name 'Path' is not defined

In [None]:
# Build JEPA model from config

# Backward-compatible key handling: the plural key wins, the singular key is
# the fallback, and True is used when neither is present.
_legacy_flag = jepa_cfg.get("use_asset_embedding", True)
use_asset_embeddings = jepa_cfg.get("use_asset_embeddings", _legacy_flag)
num_assets = None  # dataset-dependent; omitted for architecture inspection by default

# Required encoder settings are copied straight from the config (KeyError when
# one is missing); "add_cls" is optional and defaults to True.
# NOTE(review): "patch_stride" is read from the config for the hyperparameter
# table but is not passed to the encoder here — confirm the encoder's default
# matches the trained model.
encoder_kwargs = {
    key: jepa_cfg[key]
    for key in (
        "patch_len",
        "d_model",
        "n_features",
        "n_time_features",
        "nhead",
        "num_layers",
        "dim_ff",
        "dropout",
        "pooling",
        "pred_len",
    )
}
encoder_kwargs["add_cls"] = jepa_cfg.get("add_cls", True)
encoder_kwargs["num_assets"] = num_assets if use_asset_embeddings else None

context_enc = PatchTSTEncoder(**encoder_kwargs)

# The target encoder starts as an independent copy of the context encoder.
target_enc = copy.deepcopy(context_enc)

model = JEPA(
    context_enc,
    target_enc,
    d_model=jepa_cfg["d_model"],
    ema_tau_min=jepa_cfg["ema_tau_min"],
    ema_tau_max=jepa_cfg["ema_tau_max"],
)

print(type(model).__name__)


JEPA




In [None]:
def count_params(module: torch.nn.Module) -> tuple[int, int]:
    """Count the parameter elements of ``module``.

    Returns:
        ``(total, trainable)`` where ``total`` sums ``numel()`` over all
        parameters and ``trainable`` only over those with ``requires_grad``.
    """
    total = 0
    trainable = 0
    for param in module.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable


def format_m(n: int) -> str:
    """Render ``n`` with thousands separators followed by its value in millions (3 decimals)."""
    millions = n / 1e6
    return f"{n:,} ({millions:.3f}M)"


# Sub-modules to report on, in display order; the last entry is the whole model.
inspected_modules = {
    "context_enc": model.context_enc,
    "target_enc": model.target_enc,
    "predictor": model.predictor,
    "full_model": model,
}

rows = []
for name, module in inspected_modules.items():
    total, trainable = count_params(module)
    # Keep the raw counts alongside the human-readable strings.
    row = {"module": name, "params_total": total, "params_trainable": trainable}
    row["params_total_fmt"] = format_m(total)
    row["params_trainable_fmt"] = format_m(trainable)
    rows.append(row)

param_df = pd.DataFrame(rows)
param_df


Unnamed: 0,module,params_total,params_trainable,params_total_fmt,params_trainable_fmt
0,context_enc,1800385,1800385,"1,800,385 (1.800M)","1,800,385 (1.800M)"
1,target_enc,1800385,0,"1,800,385 (1.800M)",0 (0.000M)
2,predictor,296256,296256,"296,256 (0.296M)","296,256 (0.296M)"
3,full_model,3897026,2096641,"3,897,026 (3.897M)","2,096,641 (2.097M)"


In [None]:
# Encoder/JEPA settings shown under their config key; a missing key shows as None.
jepa_keys = (
    "patch_len",
    "patch_stride",
    "d_model",
    "n_features",
    "n_time_features",
    "nhead",
    "num_layers",
    "dim_ff",
    "dropout",
    "add_cls",
    "pooling",
    "pred_len",
    "ema_tau_min",
    "ema_tau_max",
)

hyper_rows = [
    ("model_name", model_name),
    *((key, jepa_cfg.get(key)) for key in jepa_keys),
    # Resolved flag from the model-building cell, not the raw config value.
    ("use_asset_embeddings", use_asset_embeddings),
    ("train_epochs", train_cfg.get("epochs")),
    ("train_batch_size", train_cfg.get("batch_size_train")),
    ("learning_rate", train_cfg.get("learning_rate")),
    ("dataset_root_path", dataset_cfg.get("root_path")),
    ("dataset_data_path", dataset_cfg.get("data_path")),
    ("dataset_timeframe", dataset_cfg.get("timeframe")),
    ("loss_type", loss_cfg.get("loss_type")),
]

hyper_df = pd.DataFrame(hyper_rows, columns=["key", "value"])
hyper_df


Unnamed: 0,key,value
0,model_name,jepa_initial6
1,patch_len,8
2,patch_stride,8
3,d_model,192
4,n_features,9
5,n_time_features,2
6,nhead,4
7,num_layers,4
8,dim_ff,768
9,dropout,0.0


In [None]:
# Summary of the selected checkpoint. The path is recorded whenever one was
# selected, even if the file is missing, so a mistyped override is visible in
# the table instead of silently showing up as None.
checkpoint_info = {
    "checkpoint_exists": False,
    "checkpoint_path": str(ckpt_path) if ckpt_path is not None else None,
    "checkpoint_epoch": None,
    "checkpoint_monitor": None,
    "asset_universe_size": None,
    "state_dict_keys": None,
    "missing_keys": None,
    "unexpected_keys": None,
}

if ckpt_path is not None and not ckpt_path.exists():
    print(f"Checkpoint not found, skipping load: {ckpt_path}")

if ckpt_path is not None and ckpt_path.exists():
    # NOTE(security): torch.load unpickles the file; only inspect checkpoints
    # that come from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # The weights may sit under "model" or "model_state_dict"; otherwise the
    # loaded object itself is treated as the state_dict.
    state = ckpt.get("model", ckpt.get("model_state_dict", ckpt))

    # strict=False: key mismatches are counted and reported rather than raised.
    missing, unexpected = model.load_state_dict(state, strict=False)

    asset_universe = ckpt.get("asset_universe")
    checkpoint_info.update({
        "checkpoint_exists": True,
        "checkpoint_path": str(ckpt_path),
        "checkpoint_epoch": ckpt.get("epoch"),
        "checkpoint_monitor": ckpt.get("monitor"),
        "asset_universe_size": len(asset_universe) if asset_universe is not None else None,
        "state_dict_keys": len(state) if isinstance(state, dict) else None,
        "missing_keys": len(missing),
        "unexpected_keys": len(unexpected),
    })

    if SHOW_STATE_KEYS and isinstance(state, dict):
        print("First state_dict keys:")
        for k in list(state)[:STATE_KEYS_N]:
            print(" -", k)

checkpoint_df = pd.DataFrame(list(checkpoint_info.items()), columns=["key", "value"])
checkpoint_df


In [None]:
# Useful quick checks
# True only when the context encoder has a non-None "asset_emb" attribute.
asset_emb_present = getattr(model.context_enc, "asset_emb", None) is not None
print(f"Context encoder has asset embedding module: {asset_emb_present}")

if ckpt_path is not None and ckpt_path.exists():
    # The checkpoint is loaded here again so this cell does not rely on
    # variables left behind by another cell.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state = ckpt.get("model", ckpt.get("model_state_dict", ckpt))

    # Checkpoint entries whose key mentions the asset embedding or asset gate.
    asset_related = [
        name for name in state.keys()
        if any(tag in name for tag in ("asset_emb", "asset_gate"))
    ]
    print(f"Asset-related checkpoint tensors: {len(asset_related)}")

    # Show at most the first 20 entries with their shapes.
    for key in asset_related[:20]:
        entry = state[key]
        shape = tuple(entry.shape) if hasattr(entry, "shape") else None
        print(f" - {key}: {shape}")
