# PPO Training for Stock Trading

This notebook implements Proximal Policy Optimization (PPO) for training a reinforcement learning agent to trade stocks.

## Overview
- **Environment**: Stock trading environment using historical price data
- **Algorithm**: PPO (Proximal Policy Optimization)
- **Model**: Actor-Critic network (MLP or CNN)
- **Features**: Volume, extra features (volatility, ATR-like), chronological train/val split


## 1. Imports and Setup


In [39]:
import os
import time
import random

import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from tensorboardX import SummaryWriter

# Import our stock trading environment package
from stock_trading_env import (
    StocksEnv,
    load_many_from_dir,
    split_many_by_ratio,
    Actions,
)

# ---------- Reproducibility ----------
def set_seed(seed: int, deterministic: bool = True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Pick a default seed (match your CLI default if you had one)
SEED = 0
set_seed(SEED, deterministic=True)

# ---------- Device ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Output directories ----------
os.makedirs("runs", exist_ok=True)
os.makedirs("saves", exist_ok=True)
os.makedirs("logs", exist_ok=True)

print("✓ Imports successful")
print(f"Seed: {SEED} (deterministic=True)")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


✓ Imports successful
Seed: 0 (deterministic=True)
PyTorch version: 2.9.0+cpu
CUDA available: False
Device: cpu


## 2. PPO Model Definitions

Define the Actor-Critic networks for PPO.


In [40]:
# Cell 2: PPO Actor-Critic models (DISCRETE actions) — matches earlier repo logic

import torch
import torch.nn as nn


class ActorCriticMLP(nn.Module):
    """
    PPO Actor-Critic for DISCRETE actions (vector observation).

    Input:  x (B, obs_dim)
    Output: logits (B, n_actions), value (B,)
      - logits are unnormalized scores for torch.distributions.Categorical(logits=logits)
      - value is V(s) for GAE/advantage estimation
    """

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 256):
        super().__init__()

        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )

        self.policy_head = nn.Linear(hidden, n_actions)
        self.value_head = nn.Linear(hidden, 1)

        self._init_weights()

    def _init_weights(self):
        # Orthogonal init is common for PPO
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=1.0)
                nn.init.constant_(m.bias, 0.0)

        # Smaller gain for the final policy head helps early training stability
        nn.init.orthogonal_(self.policy_head.weight, gain=0.01)
        nn.init.constant_(self.policy_head.bias, 0.0)

    def forward(self, x: torch.Tensor):
        z = self.shared(x)
        logits = self.policy_head(z)
        value = self.value_head(z).squeeze(-1)
        return logits, value


class ActorCriticConv1D(nn.Module):
    """
    PPO Actor-Critic for State1D observations (CNN over time bars).

    Expected env State1D shape (C, T) where:
      - C = (3 + (1 if volumes else 0)) + 2  -> typically 6 if volumes=True
      - T = bars_count (e.g., 10)

    Input:  x (B, C, T)
    Output: logits (B, n_actions), value (B,)
    """

    def __init__(self, n_actions: int, bars_count: int, volumes: bool = True, hidden: int = 256):
        super().__init__()

        # Match env.State1D channel logic:
        # base channels: 3 + (1 if volumes else 0)
        # +2 for (have_position, unrealized_return)
        in_channels = (3 + (1 if volumes else 0)) + 2

        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),  # -> (B, 64 * bars_count)
        )

        conv_out = 64 * bars_count

        self.shared = nn.Sequential(
            nn.Linear(conv_out, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )

        self.policy_head = nn.Linear(hidden, n_actions)
        self.value_head = nn.Linear(hidden, 1)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv1d)):
                nn.init.orthogonal_(m.weight, gain=1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

        nn.init.orthogonal_(self.policy_head.weight, gain=0.01)
        nn.init.constant_(self.policy_head.bias, 0.0)

    def forward(self, x: torch.Tensor):
        feat = self.conv(x)
        z = self.shared(feat)
        logits = self.policy_head(z)
        value = self.value_head(z).squeeze(-1)
        return logits, value


print("✓ PPO models defined (MLP + Conv1D)")
print("  - Use ActorCriticMLP for vector obs")
print("  - Use ActorCriticConv1D only if env returns State1D (C,T), e.g. (6, bars)")

✓ PPO models defined (MLP + Conv1D)
  - Use ActorCriticMLP for vector obs
  - Use ActorCriticConv1D only if env returns State1D (C,T), e.g. (6, bars)


## 3. PPO Buffer and Utilities

Define the rollout buffer for collecting experience and computing GAE advantages.


In [41]:
# Cell 3 (updated): RolloutBuffer + PPOBatch
# Fixes/robustness:
# - Safer dtype/device handling
# - Enforces add() obs shape
# - Handles truncated/terminated properly via done float
# - Supports 1D (vector) and 2D (C,T) observations cleanly
# - Compute GAE in float32, normalize advantages safely
# - Optional pin_memory for faster CPU->GPU transfer

from dataclasses import dataclass
from typing import Iterator, Optional, Tuple, Union

import numpy as np
import torch


@dataclass
class PPOBatch:
    """One minibatch used during PPO updates."""
    obs: torch.Tensor
    actions: torch.Tensor
    old_logprobs: torch.Tensor
    advantages: torch.Tensor
    returns: torch.Tensor
    old_values: torch.Tensor


class RolloutBuffer:
    """
    Fixed-length rollout buffer for PPO.

    Stores:
      obs[t], action[t], reward[t], done[t], value[t], logprob[t]

    After calling compute_gae():
      advantages[t], returns[t] are filled.

    Notes:
      - done should be True when episode ended for ANY reason:
        done = terminated OR truncated (Gymnasium).
      - obs can be vector (obs_dim,) or State1D (C,T).
    """

    def __init__(
        self,
        obs_shape: Union[int, Tuple[int, ...]],
        size: int,
        device: Union[str, torch.device] = "cpu",
        obs_dtype: np.dtype = np.float32,
        pin_memory: bool = False,
    ):
        if isinstance(obs_shape, int):
            obs_shape = (obs_shape,)
        self.obs_shape = tuple(obs_shape)
        self.size = int(size)
        self.device = torch.device(device) if not isinstance(device, torch.device) else device
        self.pin_memory = bool(pin_memory) and (self.device.type == "cuda")

        # Storage (numpy)
        self.obs = np.zeros((self.size, *self.obs_shape), dtype=obs_dtype)
        self.actions = np.zeros((self.size,), dtype=np.int64)
        self.rewards = np.zeros((self.size,), dtype=np.float32)
        self.dones = np.zeros((self.size,), dtype=np.float32)   # 1.0 if done else 0.0
        self.values = np.zeros((self.size,), dtype=np.float32)
        self.logprobs = np.zeros((self.size,), dtype=np.float32)

        # Computed after rollout
        self.advantages = np.zeros((self.size,), dtype=np.float32)
        self.returns = np.zeros((self.size,), dtype=np.float32)

        self.ptr = 0
        self.full = False

    def reset(self):
        self.ptr = 0
        self.full = False

    def add(
        self,
        obs: np.ndarray,
        action: int,
        reward: float,
        done: bool,
        value: float,
        logprob: float,
    ):
        """Add one transition."""
        if self.ptr >= self.size:
            raise RuntimeError("RolloutBuffer is full. Call reset() before adding more.")

        obs_arr = np.asarray(obs, dtype=self.obs.dtype)

        # Enforce shape to catch silent bugs early (common in notebooks)
        if obs_arr.shape != self.obs_shape:
            raise ValueError(
                f"Obs shape mismatch: got {obs_arr.shape}, expected {self.obs_shape}. "
                "Check env output (vector vs State1D) and model input."
            )

        self.obs[self.ptr] = obs_arr
        self.actions[self.ptr] = int(action)
        self.rewards[self.ptr] = float(reward)
        self.dones[self.ptr] = 1.0 if bool(done) else 0.0
        self.values[self.ptr] = float(value)
        self.logprobs[self.ptr] = float(logprob)

        self.ptr += 1
        self.full = (self.ptr == self.size)

    @torch.no_grad()
    def compute_gae(
        self,
        last_value: float,
        gamma: float = 0.99,
        lam: float = 0.95,
        normalize_adv: bool = True,
        eps: float = 1e-8,
    ):
        """
        Compute GAE(lambda) advantages and returns.

        last_value should be V(s_T) for the *last obs after the rollout*,
        used to bootstrap if the last transition was not terminal.
        """
        if not self.full:
            raise RuntimeError("RolloutBuffer not full. Collect 'size' steps before compute_gae().")

        last_value = float(last_value)

        adv = 0.0
        for t in reversed(range(self.size)):
            done = self.dones[t]
            mask = 1.0 - done  # 0 if terminal else 1

            next_value = last_value if (t == self.size - 1) else float(self.values[t + 1])
            delta = float(self.rewards[t]) + gamma * next_value * mask - float(self.values[t])

            adv = delta + gamma * lam * mask * adv
            self.advantages[t] = adv

        # returns target for critic
        self.returns = self.advantages + self.values

        if normalize_adv:
            m = float(self.advantages.mean())
            s = float(self.advantages.std())
            self.advantages = (self.advantages - m) / (s + eps)

    def get_batches(self, batch_size: int, shuffle: bool = True) -> Iterator[PPOBatch]:
        """Yield mini-batches as torch tensors."""
        if not self.full:
            raise RuntimeError("RolloutBuffer not full. Collect rollout before batching.")

        if batch_size <= 0:
            raise ValueError("batch_size must be > 0")

        idxs = np.arange(self.size)
        if shuffle:
            np.random.shuffle(idxs)

        for start in range(0, self.size, batch_size):
            b_idx = idxs[start : start + batch_size]

            # CPU tensors first (optionally pinned), then move to device
            obs_cpu = torch.from_numpy(self.obs[b_idx]).float()
            if self.pin_memory:
                obs_cpu = obs_cpu.pin_memory()
            obs_t = obs_cpu.to(self.device, non_blocking=self.pin_memory)

            actions_t = torch.from_numpy(self.actions[b_idx]).long().to(self.device)
            old_logp_t = torch.from_numpy(self.logprobs[b_idx]).float().to(self.device)
            adv_t = torch.from_numpy(self.advantages[b_idx]).float().to(self.device)
            ret_t = torch.from_numpy(self.returns[b_idx]).float().to(self.device)
            old_v_t = torch.from_numpy(self.values[b_idx]).float().to(self.device)

            yield PPOBatch(
                obs=obs_t,
                actions=actions_t,
                old_logprobs=old_logp_t,
                advantages=adv_t,
                returns=ret_t,
                old_values=old_v_t,
            )


print("✓ PPO buffer defined (updated/robust)")

✓ PPO buffer defined (updated/robust)


# Cell 4

In [42]:
from typing import Dict, Any
import numpy as np
import torch
from torch.distributions import Categorical

@torch.no_grad()
def policy_act(model, obs, device: torch.device, greedy: bool = False):
    """
    Sample (or greedy-select) action from current policy.
    Returns: action(int), logprob(float), value(float)
    Works for both:
      - vector obs: (obs_dim,)
      - State1D obs: (C, T)
    """
    obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    logits, value = model(obs_t)
    dist = Categorical(logits=logits)

    if greedy:
        action_t = torch.argmax(logits, dim=1)
    else:
        action_t = dist.sample()

    logprob_t = dist.log_prob(action_t)
    return int(action_t.item()), float(logprob_t.item()), float(value.item())


def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """1 - Var[y_true - y_pred] / Var[y_true]. Diagnostic for critic fit."""
    var_y = float(np.var(y_true))
    if var_y < 1e-12:
        return 0.0
    return float(1.0 - np.var(y_true - y_pred) / (var_y + 1e-12))


@torch.no_grad()
def validation_run_ppo(env, model, episodes: int = 50, device="cpu", greedy: bool = True) -> Dict[str, float]:
    """
    Evaluate a PPO policy on the given environment.
    Returns mean metrics across episodes.

    IMPORTANT: This assumes the env executes actions at OPEN(t+1), consistent with your repo env.
    """
    device = torch.device(device) if not isinstance(device, torch.device) else device

    # Ensure deterministic eval behavior (dropout/bn etc.)
    was_training = model.training
    model.eval()

    base_env = env.unwrapped if hasattr(env, "unwrapped") else env
    st = getattr(base_env, "_state", None)
    if st is None:
        # restore mode before raising
        if was_training:
            model.train()
        raise AttributeError("validation_run_ppo expected env.unwrapped to have attribute '_state'.")

    stats = {
        "episode_reward": [],
        "episode_steps": [],
        "num_trades": [],
        "win_rate": [],
        "avg_trade_return": [],
        "avg_hold_steps": [],
        "sum_trade_return": [],
    }

    for _ in range(episodes):
        obs, info = env.reset()
        done = False
        total_reward = 0.0
        steps = 0

        # Manual trade tracking (OPEN(t+1) execution)
        in_pos = False
        entry_price = None
        hold_steps = 0
        trade_returns = []
        trade_hold_steps = []

        while not done:
            action, _lp, _v = policy_act(model, obs, device=device, greedy=greedy)

            # Safer enum conversion
            try:
                act_enum = Actions(action)
            except Exception:
                # If action is invalid, treat as "Skip" (safe fallback)
                act_enum = Actions.Skip

            # Access state after reset/step (base_env._state can change reference)
            st = base_env._state
            next_idx = st._offset + 1

            exec_open = None
            if 0 <= next_idx < st._prices.open.shape[0]:
                exec_open = float(st._prices.open[next_idx])

            prev_have_pos_env = bool(st.have_position)

            # Bookkeeping BEFORE env.step(), using OPEN(t+1)
            if exec_open is not None:
                if act_enum == Actions.Buy and not in_pos:
                    in_pos = True
                    entry_price = exec_open
                    hold_steps = 0
                elif act_enum == Actions.Close and in_pos:
                    if entry_price is not None and entry_price > 0:
                        tr = 100.0 * (exec_open - entry_price) / entry_price
                    else:
                        tr = 0.0
                    trade_returns.append(tr)
                    trade_hold_steps.append(hold_steps)
                    in_pos = False
                    entry_price = None
                    hold_steps = 0

            # Step environment
            obs, reward, terminated, truncated, info = env.step(action)
            done = bool(terminated or truncated)
            total_reward += float(reward)
            steps += 1

            # Detect forced close by env (stop-loss / done close, etc.)
            now_have_pos_env = bool(base_env._state.have_position)
            if in_pos and prev_have_pos_env and (not now_have_pos_env) and act_enum != Actions.Close:
                # Use exec_open if available (OPEN(t+1)), else fallback to current close
                exit_price = exec_open if exec_open is not None else float(base_env._state._cur_close())
                if entry_price is not None and entry_price > 0:
                    tr = 100.0 * (exit_price - entry_price) / entry_price
                else:
                    tr = 0.0
                trade_returns.append(tr)
                trade_hold_steps.append(hold_steps)
                in_pos = False
                entry_price = None
                hold_steps = 0

            if in_pos:
                hold_steps += 1

        # If episode ends while holding, close at last close for reporting
        if in_pos and entry_price is not None and entry_price > 0:
            last_close = float(base_env._state._cur_close())
            tr = 100.0 * (last_close - entry_price) / entry_price
            trade_returns.append(tr)
            trade_hold_steps.append(hold_steps)

        # Episode metrics
        stats["episode_reward"].append(total_reward)
        stats["episode_steps"].append(steps)

        n_trades = len(trade_returns)
        stats["num_trades"].append(float(n_trades))

        if n_trades > 0:
            wins = sum(1 for x in trade_returns if x > 0.0)
            stats["win_rate"].append(float(wins / n_trades))
            stats["avg_trade_return"].append(float(np.mean(trade_returns)))
            stats["avg_hold_steps"].append(float(np.mean(trade_hold_steps)))
            stats["sum_trade_return"].append(float(np.sum(trade_returns)))
        else:
            stats["win_rate"].append(0.0)
            stats["avg_trade_return"].append(0.0)
            stats["avg_hold_steps"].append(0.0)
            stats["sum_trade_return"].append(0.0)

    # Restore model mode
    if was_training:
        model.train()

    return {k: float(np.mean(v)) for k, v in stats.items()}


print("✓ Helper functions defined (updated/robust)")

✓ Helper functions defined (updated/robust)


# Cell 5

In [43]:
# ===== Training Configuration (matched to CLI + holding cost experiment) =====
config = {
    # Data
    "data_dir": "yf_data",
    "run_name": "ppo_aapl_final_holdpen_2e-4",  # NEW: identify holding-cost run
    "seed": 0,                                  # same as before

    # Device
    "use_cuda": True,                           # CLI: --cuda (CPU fallback is fine)

    # PPO Hyperparameters (UNCHANGED)
    "gamma": 0.99,
    "gae_lambda": 0.95,
    "clip_eps": 0.2,
    "lr": 3e-4,
    "rollout_steps": 1024,
    "minibatch": 256,
    "epochs": 5,
    "value_coef": 0.5,
    "entropy_coef": 0.01,
    "max_grad_norm": 0.5,
    "target_kl": 0.02,

    # Environment
    "bars": 10,
    "volumes": True,
    "extra_features": True,
    "reward_mode": "close_pnl",
    "state_1d": False,
    "time_limit": 1000,

    # >>> NEW: Holding-cost controls <<<
    "hold_penalty_per_step": 2e-4,   # 0.02 reward units per step (after *100 scaling)
    "max_hold_steps": None,          # set to e.g. 50–100 ONLY if needed later

    # Data Split
    "split": True,
    "train_ratio": 0.8,
    "min_train": 300,
    "min_val": 300,

    # Training Control
    "max_rollouts": 500,
    "total_steps": 10_000_000,

    # Validation & Checkpointing
    "val_every_rollouts": 10,
    "save_every_rollouts": 10,
    "early_stop": True,
    "patience": 20,
    "min_rollouts": 50,
    "min_delta": 0.01,
}

# Set device (matches --cuda behavior)
device = torch.device("cuda" if (config["use_cuda"] and torch.cuda.is_available()) else "cpu")
print(f"Using device: {device}")

# Set seed
set_seed(config["seed"])

print("✓ Configuration set (with holding cost)")

Using device: cpu
✓ Configuration set (with holding cost)


## 6. Load and Prepare Data

Load stock price data and split into train/validation sets.


In [44]:
# Load all price data
prices_all = load_many_from_dir(config["data_dir"])
print(f"Loaded {len(prices_all)} instruments: {list(prices_all.keys())}")

# Optional: if you want to exactly match an AAPL-only run
# TARGET = "AAPL"
# if TARGET in prices_all:
#     prices_all = {TARGET: prices_all[TARGET]}
#     print(f"Filtered to single instrument: {TARGET}")

# Split into train/validation (chronological, no leakage)
if config["split"]:
    prices_train, prices_val = split_many_by_ratio(
        prices_all,
        train_ratio=config["train_ratio"],
        min_train=config["min_train"],
        min_val=config["min_val"],
    )
    print(f"Train instruments: {len(prices_train)}")
    print(f"Validation instruments: {len(prices_val)}")
else:
    prices_train = prices_all
    prices_val = prices_all
    print("No split: using same data for train and validation (in-sample)")

# Sanity checks (strongly recommended)
for k in prices_train.keys():
    tr_len = len(prices_train[k].close)
    va_len = len(prices_val[k].close)
    if config["split"]:
        assert tr_len >= config["min_train"], f"{k}: train too short ({tr_len})"
        assert va_len >= config["min_val"], f"{k}: val too short ({va_len})"
    print(f"  {k}: train_len={tr_len}, val_len={va_len}")

print("✓ Data loaded and split")

Loaded 3 instruments: ['AAPL_1d_2020-01-01_to_2025-12-23', 'TSLA_1d_2020-01-01_to_2025-11-30', 'TSLA_1d_2020-01-01_to_2025-12-31']
Train instruments: 2
Validation instruments: 2
  AAPL_1d_2020-01-01_to_2025-12-23: train_len=1201, val_len=301
  TSLA_1d_2020-01-01_to_2025-11-30: train_len=1186, val_len=300
✓ Data loaded and split


## 7. Create Environments

Set up training and validation environments.


In [45]:
# Create training environment (match CLI behavior)
env_train_base = StocksEnv(
    prices_train,
    bars_count=config["bars"],
    volumes=config["volumes"],
    extra_features=config["extra_features"],
    reward_mode=config["reward_mode"],
    state_1d=config["state_1d"],
    hold_penalty_per_step=float(config["hold_penalty_per_step"]),
    max_hold_steps=config.get("max_hold_steps", None),
)
env_train = gym.wrappers.TimeLimit(env_train_base, max_episode_steps=config["time_limit"])

# Create validation environment (use same settings + same TimeLimit)
env_val_base = StocksEnv(
    prices_val,
    bars_count=config["bars"],
    volumes=config["volumes"],
    extra_features=config["extra_features"],
    reward_mode=config["reward_mode"],
    state_1d=config["state_1d"],
    hold_penalty_per_step=float(config["hold_penalty_per_step"]),
    max_hold_steps=config.get("max_hold_steps", None),
)
env_val = gym.wrappers.TimeLimit(env_val_base, max_episode_steps=config["time_limit"])

obs_shape = env_train.observation_space.shape
n_actions = env_train.action_space.n

print(f"Observation space: {obs_shape}")
print(f"Action space: {n_actions} actions")
print("✓ Environments created")

Observation space: (45,)
Action space: 3 actions
✓ Environments created


## 8. Initialize Model and Optimizer

Create the PPO model and optimizer.


In [46]:
# Cell 8: Build model + optimizer + TensorBoard writer (compatible with updated Cell 2)

# Infer observation shape from env
# - Vector state: obs_shape = (obs_dim,)
# - State1D:      obs_shape = (C, T)
if config["state_1d"]:
    C, T = obs_shape  # (channels, time bars)

    # Our updated ActorCriticConv1D computes in_channels internally from `volumes`
    # (C should equal (3 + (1 if volumes else 0)) + 2, usually 6 if volumes=True)
    expected_C = (3 + (1 if config["volumes"] else 0)) + 2
    if C != expected_C:
        raise ValueError(
            f"State1D channel mismatch: env returned C={C}, expected C={expected_C}. "
            "Check env volumes/state_1d settings."
        )

    model = ActorCriticConv1D(
        n_actions=n_actions,
        bars_count=T,
        volumes=config["volumes"],
        hidden=256,
    ).to(device)

    print(f"Created Conv1D model: C={C}, T={T}")
else:
    obs_dim = obs_shape[0]
    model = ActorCriticMLP(obs_dim=obs_dim, n_actions=n_actions, hidden=256).to(device)
    print(f"Created MLP model: obs_dim={obs_dim}")

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=config["lr"])

# TensorBoard writer (clean run directory like scripts usually do)
run_dir = os.path.join("runs", config["run_name"])
writer = SummaryWriter(log_dir=run_dir)
print(f"TensorBoard log dir: {run_dir}")
print("Model params:", sum(p.numel() for p in model.parameters()))

print("✓ Model and optimizer initialized")


Created MLP model: obs_dim=45
TensorBoard log dir: runs\ppo_aapl_final_holdpen_2e-4
Model params: 78596
✓ Model and optimizer initialized


## 9. Training Loop

Main PPO training loop: collect rollouts, compute advantages, update policy.


In [47]:
# ===== Training bookkeeping =====
obs, info = env_train.reset(seed=config["seed"])
episode_reward = 0.0
episode_steps = 0
episode_count = 0

global_step = 0
rollout_idx = 0
t0 = time.time()

best_val_reward = -1e9
no_improve = 0

print(f"[PPO] device={device} obs_shape={obs_shape} actions={n_actions}")
print(f"[PPO] reward_mode={config['reward_mode']} volumes={config['volumes']} extra_features={config['extra_features']}")
print(f"[PPO] split={config['split']} max_rollouts={config['max_rollouts']} early_stop={config['early_stop']}")
print(f"[PPO] logs: runs/  checkpoints: saves/")
print("=" * 60)

# Main training loop
while (global_step < config["total_steps"]) and (rollout_idx < config["max_rollouts"]):
    rollout_idx += 1
    buf = RolloutBuffer(obs_shape=obs_shape, size=config["rollout_steps"], device=device)

    # ===== Collect Rollout =====
    for _ in range(config["rollout_steps"]):
        global_step += 1

        action, logprob, value = policy_act(model, obs, device=device, greedy=False)
        next_obs, reward, terminated, truncated, info = env_train.step(action)

        episode_done = bool(terminated or truncated)

        # IMPORTANT FIX: mask GAE on ANY episode end (terminated OR truncated)
        buf_done = episode_done

        buf.add(
            obs=obs,
            action=action,
            reward=float(reward),
            done=buf_done,
            value=value,
            logprob=logprob,
        )

        episode_reward += float(reward)
        episode_steps += 1
        obs = next_obs

        if episode_done:
            episode_count += 1
            writer.add_scalar("train/episode_reward", episode_reward, global_step)
            writer.add_scalar("train/episode_steps", episode_steps, global_step)

            obs, info = env_train.reset()
            episode_reward = 0.0
            episode_steps = 0

        if global_step >= config["total_steps"]:
            break

    # If we didn't fill the buffer (e.g., hit total_steps), stop cleanly
    if not buf.full:
        print(f"[PPO] stopping: reached total_steps mid-rollout (filled {buf.ptr}/{buf.size}).")
        break

    # Rollout reward heartbeat
    roll_sum = float(buf.rewards.sum())
    roll_mean = float(buf.rewards.mean())
    writer.add_scalar("train/rollout_reward_sum", roll_sum, global_step)
    writer.add_scalar("train/rollout_reward_mean", roll_mean, global_step)

    # Bootstrap last value for GAE (only used if last transition not terminal)
    with torch.no_grad():
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        _, last_v = model(obs_t)
        last_value = float(last_v.item())

    buf.compute_gae(
        last_value=last_value,
        gamma=config["gamma"],
        lam=config["gae_lambda"],
        normalize_adv=True,
    )

    # ===== PPO Update =====
    policy_losses = []
    value_losses = []
    entropies = []
    approx_kls = []
    clipfracs = []

    for _epoch in range(config["epochs"]):
        for batch in buf.get_batches(batch_size=config["minibatch"], shuffle=True):
            logits, values = model(batch.obs)
            dist = Categorical(logits=logits)

            new_logp = dist.log_prob(batch.actions)
            entropy = dist.entropy().mean()

            # PPO ratio
            ratio = torch.exp(new_logp - batch.old_logprobs)

            # Policy loss (clipped surrogate)
            unclipped = ratio * batch.advantages
            clipped = torch.clamp(ratio, 1.0 - config["clip_eps"], 1.0 + config["clip_eps"]) * batch.advantages
            loss_pi = -torch.min(unclipped, clipped).mean()

            # Value loss
            loss_v = (batch.returns - values).pow(2).mean()

            # Total loss
            loss = loss_pi + config["value_coef"] * loss_v - config["entropy_coef"] * entropy

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
            optimizer.step()

            with torch.no_grad():
                # Better PPO-style KL approximation (no abs)
                approx_kl = float((batch.old_logprobs - new_logp).mean().item())
                clipfrac = float((torch.abs(ratio - 1.0) > config["clip_eps"]).float().mean().item())

            policy_losses.append(float(loss_pi.item()))
            value_losses.append(float(loss_v.item()))
            entropies.append(float(entropy.item()))
            approx_kls.append(approx_kl)
            clipfracs.append(clipfrac)

        # Early stop PPO epoch loop if KL too big
        if config["target_kl"] > 0 and approx_kls and (np.mean(approx_kls) > config["target_kl"]):
            break

    # ===== Logging =====
    fps = global_step / max(1e-9, (time.time() - t0))
    writer.add_scalar("ppo/policy_loss", float(np.mean(policy_losses) if policy_losses else 0.0), global_step)
    writer.add_scalar("ppo/value_loss", float(np.mean(value_losses) if value_losses else 0.0), global_step)
    writer.add_scalar("ppo/entropy", float(np.mean(entropies) if entropies else 0.0), global_step)
    writer.add_scalar("ppo/approx_kl", float(np.mean(approx_kls) if approx_kls else 0.0), global_step)
    writer.add_scalar("ppo/clipfrac", float(np.mean(clipfracs) if clipfracs else 0.0), global_step)
    writer.add_scalar("train/fps", float(fps), global_step)
    writer.add_scalar("train/episodes", float(episode_count), global_step)

    # Value function explained variance
    ev = explained_variance(np.asarray(buf.values), np.asarray(buf.returns))
    writer.add_scalar("ppo/explained_variance", ev, global_step)

    # Console heartbeat
    print(
        f"[rollout {rollout_idx}] step={global_step} "
        f"roll_sum={roll_sum:.3f} roll_mean={roll_mean:.5f} "
        f"pi_loss={np.mean(policy_losses) if policy_losses else 0.0:.4f} "
        f"v_loss={np.mean(value_losses) if value_losses else 0.0:.4f} "
        f"ent={np.mean(entropies) if entropies else 0.0:.3f} "
        f"kl={np.mean(approx_kls) if approx_kls else 0.0:.4f} "
        f"clip={np.mean(clipfracs) if clipfracs else 0.0:.3f} "
        f"eps_done={episode_count} fps={fps:.1f}"
    )

    # ===== Validation + Best Model Saving + Early Stop =====
    if config["val_every_rollouts"] > 0 and (rollout_idx % config["val_every_rollouts"] == 0):
        val = validation_run_ppo(env_val, model, episodes=20, device=device, greedy=True)

        for k, v in val.items():
            writer.add_scalar("val/" + k, v, global_step)

        print(f"  [val] {val}")

        cur_val = float(val.get("episode_reward", -1e9))
        if cur_val > best_val_reward + config["min_delta"]:
            best_val_reward = cur_val
            no_improve = 0
            best_path = os.path.join("saves", f"ppo_{config['run_name']}_best.pt")
            torch.save(model.state_dict(), best_path)
            print(f"  [best] new best val episode_reward={best_val_reward:.4f} -> {best_path}")
        else:
            no_improve += 1

        if config["early_stop"] and (rollout_idx >= config["min_rollouts"]) and (no_improve >= config["patience"]):
            print(f"[PPO] early stopping: no val improvement for {no_improve} validations.")
            break

    # ===== Save Periodic Checkpoint =====
    if config["save_every_rollouts"] > 0 and (rollout_idx % config["save_every_rollouts"] == 0):
        ckpt_path = os.path.join("saves", f"ppo_{config['run_name']}_rollout{rollout_idx}.pt")
        torch.save(model.state_dict(), ckpt_path)
        print(f"  [save] {ckpt_path}")

    # Hard stop condition
    if config["max_rollouts"] and rollout_idx >= config["max_rollouts"]:
        print("[PPO] reached max_rollouts, stopping.")
        break

if global_step >= config["total_steps"]:
    print(f"[done] reached total_steps={config['total_steps']} at rollout={rollout_idx}")
elif rollout_idx >= config["max_rollouts"]:
    print(f"[done] reached max_rollouts={config['max_rollouts']} at step={global_step}")

writer.close()
print("\n✓ Training completed!")

[PPO] device=cpu obs_shape=(45,) actions=3
[PPO] reward_mode=close_pnl volumes=True extra_features=True
[PPO] split=True max_rollouts=500 early_stop=True
[PPO] logs: runs/  checkpoints: saves/
[rollout 1] step=1024 roll_sum=48.222 roll_mean=0.04709 pi_loss=-0.0002 v_loss=59.3379 ent=1.099 kl=0.0000 clip=0.000 eps_done=1 fps=934.6
[rollout 2] step=2048 roll_sum=-2.165 roll_mean=-0.00211 pi_loss=-0.0004 v_loss=29.6044 ent=1.099 kl=-0.0001 clip=0.000 eps_done=2 fps=899.3
[rollout 3] step=3072 roll_sum=34.857 roll_mean=0.03404 pi_loss=-0.0003 v_loss=13.3122 ent=1.098 kl=-0.0001 clip=0.000 eps_done=3 fps=882.7
[rollout 4] step=4096 roll_sum=-52.316 roll_mean=-0.05109 pi_loss=-0.0000 v_loss=35.8102 ent=1.098 kl=0.0003 clip=0.000 eps_done=6 fps=876.7
[rollout 5] step=5120 roll_sum=68.676 roll_mean=0.06707 pi_loss=-0.0004 v_loss=55.8822 ent=1.098 kl=0.0001 clip=0.000 eps_done=8 fps=884.7
[rollout 6] step=6144 roll_sum=97.514 roll_mean=0.09523 pi_loss=0.0001 v_loss=78.2339 ent=1.098 kl=0.0002 c

# Cell 10: Load best model + final validation

In [49]:
best_path = os.path.join("saves", f"ppo_{config['run_name']}_best.pt")
if os.path.exists(best_path):
    state = torch.load(best_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    print(f"Loaded best model from: {best_path}")

    # Evaluate on validation set
    final_val = validation_run_ppo(env_val, model, episodes=100, device=device, greedy=True)

    print("\n" + "=" * 60)
    print("Final Validation Results (100 episodes):")
    print("=" * 60)
    for k, v in final_val.items():
        print(f"  {k}: {v:.4f}")
    print("=" * 60)
else:
    print(f"Best model checkpoint not found: {best_path}")

Loaded best model from: saves\ppo_ppo_aapl_final_holdpen_2e-4_best.pt

Final Validation Results (100 episodes):
  episode_reward: 30.8059
  episode_steps: 187.9800
  num_trades: 1.0000
  win_rate: 0.9600
  avg_trade_return: 30.4833
  avg_hold_steps: 186.9800
  sum_trade_return: 30.4833


## Notes

- **TensorBoard**: View training metrics with `tensorboard --logdir runs/`
- **Checkpoints**: Best model saved to `saves/ppo_{run_name}_best.pt`
- **Configuration**: Adjust hyperparameters in the Configuration cell (Section 5)
- **Early Stopping**: Enabled by default based on validation performance
- **Data Split**: Chronological split prevents data leakage (recommended)
