# PPO Model Testing

Evaluate a trained PPO+JEPA model using a single training config file.


In [43]:
from __future__ import annotations

import os
import sys
import io
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List
import matplotlib.pyplot as plt
import copy

import numpy as np
import pandas as pd
import torch

# Resolve project root robustly when notebook is launched from different cwd
def find_project_root(start: Path) -> Path:
    p = start.resolve()
    for candidate in [p, *p.parents]:
        if (candidate / "src").exists() and (candidate / "configs").exists():
            return candidate
    raise RuntimeError("Could not locate project root containing src/ and configs/")

PROJECT_ROOT = find_project_root(Path.cwd())
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
if str(PROJECT_ROOT / "src") not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT / "src"))

from config.config_utils import load_json_config
from Datasets.multi_asset_dataset import Dataset_Finance_MultiAsset
from Training.sb3_jepa_ppo import JEPAAuxFeatureExtractor, PPOWithJEPA
from models.jepa.jepa import JEPA
from models.time_series.patchTransformer import PatchTSTEncoder

print(f"Project root: {PROJECT_ROOT}")


Project root: C:\python\koulu\Gradu


In [44]:
# -----------------------------
# User parameters
# -----------------------------
# Config used for training this PPO model
PPO_CONFIG_PATH = "configs/ppo_jepa_final_1.json"

# Optional overrides (set to None to auto-resolve from config)
PPO_CHECKPOINT_PATH = "checkpoints/jepa6_ppo_final1/ppo_4800000_steps.zip"
# Optional: used only for asset_universe lookup, not for JEPA model weights
JEPA_CHECKPOINT_PATH = "checkpoints/jepa6_ppo_final1/jepa_step_300000.pt"

# If None, evaluate all assets available in validation split
MAX_ASSETS = None

# Deterministic policy during evaluation
DETERMINISTIC = True


In [45]:
def get_latest_ppo_checkpoint(checkpoint_dir: str) -> str | None:
    if not os.path.isdir(checkpoint_dir):
        return None
    ckpts = []
    for fname in os.listdir(checkpoint_dir):
        if fname.startswith("ppo_") and fname.endswith("_steps.zip"):
            ckpts.append(os.path.join(checkpoint_dir, fname))
    if not ckpts:
        return None
    ckpts.sort(key=lambda p: os.path.getmtime(p))
    return ckpts[-1]


def load_tickers(path: str) -> list | None:
    if not path or not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        tickers = [line.strip() for line in f if line.strip()]
    return tickers or None


def load_asset_universe_from_checkpoint(path: str | None) -> list | None:
    if not path or not os.path.exists(path):
        return None
    try:
        checkpoint = torch.load(path, map_location="cpu")
    except Exception:
        return None
    asset_universe = checkpoint.get("asset_universe")
    return list(asset_universe) if asset_universe else None


cfg = load_json_config(str(PROJECT_ROOT / PPO_CONFIG_PATH), "", str(PROJECT_ROOT / "notebooks" / "ppo_model_test.ipynb"))

model_name = cfg["model_name"]
paths_cfg = cfg["paths"]
dataset_cfg = cfg["dataset"]
env_cfg = cfg["env"]
ppo_cfg = cfg["ppo"]
jepa_cfg = cfg["jepa_model"]

checkpoint_root = paths_cfg.get("checkpoint_root", "checkpoints")
log_root = paths_cfg.get("log_root", "logs")
ppo_checkpoint_dir = str(PROJECT_ROOT / checkpoint_root / model_name)

if PPO_CHECKPOINT_PATH is None:
    PPO_CHECKPOINT_PATH = get_latest_ppo_checkpoint(ppo_checkpoint_dir)
if PPO_CHECKPOINT_PATH is None:
    raise FileNotFoundError(f"No PPO checkpoint found under {ppo_checkpoint_dir}")
if not os.path.isabs(PPO_CHECKPOINT_PATH):
    PPO_CHECKPOINT_PATH = str(PROJECT_ROOT / PPO_CHECKPOINT_PATH)

if JEPA_CHECKPOINT_PATH is None and paths_cfg.get("jepa_checkpoint_dir"):
    jepa_checkpoint_dir = paths_cfg["jepa_checkpoint_dir"]
    JEPA_CHECKPOINT_PATH = str(PROJECT_ROOT / jepa_checkpoint_dir / "best.pt")
if JEPA_CHECKPOINT_PATH is not None and not os.path.isabs(JEPA_CHECKPOINT_PATH):
    JEPA_CHECKPOINT_PATH = str(PROJECT_ROOT / JEPA_CHECKPOINT_PATH)

ACTION_MODE = env_cfg.get("action_mode", "continuous")
ALLOW_SHORT = env_cfg.get("allow_short", True)
INCLUDE_WEALTH = env_cfg.get("include_wealth", True)
TRANSACTION_COST = 0.0005

print("Model name:", model_name)
print("PPO checkpoint:", PPO_CHECKPOINT_PATH)
print("JEPA checkpoint:", JEPA_CHECKPOINT_PATH)
print("Action mode:", ACTION_MODE)
print("Asset universe:", paths_cfg["asset_universe_path"])


Model name: jepa6_ppo_final1
PPO checkpoint: C:\python\koulu\Gradu\checkpoints\jepa6_ppo_final1\ppo_4800000_steps.zip
JEPA checkpoint: C:\python\koulu\Gradu\checkpoints\jepa6_ppo_final1\jepa_step_300000.pt
Action mode: discrete_3
Asset universe: None


In [46]:
@dataclass
class EvalConfig:
    annual_trading_days: int = 252
    regular_hours_only: bool = True
    timeframe: str = "15min"
    flat_threshold: float = 1e-3


def _timeframe_to_minutes(timeframe: str) -> int:
    tf = timeframe.strip().lower()
    if tf.endswith("min"):
        return int(tf[:-3])
    if tf.endswith("h"):
        return int(tf[:-1]) * 60
    raise ValueError(f"Unsupported timeframe: {timeframe}")


def annualization_factor(cfg: EvalConfig) -> float:
    minutes_per_day = 390 if cfg.regular_hours_only else 24 * 60
    minutes = _timeframe_to_minutes(cfg.timeframe)
    bars_per_day = max(1, minutes_per_day // minutes)
    return bars_per_day * cfg.annual_trading_days


def action_to_weight(action) -> float:
    if ACTION_MODE == "discrete_3":
        discrete_actions = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
        idx = int(np.asarray(action).reshape(-1)[0])
        idx = int(np.clip(idx, 0, len(discrete_actions) - 1))
        w_t = float(discrete_actions[idx])
    else:
        w_t = float(np.clip(np.asarray(action).reshape(-1)[0], -1.0, 1.0))
    if not ALLOW_SHORT:
        w_t = max(0.0, w_t)
    return w_t


def compute_drawdown(equity: np.ndarray) -> float:
    if equity.size == 0:
        return float("nan")
    peak = np.maximum.accumulate(equity)
    drawdown = (equity - peak) / peak
    return float(np.min(drawdown))


def safe_sharpe(mean: float, std: float, ann_factor: float) -> float:
    if std <= 0 or np.isnan(std):
        return float("nan")
    return float(mean / std * np.sqrt(ann_factor))


def build_jepa_model(device: str, num_assets: int) -> JEPA:
    encoder_num_assets = num_assets if jepa_cfg.get("use_asset_embeddings", True) else None

    jepa_context_encoder = PatchTSTEncoder(
        patch_len=jepa_cfg["patch_len"],
        d_model=jepa_cfg["d_model"],
        n_features=jepa_cfg["n_features"],
        n_time_features=jepa_cfg["n_time_features"],
        nhead=jepa_cfg["nhead"],
        num_layers=jepa_cfg["num_layers"],
        dim_ff=jepa_cfg["dim_ff"],
        dropout=jepa_cfg["dropout"],
        add_cls=jepa_cfg.get("add_cls", True),
        pooling=jepa_cfg["pooling"],
        pred_len=jepa_cfg["pred_len"],
        num_assets=encoder_num_assets,
    )
    jepa_target_encoder = copy.deepcopy(jepa_context_encoder)

    jepa_model = JEPA(
        jepa_context_encoder,
        jepa_target_encoder,
        d_model=jepa_cfg["d_model"],
        ema_tau_min=jepa_cfg["ema_tau_min"],
        ema_tau_max=jepa_cfg["ema_tau_max"],
    )

    for param in jepa_model.parameters():
        param.requires_grad = False
    jepa_model.eval()
    return jepa_model.to(device)


def extract_jepa_state_dict_from_ppo_zip(model_path: str) -> Dict[str, torch.Tensor]:
    with zipfile.ZipFile(model_path, "r") as zf:
        with zf.open("policy.pth", "r") as f:
            policy_state = torch.load(io.BytesIO(f.read()), map_location="cpu")

    prefix = "features_extractor.jepa_model."
    jepa_state = {}
    for k, v in policy_state.items():
        if k.startswith(prefix):
            jepa_state[k[len(prefix):]] = v
    if not jepa_state:
        raise RuntimeError("No JEPA weights found in PPO zip policy state.")
    return jepa_state


def load_ppo_model(model_path: str, device: str, policy_kwargs: Dict) -> PPOWithJEPA:
    try:
        print("Loading PPO (and embedded JEPA) directly from PPO zip...")
        return PPOWithJEPA.load(model_path, device=device)
    except Exception as exc:
        print(f"Primary PPO load failed ({exc}); retrying with custom policy_kwargs and JEPA weights from PPO zip.")

        fx_kwargs = policy_kwargs.get("features_extractor_kwargs", {})
        jepa_model = fx_kwargs.get("jepa_model")
        if jepa_model is None:
            raise RuntimeError("Fallback requires policy_kwargs.features_extractor_kwargs.jepa_model")

        jepa_state = extract_jepa_state_dict_from_ppo_zip(model_path)
        missing, unexpected = jepa_model.load_state_dict(jepa_state, strict=False)
        if missing:
            print(f"Missing keys when loading JEPA from PPO zip: {missing}")
        if unexpected:
            print(f"Unexpected keys when loading JEPA from PPO zip: {unexpected}")

        return PPOWithJEPA.load(model_path, device=device, custom_objects={"policy_kwargs": policy_kwargs})


def eval_asset(model: PPOWithJEPA, dataset: Dataset_Finance_MultiAsset, asset_id: str, cfg: EvalConfig) -> Dict[str, float]:
    asset_idx = dataset.asset_id_to_idx.get(asset_id, -1)
    X = dataset.data_x[asset_id]
    dates = dataset.dates[asset_id]
    ohlcv = dataset.ohlcv[asset_id]

    seq_len = dataset.seq_len
    pred_len = dataset.pred_len
    n_steps = len(X) - seq_len - pred_len
    if n_steps <= 0:
        return {}

    w_prev = 0.0
    wealth = 1.0
    rewards, asset_returns, positions, turnovers, equity = [], [], [], [], []

    for cursor in range(n_steps):
        x_context = X[cursor : cursor + seq_len].astype(np.float32)
        t_context = dates[cursor : cursor + seq_len].astype(np.float32)
        x_target = X[cursor + seq_len : cursor + seq_len + pred_len].astype(np.float32)
        t_target = dates[cursor + seq_len : cursor + seq_len + pred_len].astype(np.float32)

        obs = {
            "x_context": x_context,
            "t_context": t_context,
            "x_target": x_target,
            "t_target": t_target,
            "asset_id": np.int64(asset_idx),
            "w_prev": np.array([w_prev], dtype=np.float32),
        }
        if INCLUDE_WEALTH:
            obs["wealth_feats"] = np.array([np.log(wealth)], dtype=np.float32)

        action, _ = model.predict(obs, deterministic=DETERMINISTIC)
        w_t = action_to_weight(action)

        close_t = float(ohlcv[cursor + seq_len - 1][3])
        close_tp1 = float(ohlcv[cursor + seq_len][3])
        r_tp1 = float(np.log(close_tp1 / close_t))

        turnover = abs(w_t - w_prev)
        reward = w_t * r_tp1 - TRANSACTION_COST * turnover
        wealth *= float(np.exp(reward))

        rewards.append(reward)
        asset_returns.append(r_tp1)
        positions.append(w_t)
        turnovers.append(turnover)
        equity.append(wealth)
        w_prev = w_t

    rewards = np.asarray(rewards, dtype=np.float64)
    asset_returns = np.asarray(asset_returns, dtype=np.float64)
    positions = np.asarray(positions, dtype=np.float64)
    turnovers = np.asarray(turnovers, dtype=np.float64)
    equity = np.asarray(equity, dtype=np.float64)

    ann_factor = annualization_factor(cfg)
    mean_reward = float(np.mean(rewards)) if rewards.size else float("nan")
    std_reward = float(np.std(rewards, ddof=1)) if rewards.size > 1 else float("nan")

    total_log_return = float(np.sum(rewards)) if rewards.size else float("nan")
    total_return = float(np.exp(total_log_return) - 1.0) if rewards.size else float("nan")
    annualized_return = float(np.exp(mean_reward * ann_factor) - 1.0) if rewards.size else float("nan")
    annualized_vol = float(std_reward * np.sqrt(ann_factor)) if rewards.size > 1 else float("nan")
    sharpe = safe_sharpe(mean_reward, std_reward, ann_factor)

    downside = rewards[rewards < 0]
    downside_std = float(np.std(downside, ddof=1)) if downside.size > 1 else float("nan")
    sortino = safe_sharpe(mean_reward, downside_std, ann_factor)

    max_drawdown = compute_drawdown(equity)
    calmar = float(annualized_return / abs(max_drawdown)) if max_drawdown < 0 else float("nan")

    win_rate = float(np.mean(rewards > 0)) if rewards.size else float("nan")
    avg_turnover = float(np.mean(turnovers)) if turnovers.size else float("nan")
    total_turnover = float(np.sum(turnovers)) if turnovers.size else float("nan")
    avg_position = float(np.mean(positions)) if positions.size else float("nan")
    pos_std = float(np.std(positions, ddof=1)) if positions.size > 1 else float("nan")
    abs_pos = float(np.mean(np.abs(positions))) if positions.size else float("nan")

    flat_mask = np.abs(positions) <= cfg.flat_threshold
    long_mask = positions > cfg.flat_threshold
    short_mask = positions < -cfg.flat_threshold
    flat_frac = float(np.mean(flat_mask)) if positions.size else float("nan")
    long_frac = float(np.mean(long_mask)) if positions.size else float("nan")
    short_frac = float(np.mean(short_mask)) if positions.size else float("nan")

    trade_count = int(np.sum(np.abs(np.diff(positions)) > cfg.flat_threshold)) if positions.size > 1 else 0

    bh_mean = float(np.mean(asset_returns)) if asset_returns.size else float("nan")
    bh_std = float(np.std(asset_returns, ddof=1)) if asset_returns.size > 1 else float("nan")
    bh_total_return = float(np.exp(np.sum(asset_returns)) - 1.0) if asset_returns.size else float("nan")
    bh_annualized_return = float(np.exp(bh_mean * ann_factor) - 1.0) if asset_returns.size else float("nan")
    bh_annualized_vol = float(bh_std * np.sqrt(ann_factor)) if asset_returns.size > 1 else float("nan")
    bh_sharpe = safe_sharpe(bh_mean, bh_std, ann_factor)

    return {
        "asset_id": asset_id,
        "steps": int(n_steps),
        "total_return": total_return,
        "annualized_return": annualized_return,
        "annualized_volatility": annualized_vol,
        "sharpe": sharpe,
        "sortino": sortino,
        "max_drawdown": max_drawdown,
        "calmar": calmar,
        "avg_reward": mean_reward,
        "reward_volatility": std_reward,
        "win_rate": win_rate,
        "avg_turnover": avg_turnover,
        "total_turnover": total_turnover,
        "avg_position": avg_position,
        "position_std": pos_std,
        "avg_abs_position": abs_pos,
        "long_frac": long_frac,
        "short_frac": short_frac,
        "flat_frac": flat_frac,
        "trade_count": trade_count,
        "bh_total_return": bh_total_return,
        "bh_annualized_return": bh_annualized_return,
        "bh_annualized_volatility": bh_annualized_vol,
        "bh_sharpe": bh_sharpe,
    }



In [47]:
def _resolve_project_path(path_value: str | None) -> str | None:
    if path_value is None:
        return None
    p = Path(path_value)
    if p.is_absolute():
        return str(p)
    return str((PROJECT_ROOT / p).resolve())



ticker_list_path = paths_cfg.get("ticker_list_path")
if not ticker_list_path:
    raise ValueError("Config is missing paths.ticker_list_path")

# Build dataset from the same config used in training
tickers_path = PROJECT_ROOT / ticker_list_path
tickers = load_tickers(str(tickers_path))
if not tickers:
    raise RuntimeError(f"No tickers loaded from {tickers_path}")

dataset_kwargs = {
    "root_path": _resolve_project_path(dataset_cfg["root_path"]),
    "data_path": dataset_cfg["data_path"],
    "start_date": dataset_cfg.get("start_date"),
    "split": "val",
    "size": [dataset_cfg["context_len"], dataset_cfg["target_len"]],
    "use_time_features": dataset_cfg.get("use_time_features", True),
    "rolling_window": dataset_cfg["rolling_window"],
    "train_split": dataset_cfg["train_split"],
    "test_split": dataset_cfg["test_split"],
    "regular_hours_only": dataset_cfg.get("regular_hours_only", True),
    "timeframe": dataset_cfg.get("timeframe", "15min"),
    "tickers": tickers,
}

asset_universe = load_asset_universe_from_checkpoint(JEPA_CHECKPOINT_PATH)
if asset_universe:
    dataset_kwargs["asset_universe"] = asset_universe

dataset_kwargs["tickers"] = tickers
print(f"Loaded {len(tickers)} tickers from {tickers_path}")
print(f"tickers: {tickers}")

print("Loading evaluation dataset...")
test_dataset = Dataset_Finance_MultiAsset(**dataset_kwargs)
if not test_dataset.asset_ids:
    raise RuntimeError("No assets found in the validation dataset.")

if MAX_ASSETS is not None:
    test_dataset.asset_ids = test_dataset.asset_ids[: int(MAX_ASSETS)]

print(f"Assets to evaluate: {len(test_dataset.asset_ids)}")


Loaded 11 tickers from C:\python\koulu\Gradu\configs\assets\tickers1.txt
tickers: ['AMZN', 'KO', 'DIS', 'V', 'SPY', 'NKE', 'CSCO', 'JPM', 'CAT', 'AMGN', 'DIA']
Loading evaluation dataset...


  checkpoint = torch.load(path, map_location="cpu")


[Dataset_Finance_MultiAsset] Global date splits: train_end=2024-08-01 15:15:00+00:00 val_end=2025-04-29 16:30:00+00:00 n_dates=33178 n_train=23224 n_val=4978 n_test=4976
Assets to evaluate: 11


In [48]:
device = "cuda" if torch.cuda.is_available() else "cpu"
num_asset_ids = int(getattr(test_dataset, "num_asset_ids", len(test_dataset.asset_ids)))

print("Loading JEPA model...")
jepa_model = build_jepa_model(device, num_assets=num_asset_ids)

policy_kwargs = dict(
    features_extractor_class=JEPAAuxFeatureExtractor,
    features_extractor_kwargs=dict(
        jepa_model=jepa_model,
        embedding_dim=jepa_cfg["d_model"],
        patch_len=jepa_cfg["patch_len"],
        patch_stride=jepa_cfg["patch_stride"],
        use_obs_targets=True,
        target_len=test_dataset.pred_len,
    ),
    net_arch=dict(pi=[256, 256], vf=[256, 256]),
)

print(f"Loading PPO model from {PPO_CHECKPOINT_PATH}...")
model = load_ppo_model(PPO_CHECKPOINT_PATH, device=device, policy_kwargs=policy_kwargs)
model.policy.eval()


Loading JEPA model...
Loading PPO model from C:\python\koulu\Gradu\checkpoints\jepa6_ppo_final1\ppo_4800000_steps.zip...


  checkpoint = torch.load(JEPA_CHECKPOINT_PATH, map_location="cpu")


MultiInputActorCriticPolicy(
  (features_extractor): JEPAAuxFeatureExtractor(
    (jepa_model): JEPA(
      (context_enc): PatchTSTEncoder(
        (proj_price): Linear(in_features=72, out_features=192, bias=True)
        (proj_time): Linear(in_features=32, out_features=192, bias=True)
        (posenc): PositionalEncoding()
        (encoder): TransformerEncoder(
          (layers): ModuleList(
            (0-3): 4 x TransformerEncoderLayer(
              (self_attn): MultiheadAttention(
                (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
              )
              (linear1): Linear(in_features=192, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (linear2): Linear(in_features=768, out_features=192, bias=True)
              (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
              (dropou

In [49]:
eval_cfg = EvalConfig(
    annual_trading_days=252,
    regular_hours_only=dataset_kwargs.get("regular_hours_only", True),
    timeframe=dataset_kwargs.get("timeframe", "15min"),
)

results: List[Dict[str, float]] = []
for idx, asset_id in enumerate(test_dataset.asset_ids, start=1):
    print(f"[{idx}/{len(test_dataset.asset_ids)}] Evaluating {asset_id}...")
    metrics = eval_asset(model, test_dataset, asset_id, eval_cfg)
    if metrics:
        results.append(metrics)

if not results:
    raise RuntimeError("No evaluation results produced.")

df = pd.DataFrame(results).sort_values("asset_id")
df.head(11)


[1/11] Evaluating AMZN...
[2/11] Evaluating KO...
[3/11] Evaluating DIS...
[4/11] Evaluating V...
[5/11] Evaluating SPY...
[6/11] Evaluating NKE...
[7/11] Evaluating CSCO...
[8/11] Evaluating JPM...
[9/11] Evaluating CAT...
[10/11] Evaluating AMGN...
[11/11] Evaluating DIA...


Unnamed: 0,asset_id,steps,total_return,annualized_return,annualized_volatility,sharpe,sortino,max_drawdown,calmar,avg_reward,...,position_std,avg_abs_position,long_frac,short_frac,flat_frac,trade_count,bh_total_return,bh_annualized_return,bh_annualized_volatility,bh_sharpe
9,AMGN,4416,0.078785,0.11909,0.29453,0.38202,0.419802,-0.246258,0.483601,1.7e-05,...,0.963226,0.939312,0.415534,0.523777,0.060688,84,-0.146664,-0.20968,0.305772,-0.769582
0,AMZN,4417,-0.263604,-0.364847,0.354345,-1.280925,-1.407798,-0.396878,-0.919292,-6.9e-05,...,0.921993,0.938646,0.320353,0.618293,0.061354,73,0.104821,0.159358,0.362985,0.407364
8,CAT,4416,-0.158503,-0.225893,0.347918,-0.735937,-0.817875,-0.271705,-0.83139,-3.9e-05,...,0.950269,0.944067,0.573596,0.370471,0.055933,81,-0.130198,-0.186948,0.354228,-0.584258
6,CSCO,4416,-0.267382,-0.369741,0.234086,-1.972031,-2.286557,-0.311985,-1.185122,-7e-05,...,0.97896,0.961051,0.507473,0.453578,0.038949,61,0.130017,0.198842,0.235809,0.769079
10,DIA,4408,-0.082453,-0.120064,0.193355,-0.661511,-0.761481,-0.2249,-0.533855,-2e-05,...,0.87061,0.979809,0.725499,0.25431,0.020191,47,-0.03181,-0.046914,0.195193,-0.246168
2,DIS,4370,-0.24263,-0.340758,0.295608,-1.409521,-1.432984,-0.281929,-1.208668,-6.4e-05,...,0.9442,0.908924,0.520824,0.388101,0.091076,84,0.007318,0.010991,0.306695,0.035642
7,JPM,4416,-0.372595,-0.499248,0.321658,-2.150243,-2.160641,-0.495557,-1.007449,-0.000106,...,0.953801,0.971241,0.361413,0.609828,0.028759,53,0.093133,0.141246,0.323585,0.408302
1,KO,4416,0.110848,0.168792,0.186566,0.836008,1.101752,-0.116962,1.44314,2.4e-05,...,0.988402,0.976902,0.495245,0.481658,0.023098,61,0.003404,0.005055,0.18784,0.026841
5,NKE,4416,-0.20096,-0.283128,0.417414,-0.797429,-0.814938,-0.349062,-0.811112,-5.1e-05,...,0.733227,0.9649,0.80933,0.155571,0.0351,47,-0.319929,-0.435633,0.419046,-1.365125
4,SPY,4417,-0.092966,-0.134752,0.212124,-0.682334,-0.747779,-0.240745,-0.55973,-2.2e-05,...,0.866528,0.970115,0.719266,0.250849,0.029885,47,-0.016873,-0.024926,0.213149,-0.118423


In [52]:
ret_analysis = df[["asset_id", "total_return", "bh_total_return", "sharpe", "bh_sharpe"]].copy()
ret_analysis["ret_diff"] = ret_analysis["total_return"] - ret_analysis["bh_total_return"]
ret_analysis["sharpe_diff"] = ret_analysis["sharpe"] - ret_analysis["bh_sharpe"]
ret_analysis

Unnamed: 0,asset_id,total_return,bh_total_return,sharpe,bh_sharpe,ret_diff,sharpe_diff
9,AMGN,0.078785,-0.146664,0.38202,-0.769582,0.225449,1.151602
0,AMZN,-0.263604,0.104821,-1.280925,0.407364,-0.368426,-1.688289
8,CAT,-0.158503,-0.130198,-0.735937,-0.584258,-0.028305,-0.151679
6,CSCO,-0.267382,0.130017,-1.972031,0.769079,-0.397399,-2.74111
10,DIA,-0.082453,-0.03181,-0.661511,-0.246168,-0.050643,-0.415343
2,DIS,-0.24263,0.007318,-1.409521,0.035642,-0.249948,-1.445163
7,JPM,-0.372595,0.093133,-2.150243,0.408302,-0.465728,-2.558544
1,KO,0.110848,0.003404,0.836008,0.026841,0.107444,0.809166
5,NKE,-0.20096,-0.319929,-0.797429,-1.365125,0.118969,0.567695
4,SPY,-0.092966,-0.016873,-0.682334,-0.118423,-0.076093,-0.563911


In [53]:
# Save outputs
os.makedirs(PROJECT_ROOT / log_root, exist_ok=True)
metrics_path = PROJECT_ROOT / log_root / f"{model_name}_test_metrics.csv"
summary_path = PROJECT_ROOT / log_root / f"{model_name}_test_summary.csv"

df.to_csv(metrics_path, index=False)
summary = df.drop(columns=["asset_id"]).agg(["mean", "median"])
summary.to_csv(summary_path)

print(f"Saved per-asset metrics to {metrics_path}")
print(f"Saved summary to {summary_path}")
summary


Saved per-asset metrics to C:\python\koulu\Gradu\logs\jepa6_ppo_final1_test_metrics.csv
Saved summary to C:\python\koulu\Gradu\logs\jepa6_ppo_final1_test_summary.csv


Unnamed: 0,steps,total_return,annualized_return,annualized_volatility,sharpe,sortino,max_drawdown,calmar,avg_reward,reward_volatility,...,position_std,avg_abs_position,long_frac,short_frac,flat_frac,trade_count,bh_total_return,bh_annualized_return,bh_annualized_volatility,bh_sharpe
mean,4411.272727,-0.147141,-0.203017,0.280635,-0.8501,-0.9102,-0.287118,-0.541034,-3.9e-05,0.003467,...,0.920517,0.956352,0.549751,0.4066,0.043648,65.363636,-0.00638,-0.001657,0.28581,-0.011028
median,4416.0,-0.158503,-0.225893,0.29453,-0.797429,-0.817875,-0.271705,-0.821503,-3.9e-05,0.003639,...,0.950269,0.9649,0.520824,0.388101,0.0351,61.0,0.003404,0.005055,0.305772,0.026841
