# PPO Model Testing

Evaluate a trained PPO+JEPA model using a single training config file.


In [1]:
from __future__ import annotations

import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List
import matplotlib.pyplot as plt
import copy

import numpy as np
import pandas as pd
import torch

# Resolve the project root robustly, even when the notebook is launched from a different cwd
def find_project_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` that holds ``src`` and ``configs``.

    ``start`` is resolved first, then the directory itself and each of its
    parents are checked in order from closest to farthest.

    Raises:
        RuntimeError: if no directory on that chain contains both entries.
    """
    origin = start.resolve()
    root = next(
        (
            folder
            for folder in (origin, *origin.parents)
            if (folder / "src").exists() and (folder / "configs").exists()
        ),
        None,
    )
    if root is None:
        raise RuntimeError("Could not locate project root containing src/ and configs/")
    return root

PROJECT_ROOT = find_project_root(Path.cwd())
# Make the repository root and its src/ folder importable. The root is handled
# first, so a newly inserted src/ entry ends up ahead of it on sys.path.
for _import_dir in (str(PROJECT_ROOT), str(PROJECT_ROOT / "src")):
    if _import_dir not in sys.path:
        sys.path.insert(0, _import_dir)

from config.config_utils import load_json_config
from Datasets.multi_asset_dataset import Dataset_Finance_MultiAsset
from Training.sb3_jepa_ppo import JEPAAuxFeatureExtractor, PPOWithJEPA
from models.jepa.jepa import JEPA
from models.time_series.patchTransformer import PatchTSTEncoder

print(f"Project root: {PROJECT_ROOT}")


Project root: C:\python\koulu\Gradu


In [2]:
# -----------------------------
# User parameters
# -----------------------------
# Training config of the PPO model under test. The notebook joins this path
# onto PROJECT_ROOT before loading it.
PPO_CONFIG_PATH = "configs/ppo_jepa_train1.json"

# Optional checkpoint overrides; relative paths are resolved against PROJECT_ROOT.
# - PPO_CHECKPOINT_PATH = None: pick the most recently modified ppo_*_steps.zip
#   under <checkpoint_root>/<model_name>.
# - JEPA_CHECKPOINT_PATH = None: use <paths.jepa_checkpoint_dir>/best.pt from the config.
PPO_CHECKPOINT_PATH = None
JEPA_CHECKPOINT_PATH = "checkpoints/jepa6_ppo1/jepa_step_700000.pt"

# Upper bound on the number of assets to evaluate; None keeps every asset the
# validation dataset exposes.
MAX_ASSETS = None

# Forwarded as `deterministic=` to model.predict during evaluation.
DETERMINISTIC = True


In [3]:
def get_latest_ppo_checkpoint(checkpoint_dir: str) -> str | None:
    """Return the most recently modified PPO checkpoint inside ``checkpoint_dir``.

    Only files named ``ppo_*_steps.zip`` are considered, and "latest" is decided
    by file modification time, not by the step count in the file name.

    Returns:
        The joined path of the newest matching file, or ``None`` when the
        directory does not exist or holds no matching file.
    """
    if not os.path.isdir(checkpoint_dir):
        return None
    candidates = [
        os.path.join(checkpoint_dir, entry)
        for entry in os.listdir(checkpoint_dir)
        if entry.startswith("ppo_") and entry.endswith("_steps.zip")
    ]
    if not candidates:
        return None
    # Stable sort + last element: same winner as an in-place sort when mtimes tie.
    return sorted(candidates, key=os.path.getmtime)[-1]


def load_tickers(path: str) -> list | None:
    """Read one ticker per line from a UTF-8 text file.

    Surrounding whitespace is stripped and blank lines are skipped.

    Returns:
        The tickers in file order, or ``None`` when ``path`` is empty, does
        not exist, or yields no tickers.
    """
    if not (path and os.path.exists(path)):
        return None
    symbols = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            symbol = raw_line.strip()
            if symbol:
                symbols.append(symbol)
    return symbols if symbols else None


def load_asset_universe_from_checkpoint(path: str | None) -> list | None:
    """Best-effort read of the ``asset_universe`` entry stored in a checkpoint.

    This helper never raises for an unusable checkpoint: callers treat ``None``
    as "no universe recorded" and continue without it.

    Args:
        path: Checkpoint file to inspect; ``None`` or an empty string is allowed.

    Returns:
        ``asset_universe`` converted to a list, or ``None`` when the path is
        missing, the file cannot be loaded, the loaded object has no ``get``
        method, or the entry is absent or empty.
    """
    if not path or not os.path.exists(path):
        return None
    try:
        # NOTE(review): torch.load unpickles the file, so only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location="cpu")
    except Exception as exc:
        # Still best effort, but report the reason so a silently ignored
        # universe can be diagnosed from the notebook output.
        print(f"Could not read asset universe from {path}: {exc}")
        return None

    # A checkpoint that is not dict-like (e.g. a bare tensor or list) has no
    # metadata to read; previously this raised AttributeError outside the try.
    getter = getattr(checkpoint, "get", None)
    if not callable(getter):
        return None

    asset_universe = getter("asset_universe")
    if asset_universe is None:
        return None
    try:
        # Convert before testing emptiness: truth-testing a multi-element
        # numpy array raises ValueError.
        universe = list(asset_universe)
    except TypeError:
        return None
    return universe or None


# Load the training config and pull out the sections this notebook reads.
cfg = load_json_config(
    str(PROJECT_ROOT / PPO_CONFIG_PATH),
    "",
    str(PROJECT_ROOT / "notebooks" / "ppo_model_test.ipynb"),
)

model_name = cfg["model_name"]
paths_cfg, dataset_cfg, env_cfg = cfg["paths"], cfg["dataset"], cfg["env"]
ppo_cfg, jepa_cfg = cfg["ppo"], cfg["jepa_model"]

checkpoint_root = paths_cfg.get("checkpoint_root", "checkpoints")
log_root = paths_cfg.get("log_root", "logs")
ppo_checkpoint_dir = str(PROJECT_ROOT / checkpoint_root / model_name)

# PPO checkpoint: an explicit override wins; otherwise take the newest file in
# the model's checkpoint directory and fail if there is none.
if PPO_CHECKPOINT_PATH is None:
    PPO_CHECKPOINT_PATH = get_latest_ppo_checkpoint(ppo_checkpoint_dir)
    if PPO_CHECKPOINT_PATH is None:
        raise FileNotFoundError(f"No PPO checkpoint found under {ppo_checkpoint_dir}")
if not os.path.isabs(PPO_CHECKPOINT_PATH):
    PPO_CHECKPOINT_PATH = str(PROJECT_ROOT / PPO_CHECKPOINT_PATH)

# JEPA checkpoint: an explicit override wins; otherwise fall back to best.pt in
# the configured JEPA checkpoint directory.
if JEPA_CHECKPOINT_PATH is None:
    jepa_checkpoint_dir = paths_cfg["jepa_checkpoint_dir"]
    JEPA_CHECKPOINT_PATH = str(PROJECT_ROOT / jepa_checkpoint_dir / "best.pt")
if not os.path.isabs(JEPA_CHECKPOINT_PATH):
    JEPA_CHECKPOINT_PATH = str(PROJECT_ROOT / JEPA_CHECKPOINT_PATH)

# Environment settings that the evaluation loop below reads.
ACTION_MODE = env_cfg.get("action_mode", "continuous")
ALLOW_SHORT = env_cfg.get("allow_short", True)
INCLUDE_WEALTH = env_cfg.get("include_wealth", True)
TRANSACTION_COST = env_cfg["transaction_cost"]

for _label, _value in (
    ("Model name:", model_name),
    ("PPO checkpoint:", PPO_CHECKPOINT_PATH),
    ("JEPA checkpoint:", JEPA_CHECKPOINT_PATH),
    ("Action mode:", ACTION_MODE),
):
    print(_label, _value)


Model name: jepa6_ppo1
PPO checkpoint: C:\python\koulu\Gradu\checkpoints\jepa6_ppo1\ppo_5600000_steps.zip
JEPA checkpoint: C:\python\koulu\Gradu\checkpoints\jepa6_ppo1\jepa_step_700000.pt
Action mode: discrete_3


In [4]:
@dataclass
class EvalConfig:
    """Settings used to annualize returns and to classify positions during evaluation."""
    # Trading days per year; annualization_factor multiplies it by bars per day.
    annual_trading_days: int = 252
    # True -> 390 minutes per day, False -> 24 * 60 minutes (see annualization_factor).
    regular_hours_only: bool = True
    # Bar size such as "15min" or "1h"; parsed by _timeframe_to_minutes.
    timeframe: str = "15min"
    # Absolute position size at or below which eval_asset counts a step as flat.
    flat_threshold: float = 1e-3


def _timeframe_to_minutes(timeframe: str) -> int:
    """Convert a bar-size label such as ``"15min"`` or ``"1h"`` to minutes.

    The label is stripped and lower-cased before parsing. Only the ``min`` and
    ``h`` suffixes are recognised.

    Raises:
        ValueError: if the suffix is not supported, or the part before the
            suffix is not an integer.
    """
    tf = timeframe.strip().lower()
    # (suffix, minutes per unit), checked in this order.
    for suffix, minutes_per_unit in (("min", 1), ("h", 60)):
        if tf.endswith(suffix):
            return int(tf[: -len(suffix)]) * minutes_per_unit
    raise ValueError(f"Unsupported timeframe: {timeframe}")


def annualization_factor(cfg: EvalConfig) -> float:
    """Return the number of bars per year implied by ``cfg``.

    A day is 390 minutes when ``cfg.regular_hours_only`` is set and 24 hours
    otherwise. Bars per day is the floor of day length over bar length, with a
    minimum of one, and is multiplied by ``cfg.annual_trading_days``.
    """
    session_minutes = 24 * 60
    if cfg.regular_hours_only:
        session_minutes = 390
    bar_minutes = _timeframe_to_minutes(cfg.timeframe)
    bars_per_session = session_minutes // bar_minutes
    if bars_per_session < 1:
        # Bars longer than a whole day still count as one bar per day.
        bars_per_session = 1
    return bars_per_session * cfg.annual_trading_days


def action_to_weight(action) -> float:
    """Map a raw policy action to a position weight in ``[-1, 1]``.

    Only the first element of the flattened action is used. In ``discrete_3``
    mode it is an index into ``(-1, 0, 1)``, clipped to the valid range; in any
    other mode it is a continuous value clipped to ``[-1, 1]``. When shorting is
    disabled, negative weights are floored at zero.
    """
    first = np.asarray(action).reshape(-1)[0]
    if ACTION_MODE == "discrete_3":
        weights = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
        slot = int(np.clip(int(first), 0, len(weights) - 1))
        w_t = float(weights[slot])
    else:
        w_t = float(np.clip(first, -1.0, 1.0))
    return w_t if ALLOW_SHORT else max(0.0, w_t)


def compute_drawdown(equity: np.ndarray) -> float:
    """Return the maximum drawdown of an equity curve as a non-positive fraction.

    Each point is compared with the running peak up to that point; the most
    negative relative gap is returned. An empty curve yields NaN.
    """
    if equity.size > 0:
        running_peak = np.maximum.accumulate(equity)
        return float(((equity - running_peak) / running_peak).min())
    return float("nan")


def safe_sharpe(mean: float, std: float, ann_factor: float) -> float:
    """Return ``mean / std`` scaled by ``sqrt(ann_factor)``, or NaN when undefined.

    The ratio is undefined when ``std`` is zero, negative, or NaN.
    """
    # `std > 0` is False for zero, negatives and NaN alike, so one test covers
    # every undefined case.
    if std > 0:
        return float(mean / std * np.sqrt(ann_factor))
    return float("nan")


def build_jepa_model(device: str, num_assets: int) -> JEPA:
    """Build the JEPA model from ``jepa_cfg``, load its checkpoint and freeze it.

    The context encoder is built from the config and the target encoder is a
    deep copy of it. Weights come from the ``"model"`` entry of the file at
    ``JEPA_CHECKPOINT_PATH`` and are loaded non-strictly; any missing or
    unexpected keys are printed.

    Args:
        device: Device the finished model is moved to.
        num_assets: Passed to the encoder as ``num_assets`` when the config
            enables ``use_asset_embeddings`` (the default); otherwise ``None``
            is passed.

    Returns:
        The model in eval mode with gradients disabled on every parameter.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    use_asset_embeddings = jepa_cfg.get("use_asset_embeddings", True)
    encoder_kwargs = {
        "patch_len": jepa_cfg["patch_len"],
        "d_model": jepa_cfg["d_model"],
        "n_features": jepa_cfg["n_features"],
        "n_time_features": jepa_cfg["n_time_features"],
        "nhead": jepa_cfg["nhead"],
        "num_layers": jepa_cfg["num_layers"],
        "dim_ff": jepa_cfg["dim_ff"],
        "dropout": jepa_cfg["dropout"],
        "add_cls": jepa_cfg.get("add_cls", True),
        "pooling": jepa_cfg["pooling"],
        "pred_len": jepa_cfg["pred_len"],
        "num_assets": num_assets if use_asset_embeddings else None,
    }
    context_encoder = PatchTSTEncoder(**encoder_kwargs)
    target_encoder = copy.deepcopy(context_encoder)

    jepa_model = JEPA(
        context_encoder,
        target_encoder,
        d_model=jepa_cfg["d_model"],
        ema_tau_min=jepa_cfg["ema_tau_min"],
        ema_tau_max=jepa_cfg["ema_tau_max"],
    )

    if not os.path.exists(JEPA_CHECKPOINT_PATH):
        raise FileNotFoundError(f"JEPA checkpoint not found: {JEPA_CHECKPOINT_PATH}")

    saved = torch.load(JEPA_CHECKPOINT_PATH, map_location="cpu")
    # Non-strict load: report key mismatches rather than failing on them.
    missing_keys, unexpected_keys = jepa_model.load_state_dict(saved["model"], strict=False)
    if missing_keys:
        print(f"Missing keys in JEPA checkpoint: {missing_keys}")
    if unexpected_keys:
        print(f"Unexpected keys in JEPA checkpoint: {unexpected_keys}")

    # Inference only: freeze every weight and switch to eval mode.
    for weight in jepa_model.parameters():
        weight.requires_grad = False
    jepa_model.eval()
    return jepa_model.to(device)


def load_ppo_model(model_path: str, device: str, policy_kwargs: Dict) -> PPOWithJEPA:
    """Load a saved PPO model, retrying with explicit ``policy_kwargs`` on failure.

    The first attempt uses only what is stored in the archive. If it raises,
    the error is printed and the load is retried with ``policy_kwargs``
    supplied through ``custom_objects``. An error from the retry propagates.
    """
    try:
        loaded = PPOWithJEPA.load(model_path, device=device)
    except Exception as exc:
        print(f"Primary PPO load failed ({exc}); retrying with custom policy_kwargs.")
        overrides = {"policy_kwargs": policy_kwargs}
        loaded = PPOWithJEPA.load(model_path, device=device, custom_objects=overrides)
    return loaded


def _simulate_asset(
    model: PPOWithJEPA,
    dataset: Dataset_Finance_MultiAsset,
    asset_id: str,
) -> Dict[str, np.ndarray] | None:
    """Step the policy through one asset's series and record the trajectory.

    At every step the policy sees a context window of ``dataset.seq_len`` rows
    and the following ``dataset.pred_len`` target rows, plus the previous
    weight and (when ``INCLUDE_WEALTH`` is set) the log of current wealth. The
    chosen weight earns the next bar's log return minus a transaction cost
    proportional to the change in weight.

    Returns:
        Float64 arrays keyed ``rewards``, ``asset_returns``, ``positions``,
        ``turnovers`` and ``equity``, one entry per step; ``None`` when the
        series is too short for a single step.
    """
    # An unknown asset is passed to the policy as index -1 (no error is raised).
    asset_idx = dataset.asset_id_to_idx.get(asset_id, -1)
    X = dataset.data_x[asset_id]
    dates = dataset.dates[asset_id]
    ohlcv = dataset.ohlcv[asset_id]

    seq_len = dataset.seq_len
    pred_len = dataset.pred_len
    n_steps = len(X) - seq_len - pred_len
    if n_steps <= 0:
        return None

    w_prev = 0.0  # the simulation starts flat
    wealth = 1.0
    rewards, asset_returns, positions, turnovers, equity = [], [], [], [], []

    for cursor in range(n_steps):
        ctx_end = cursor + seq_len
        tgt_end = ctx_end + pred_len

        obs = {
            "x_context": X[cursor:ctx_end].astype(np.float32),
            "t_context": dates[cursor:ctx_end].astype(np.float32),
            "x_target": X[ctx_end:tgt_end].astype(np.float32),
            "t_target": dates[ctx_end:tgt_end].astype(np.float32),
            "asset_id": np.int64(asset_idx),
            "w_prev": np.array([w_prev], dtype=np.float32),
        }
        if INCLUDE_WEALTH:
            obs["wealth_feats"] = np.array([np.log(wealth)], dtype=np.float32)

        action, _ = model.predict(obs, deterministic=DETERMINISTIC)
        w_t = action_to_weight(action)

        # Column 3 is read as the close price (assumes OHLCV column order —
        # TODO confirm against the dataset). The return runs from the last
        # context bar to the first bar after it.
        close_t = float(ohlcv[ctx_end - 1][3])
        close_tp1 = float(ohlcv[ctx_end][3])
        r_tp1 = float(np.log(close_tp1 / close_t))

        turnover = abs(w_t - w_prev)
        reward = w_t * r_tp1 - TRANSACTION_COST * turnover
        # Rewards are log returns, so wealth compounds through exp().
        wealth *= float(np.exp(reward))

        rewards.append(reward)
        asset_returns.append(r_tp1)
        positions.append(w_t)
        turnovers.append(turnover)
        equity.append(wealth)
        w_prev = w_t

    return {
        "rewards": np.asarray(rewards, dtype=np.float64),
        "asset_returns": np.asarray(asset_returns, dtype=np.float64),
        "positions": np.asarray(positions, dtype=np.float64),
        "turnovers": np.asarray(turnovers, dtype=np.float64),
        "equity": np.asarray(equity, dtype=np.float64),
    }


def _return_stats(log_returns: np.ndarray, ann_factor: float) -> Dict[str, float]:
    """Summarise a series of per-step log returns (mean, std, totals, Sharpe)."""
    nan = float("nan")
    count = log_returns.size
    mean = float(np.mean(log_returns)) if count else nan
    # The sample standard deviation needs at least two observations.
    std = float(np.std(log_returns, ddof=1)) if count > 1 else nan
    return {
        "mean": mean,
        "std": std,
        "total_return": float(np.exp(float(np.sum(log_returns))) - 1.0) if count else nan,
        "annualized_return": float(np.exp(mean * ann_factor) - 1.0) if count else nan,
        "annualized_volatility": float(std * np.sqrt(ann_factor)) if count > 1 else nan,
        "sharpe": safe_sharpe(mean, std, ann_factor),
    }


def eval_asset(model: PPOWithJEPA, dataset: Dataset_Finance_MultiAsset, asset_id: str, cfg: EvalConfig) -> Dict[str, float]:
    """Evaluate the policy on one asset and return its performance metrics.

    The policy is simulated over the asset's whole series, then summarised as
    return, risk, position and turnover statistics, next to buy-and-hold
    figures (``bh_*``) computed from the raw asset returns without costs.

    ``trade_count`` counts the steps whose weight change exceeds
    ``cfg.flat_threshold``. The change is measured from the previous weight,
    starting flat, so an opening position taken on the first step counts as a
    trade, consistent with ``total_turnover`` and the transaction cost.

    Args:
        model: Policy queried through ``model.predict``.
        dataset: Source of features, time features and prices for ``asset_id``.
        asset_id: Key of the asset inside ``dataset``.
        cfg: Annualization settings and the flat-position threshold.

    Returns:
        A dict of metrics, or an empty dict when the series is too short to
        produce a single step.
    """
    trajectory = _simulate_asset(model, dataset, asset_id)
    if trajectory is None:
        return {}

    rewards = trajectory["rewards"]
    asset_returns = trajectory["asset_returns"]
    positions = trajectory["positions"]
    turnovers = trajectory["turnovers"]
    equity = trajectory["equity"]
    nan = float("nan")

    ann_factor = annualization_factor(cfg)
    strategy = _return_stats(rewards, ann_factor)
    buy_hold = _return_stats(asset_returns, ann_factor)

    # NOTE(review): the downside deviation is the sample std of the negative
    # rewards only, not the root-mean-square shortfall over all steps used by
    # the textbook Sortino ratio — confirm this variant is intended.
    downside = rewards[rewards < 0]
    downside_std = float(np.std(downside, ddof=1)) if downside.size > 1 else nan
    sortino = safe_sharpe(strategy["mean"], downside_std, ann_factor)

    max_drawdown = compute_drawdown(equity)
    calmar = float(strategy["annualized_return"] / abs(max_drawdown)) if max_drawdown < 0 else nan

    # Every trajectory holds at least one step, so plain means are safe here.
    pos_std = float(np.std(positions, ddof=1)) if positions.size > 1 else nan
    flat_mask = np.abs(positions) <= cfg.flat_threshold
    long_mask = positions > cfg.flat_threshold
    short_mask = positions < -cfg.flat_threshold

    # turnovers[i] is |w_i - w_{i-1}| with w_{-1} = 0, so the first entry
    # captures the opening trade that np.diff(positions) would miss.
    trade_count = int(np.sum(turnovers > cfg.flat_threshold))

    return {
        "asset_id": asset_id,
        "steps": int(rewards.size),
        "total_return": strategy["total_return"],
        "annualized_return": strategy["annualized_return"],
        "annualized_volatility": strategy["annualized_volatility"],
        "sharpe": strategy["sharpe"],
        "sortino": sortino,
        "max_drawdown": max_drawdown,
        "calmar": calmar,
        "avg_reward": strategy["mean"],
        "reward_volatility": strategy["std"],
        "win_rate": float(np.mean(rewards > 0)),
        "avg_turnover": float(np.mean(turnovers)),
        "total_turnover": float(np.sum(turnovers)),
        "avg_position": float(np.mean(positions)),
        "position_std": pos_std,
        "avg_abs_position": float(np.mean(np.abs(positions))),
        "long_frac": float(np.mean(long_mask)),
        "short_frac": float(np.mean(short_mask)),
        "flat_frac": float(np.mean(flat_mask)),
        "trade_count": trade_count,
        "bh_total_return": buy_hold["total_return"],
        "bh_annualized_return": buy_hold["annualized_return"],
        "bh_annualized_volatility": buy_hold["annualized_volatility"],
        "bh_sharpe": buy_hold["sharpe"],
    }


In [7]:

ticker_list_path = paths_cfg.get("ticker_list_path")
if not ticker_list_path:
    raise ValueError("Config is missing paths.ticker_list_path")

# Build dataset from the same config used in training
tickers_path = PROJECT_ROOT / ticker_list_path
tickers = load_tickers(str(tickers_path))
if not tickers:
    raise RuntimeError(f"No tickers loaded from {tickers_path}")

# Anchor a relative data root at PROJECT_ROOT, like every other config path in
# this notebook. Otherwise the dataset location depends on the directory the
# notebook was launched from. os.path.join keeps any trailing separator of the
# configured value; absolute paths are passed through unchanged.
dataset_root_path = dataset_cfg["root_path"]
if isinstance(dataset_root_path, str) and not os.path.isabs(dataset_root_path):
    dataset_root_path = os.path.join(str(PROJECT_ROOT), dataset_root_path)

dataset_kwargs = {
    "root_path": dataset_root_path,
    "data_path": dataset_cfg["data_path"],
    "start_date": dataset_cfg.get("start_date"),
    "split": "val",
    "size": [dataset_cfg["context_len"], dataset_cfg["target_len"]],
    "use_time_features": dataset_cfg.get("use_time_features", True),
    "rolling_window": dataset_cfg["rolling_window"],
    "train_split": dataset_cfg["train_split"],
    "test_split": dataset_cfg["test_split"],
    "regular_hours_only": dataset_cfg.get("regular_hours_only", True),
    "timeframe": dataset_cfg.get("timeframe", "15min"),
    "tickers": tickers,
}

# Reuse the asset universe recorded in the JEPA checkpoint when there is one.
asset_universe = load_asset_universe_from_checkpoint(JEPA_CHECKPOINT_PATH)
if asset_universe:
    dataset_kwargs["asset_universe"] = asset_universe

print(f"Loaded {len(tickers)} tickers from {tickers_path}")
print(f"tickers: {tickers}")

print("Loading evaluation dataset...")
# Show the resolved data root so an empty dataset can be traced to its path.
print(f"Dataset root: {dataset_root_path}")
test_dataset = Dataset_Finance_MultiAsset(**dataset_kwargs)
if not test_dataset.asset_ids:
    raise RuntimeError("No assets found in the validation dataset.")

if MAX_ASSETS is not None:
    test_dataset.asset_ids = test_dataset.asset_ids[: int(MAX_ASSETS)]

print(f"Assets to evaluate: {len(test_dataset.asset_ids)}")


Loaded 10 tickers from C:\python\koulu\Gradu\configs\assets\tickers0.txt
tickers: ['CSCO', 'MRK', 'NKE', 'NVDA', 'DIS', 'AAPL', 'CAT', 'V', 'HON', 'AMZN']
Loading evaluation dataset...


  checkpoint = torch.load(path, map_location="cpu")


RuntimeError: No assets found in the validation dataset.

In [None]:
# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Prefer the dataset's own id count; fall back to the number of loaded assets.
num_asset_ids = int(getattr(test_dataset, "num_asset_ids", len(test_dataset.asset_ids)))

print("Loading JEPA model...")
jepa_model = build_jepa_model(device, num_assets=num_asset_ids)

# Policy settings handed to load_ppo_model for its fallback load path.
policy_kwargs = {
    "features_extractor_class": JEPAAuxFeatureExtractor,
    "features_extractor_kwargs": {
        "jepa_model": jepa_model,
        "embedding_dim": jepa_cfg["d_model"],
        "patch_len": jepa_cfg["patch_len"],
        "patch_stride": jepa_cfg["patch_stride"],
        "use_obs_targets": True,
        "target_len": test_dataset.pred_len,
    },
    "net_arch": {"pi": [256, 256], "vf": [256, 256]},
}

print(f"Loading PPO model from {PPO_CHECKPOINT_PATH}...")
model = load_ppo_model(PPO_CHECKPOINT_PATH, device=device, policy_kwargs=policy_kwargs)
model.policy.eval()


In [None]:
# Annualization follows the bar settings the evaluation dataset was built with.
eval_cfg = EvalConfig(
    timeframe=dataset_kwargs.get("timeframe", "15min"),
    regular_hours_only=dataset_kwargs.get("regular_hours_only", True),
    annual_trading_days=252,
)

results: List[Dict[str, float]] = []
for idx, asset_id in enumerate(test_dataset.asset_ids, start=1):
    print(f"[{idx}/{len(test_dataset.asset_ids)}] Evaluating {asset_id}...")
    metrics = eval_asset(model, test_dataset, asset_id, eval_cfg)
    if not metrics:
        # eval_asset returns an empty dict for series too short to evaluate.
        continue
    results.append(metrics)

if not results:
    raise RuntimeError("No evaluation results produced.")

# One row per asset, ordered by asset id.
df = pd.DataFrame(results)
df = df.sort_values("asset_id")
df.head()


In [None]:
# Save outputs: per-asset metrics and a mean/median summary, both written as
# CSV files under the configured log root (created if missing).
os.makedirs(PROJECT_ROOT / log_root, exist_ok=True)
metrics_path = PROJECT_ROOT / log_root / f"{model_name}_test_metrics.csv"
summary_path = PROJECT_ROOT / log_root / f"{model_name}_test_summary.csv"

df.to_csv(metrics_path, index=False)
# asset_id is dropped before aggregating; each remaining column gets one
# "mean" row and one "median" row across assets.
summary = df.drop(columns=["asset_id"]).agg(["mean", "median"])
summary.to_csv(summary_path)

print(f"Saved per-asset metrics to {metrics_path}")
print(f"Saved summary to {summary_path}")
# Bare expression so the notebook renders the summary table as the cell output.
summary
