# PPO Training for Stock Trading

This notebook implements Proximal Policy Optimization (PPO) for training a reinforcement learning agent to trade stocks.

## Overview
- **Environment**: Stock trading environment using historical price data
- **Algorithm**: PPO (Proximal Policy Optimization)
- **Model**: Actor-Critic network (MLP or CNN)
- **Features**: Volume and extra (volatility, ATR-like) features; chronological train/val split


## 1. Imports and Setup


In [13]:
import os
import time

import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from tensorboardX import SummaryWriter

# Import our stock trading environment package
from stock_trading_env import (
    StocksEnv,
    load_many_from_dir,
    split_many_by_ratio,
    Actions,
)

# Set random seeds for reproducibility
def set_seed(seed: int):
    """Seed the global RNGs this notebook draws from.

    Seeds Python's ``random`` module, NumPy's global RNG (used e.g. for
    minibatch shuffling via ``np.random.shuffle``), and the PyTorch CPU and
    CUDA generators (weight init, action sampling). Environment resets are
    seeded separately through ``env.reset(seed=...)``.

    Args:
        seed: integer seed applied to every generator.
    """
    import random  # local import: the imports cell does not import `random`

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Output folders: TensorBoard event files, model checkpoints, and plain logs
for _out_dir in ("runs", "saves", "logs"):
    os.makedirs(_out_dir, exist_ok=True)

# Environment report
print("✓ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

✓ Imports successful
PyTorch version: 2.9.0+cpu
CUDA available: False


## 2. PPO Model Definitions

Define the Actor-Critic networks for PPO.


In [14]:
class ActorCriticMLP(nn.Module):
    """
    Discrete-action PPO actor-critic built from fully connected layers.

    A shared two-layer ReLU trunk feeds two linear heads:
      - policy_head -> unnormalized logits for a Categorical policy
      - value_head  -> scalar state-value baseline V(s)

    Shapes:
        x:      (B, obs_dim)
        logits: (B, n_actions)
        value:  (B,)
    """

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 256):
        super().__init__()

        trunk_layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk_layers)

        self.policy_head = nn.Linear(hidden, n_actions)  # actor: action logits
        self.value_head = nn.Linear(hidden, 1)           # critic: V(s)

        self._init_weights()

    def _init_weights(self):
        """Orthogonal init (gain 1) with zero bias everywhere, then a tiny-gain policy head."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight, gain=1.0)
            nn.init.constant_(layer.bias, 0.0)

        # Gain 0.01 keeps the initial logits close to zero, i.e. a near-uniform starting policy.
        nn.init.orthogonal_(self.policy_head.weight, gain=0.01)
        nn.init.constant_(self.policy_head.bias, 0.0)

    def forward(self, x: torch.Tensor):
        """Map (B, obs_dim) observations to (logits (B, n_actions), value (B,))."""
        features = self.shared(x)
        return self.policy_head(features), self.value_head(features).squeeze(-1)


class ActorCriticConv1D(nn.Module):
    """
    Discrete-action PPO actor-critic for State1D (channels x time) observations.

    Two length-preserving Conv1d layers (kernel 3, padding 1) are flattened into
    a two-layer ReLU trunk, which then splits into a policy-logit head and a
    scalar value head.

    Shapes:
        x:      (B, in_channels, bars_count)
        logits: (B, n_actions)
        value:  (B,)
    """

    def __init__(self, in_channels: int, n_actions: int, bars_count: int, hidden: int = 256):
        super().__init__()

        conv_layers = [
            nn.Conv1d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.conv = nn.Sequential(*conv_layers)

        # kernel_size=3 with padding=1 keeps the time length unchanged, so the
        # flattened features are 64 channels x bars_count positions.
        flat_dim = 64 * bars_count

        trunk_layers = [
            nn.Linear(flat_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk_layers)

        self.policy_head = nn.Linear(hidden, n_actions)
        self.value_head = nn.Linear(hidden, 1)

        self._init_weights()

    def _init_weights(self):
        """Orthogonal init for every conv/linear layer, then a tiny-gain policy head."""
        learnable = [m for m in self.modules() if isinstance(m, (nn.Linear, nn.Conv1d))]
        for layer in learnable:
            nn.init.orthogonal_(layer.weight, gain=1.0)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.0)

        # Gain 0.01 keeps the initial logits close to zero (near-uniform policy).
        nn.init.orthogonal_(self.policy_head.weight, gain=0.01)
        nn.init.constant_(self.policy_head.bias, 0.0)

    def forward(self, x: torch.Tensor):
        """Map (B, C, T) observations to (logits (B, n_actions), value (B,))."""
        trunk = self.shared(self.conv(x))
        return self.policy_head(trunk), self.value_head(trunk).squeeze(-1)

print("✓ PPO models defined")

✓ PPO models defined


## 3. PPO Buffer and Utilities

Define the rollout buffer for collecting experience and computing GAE advantages.


In [15]:
from dataclasses import dataclass
from typing import Iterator


@dataclass
class PPOBatch:
    """Tensors for one PPO minibatch, as yielded by ``RolloutBuffer.get_batches``."""
    obs: torch.Tensor           # observations, float32
    actions: torch.Tensor       # taken action indices, int64
    old_logprobs: torch.Tensor  # log-probs under the policy that collected the data
    advantages: torch.Tensor    # GAE advantages (possibly normalized)
    returns: torch.Tensor       # value-function targets
    old_values: torch.Tensor    # critic estimates recorded at collection time


class RolloutBuffer:
    """
    Fixed-capacity on-policy storage for a single PPO rollout.

    Usage: ``add`` exactly ``size`` transitions, call ``compute_gae`` once to
    fill GAE(lambda) advantages and value targets, then iterate
    ``get_batches`` for minibatches. ``reset`` rewinds the write pointer.
    """

    def __init__(self, obs_shape, size: int, device="cpu", dtype=np.float32):
        shape = (obs_shape,) if isinstance(obs_shape, int) else obs_shape
        self.obs_shape = tuple(shape)
        self.size = int(size)
        self.device = device

        n = self.size
        # Per-step transition storage, written by add()
        self.obs = np.zeros((n, *self.obs_shape), dtype=dtype)
        self.actions = np.zeros((n,), dtype=np.int64)
        self.rewards = np.zeros((n,), dtype=np.float32)
        self.dones = np.zeros((n,), dtype=np.float32)
        self.values = np.zeros((n,), dtype=np.float32)
        self.logprobs = np.zeros((n,), dtype=np.float32)

        # Derived targets, written by compute_gae()
        self.advantages = np.zeros((n,), dtype=np.float32)
        self.returns = np.zeros((n,), dtype=np.float32)

        self.reset()

    def reset(self):
        """Rewind the write pointer; stored arrays are overwritten by later adds."""
        self.ptr = 0
        self.full = False

    def add(self, obs: np.ndarray, action: int, reward: float, done: bool,
            value: float, logprob: float):
        """Store one transition at the write pointer; raises RuntimeError when already full."""
        i = self.ptr
        if i >= self.size:
            raise RuntimeError("RolloutBuffer is full. Call reset() before adding more.")

        self.obs[i] = np.asarray(obs, dtype=self.obs.dtype)
        self.actions[i] = int(action)
        self.rewards[i] = float(reward)
        self.dones[i] = float(bool(done))
        self.values[i] = float(value)
        self.logprobs[i] = float(logprob)

        self.ptr = i + 1
        if self.ptr == self.size:
            self.full = True

    def compute_gae(self, last_value: float, gamma: float = 0.99,
                    lam: float = 0.95, normalize_adv: bool = True):
        """
        Fill ``advantages`` and ``returns`` with GAE(lambda) over the whole buffer.

        ``last_value`` bootstraps the step after the final stored transition.
        A stored done flag zeroes both the bootstrap and the carried advantage.
        Returns are computed from the raw advantages, before optional normalization.
        """
        if not self.full:
            raise RuntimeError("RolloutBuffer not full. Collect 'size' steps before compute_gae().")

        running = 0.0
        last = self.size - 1
        for t in range(last, -1, -1):
            not_done = 1.0 - self.dones[t]
            v_next = self.values[t + 1] if t < last else last_value
            delta = self.rewards[t] + gamma * v_next * not_done - self.values[t]
            running = delta + gamma * lam * not_done * running
            self.advantages[t] = running

        self.returns = self.advantages + self.values

        if normalize_adv:
            mean = float(self.advantages.mean())
            std = float(self.advantages.std()) + 1e-8
            self.advantages = (self.advantages - mean) / std

    def get_batches(self, batch_size: int, shuffle: bool = True) -> Iterator[PPOBatch]:
        """Yield PPOBatch minibatches (last one may be smaller) as tensors on ``self.device``."""
        if not self.full:
            raise RuntimeError("RolloutBuffer not full. Collect rollout before batching.")

        order = np.arange(self.size)
        if shuffle:
            np.random.shuffle(order)

        def to_tensor(arr, dtype):
            return torch.as_tensor(arr, device=self.device, dtype=dtype)

        for lo in range(0, self.size, batch_size):
            sel = order[lo:lo + batch_size]
            yield PPOBatch(
                obs=to_tensor(self.obs[sel], torch.float32),
                actions=to_tensor(self.actions[sel], torch.long),
                old_logprobs=to_tensor(self.logprobs[sel], torch.float32),
                advantages=to_tensor(self.advantages[sel], torch.float32),
                returns=to_tensor(self.returns[sel], torch.float32),
                old_values=to_tensor(self.values[sel], torch.float32),
            )

print("✓ PPO buffer defined")

✓ PPO buffer defined


In [None]:
@torch.no_grad()
def policy_act(model, obs, device: torch.device, greedy: bool = False):
    """
    Choose one action for a single, unbatched observation.

    The observation gets a leading batch axis of 1 before the forward pass.
    With ``greedy=True`` the argmax of the logits is taken; otherwise an action
    is sampled from the Categorical policy.

    Returns:
        (action: int, logprob: float, value: float). ``logprob`` is the
        log-probability of the chosen action under the current policy.
    """
    batch = torch.as_tensor(obs, dtype=torch.float32, device=device)[None]
    logits, value = model(batch)
    pi = Categorical(logits=logits)
    chosen = torch.argmax(logits, dim=1) if greedy else pi.sample()
    return int(chosen.item()), float(pi.log_prob(chosen).item()), float(value.item())


def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """
    Critic diagnostic: 1 - Var[y_true - y_pred] / Var[y_true].

    1.0 means the predictions match the targets exactly. Values near 0 mean the
    predictions explain no more variance than a constant. Negative values mean
    they are worse than a constant. Returns 0.0 when the targets are (nearly)
    constant.
    """
    target_var = np.var(y_true)
    if target_var < 1e-12:
        return 0.0
    residual_var = np.var(y_true - y_pred)
    return float(1.0 - residual_var / (target_var + 1e-12))


@torch.no_grad()
def validation_run_ppo(env, model, episodes: int = 50, device="cpu", greedy: bool = True):
    """
    Evaluate a PPO policy on the given environment.
    Returns mean metrics across episodes.

    Each of ``episodes`` episodes runs until ``terminated or truncated``. Trades
    are reconstructed inside this function by shadowing the env's position
    state, rather than read from ``info``.

    Args:
        env: Gymnasium-style env (reset -> (obs, info); step -> 5-tuple),
            optionally wrapped. Its unwrapped ``_state`` is read directly, so it
            must expose ``_offset``, ``_prices.open``, ``have_position`` and
            ``_cur_close()`` (StocksEnv internals).
        model: callable mapping a (1, *obs_shape) float tensor to
            ``(logits, value)``; only ``logits`` is used.
        episodes: number of episodes to average over.
        device: torch.device or anything ``torch.device(...)`` accepts.
        greedy: argmax of logits if True, otherwise sample from Categorical.

    Returns:
        dict of metric name -> mean over episodes, for: "episode_reward",
        "episode_steps", "num_trades", "win_rate", "avg_trade_return",
        "avg_hold_steps", "sum_trade_return". Trade returns are percentages.
        Episodes without trades contribute 0.0 to every trade metric.
    """
    device = torch.device(device) if not isinstance(device, torch.device) else device
    # Reach through wrappers (e.g. TimeLimit) to the raw env holding _state
    base_env = env.unwrapped if hasattr(env, "unwrapped") else env
    
    # One entry per episode for every metric; averaged at the end
    stats = {
        "episode_reward": [],
        "episode_steps": [],
        "num_trades": [],
        "win_rate": [],
        "avg_trade_return": [],
        "avg_hold_steps": [],
        "sum_trade_return": [],
    }
    
    for _ in range(episodes):
        obs, info = env.reset()
        done = False
        total_reward = 0.0
        steps = 0
        
        # Manual trade tracking
        in_pos = False
        entry_price = None
        hold_steps = 0
        trade_returns = []
        trade_hold_steps = []
        
        while not done:
            obs_v = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            logits, _value = model(obs_v)
            
            if greedy:
                action = int(torch.argmax(logits, dim=1).item())
            else:
                dist = Categorical(logits=logits)
                action = int(dist.sample().item())
            
            act_enum = Actions(action)
            st = base_env._state
            # NOTE(review): assumes the env fills orders at the NEXT bar's open;
            # confirm against StocksEnv's step logic.
            next_idx = st._offset + 1
            
            # None when there is no next bar (end of data): no fill is recorded then
            exec_open = None
            if 0 <= next_idx < st._prices.open.shape[0]:
                exec_open = float(st._prices.open[next_idx])
            
            # Snapshot of the env's own position flag, used to detect closes it makes itself
            prev_have_pos_env = bool(st.have_position)
            
            # Bookkeeping BEFORE step
            if exec_open is not None:
                if act_enum == Actions.Buy and not in_pos:
                    in_pos = True
                    entry_price = exec_open
                    hold_steps = 0
                elif act_enum == Actions.Close and in_pos:
                    # Percent return of the round trip, entry -> exit open prices
                    if entry_price and entry_price > 0:
                        tr = 100.0 * (exec_open - entry_price) / entry_price
                    else:
                        tr = 0.0
                    trade_returns.append(tr)
                    trade_hold_steps.append(hold_steps)
                    in_pos = False
                    entry_price = None
                    hold_steps = 0
            
            # Step environment
            obs, reward, terminated, truncated, info = env.step(action)
            done = bool(terminated or truncated)
            total_reward += float(reward)
            steps += 1
            
            # Detect forced close
            # The env dropped the position without a Close action from the agent;
            # record the trade at the next open if known, else at the current close.
            now_have_pos_env = bool(base_env._state.have_position)
            if in_pos and prev_have_pos_env and (not now_have_pos_env) and act_enum != Actions.Close:
                exit_price = exec_open if exec_open is not None else float(base_env._state._cur_close())
                if entry_price and entry_price > 0:
                    tr = 100.0 * (exit_price - entry_price) / entry_price
                else:
                    tr = 0.0
                trade_returns.append(tr)
                trade_hold_steps.append(hold_steps)
                in_pos = False
                entry_price = None
                hold_steps = 0
            
            if in_pos:
                hold_steps += 1
        
        # If episode ends while holding
        # Mark the open position to the final close and count it as a trade
        if in_pos and entry_price and entry_price > 0:
            last_close = float(base_env._state._cur_close())
            tr = 100.0 * (last_close - entry_price) / entry_price
            trade_returns.append(tr)
            trade_hold_steps.append(hold_steps)
        
        # Episode metrics
        stats["episode_reward"].append(total_reward)
        stats["episode_steps"].append(steps)
        n_trades = len(trade_returns)
        stats["num_trades"].append(float(n_trades))
        
        if n_trades > 0:
            # A trade is a "win" only if its return is strictly positive
            wins = sum(1 for x in trade_returns if x > 0.0)
            stats["win_rate"].append(float(wins / n_trades))
            stats["avg_trade_return"].append(float(np.mean(trade_returns)))
            stats["avg_hold_steps"].append(float(np.mean(trade_hold_steps)))
            stats["sum_trade_return"].append(float(np.sum(trade_returns)))
        else:
            stats["win_rate"].append(0.0)
            stats["avg_trade_return"].append(0.0)
            stats["avg_hold_steps"].append(0.0)
            stats["sum_trade_return"].append(0.0)
    
    # Mean over episodes (np.mean of an empty list -> nan if episodes == 0)
    return {k: float(np.mean(v)) for k, v in stats.items()}

print("✓ Helper functions defined")

In [None]:
# ===== Training Configuration =====
# Every tunable knob used by the cells below lives here.
config = {
    # Data
    "data_dir": "yf_data",        # directory passed to load_many_from_dir
    "run_name": "ppo_training",   # used in checkpoint file names and the TensorBoard run comment
    "seed": 42,

    # Device
    "use_cuda": True,  # Set to False to force CPU

    # PPO Hyperparameters
    "gamma": 0.99,            # discount factor
    "gae_lambda": 0.95,       # GAE(lambda) smoothing
    "clip_eps": 0.2,          # PPO probability-ratio clip range
    "lr": 3e-4,               # Adam learning rate
    "rollout_steps": 1024,    # env steps per rollout (RolloutBuffer size)
    "minibatch": 256,         # minibatch size for PPO updates
    "epochs": 5,              # max optimization passes over each rollout
    "value_coef": 0.5,        # weight of the value loss
    "entropy_coef": 0.01,     # weight of the entropy bonus
    "max_grad_norm": 0.5,     # gradient-norm clipping threshold
    "target_kl": 0.02,        # stop the epoch loop early when approx KL exceeds this (<= 0 disables)

    # Environment
    "bars": 10,               # bars_count passed to StocksEnv
    "volumes": True,
    "extra_features": True,
    "reward_mode": "close_pnl",  # or "step_logret"
    "state_1d": False,  # True for CNN, False for MLP
    "time_limit": 1000,       # max steps per training episode (TimeLimit wrapper)

    # Data Split
    "split": True,  # Chronological train/val split (recommended)
    "train_ratio": 0.8,
    "min_train": 200,
    "min_val": 200,

    # Training Control
    "max_rollouts": 500,
    "total_steps": 10_000_000,

    # Validation & Checkpointing
    "val_every_rollouts": 10,    # validate every N rollouts (0 disables)
    "val_episodes": 20,          # episodes per periodic validation run
    "final_val_episodes": 100,   # episodes for the final evaluation of the best checkpoint
    "save_every_rollouts": 10,   # periodic checkpoint every N rollouts (0 disables)
    "early_stop": True,
    "patience": 20,              # validations without improvement before early stopping
    "min_rollouts": 50,          # never early-stop before this many rollouts
    "min_delta": 1e-3,           # minimum gain in val episode_reward that counts as improvement
}

# Set device
device = torch.device("cuda" if (config["use_cuda"] and torch.cuda.is_available()) else "cpu")
print(f"Using device: {device}")

# Set seed
set_seed(config["seed"])

print("✓ Configuration set")

## 6. Load and Prepare Data

Load stock price data and split into train/validation sets.


In [None]:
# Read every instrument available in the data directory
prices_all = load_many_from_dir(config["data_dir"])
print(f"Loaded {len(prices_all)} instruments: {list(prices_all.keys())}")

if not config["split"]:
    # In-sample mode: validation metrics will be optimistic
    prices_train = prices_val = prices_all
    print("No split: using same data for train and validation (in-sample)")
else:
    # Chronological split per instrument: earlier bars train, later bars validate
    prices_train, prices_val = split_many_by_ratio(
        prices_all,
        train_ratio=config["train_ratio"],
        min_train=config["min_train"],
        min_val=config["min_val"],
    )
    print(f"Train instruments: {len(prices_train)}")
    print(f"Validation instruments: {len(prices_val)}")

print("✓ Data loaded and split")


## 7. Create Environments

Set up training and validation environments.


In [None]:
# Identical observation/reward settings for both envs, so a policy trained on
# one produces valid actions on the other.
env_kwargs = dict(
    bars_count=config["bars"],
    volumes=config["volumes"],
    extra_features=config["extra_features"],
    reset_on_close=False,
    reward_on_close=False,
    reward_mode=config["reward_mode"],
    state_1d=config["state_1d"],
)

# Training env: episodes capped at config["time_limit"] steps
env_train_base = StocksEnv(prices_train, **env_kwargs)
env_train = gym.wrappers.TimeLimit(env_train_base, max_episode_steps=config["time_limit"])

# Validation env: no TimeLimit wrapper
env_val = StocksEnv(prices_val, **env_kwargs)

obs_shape = env_train.observation_space.shape
n_actions = env_train.action_space.n

print(f"Observation space: {obs_shape}")
print(f"Action space: {n_actions} actions")
print("✓ Environments created")

## 8. Initialize Model and Optimizer

Create the PPO model and optimizer.


In [None]:
# Choose the network that matches the observation layout: flat vector -> MLP,
# (channels, time) -> Conv1D.
if not config["state_1d"]:
    obs_dim = obs_shape[0]
    model = ActorCriticMLP(obs_dim=obs_dim, n_actions=n_actions).to(device)
    print(f"Created MLP model: obs_dim={obs_dim}")
else:
    C, T = obs_shape  # (channels, time)
    model = ActorCriticConv1D(in_channels=C, n_actions=n_actions, bars_count=T).to(device)
    print(f"Created Conv1D model: channels={C}, bars={T}")

# One Adam optimizer over actor and critic parameters together
optimizer = optim.Adam(model.parameters(), lr=config["lr"])

# TensorBoard writer; the run name is appended to the auto-generated log dir
writer = SummaryWriter(comment=f"-ppo-{config['run_name']}")

print("✓ Model and optimizer initialized")

## 9. Training Loop

Main PPO training loop: collect rollouts, compute advantages, update policy.


In [None]:
# Training bookkeeping
obs, info = env_train.reset(seed=config["seed"])
episode_reward = 0.0
episode_steps = 0
episode_count = 0

global_step = 0
rollout_idx = 0
t0 = time.time()

best_val_reward = float("-inf")
no_improve = 0

# A falsy max_rollouts (0 / None) means "no rollout cap"; total_steps still bounds training.
max_rollouts = config["max_rollouts"] or float("inf")
# Episodes per periodic validation (config key optional for backward compatibility).
val_episodes = int(config.get("val_episodes", 20))


@torch.no_grad()
def _critic_value(net, observation, dev) -> float:
    """Return the critic's V(s) for a single unbatched observation."""
    obs_batch = torch.as_tensor(observation, dtype=torch.float32, device=dev).unsqueeze(0)
    _, v = net(obs_batch)
    return float(v.item())


print(f"[PPO] device={device} obs_shape={obs_shape} actions={n_actions}")
print(f"[PPO] reward_mode={config['reward_mode']} volumes={config['volumes']} extra_features={config['extra_features']}")
print(f"[PPO] split={config['split']} max_rollouts={config['max_rollouts']} early_stop={config['early_stop']}")
print(f"[PPO] logs: runs/  checkpoints: saves/")
print("=" * 60)

try:
    # Main training loop
    while (global_step < config["total_steps"]) and (rollout_idx < max_rollouts):
        rollout_idx += 1
        buf = RolloutBuffer(obs_shape=obs_shape, size=config["rollout_steps"], device=device)
        roll_sum = 0.0  # raw env reward this rollout (excludes time-limit bootstrap terms)

        # ===== Collect Rollout =====
        # Always fill the buffer completely: compute_gae()/get_batches() reject a
        # partial buffer, so the last rollout may overshoot total_steps by < rollout_steps.
        for _ in range(config["rollout_steps"]):
            global_step += 1

            action, logprob, value = policy_act(model, obs, device=device, greedy=False)
            next_obs, reward, terminated, truncated, info = env_train.step(action)
            reward = float(reward)
            episode_done = bool(terminated or truncated)

            # Every episode end cuts the GAE recursion (done=True), because obs is
            # reset below and values[t+1] will belong to a NEW episode. A pure
            # time-limit truncation is not a true terminal state, so its
            # continuation value gamma * V(s_final) is folded into the stored reward.
            buf_reward = reward
            if truncated and not terminated:
                buf_reward += config["gamma"] * _critic_value(model, next_obs, device)

            buf.add(
                obs=obs,
                action=action,
                reward=buf_reward,
                done=episode_done,
                value=value,
                logprob=logprob,
            )

            roll_sum += reward
            episode_reward += reward
            episode_steps += 1
            obs = next_obs

            if episode_done:
                episode_count += 1
                writer.add_scalar("train/episode_reward", episode_reward, global_step)
                writer.add_scalar("train/episode_steps", episode_steps, global_step)

                obs, info = env_train.reset()
                episode_reward = 0.0
                episode_steps = 0

        # Rollout reward heartbeat (raw env rewards only)
        roll_mean = roll_sum / buf.size
        writer.add_scalar("train/rollout_reward_sum", roll_sum, global_step)
        writer.add_scalar("train/rollout_reward_mean", roll_mean, global_step)

        # Bootstrap value of the state after the last stored step (ignored by GAE
        # if that step ended an episode, since its done flag masks it out)
        last_value = _critic_value(model, obs, device)

        buf.compute_gae(
            last_value=last_value,
            gamma=config["gamma"],
            lam=config["gae_lambda"],
            normalize_adv=True
        )

        # ===== PPO Update =====
        policy_losses = []
        value_losses = []
        entropies = []
        approx_kls = []
        clipfracs = []

        for _epoch in range(config["epochs"]):
            epoch_kls = []
            for batch in buf.get_batches(batch_size=config["minibatch"], shuffle=True):
                logits, values = model(batch.obs)
                dist = Categorical(logits=logits)

                new_logp = dist.log_prob(batch.actions)
                entropy = dist.entropy().mean()

                # PPO ratio pi_new / pi_old
                log_ratio = new_logp - batch.old_logprobs
                ratio = torch.exp(log_ratio)

                # Policy loss (clipped surrogate)
                unclipped = ratio * batch.advantages
                clipped = torch.clamp(ratio, 1.0 - config["clip_eps"], 1.0 + config["clip_eps"]) * batch.advantages
                loss_pi = -torch.min(unclipped, clipped).mean()

                # Value loss
                loss_v = (batch.returns - values).pow(2).mean()

                # Total loss
                loss = loss_pi + config["value_coef"] * loss_v - config["entropy_coef"] * entropy

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
                optimizer.step()

                with torch.no_grad():
                    # Non-negative, low-variance KL(old || new) estimator: E[(r - 1) - log r]
                    approx_kl = float(((ratio - 1.0) - log_ratio).mean().item())
                    clipfrac = float((torch.abs(ratio - 1.0) > config["clip_eps"]).float().mean().item())

                policy_losses.append(float(loss_pi.item()))
                value_losses.append(float(loss_v.item()))
                entropies.append(float(entropy.item()))
                approx_kls.append(approx_kl)
                epoch_kls.append(approx_kl)
                clipfracs.append(clipfrac)

            # Early stop the epoch loop if this epoch's KL is too large
            # (judged per epoch so earlier near-zero epochs don't dilute the signal)
            if config["target_kl"] > 0 and epoch_kls and (np.mean(epoch_kls) > config["target_kl"]):
                break

        # ===== Logging =====
        fps = global_step / max(1e-9, (time.time() - t0))
        writer.add_scalar("ppo/policy_loss", float(np.mean(policy_losses) if policy_losses else 0.0), global_step)
        writer.add_scalar("ppo/value_loss", float(np.mean(value_losses) if value_losses else 0.0), global_step)
        writer.add_scalar("ppo/entropy", float(np.mean(entropies) if entropies else 0.0), global_step)
        writer.add_scalar("ppo/approx_kl", float(np.mean(approx_kls) if approx_kls else 0.0), global_step)
        writer.add_scalar("ppo/clipfrac", float(np.mean(clipfracs) if clipfracs else 0.0), global_step)
        writer.add_scalar("train/fps", float(fps), global_step)
        writer.add_scalar("train/episodes", float(episode_count), global_step)

        # Value function explained variance
        ev = explained_variance(np.asarray(buf.values), np.asarray(buf.returns))
        writer.add_scalar("ppo/explained_variance", ev, global_step)

        # Console heartbeat
        print(
            f"[rollout {rollout_idx}] step={global_step} "
            f"roll_sum={roll_sum:.3f} roll_mean={roll_mean:.5f} "
            f"pi_loss={np.mean(policy_losses) if policy_losses else 0.0:.4f} "
            f"v_loss={np.mean(value_losses) if value_losses else 0.0:.4f} "
            f"ent={np.mean(entropies) if entropies else 0.0:.3f} "
            f"kl={np.mean(approx_kls) if approx_kls else 0.0:.4f} "
            f"clip={np.mean(clipfracs) if clipfracs else 0.0:.3f} "
            f"eps_done={episode_count} fps={fps:.1f}"
        )

        # ===== Validation + Best Model Saving + Early Stop =====
        if config["val_every_rollouts"] > 0 and (rollout_idx % config["val_every_rollouts"] == 0):
            model.eval()
            val = validation_run_ppo(env_val, model, episodes=val_episodes, device=device, greedy=True)
            model.train()

            for k, v in val.items():
                writer.add_scalar("val/" + k, v, global_step)

            print(f"  [val] {val}")

            cur_val = float(val.get("episode_reward", float("-inf")))
            if cur_val > best_val_reward + config["min_delta"]:
                best_val_reward = cur_val
                no_improve = 0
                best_path = os.path.join("saves", f"ppo_{config['run_name']}_best.pt")
                torch.save(model.state_dict(), best_path)
                print(f"  [best] new best val episode_reward={best_val_reward:.4f} -> {best_path}")
            else:
                no_improve += 1

            if config["early_stop"] and (rollout_idx >= config["min_rollouts"]) and (no_improve >= config["patience"]):
                print(f"[PPO] early stopping: no val improvement for {no_improve} validations.")
                break

        # ===== Save Periodic Checkpoint =====
        if config["save_every_rollouts"] > 0 and (rollout_idx % config["save_every_rollouts"] == 0):
            ckpt_path = os.path.join("saves", f"ppo_{config['run_name']}_rollout{rollout_idx}.pt")
            torch.save(model.state_dict(), ckpt_path)
            print(f"  [save] {ckpt_path}")

    if global_step >= config["total_steps"]:
        print(f"[done] reached total_steps={config['total_steps']} at rollout={rollout_idx}")
    elif rollout_idx >= max_rollouts:
        print(f"[done] reached max_rollouts={config['max_rollouts']} at step={global_step}")
finally:
    # Flush TensorBoard events even if training is interrupted (e.g. KeyboardInterrupt)
    writer.close()

print("\n✓ Training completed!")

In [None]:
# Load the best checkpoint (by validation episode_reward) and evaluate it on more episodes
best_path = os.path.join("saves", f"ppo_{config['run_name']}_best.pt")
final_episodes = int(config.get("final_val_episodes", 100))

if os.path.exists(best_path):
    # The checkpoint is a plain state_dict; weights_only=True refuses arbitrary pickled objects.
    best_state = torch.load(best_path, map_location=device, weights_only=True)
    model.load_state_dict(best_state)
    print(f"Loaded best model from: {best_path}")

    # Evaluate on validation set
    model.eval()
    final_val = validation_run_ppo(env_val, model, episodes=final_episodes, device=device, greedy=True)
    model.train()

    print("\n" + "=" * 60)
    print(f"Final Validation Results ({final_episodes} episodes):")
    print("=" * 60)
    for k, v in final_val.items():
        print(f"  {k}: {v:.4f}")
    print("=" * 60)
else:
    print(f"Best model checkpoint not found: {best_path}")

## Notes

- **TensorBoard**: View training metrics with `tensorboard --logdir runs/`
- **Checkpoints**: Best model saved to `saves/ppo_{run_name}_best.pt`
- **Configuration**: Adjust hyperparameters in the `config` dict (Training Configuration cell)
- **Early Stopping**: Enabled by default based on validation performance
- **Data Split**: Chronological split prevents data leakage (recommended)
