Import Required Packages

In [19]:
from datetime import datetime, timedelta
from loguru import logger
import MetaTrader5 as mt5
import pandas as pd
import numpy as np
import os
import yaml
from typing import Dict, Any, Optional, Tuple, List
import torch
import torch.nn as nn

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, VecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from __future__ import annotations
import math
from gymnasium import spaces
from pathlib import Path
import pandas_ta as ta

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  from pkg_resources import get_distribution, DistributionNotFound


In [20]:
end_dt = datetime.now()
start_dt = end_dt - timedelta(days=730)
symbol_name = "XAUUSDm"
is_connected = False
timeframe = "M15"
data = pd.DataFrame()

Initialize MT5 and connect to MT5 Broker account

In [22]:
if not mt5.initialize(
    login=210888620,
    password="S@jasper&12345",
    server="Exness-MT5Trial9"
):
    logger.error(f"MT5 initialization failed: {mt5.last_error()}")
            
    # Get account info
    account_info = mt5.account_info()
    if account_info is None:
        logger.error("Failed to get account information")
    
        symbol_info = mt5.symbol_info(symbol_name)
        if symbol_info is None:
            logger.warning(f"Symbol {symbol_name} not found")
        else:
            mt5.symbol_select(symbol_name, True)
    
    is_connected = True
    logger.info(f"Connected to MT5: Account {account_info.login},"
                f"Balance: {account_info.balance}, "
                f"Server: {account_info.server}")
else:
    logger.info("MT5 initialization successful")
    is_connected = True

[32m2025-08-27 19:53:29.680[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1mMT5 initialization successful[0m


Get historical market data.
        
Args:
    symbol: Trading symbol
    timeframe: Timeframe (M1, M5, M15, M30, H1, H4, D1)
    start_date: Start date for data
    end_date: End date for data
    count: Number of bars to retrieve
    
Returns:
    DataFrame with OHLCV data

In [23]:
if not is_connected:
    logger.error("Not connected to MT5")
else:        
    try:
        # Map timeframe strings to MT5 constants
        timeframe_map = {
            "M1": mt5.TIMEFRAME_M1,
            "M5": mt5.TIMEFRAME_M5,
            "M15": mt5.TIMEFRAME_M15,
            "M30": mt5.TIMEFRAME_M30,
            "H1": mt5.TIMEFRAME_H1,
            "H4": mt5.TIMEFRAME_H4,
            "D1": mt5.TIMEFRAME_D1
        }
        
        mt5_timeframe = timeframe_map.get(timeframe, mt5.TIMEFRAME_H1)
        
        # Get data
        rates = mt5.copy_rates_range(symbol_name, mt5_timeframe, start_dt, end_dt)
        
        if rates is None or len(rates) == 0:
            logger.warning(f"No data retrieved for {symbol_name}")
        
        # Convert to DataFrame
        df = pd.DataFrame(rates)
        df['time'] = pd.to_datetime(df['time'], unit='s')
        df.set_index('time', inplace=True)
        
        # Rename columns to standard format
        df.rename(columns={
            'open': 'open',
            'high': 'high', 
            'low': 'low',
            'close': 'close',
            'tick_volume': 'volume'
        }, inplace=True)
        
        logger.info(f"Retrieved {len(df)} bars for {symbol_name} ({timeframe})")
        data = df
        
    except Exception as e:
        logger.error(f"Error retrieving historical data: {e}")

[32m2025-08-27 19:53:35.665[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m38[0m - [1mRetrieved 47224 bars for XAUUSDm (M15)[0m


In [24]:
# Split data for training and validation
logger.info(f'Traning Data: {data}')

split_point = int(len(data) * 0.8)
train_data = data.iloc[:split_point]
eval_data = data.iloc[split_point:]

logger.info(f'Train Data: {train_data}')
logger.info(f'Evaluation Data: {eval_data}')

[32m2025-08-27 19:53:41.843[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mTraning Data:                          open      high       low     close  volume  spread  \
time                                                                          
2023-08-28 19:00:00  1919.521  1920.308  1919.521  1919.877     447     200   
2023-08-28 19:15:00  1919.876  1920.547  1919.875  1920.396     516     200   
2023-08-28 19:30:00  1920.398  1920.503  1919.774  1920.052     558     200   
2023-08-28 19:45:00  1920.032  1920.127  1919.457  1919.659     593     200   
2023-08-28 20:00:00  1919.676  1920.208  1919.676  1919.972     337     200   
...                       ...       ...       ...       ...     ...     ...   
2025-08-27 15:45:00  3388.838  3388.838  3386.470  3387.898    1453     160   
2025-08-27 16:00:00  3387.917  3389.719  3387.478  3388.930    1325     160   
2025-08-27 16:15:00  3388.944  3391.179  3388.111  3391.153    1354     160   
2025-08-27

State Feature Extractor
=======================

This module extracts and processes features for the trading environment state.
It handles technical indicators, price features, and time-based features.

Features include:
- Price data (OHLCV)
- Technical indicators (SMA, EMA, RSI, MACD, Bollinger Bands, ATR)
- Time features (hour, day of week, etc.)
- Normalized and scaled features

Author: PPO Trading System

In [32]:
config_path: str = "config/model_config.yaml"
# load main config
with open(config_path, "r") as f:
    config = yaml.safe_load(f)
    
feature_config = config['environment']['features']

# Feature scaling parameters (learned during first extraction)
feature_stats = {}
is_fitted = False

# Track minimum data requirements for each indicator
min_periods = {
    'sma_5': 5, 'sma_20': 20, 'sma_50': 50,
    'ema_12': 12, 'ema_26': 26,
    'rsi_14': 14, 'macd': 26, 'atr_14': 14,
    'bb_upper': 20, 'bb_lower': 20
}

logger.info("State feature extractor initialized")


[32m2025-08-27 19:58:25.130[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m20[0m - [1mState feature extractor initialized[0m


In [39]:
def _validate_input_data(data: pd.DataFrame) -> bool:
    required_cols = ['open', 'high', 'low', 'close', 'volume']
    for col in required_cols:
        if col not in data.columns:
            logger.error(f"Missing required column: {col}")
            return False
    
    # Check for NaN in essential columns
    essential_nan = data[required_cols].isnull().sum().sum()
    if essential_nan > 0:
        logger.warning(f"Found {essential_nan} NaN values in essential columns")
        # Fill essential NaN with forward then backward fill
        data[required_cols] = data[required_cols].ffill().bfill()
    
    return True
    
def _add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
    indicators = feature_config.get('technical_indicators', [])
    n_rows = len(df)
    
    for indicator in indicators:
        try:
            # Skip indicators that require more data than available
            min_periods = min_periods.get(indicator, 1)
            if n_rows < min_periods:
                logger.warning(f"Skipping {indicator}: {n_rows} < {min_periods}")
                continue
            
            if ta:
                # Use pandas_ta if available
                if indicator == 'sma_5':
                    df['sma_5'] = ta.sma(df['close'], length=5)
                elif indicator == 'sma_20':
                    df['sma_20'] = ta.sma(df['close'], length=20)
                elif indicator == 'sma_50':
                    df['sma_50'] = ta.sma(df['close'], length=50)
                elif indicator == 'ema_12':
                    df['ema_12'] = ta.ema(df['close'], length=12)
                elif indicator == 'ema_26':
                    df['ema_26'] = ta.ema(df['close'], length=26)
                elif indicator == 'rsi_14':
                    df['rsi_14'] = ta.rsi(df['close'], length=14)
                elif indicator == 'macd':
                    macd_df = ta.macd(df['close'], fast=12, slow=26, signal=9)
                    df['macd'] = macd_df['MACD_12_26_9']
                    df['macd_signal'] = macd_df['MACDs_12_26_9']
                elif indicator in ['bb_upper', 'bb_lower']:
                    bbands = ta.bbands(df['close'], length=20, std=2)
                    df['bb_upper'] = bbands['BBU_20_2.0']
                    df['bb_lower'] = bbands['BBL_20_2.0']
                elif indicator == 'atr_14':
                    df['atr_14'] = ta.atr(df['high'], df['low'], df['close'], length=14)
        except Exception as e:
            logger.warning(f"Failed to calculate indicator {indicator}: {e}")
    
    # Add price-based features with robust calculation
    df['price_change'] = df['close'].pct_change()
    
    # Handle division by zero in high_low_ratio
    high_low_ratio = np.where(
        df['low'] != 0, 
        df['high'] / df['low'], 
        np.where(df['high'] != 0, 1.0, 1.0)
    )
    df['high_low_ratio'] = high_low_ratio
    
    # Volume features with robust handling
    df['volume_sma'] = ta.sma(df['volume'], length=20) if ta else df['volume'].rolling(window=20).mean()
    
    # Handle division by zero in volume_ratio
    volume_ratio = np.where(
        (df['volume_sma'] != 0) & (~np.isclose(df['volume_sma'], 0)),
        df['volume'] / df['volume_sma'],
        np.where(df['volume'] != 0, 1.0, 0.0)
    )
    df['volume_ratio'] = volume_ratio
    
    # Volatility features
    df['volatility'] = df['price_change'].rolling(window=20, min_periods=1).std()
    
    # Handle division by zero in price_position
    price_range = df['high'] - df['low']
    price_position = np.where(
        price_range != 0,
        (df['close'] - df['low']) / price_range,
        np.where(df['close'] != 0, 0.5, 0.5)
    )
    df['price_position'] = price_position
    
    return df

def _handle_nan_values(df: pd.DataFrame) -> pd.DataFrame:
    # Fill NaN values with appropriate methods
    for col in df.columns:
        if df[col].isnull().any():
            # For price-based columns, use forward/backward fill
            if col in ['open', 'high', 'low', 'close', 'volume']:
                df[col] = df[col].ffill().bfill()
            # For technical indicators, use the mean of the column
            else:
                col_mean = df[col].mean()
                if not pd.isna(col_mean):
                    df[col] = df[col].fillna(col_mean)
                else:
                    # If mean is also NaN, use a default value
                    df[col] = df[col].fillna(0)
    
    # Check for infinite values
    inf_mask = np.isinf(df.values)
    if np.any(inf_mask):
        logger.warning(f"Found {inf_mask.sum()} infinite values in features")
        df = df.replace([np.inf, -np.inf], np.nan)
        df = df.ffill().bfill().fillna(0)
    
    return df

def _normalize_features(data: pd.DataFrame, is_fitted: bool) -> pd.DataFrame:
    """Normalize features using z-score normalization with robust handling."""
    df = data.copy()
    
    # Features that should not be normalized (already in good range)
    skip_normalization = ['is_weekend', 'is_month_end', 'is_quarter_end', 
                        'hour_sin', 'hour_cos', 'dow_sin', 'dow_cos']
    
    # Features that should be normalized
    normalize_features = [col for col in df.select_dtypes(include=[np.number]).columns 
                        if col not in skip_normalization]
    
    if not is_fitted:
        # Calculate statistics on first run
        feature_stats = {}
        for feature in normalize_features:
            if feature in df.columns:
                values = df[feature].dropna()
                if len(values) > 0:
                    # Handle cases where std is zero or very small
                    std = values.std()
                    if std < 1e-8:  # Very small standard deviation
                        std = 1.0  # Avoid division by zero
                    
                    feature_stats[feature] = {
                        'mean': values.mean(),
                        'std': std
                    }
                else:
                    # If no valid values, use defaults
                    feature_stats[feature] = {'mean': 0, 'std': 1.0}
        is_fitted = True
        logger.info(f"Feature normalization parameters fitted for {len(feature_stats)} features")
    
    # Apply normalization
    for feature, stats in feature_stats.items():
        if feature in df.columns:
            # Handle cases where feature values are constant
            if stats['std'] < 1e-8:
                df[feature] = 0  # Set to zero if no variation
            else:
                df[feature] = (df[feature] - stats['mean']) / stats['std']
    
    return df

def _add_time_features(data: pd.DataFrame) -> pd.DataFrame:
    """Add time-based features."""
    df = data.copy()
    
    time_features = feature_config.get('time_features', [])
    
    # Ensure index is datetime
    if not isinstance(df.index, pd.DatetimeIndex):
        # common timestamp column names to try
        ts_candidates = [c for c in ('timestamp', 'time', 'date', 'datetime') if c in df.columns]
        coerced = False
        for col in ts_candidates:
            try:
                df.index = pd.to_datetime(df[col])
                coerced = True
                break
            except Exception:
                continue

        if not coerced:
            # Try to coerce the existing index
            try:
                df.index = pd.to_datetime(df.index)
                coerced = True
            except Exception:
                coerced = False

        if not coerced:
            logger.warning("Could not coerce a DatetimeIndex from data. "
                        "Time features will be omitted. Provide a DatetimeIndex or a 'timestamp' column.")
            
    for feature in time_features:
        try:
            if feature == 'hour_of_day':
                df['hour_of_day'] = df.index.hour
            elif feature == 'day_of_week':
                df['day_of_week'] = df.index.dayofweek
            elif feature == 'day_of_month':
                df['day_of_month'] = df.index.day
            elif feature == 'month_of_year':
                df['month_of_year'] = df.index.month
            elif feature == 'is_weekend':
                df['is_weekend'] = (df.index.dayofweek >= 5).astype(int)
            elif feature == 'is_month_end':
                df['is_month_end'] = df.index.is_month_end.astype(int)
            elif feature == 'is_quarter_end':
                df['is_quarter_end'] = df.index.is_quarter_end.astype(int)
                
        except Exception as e:
            logger.warning(f"Failed to calculate time feature {feature}: {e}")
    
    # Cyclical encoding for time features
    if 'hour_of_day' in df.columns:
        df['hour_sin'] = np.sin(2 * np.pi * df['hour_of_day'] / 24)
        df['hour_cos'] = np.cos(2 * np.pi * df['hour_of_day'] / 24)
    
    if 'day_of_week' in df.columns:
        df['dow_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
        df['dow_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)
    
    return df

def _select_features(data: pd.DataFrame) -> pd.DataFrame:
    """Select only the configured features."""
    # Get all configured features
    price_features = feature_config.get('price_features', [])
    technical_indicators = feature_config.get('technical_indicators', [])
    time_features = feature_config.get('time_features', [])
    
    # Add derived features
    derived_features = ['price_change', 'high_low_ratio', 'volume_ratio', 'volatility', 'price_position']
    
    # Do NOT add cyclical_features manually, they are already created in _add_time_features
    all_features = (price_features + technical_indicators + time_features + derived_features)
    
    # Deduplicate while preserving order
    all_features = list(dict.fromkeys(all_features))
    
    # Select only existing features
    existing_features = [f for f in all_features if f in data.columns]
    
    if not existing_features:
        logger.warning("No configured features found in data, using all numeric columns")
        existing_features = list(data.select_dtypes(include=[np.number]).columns)
    
    logger.info(f"Selected {len(existing_features)} features: {existing_features}")
    return data[existing_features]

def get_feature_importance(model, feature_names: List[str]) -> Dict[str, float]:
    """
    Calculate feature importance if the model supports it.
    
    Args:
        model: Trained model with feature_importances_ attribute
        feature_names: List of feature names
        
    Returns:
        Dictionary mapping feature names to importance scores
    """
    try:
        if hasattr(model, 'feature_importances_'):
            importances = model.feature_importances_
            return dict(zip(feature_names, importances))
        else:
            logger.warning("Model does not support feature importance calculation")
            return {}
    except Exception as e:
        logger.error(f"Error calculating feature importance: {e}")
        return {}



"""
Reward Functions
================

This module implements various reward functions for the PPO trading agent.
The reward function is crucial for training the agent to make profitable trades.

Available reward strategies:
- Profit-based rewards
- Sharpe ratio-based rewards
- Risk-adjusted returns
- Custom composite rewards

Author: PPO Trading System
"""

In [40]:
reward_config = config['environment']['reward_function']
params = reward_config['parameters']

# Tracking variables
previous_portfolio_value = None
returns_history = []
max_returns_history = 1000

logger.info(f"Reward calculator initialized")

[32m2025-08-27 20:17:23.100[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mReward calculator initialized[0m


In [None]:
def _adapt_parameters() -> None:
    """Adapt reward parameters based on recent performance."""
    if len(returns_history) < 50:
        return
    
    recent_returns = np.array(returns_history[-50:])
    
    # Analyze performance
    mean_return = np.mean(recent_returns)
    volatility = np.std(recent_returns)
    sharpe = mean_return / (volatility + 1e-8)
    
    # Adapt drawdown penalty based on recent volatility
    if volatility > 0.05:  # High volatility
        params['drawdown_penalty'] = min(params['drawdown_penalty'] * 1.1, 5.0)
    elif volatility < 0.01:  # Low volatility
        params['drawdown_penalty'] = max(params['drawdown_penalty'] * 0.9, 0.1)
    
    # Adapt profit weight based on Sharpe ratio
    if sharpe > 1.0:  # Good performance
        params['profit_weight'] = min(params['profit_weight'] * 1.05, 2.0)
    elif sharpe < -0.5:  # Poor performance
        params['profit_weight'] = max(params['profit_weight'] * 0.95, 0.5)
    
    logger.info(f"Adapted reward parameters: profit_weight={params['profit_weight']:.3f}, "
                f"drawdown_penalty={params['drawdown_penalty']:.3f}")

def _calculate_reward(
                      position: float,
                      drawdown: float,
                      portfolio_value: float) -> float:
        """
        Calculate custom composite reward.
        
        This combines multiple reward components for more sophisticated training.
        """
        if previous_portfolio_value is None:
            previous_portfolio_value = portfolio_value
            return 0.0
        
        # Base profit reward
        step_return = (portfolio_value - previous_portfolio_value) / previous_portfolio_value
        profit_reward = step_return * params['profit_weight']
        
        # Custom reward components
        custom_rewards = reward_config.get('custom_rewards', {})
        custom_reward_total = 0.0
        
        # Trend following reward
        if 'trend_following' in custom_rewards and len(returns_history) > 5:
            recent_returns = returns_history[-5:]
            trend_strength = np.mean(recent_returns)
            position_alignment = np.sign(position) * np.sign(trend_strength)
            trend_reward = position_alignment * abs(trend_strength) * custom_rewards['trend_following']
            custom_reward_total += trend_reward
        
        # Mean reversion reward
        if 'mean_reversion' in custom_rewards and len(returns_history) > 20:
            recent_returns = np.array(returns_history[-20:])
            z_score = (recent_returns[-1] - np.mean(recent_returns)) / (np.std(recent_returns) + 1e-8)
            # Reward betting against extreme moves
            mean_reversion_signal = -np.sign(z_score) if abs(z_score) > 1.5 else 0
            position_alignment = np.sign(position) * mean_reversion_signal
            mr_reward = position_alignment * custom_rewards['mean_reversion']
            custom_reward_total += mr_reward
        
        # Volatility targeting reward
        if 'volatility_targeting' in custom_rewards and len(returns_history) > 10:
            recent_volatility = np.std(returns_history[-10:])
            target_volatility = 0.02  # 2% daily volatility target
            vol_adjustment = 1.0 - abs(recent_volatility - target_volatility) / target_volatility
            vol_reward = vol_adjustment * custom_rewards['volatility_targeting']
            custom_reward_total += vol_reward
        
        # Risk penalties
        drawdown_penalty = drawdown * params['drawdown_penalty']
        transaction_penalty = abs(position) * params['transaction_cost'] * params['penalty_weight']
        
        # Combine all components
        total_reward = (profit_reward + custom_reward_total - 
                       drawdown_penalty - transaction_penalty)
        
        # Update tracking
        previous_portfolio_value = portfolio_value
        returns_history.append(step_return)
        
        if len(returns_history) > max_returns_history:
            returns_history.pop(0)

        _adapt_parameters()
        
        return float(total_reward)
    


In [43]:
logger.info("Extracting features from market data")
        
# Validate input data first
if not _validate_input_data(data):
    logger.error("Invalid input data for feature extraction")

# Add technical indicators with NaN handling
data = _add_technical_indicators(data)
logger.info(f"Added Technical Indicator: {data}")

# Add time features
data = _add_time_features(data)
logger.info(f"Added Time Features: {data}")

# Normalize features
data = _normalize_features(data, is_fitted)
logger.info(f"Normalized Features: {data}")

# Select only the configured features
data = _select_features(data)
logger.info(f"Selected Features: {data}")

# Handle any remaining NaN values
data = _handle_nan_values(data)
logger.info(f"Handled NaN Values: {data}")

logger.info(f"Feature extraction complete. Shape: {data.shape}")

[32m2025-08-27 20:18:23.685[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mExtracting features from market data[0m
[32m2025-08-27 20:18:23.936[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mAdded Technical Indicator:                          open      high       low     close    volume  \
time                                                                    
2023-08-28 19:00:00 -1.334567 -1.334879 -1.332644 -1.333872 -0.915339   
2023-08-28 19:15:00 -1.333830 -1.334383 -1.331908 -1.332794 -0.871636   
2023-08-28 19:30:00 -1.332746 -1.334474 -1.332118 -1.333509 -0.845035   
2023-08-28 19:45:00 -1.333506 -1.335254 -1.332777 -1.334325 -0.822867   
2023-08-28 20:00:00 -1.334245 -1.335086 -1.332321 -1.333675 -0.985010   
...                       ...       ...       ...       ...       ...   
2025-08-27 15:45:00  1.716620  1.711297  1.717232  1.714590 -0.278167   
2025-08-27 16:00:00  1.714707  1.713125  1.719328  1.716733 -

Create training Model

[32m2025-08-27 22:56:11.687[0m | [1mINFO    [0m | [36m__main__[0m:[36m_setup_spaces[0m:[36m70[0m - [1mAction space: Discrete(3)[0m
[32m2025-08-27 22:56:11.689[0m | [1mINFO    [0m | [36m__main__[0m:[36m_setup_spaces[0m:[36m71[0m - [1mObservation space: (1705,)[0m
[32m2025-08-27 22:56:11.752[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m174[0m - [1mTrading environment initialized with 47224 data points[0m


In [74]:
def trading_env_init(
    data: pd.DataFrame,
    config_path: str = "config/model_config.yaml",
    initial_balance: float = 10_000.0,
    transaction_cost: float = 0.0001,
    render_mode: Optional[str] = None,
    expected_signature: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Initialize a trading environment state dictionary.
    
    Returns a state dictionary that contains all environment parameters and state.
    """
    # Load main config
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Create state dictionary
    state = {
        "config": config,
        "raw_data": data.copy() if data is not None else pd.DataFrame(),
        "initial_balance": float(initial_balance),
        "transaction_cost": float(transaction_cost),
        "render_mode": render_mode,
        "_expected_signature": expected_signature,
    }

    # Environment params from config as defaults
    env_cfg = config.get("environment", {})
    state["action_config"] = env_cfg.get("action_space", {"type": "discrete", "actions": ["hold", "buy", "sell"]})
    state["reward_config"] = env_cfg.get("reward_function", {})
    
    # Backtest config
    state["backtest_cfg"] = config.get("backtesting", {})

    # Load persisted observation signature if present
    obs_config_path = Path("models/saved_models/obs_config.yaml")
    if obs_config_path.exists():
        with open(obs_config_path, "r") as f:
            obs_config = yaml.safe_load(f) or {}
    else:
        obs_config = {}

    # Persistent observation params, fall back to config-derived defaults
    state["feature_names"] = obs_config.get("feature_names", [])
    state["n_features"] = int(obs_config.get("n_features", len(state["feature_names"]))) if obs_config else 0
    state["lookback_window"] = int(obs_config.get("lookback_window", env_cfg.get("lookback_window", 100)))
    state["portfolio_features_dim"] = int(obs_config.get("portfolio_features_dim", 5))
    
    # Compute obs_dim if present, otherwise compute from lookback and n_features
    state["obs_dim"] = int(obs_config.get("obs_dim", 
        state["lookback_window"] * max(1, state["n_features"]) + state["portfolio_features_dim"]))

    # Prepare data and validate against obs_config when available
    state = _prepare_data(state)

    # If feature names were not provided by obs_config, derive them now
    if not state["feature_names"]:
        state["feature_names"] = list(state["processed_data"].columns)
        state["n_features"] = int(len(state["feature_names"]))
        state["obs_dim"] = int(state["lookback_window"] * state["n_features"] + state["portfolio_features_dim"])

    # Validate observation signature consistency
    expected_obs_dim = state["lookback_window"] * state["n_features"] + state["portfolio_features_dim"]
    if state["obs_dim"] != expected_obs_dim:
        logger.warning(f"Observation dimension mismatch: config={state['obs_dim']}, calculated={expected_obs_dim}")
        state["obs_dim"] = expected_obs_dim

    # Build action and observation spaces
    state = _setup_spaces(state)

    # Reset internal state
    state = trading_env_reset(state)

    logger.info(f"Trading environment initialized with {len(state['raw_data'])} data points")
    
    return state

def trading_env_obs_signature(state: Dict[str, Any]) -> Dict[str, Any]:
    """Get observation signature from environment state."""
    return {
        "lookback_window": int(state["lookback_window"]),
        "n_features": int(state["n_features"]),
        "feature_names": list(state["feature_names"]),
        "portfolio_features_dim": int(state["portfolio_features_dim"]),
        "obs_dim": int(state["obs_dim"])
    }

def _prepare_data(state: Dict[str, Any]) -> Dict[str, Any]:
    """Prepare and validate data for the trading environment."""
    required_columns = ["open", "high", "low", "close", "volume"]
    if not isinstance(state["raw_data"], pd.DataFrame):
        raise ValueError("data must be a pandas DataFrame")

    missing = [c for c in required_columns if c not in state["raw_data"].columns]
    if missing:
        raise ValueError(f"Data must contain columns: {required_columns}. Missing: {missing}")

    # Extract features (in a real implementation, this would be more complex)
    state["processed_data"] = state["raw_data"]

    # Drop NaNs from indicator warmup
    state["processed_data"] = state["processed_data"].dropna(axis=0, how="any").reset_index(drop=True)

    if len(state["processed_data"]) < state["lookback_window"]:
        raise ValueError(f"Insufficient data after feature extraction: need at least lookback_window ({state['lookback_window']}) rows, got {len(state['processed_data'])}")

    # If obs_config provided feature_names, ensure processed_data contains them
    if state["feature_names"]:
        missing_feats = [f for f in state["feature_names"] if f not in state["processed_data"].columns]

        if missing_feats:
            logger.warning(f"Processed data missing features required by obs_config.yaml: {missing_feats}")
            # Use available features instead of failing
            available_features = [f for f in state["feature_names"] if f in state["processed_data"].columns]
            if len(available_features) < len(state["feature_names"]) * 0.8:  # If less than 80% features available
                raise ValueError(f"Too many missing features. Available: {len(available_features)}, Required: {len(state['feature_names'])}")
            
            logger.info(f"Using {len(available_features)} available features out of {len(state['feature_names'])} expected")
            state["feature_names"] = available_features
            
        # Ensure n_features matches
        state["n_features"] = int(len(state["feature_names"]))
    else:
        # Derive feature names from processed_data
        state["feature_names"] = list(state["processed_data"].columns)
        state["n_features"] = int(len(state["feature_names"]))

    # Recompute obs_dim from final values
    state["obs_dim"] = int(state["lookback_window"] * state["n_features"] + state["portfolio_features_dim"])

    # Validate against expected signature if supplied
    if state["_expected_signature"]:
        expected_obs_dim = int(state["_expected_signature"].get("obs_dim", -1))
        expected_n_features = int(state["_expected_signature"].get("n_features", -1))
        expected_lookback = int(state["_expected_signature"].get("lookback_window", -1))
        
        logger.info(f"Validating against expected signature:")
        logger.info(f"  Expected obs_dim: {expected_obs_dim}, Current: {state['obs_dim']}")
        logger.info(f"  Expected n_features: {expected_n_features}, Current: {state['n_features']}")
        logger.info(f"  Expected lookback: {expected_lookback}, Current: {state['lookback_window']}")
        
        if int(state["_expected_signature"].get("obs_dim", -1)) != int(state["obs_dim"]):
            # Try to adjust to match expected signature
            expected_obs_dim = int(state["_expected_signature"].get("obs_dim", -1))
            expected_n_features = int(state["_expected_signature"].get("n_features", -1))
            expected_lookback = int(state["_expected_signature"].get("lookback_window", -1))
            
            if expected_lookback > 0 and expected_n_features > 0:
                logger.warning(f"Adjusting environment to match expected signature")
                state["lookback_window"] = expected_lookback
                
                # If we have more features than expected, select the first N
                if state["n_features"] > expected_n_features:
                    expected_feature_names = state["_expected_signature"].get("feature_names", [])
                    if expected_feature_names:
                        # Use expected feature names if available
                        available_expected = [f for f in expected_feature_names if f in state["processed_data"].columns]
                        if len(available_expected) >= expected_n_features:
                            state["feature_names"] = available_expected[:expected_n_features]
                        else:
                            # Fallback to first N features
                            state["feature_names"] = state["feature_names"][:expected_n_features]
                    else:
                        state["feature_names"] = state["feature_names"][:expected_n_features]
                    
                    state["n_features"] = len(state["feature_names"])
                
                # Recalculate obs_dim
                state["obs_dim"] = int(state["lookback_window"] * state["n_features"] + state["portfolio_features_dim"])
                
                logger.info(f"Adjusted to: obs_dim={state['obs_dim']}, n_features={state['n_features']}, lookback={state['lookback_window']}")
                
                if state["obs_dim"] != expected_obs_dim:
                    raise ValueError(
                        f"Cannot adjust environment to match expected signature. "
                        f"Expected obs_dim={expected_obs_dim}, achieved obs_dim={state['obs_dim']}"
                    )
            else:
                raise ValueError(
                    f"Observation signature mismatch. Expected obs_dim={expected_obs_dim}, "
                    f"but environment produces obs_dim={state['obs_dim']}"
                )

    logger.info(f"Data prepared: {len(state['processed_data'])} samples, {len(state['processed_data'].columns)} features")
    
    return state

def _setup_spaces(state: Dict[str, Any]) -> Dict[str, Any]:
    """Set up action and observation spaces."""
    # Action space
    if state["action_config"].get("type", "discrete") == "discrete":
        actions = list(state["action_config"].get("actions", ["hold", "buy", "sell"]))
        state["action_space"] = spaces.Discrete(len(actions))
    else:
        bounds = state["action_config"].get("continuous_bounds", {"position_size": [-1.0, 1.0]})
        low = float(bounds["position_size"][0])
        high = float(bounds["position_size"][1])
        state["action_space"] = spaces.Box(
            low=np.array([low], dtype=np.float32),
            high=np.array([high], dtype=np.float32),
            dtype=np.float32,
        )

    # Observation space
    obs_dim = int(state["lookback_window"] * state["n_features"] + state["portfolio_features_dim"])
    state["obs_dim"] = obs_dim
    state["observation_space"] = spaces.Box(
        low=-np.inf,
        high=np.inf,
        shape=(obs_dim,),
        dtype=np.float32,
    )

    logger.info(f"Action space: {state['action_space']}")
    logger.info(f"Observation space: {state['observation_space'].shape}")
    
    return state

def trading_env_reset(state: Dict[str, Any], seed: Optional[int] = None, options: Optional[Dict] = None) -> Dict[str, Any]:
    """Reset the trading environment state."""
    # Reset internal state
    state["current_step"] = int(state["lookback_window"])
    state["balance"] = float(state["initial_balance"])
    state["position"] = 0.0
    state["position_entry_price"] = 0.0
    state["total_trades"] = 0
    state["winning_trades"] = 0
    state["trade_history"] = []
    state["portfolio_values"] = [state["initial_balance"]]
    state["unrealized_pnl"] = 0.0

    state["max_portfolio_value"] = state["initial_balance"]
    state["current_drawdown"] = 0.0
    state["max_drawdown"] = 0.0
    
    return state

def trading_env_step(state: Dict[str, Any], action: Any) -> Tuple[Dict[str, Any], np.ndarray, float, bool, bool, Dict[str, Any]]:
    """Take a step in the trading environment."""
    state = _execute_action(state, action)
    state["current_step"] += 1

    reward = _calculate_reward(
        position=state["position"],
        drawdown=state["current_drawdown"],
        portfolio_value=_get_portfolio_value(state)
    )

    state = _update_portfolio_tracking(state)

    terminated = (state["current_step"] >= (len(state["processed_data"]) - 1))
    truncated = _check_early_termination(state)

    obs = _get_observation(state)
    info = _get_info(state)
    
    return state, obs, float(reward), bool(terminated), bool(truncated), info

def _execute_action(state: Dict[str, Any], action: Any) -> Dict[str, Any]:
    """Execute a trading action."""
    current_price = _get_current_price(state)

    if state["action_config"].get("type", "discrete") == "discrete":
        if not isinstance(action, (int, np.integer)):
            if isinstance(action, (list, tuple, np.ndarray)):
                action = int(np.asarray(action).flatten()[0])
            else:
                action = int(action)

        actions = state["action_config"].get("actions", ["hold", "buy", "sell"])
        idx = int(action) % len(actions)
        action_name = actions[idx]

        if action_name == "buy" and state["position"] <= 0:
            state = _open_position(state, 1.0, current_price)
        elif action_name == "sell" and state["position"] >= 0:
            state = _open_position(state, -1.0, current_price)
        elif action_name == "hold":
            if state["position"] != 0.0:
                state["unrealized_pnl"] = (current_price - state["position_entry_price"]) * state["position"]
    else:
        if isinstance(action, (list, tuple, np.ndarray)):
            target_position = float(np.asarray(action).flatten()[0])
        else:
            target_position = float(action)
        if not math.isfinite(target_position):
            return state
        if target_position != state["position"]:
            state = _adjust_position(state, target_position, current_price)
    
    return state

def _open_position(state: Dict[str, Any], direction: float, price: float) -> Dict[str, Any]:
    """Open a new position."""
    if state["position"] != 0.0:
        state = _close_position(state, price)

    sizing = state["backtest_cfg"].get("position_sizing", {"method": "fixed_fraction", "fraction": 0.01, "max_position": 0.1})

    if sizing.get("method", "fixed_fraction") == "fixed_fraction":
        fraction = float(sizing.get("fraction", 0.01))
        position_value = state["balance"] * fraction
        position_size = (position_value / price) * float(direction)
    elif sizing.get("method") == "fixed":
        position_size = float(sizing.get("size", 0.01)) * float(direction)
    else:
        position_size = 0.01 * float(direction)

    max_pos = float(sizing.get("max_position", 0.1))
    max_pos_val = state["balance"] * max_pos
    max_pos_size = max_pos_val / price if price > 0 else max_pos_val
    if abs(position_size) > max_pos_size:
        position_size = math.copysign(max_pos_size, position_size)

    trade_cost = abs(position_size) * price * state["transaction_cost"]

    if state["balance"] >= trade_cost:
        state["position"] = float(position_size)
        state["position_entry_price"] = float(price)
        state["balance"] -= float(trade_cost)
        state["total_trades"] += 1
        state["unrealized_pnl"] = 0.0

        state["trade_history"].append({
            "step": state["current_step"],
            "action": "BUY" if direction > 0 else "SELL",
            "size": state["position"],
            "price": price,
            "cost": trade_cost,
            "balance_after": state["balance"]
        })
    
    return state

def _close_position(state: Dict[str, Any], price: float) -> Dict[str, Any]:
    """Close the current position."""
    if state["position"] == 0.0:
        return state

    pnl = (price - state["position_entry_price"]) * state["position"]
    trade_cost = abs(state["position"]) * price * state["transaction_cost"]
    net_pnl = pnl - trade_cost

    state["balance"] += state["position"] * price + net_pnl

    if net_pnl > 0:
        state["winning_trades"] += 1

    state["trade_history"].append({
        "step": state["current_step"],
        "action": "CLOSE",
        "size": state["position"],
        "entry_price": state["position_entry_price"],
        "exit_price": price,
        "pnl": net_pnl,
        "balance_after": state["balance"]
    })

    state["position"] = 0.0
    state["position_entry_price"] = 0.0
    state["unrealized_pnl"] = 0.0
    
    return state

def _adjust_position(state: Dict[str, Any], target_position: float, price: float) -> Dict[str, Any]:
    """Adjust the current position to a target position."""
    if target_position == state["position"]:
        return state

    delta = target_position - state["position"]
    trade_cost = abs(delta) * price * state["transaction_cost"]
    if state["balance"] >= trade_cost:
        if state["position"] == 0.0 and target_position != 0.0:
            state["position_entry_price"] = price
        elif state["position"] != 0.0 and target_position != 0.0:
            prev_val = state["position"] * state["position_entry_price"]
            added_val = delta * price
            new_pos = state["position"] + delta
            if new_pos != 0:
                state["position_entry_price"] = (prev_val + added_val) / new_pos
            else:
                state["position_entry_price"] = 0.0

        state["position"] = float(target_position)
        state["balance"] -= float(trade_cost)
        state["total_trades"] += 1
    
    return state

def _get_current_price(state: Dict[str, Any]) -> float:
    """Get the current price from processed data."""
    try:
        return float(state["processed_data"].iloc[state["current_step"]]["close"])
    except Exception:
        return float(state["processed_data"].iloc[state["current_step"]].iloc[-1])

def _get_portfolio_value(state: Dict[str, Any]) -> float:
    """Calculate the current portfolio value."""
    current_price = _get_current_price(state)
    position_value = state["position"] * current_price if state["position"] != 0.0 else 0.0
    return float(state["balance"] + position_value)

def _update_portfolio_tracking(state: Dict[str, Any]) -> Dict[str, Any]:
    """Update portfolio tracking metrics."""
    pv = _get_portfolio_value(state)
    state["portfolio_values"].append(pv)

    if pv > state["max_portfolio_value"]:
        state["max_portfolio_value"] = pv

    if state["max_portfolio_value"] > 0:
        state["current_drawdown"] = (state["max_portfolio_value"] - pv) / float(state["max_portfolio_value"])
    else:
        state["current_drawdown"] = 0.0

    if state["current_drawdown"] > state["max_drawdown"]:
        state["max_drawdown"] = state["current_drawdown"]
    
    return state

def _check_early_termination(state: Dict[str, Any]) -> bool:
    """Check if early termination conditions are met."""
    risk_cfg = state["backtest_cfg"].get("risk_management", {})
    max_drawdown_stop = float(risk_cfg.get("max_drawdown_stop", 0.5))
    if state["current_drawdown"] > max_drawdown_stop:
        return True

    if state["balance"] < (0.1 * state["initial_balance"]):
        return True

    return False

def _get_observation(state: Dict[str, Any]) -> np.ndarray:
    """Get the current observation."""
    start_idx = max(0, state["current_step"] - state["lookback_window"])

    # Select features in the order defined by feature_names
    df = state["processed_data"]
    if all(f in df.columns for f in state["feature_names"]):
        window = df[state["feature_names"]].iloc[start_idx:state["current_step"]]
        n_cols = len(state["feature_names"])
    else:
        # Fallback to all columns if mismatch
        logger.warning(f"Feature mismatch in observation. Expected: {state['feature_names']}, Available: {list(df.columns)}")
        window = df.iloc[start_idx:state["current_step"]]
        n_cols = window.shape[1]

    # Pad if needed
    if len(window) < state["lookback_window"]:
        n_missing = state["lookback_window"] - len(window)
        pad_shape = (n_missing, n_cols)
        pad = np.zeros(pad_shape, dtype=np.float32)
        window_vals = np.vstack([pad, window.values.astype(np.float32)])
    else:
        window_vals = window.values.astype(np.float32)

    flat_market = window_vals.flatten()

    portfolio_features = np.array([
        state["balance"] / float(state["initial_balance"]),
        state["position"],
        (state["unrealized_pnl"] / float(state["initial_balance"])) if state["initial_balance"] != 0 else 0.0,
        float(state["total_trades"]) / 1000.0,
        float(state["current_drawdown"])
    ], dtype=np.float32)

    obs = np.concatenate([flat_market.astype(np.float32), portfolio_features], axis=0)

    # Ensure obs length matches expected obs_dim, pad or trim if necessary to keep deterministic shape
    expected_len = int(state["observation_space"].shape[0]) if "observation_space" in state else int(state["obs_dim"])
    if obs.shape[0] < expected_len:
        pad_len = expected_len - obs.shape[0]
        obs = np.concatenate([obs, np.zeros(pad_len, dtype=np.float32)], axis=0)
    elif obs.shape[0] > expected_len:
        obs = obs[:expected_len]
        
    # Final validation
    if obs.shape[0] != expected_len:
        logger.error(f"Observation shape mismatch: got {obs.shape[0]}, expected {expected_len}")
        logger.error(f"Market features: {flat_market.shape[0]}, Portfolio features: {portfolio_features.shape[0]}")
        logger.error(f"Lookback window: {state['lookback_window']}, N features: {n_cols}")

    return obs

def _get_info(state: Dict[str, Any]) -> Dict[str, Any]:
    """Get information about the current state."""
    pv = _get_portfolio_value(state)
    win_rate = (float(state["winning_trades"]) / float(max(1, state["total_trades"]))) if state["total_trades"] > 0 else 0.0
    return {
        "step": int(state["current_step"]),
        "balance": float(state["balance"]),
        "position": float(state["position"]),
        "portfolio_value": float(pv),
        "total_return": float((pv - state["initial_balance"]) / float(state["initial_balance"])),
        "total_trades": int(state["total_trades"]),
        "winning_trades": int(state["winning_trades"]),
        "win_rate": float(win_rate),
        "max_drawdown": float(state["max_drawdown"]),
        "current_drawdown": float(state["current_drawdown"]),
        "unrealized_pnl": float(state["unrealized_pnl"])
    }

def trading_env_render(state: Dict[str, Any]) -> None:
    """Render the current environment state."""
    if state["render_mode"] == "human":
        info = _get_info(state)
        print(
            f"[Step {info['step']}] Balance: ${info['balance']:.2f} | "
            f"Portfolio: ${info['portfolio_value']:.2f} | Return: {info['total_return']:.2%} | "
            f"Drawdown: {info['current_drawdown']:.2%} | Position: {info['position']:.4f}"
        )

def trading_env_build_for_model_loading(
    data: pd.DataFrame,
    config_path: str,
    expected_signature: Dict[str, Any]
) -> Dict[str, Any]:
    """Build environment for model loading with signature validation."""
    state = trading_env_init(data=data, config_path=config_path, expected_signature=expected_signature)
    sig = trading_env_obs_signature(state)
    if sig.get("obs_dim") != expected_signature.get("obs_dim"):
        raise ValueError(
            f"obs_dim mismatch building env_for_model_loading: expected {expected_signature.get('obs_dim')}, got {sig.get('obs_dim')}"
        )
    return state

Setup and run Agent

In [75]:
class TradingFeatureExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.Space, portfolio_features_dim: int = 5):
        super().__init__(observation_space, features_dim=256)
        self.portfolio_features_dim = portfolio_features_dim

        total_obs_dim = int(observation_space.shape[0])
        market_dim = total_obs_dim - self.portfolio_features_dim
        if market_dim <= 0:
            raise ValueError("Observation space too small for configured portfolio_features_dim")

        self.market_extractor = nn.Sequential(
            nn.Linear(market_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU()
        )

        self.portfolio_extractor = nn.Sequential(
            nn.Linear(self.portfolio_features_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU()
        )

        self.combined_extractor = nn.Sequential(
            nn.Linear(128 + 32, 256),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        market_features = obs[:, :-self.portfolio_features_dim]
        portfolio_features = obs[:, -self.portfolio_features_dim:]
        market_embed = self.market_extractor(market_features)
        portfolio_embed = self.portfolio_extractor(portfolio_features)
        combined = torch.cat([market_embed, portfolio_embed], dim=1)
        return self.combined_extractor(combined)

In [76]:

OBS_CONFIG_FILENAME = "obs_config.yaml"
model_params = config['ppo']
training_params = config['training']

model: Optional[PPO] = None
env: Optional[gym.Env] = None
vec_env: Optional[VecNormalize] = None

# Track policy kwargs used and a serializable snapshot
_policy_kwargs_used: Optional[Dict[str, Any]] = None
_policy_kwargs_serializable: Optional[Dict[str, Any]] = None

def _read_yaml_if_exists(path: str) -> Optional[Dict[str, Any]]:
    if os.path.exists(path):
        try:
            with open(path, 'r') as f:
                return yaml.safe_load(f)
        except Exception as e:
            logger.warning(f"Failed to read yaml at {path}: {e}")
    return None

def _ensure_vec_env(env: Optional[gym.Env]) -> Optional[VecEnv]:
    if env is None:
        return None

    if isinstance(env, VecEnv):
        return env

    base = env
    try:
        underlying = getattr(base, "env", None)
        if underlying is not None and isinstance(underlying, gym.Env):
            base = base
    except Exception:
        pass

    try:
        monitored = base if isinstance(base, Monitor) else Monitor(base)
        return DummyVecEnv([lambda: monitored])
    except Exception as exc:
        raise ValueError(f"Failed to wrap provided env into VecEnv: {exc}")

def _unwrap_env_for_obs_signature(env: Optional[gym.Env]):
    if env is None:
        return None

    if isinstance(env, DummyVecEnv):
        try:
            inner = getattr(env, "envs", None)
            if inner and len(inner) > 0:
                return inner[0]
        except Exception:
            pass

    try:
        inner = getattr(env, "env", None)
        if inner is not None:
            return inner
    except Exception:
        pass

    return env

def _map_activation_fn(act):
    if isinstance(act, str):
        act_map = {
            "relu": nn.ReLU,
            "tanh": nn.Tanh,
            "sigmoid": nn.Sigmoid,
            "leaky_relu": nn.LeakyReLU,
            "swish": nn.SiLU,
            "gelu": nn.GELU
        }
        return act_map.get(act.lower(), nn.ReLU)
    if isinstance(act, type) and issubclass(act, nn.Module):
        return act
    return nn.ReLU

def _serialize_policy_kwargs(pk: Dict[str, Any]) -> Dict[str, Any]:
    serial = {}
    serial['net_arch'] = pk.get('net_arch')
    act = pk.get('activation_fn')
    if isinstance(act, type):
        serial['activation_fn'] = act.__name__.lower()
    else:
        serial['activation_fn'] = str(act).lower() if act is not None else None

    fe = pk.get('features_extractor_class')
    serial['features_extractor_class'] = fe.__name__ if hasattr(fe, '__name__') else str(fe)
    serial['features_extractor_kwargs'] = pk.get('features_extractor_kwargs', {})
    serial['ortho_init'] = pk.get('ortho_init', None)

    other = {k: v for k, v in pk.items() if k not in ['net_arch', 'activation_fn', 'features_extractor_class', 'features_extractor_kwargs', 'ortho_init']}
    serial['other'] = other
    serial['version'] = 1
    return serial

def _deserialize_policy_kwargs(serialized: Dict[str, Any]) -> Dict[str, Any]:
    if not serialized:
        return {}
    pk: Dict[str, Any] = {}
    if 'net_arch' in serialized and serialized['net_arch'] is not None:
        pk['net_arch'] = serialized['net_arch']
    act_name = serialized.get('activation_fn')
    if act_name:
        pk['activation_fn'] = _map_activation_fn(act_name)
    fe_name = serialized.get('features_extractor_class')
    if fe_name and fe_name.lower().startswith('tradingfeatureextractor'):
        pk['features_extractor_class'] = TradingFeatureExtractor
    pk['features_extractor_kwargs'] = serialized.get('features_extractor_kwargs', {}) or {}
    if 'ortho_init' in serialized:
        pk['ortho_init'] = serialized.get('ortho_init')
    other = serialized.get('other', {}) or {}
    pk.update(other)
    return pk

def _build_eval_vec_env(eval_data: pd.DataFrame, vec_env: Optional[VecNormalize], model_dir: Optional[str] = None) -> VecNormalize:
    expected_signature = None
    if model_dir:
        sig_path = os.path.join(model_dir, OBS_CONFIG_FILENAME)
        expected_signature = _read_yaml_if_exists(sig_path)

    eval_env = TradingEnvironment(
        data=eval_data,
        config_path="config/model_config.yaml",
        **({"expected_signature": expected_signature} if expected_signature is not None else {})
    )
    eval_env = Monitor(eval_env)
    eval_vec = DummyVecEnv([lambda: eval_env])

    eval_vec_norm = VecNormalize(eval_vec, norm_obs=True, norm_reward=False, training=False)

    if isinstance(vec_env, VecNormalize) and getattr(vec_env, "obs_rms", None) is not None:
        eval_vec_norm.obs_rms = vec_env.obs_rms

    return eval_vec_norm

In [66]:
VECNORM_FILENAME = "vecnormalize.pkl"
OBS_CONFIG_FILENAME = "obs_config.yaml"
POLICY_KWARGS_FILENAME = "policy_kwargs.yaml"

def train(
    model: Optional[PPO],
    vec_env: Optional[VecNormalize],
    env: Optional[gym.Env],
    config: Dict[str, Any],
    training_params: Dict[str, Any],
    policy_kwargs_serializable: Optional[Dict[str, Any]],
    train_data: pd.DataFrame,
    eval_data: Optional[pd.DataFrame] = None
) -> Tuple[Dict[str, Any], Optional[PPO], Optional[VecNormalize], Optional[gym.Env], Optional[Dict[str, Any]]]:
    """
    Train a PPO model with the provided data and parameters.
    
    Returns: (result_dict, updated_model, updated_vec_env, updated_env, updated_policy_kwargs_serializable)
    """
    if model is None:
        # Create model logic would need to be implemented separately
        raise ValueError("Model must be created before training")

    logger.info(f"Starting PPO training with {len(train_data)} samples")

    callbacks = [
        CheckpointCallback(
            save_freq=training_params.get('save_freq', 5000),
            save_path='./models/saved_models/checkpoints/',
            name_prefix='ppo_trading'
        )
    ]

    if eval_data is not None:
        eval_vec_norm = _build_eval_vec_env(eval_data, vec_env)
        callbacks.append(EvalCallback(
            eval_vec_norm,
            eval_freq=training_params.get('eval_freq', 10000),
            n_eval_episodes=training_params.get('n_eval_episodes', 5),
            best_model_save_path='./models/saved_models/',
            log_path='./logs/eval/',
            verbose=1
        ))

    total_timesteps = training_params['total_timesteps']

    try:
        model.learn(total_timesteps=total_timesteps, callback=callbacks, progress_bar=True)

        model_path = f"./models/saved_models/ppo_final.zip"
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        model.save(model_path)

        try:
            vecnorm_path = os.path.join(os.path.dirname(model_path), VECNORM_FILENAME)
            if isinstance(vec_env, VecNormalize):
                vec_env.save(vecnorm_path)
                logger.info(f"Saved VecNormalize stats to {vecnorm_path}")
        except Exception as e:
            logger.warning(f"Could not save VecNormalize stats: {e}")

        try:
            base_env = _unwrap_env_for_obs_signature(env)
            obs_sig_fn = getattr(base_env, "obs_signature", None)
            if callable(obs_sig_fn):
                obs_sig = obs_sig_fn()
                with open(os.path.join(os.path.dirname(model_path), OBS_CONFIG_FILENAME), "w") as f:
                    yaml.safe_dump(obs_sig, f)
        except Exception as e:
            logger.warning(f"Could not save {OBS_CONFIG_FILENAME}: {e}")

        try:
            # prefer serializable snapshot created earlier
            policy_to_save = policy_kwargs_serializable or _serialize_policy_kwargs(config['ppo'].get('policy_kwargs', {}))
            with open(os.path.join(os.path.dirname(model_path), POLICY_KWARGS_FILENAME), "w") as f:
                yaml.safe_dump(policy_to_save, f)
        except Exception as e:
            logger.warning(f"Could not save {POLICY_KWARGS_FILENAME}: {e}")

        logger.info(f"Training completed. Model saved to {model_path}")

        training_cb = callbacks[0]
        result = {
            'success': True,
            'model_path': model_path,
            'total_timesteps': total_timesteps,
            'best_model_path': getattr(training_cb, 'best_model_path', None),
            'final_performance': {
                'mean_reward': training_cb.best_mean_reward,
                'total_episodes': len(training_cb.episode_rewards)
            }
        }

        return result, model, vec_env, env, policy_kwargs_serializable

    except Exception as e:
        logger.error(f"Training failed: {e}")
        return {'success': False, 'error': str(e)}, model, vec_env, env, policy_kwargs_serializable

def predict(
    model: PPO,
    observation: np.ndarray,
    deterministic: bool = True
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Predict an action given an observation using the trained model.
    """
    if model is None:
        raise ValueError("Model not loaded. Call create_model() or train() first.")

    # Check for NaN in input
    if np.any(np.isnan(observation)):
        logger.warning("NaN values in observation input")
        observation = np.nan_to_num(observation)
    
    action, _ = model.predict(observation, deterministic=deterministic)
    return action, None

def evaluate(
    model: PPO,
    eval_data: pd.DataFrame,
    vec_env: Optional[VecNormalize],
    n_episodes: int = 10
) -> Dict[str, float]:
    """
    Evaluate the model on the provided evaluation data.
    """
    if model is None:
        raise ValueError("Model not loaded. Call create_model() or train() first.")

    model_dir = None
    eval_vec_norm = _build_eval_vec_env(eval_data, vec_env, model_dir=model_dir)

    episode_returns, episode_lengths, win_rates, max_drawdowns = [], [], [], []

    for episode in range(n_episodes):
        obs = eval_vec_norm.reset()
        done = [False]
        ep_ret = 0.0
        ep_len = 0
        last_info: Dict[str, Any] = {}

        while not done[0]:
            action, _ = model.predict(obs, deterministic=True)
            obs, rewards, done, infos = eval_vec_norm.step(action)
            ep_ret += float(rewards[0])
            ep_len += 1
            if isinstance(infos, (list, tuple)) and len(infos) > 0 and isinstance(infos[0], dict):
                last_info = infos[0] or last_info

        episode_returns.append(ep_ret)
        episode_lengths.append(ep_len)
        win_rates.append(last_info.get('win_rate', 0.0))
        max_drawdowns.append(last_info.get('max_drawdown', 0.0))

        logger.info(f"Evaluation episode {episode + 1}/{n_episodes}: Return={ep_ret:.2f}, Length={ep_len}")

    return {
        'mean_return': float(np.mean(episode_returns)),
        'std_return': float(np.std(episode_returns)),
        'mean_length': float(np.mean(episode_lengths)),
        'mean_win_rate': float(np.mean(win_rates)),
        'mean_max_drawdown': float(np.mean(max_drawdowns)),
        'sharpe_ratio': float(np.mean(episode_returns) / (np.std(episode_returns) + 1e-8))
    }

def save_model(
    model: PPO,
    vec_env: Optional[VecNormalize],
    env: Optional[gym.Env],
    config: Dict[str, Any],
    policy_kwargs_serializable: Optional[Dict[str, Any]],
    path: str
) -> None:
    """
    Save the model and related files to the specified path.
    """
    if model is None:
        raise ValueError("No model to save")
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    model.save(path)

    if isinstance(vec_env, VecNormalize):
        try:
            vec_env.save(os.path.join(os.path.dirname(path) or ".", VECNORM_FILENAME))
        except Exception as e:
            logger.warning(f"Could not save VecNormalize stats on save_model: {e}")

    try:
        base_env = _unwrap_env_for_obs_signature(env)
        obs_sig_fn = getattr(base_env, "obs_signature", None)
        if callable(obs_sig_fn):
            obs_sig = obs_sig_fn()
            with open(os.path.join(os.path.dirname(path) or ".", OBS_CONFIG_FILENAME), "w") as f:
                yaml.safe_dump(obs_sig, f)
    except Exception as e:
        logger.warning(f"Could not save obs_config.yaml: {e}")

    try:
        policy_to_save = policy_kwargs_serializable or _serialize_policy_kwargs(config['ppo'].get('policy_kwargs', {}))
        with open(os.path.join(os.path.dirname(path) or ".", POLICY_KWARGS_FILENAME), "w") as f:
            yaml.safe_dump(policy_to_save, f)
    except Exception as e:
        logger.warning(f"Could not save policy_kwargs.yaml: {e}")

    logger.info(f"Model saved to {path}")

def load_model(
    path: str,
    env: Optional[gym.Env] = None,
    policy_kwargs_serializable: Optional[Dict[str, Any]] = None
) -> Tuple[PPO, Optional[VecNormalize], Optional[gym.Env], Optional[Dict[str, Any]]]:
    """
    Load a model from the specified path.
    
    Returns: (model, vec_env, env, policy_kwargs_serializable)
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model file not found: {path}")

    model_dir = os.path.dirname(path)

    policy_kwargs_path = os.path.join(model_dir, POLICY_KWARGS_FILENAME)
    policy_kwargs_serial = _read_yaml_if_exists(policy_kwargs_path) or {}
    # deserialize into runtime objects
    try:
        policy_kwargs = _deserialize_policy_kwargs(policy_kwargs_serial)
    except Exception as e:
        logger.warning(f"Failed to deserialize policy kwargs from file: {e}")
        policy_kwargs = {}

    policy_kwargs.setdefault("features_extractor_class", TradingFeatureExtractor)
    policy_kwargs.setdefault("features_extractor_kwargs", {"portfolio_features_dim": 5})

    vec_env_loaded = None
    vecnorm_path = os.path.join(model_dir, VECNORM_FILENAME)
    if os.path.exists(vecnorm_path):
        if env is None:
            raise ValueError(
                f"Model appears to use VecNormalize (found {VECNORM_FILENAME}). Provide a real environment with matching observation signature when loading."
            )
        env_vec = _ensure_vec_env(env)

        obs_sig_path = os.path.join(model_dir, OBS_CONFIG_FILENAME)
        expected_sig = _read_yaml_if_exists(obs_sig_path)
        if expected_sig is not None:
            underlying = _unwrap_env_for_obs_signature(env)
            obs_sig_fn = getattr(underlying, "obs_signature", None)
            if callable(obs_sig_fn):
                current_sig = obs_sig_fn()
                logger.info(f"Expected signature: {expected_sig}")
                logger.info(f"Current signature: {current_sig}")
                if int(current_sig.get("obs_dim", -1)) != int(expected_sig.get("obs_dim", -1)):
                    raise RuntimeError(
                        f"Saved model obs_dim={expected_sig.get('obs_dim')} does not match provided env obs_dim={current_sig.get('obs_dim')}. Create an environment with matching features/lookback or retrain the model."
                    )
                
                # Validate feature consistency
                expected_features = expected_sig.get("feature_names", [])
                current_features = current_sig.get("feature_names", [])
                if expected_features != current_features:
                    logger.warning(f"Feature mismatch - Expected: {len(expected_features)}, Current: {len(current_features)}")
                    logger.warning(f"Expected features: {expected_features}")
                    logger.warning(f"Current features: {current_features}")
                    
                    # Check if it's just ordering or missing features
                    missing_features = set(expected_features) - set(current_features)
                    extra_features = set(current_features) - set(expected_features)
                    
                    if missing_features:
                        raise RuntimeError(f"Missing features in current environment: {missing_features}")
                    if extra_features:
                        logger.warning(f"Extra features in current environment (will be ignored): {extra_features}")
        
        vec_env_loaded = VecNormalize.load(vecnorm_path, env_vec)
        vec_env_loaded.training = False
        vec_env_loaded.norm_reward = False
        env = vec_env_loaded
        logger.info(f"Loaded VecNormalize stats from {vecnorm_path}")
    elif env is not None:
        vec_env_loaded = env

    try:
        model = PPO.load(path, env=env, custom_objects={"policy_kwargs": policy_kwargs})
        # record policy kwargs used
        policy_kwargs_used = getattr(model, 'policy_kwargs', policy_kwargs)
        if policy_kwargs_used:
            policy_kwargs_serializable = _serialize_policy_kwargs(policy_kwargs_used)
        logger.info(f"Model loaded from {path}")
        return model, vec_env_loaded, env, policy_kwargs_serializable
    except Exception as e:
        logger.error(f"Failed to load model due to: {e}")
        raise

def create_model(train_data: pd.DataFrame, model_path: Optional[str] = None) -> None:
        env = TradingEnvironment(
            data=train_data,
            config_path="config/model_config.yaml"
        )
        env = Monitor(TradingEnvironment)
        train_vec = DummyVecEnv([lambda: env])
        vec_env = VecNormalize(train_vec, norm_obs=True, norm_reward=True)

        if model_path and os.path.exists(model_path):
            logger.info(f"Loading model from {model_path}")

            model_dir = os.path.dirname(model_path)
            policy_kwargs_path = os.path.join(model_dir, POLICY_KWARGS_FILENAME)
            saved_policy_kwargs = _read_yaml_if_exists(policy_kwargs_path)

            custom_objects = {"learning_rate": model_params.get('learning_rate')}
            if saved_policy_kwargs:
                try:
                    deserialized = _deserialize_policy_kwargs(saved_policy_kwargs)
                    # Ensure feature extractor present
                    deserialized.setdefault("features_extractor_class", TradingFeatureExtractor)
                    deserialized.setdefault("features_extractor_kwargs", {"portfolio_features_dim": 5})
                    custom_objects["policy_kwargs"] = deserialized
                except Exception as e:
                    logger.warning(f"Failed to deserialize saved policy kwargs: {e}")

            try:
                model = PPO.load(model_path, env=vec_env, custom_objects=custom_objects)
                # store policy kwargs used if available
                _policy_kwargs_used = getattr(model, 'policy_kwargs', custom_objects.get('policy_kwargs'))
                if _policy_kwargs_used:
                    _policy_kwargs_serializable = _serialize_policy_kwargs(_policy_kwargs_used)
                logger.info("PPO model loaded successfully for continued training")
                return
            except Exception as e:
                logger.error(f"Failed to load model for continued training: {e}")
                raise

        logger.info("Creating new PPO model")

        policy_kwargs = dict(model_params.get("policy_kwargs", {}) or {})

        act_fn = policy_kwargs.get("activation_fn")
        if isinstance(act_fn, str):
            policy_kwargs["activation_fn"] = _map_activation_fn(act_fn)

        if model_params.get("use_custom_policy", True):
            extractor_kwargs = policy_kwargs.pop('features_extractor_kwargs', {}) or {}
            extractor_kwargs.update({'portfolio_features_dim': 5})
            policy_kwargs.update({
                'features_extractor_class': TradingFeatureExtractor,
                'features_extractor_kwargs': extractor_kwargs
            })

        # store used policy kwargs for later serialization
        _policy_kwargs_used = policy_kwargs
        _policy_kwargs_serializable = _serialize_policy_kwargs(policy_kwargs)

        model = PPO(
            policy='MlpPolicy',
            env=vec_env,
            learning_rate=model_params['learning_rate'],
            n_steps=model_params['n_steps'],
            batch_size=model_params['batch_size'],
            n_epochs=model_params['n_epochs'],
            gamma=model_params['gamma'],
            gae_lambda=model_params['gae_lambda'],
            clip_range=model_params['clip_range'],
            ent_coef=model_params['ent_coef'],
            vf_coef=model_params['vf_coef'],
            max_grad_norm=model_params['max_grad_norm'],
            policy_kwargs=policy_kwargs,
            verbose=1,
            device='auto'
        )

        logger.info("PPO model created successfully")


In [77]:
# Run training
model, env, vec_env, config, training_params, policy_kwargs_serializable = create_model(data, "models/saved_models")

results, updated_model, updated_vec_env, updated_env, updated_policy_kwargs = train(
    model=model,
    vec_env=vec_env,
    env=env,
    config=config,
    training_params=training_params,
    websocket_broadcaster=None,  # Pass your websocket broadcaster if available
    performance_tracker=None,    # Pass your performance tracker if available
    policy_kwargs_serializable=policy_kwargs_serializable,
    train_data=train_data,
    eval_data=eval_data,
    session_id=None  # Pass your session ID if available
)

# Update references to the potentially updated objects
model = updated_model
vec_env = updated_vec_env
env = updated_env
policy_kwargs_serializable = updated_policy_kwargs

logger.info(f"Training completed for {symbol_name}: {results}")

[32m2025-08-28 01:52:30.643[0m | [1mINFO    [0m | [36m__main__[0m:[36m_prepare_data[0m:[36m194[0m - [1mData prepared: 47224 samples, 17 features[0m
[32m2025-08-28 01:52:30.650[0m | [1mINFO    [0m | [36m__main__[0m:[36m_setup_spaces[0m:[36m222[0m - [1mAction space: Discrete(3)[0m
[32m2025-08-28 01:52:30.652[0m | [1mINFO    [0m | [36m__main__[0m:[36m_setup_spaces[0m:[36m223[0m - [1mObservation space: (1705,)[0m
[32m2025-08-28 01:52:30.671[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m85[0m - [1mTrading environment initialized with 47224 data points[0m


AssertionError: Expected env to be a `gymnasium.Env` but got <class 'type'>