stock_trading_env.py
Module: StockTradingEnv
Purpose: Custom Gymnasium-compatible multi-stock trading environment with modular reward injection.
Design: 
- Multi-episode training architecture
- Sliding window for state representation
- Modular reward design (S_f: sentiment factor, R_f: risk factor, CVaR-aware)
Linkage: Uses config_trading; receives rl_data from FeatureEngineer.
Robustness: Handles invalid prices; logs state/action/reward.

In [1]:
import os
os.chdir('/Users/archy/Projects/finbert_trader/')

In [2]:
import gymnasium as gym
from gymnasium.spaces import Box
import numpy as np
import pandas as pd
import logging

In [3]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

In [None]:
class StockTradingEnv(gym.Env):
    def __init__(self, config_trading, rl_data, env_type='train'):
        """
        Initialize trading environment.
        Parameters:
        - config_trading: Trading hyperparameters (commission, slippage, gamma, etc)
        - rl_data: dict_list of features, targets, and start_dates
        - env_type: 'train', 'valid', or 'test'
        """
        self.config = config_trading
        self.env_type = env_type
        self.data = rl_data.copy()
        # Core dimensions
        self.symbols = self.config.symbols
        self.state_dim = self.config.state_dim
        self.action_dim = self.config.action_dim
        self.window_size = self.config.window_size
        self.features_all_flatten = self.config.features_all_flatten
        self.features_price_flatten = self.config.features_price_flatten
        self.features_ind_flatten = self.config.features_ind_flatten
        self.features_senti_flatten = self.config.features_senti_flatten
        self.features_risk_flatten = self.config.features_risk_flatten

        self.price_feature_index = self.config.price_feature_index
        self.ind_feature_index = self.config.ind_feature_index
        self.senti_feature_index = self.config.senti_feature_index
        self.risk_feature_index = self.config.risk_feature_index

        # Initialize environment state placeholders
        self.episode_idx = None
        self.trading_df = None
        self.targets = None
        self.terminal_step = None
        self.current_step = None

        self.position = None
        self.cash = None
        self.cost = None
        self.total_asset = None
        self.asset_memory = None
        self.returns_history = []

        # Experiment mode switches for ablation experiment (default True if not set)
        self.use_senti_factor = getattr(self.config, 'use_senti_factor', True)
        self.use_risk_factor = getattr(self.config, 'use_risk_factor', True)

        self.use_senti_features = getattr(self.config, 'use_senti_features', True)
        self.use_risk_features = getattr(self.config, 'use_risk_features', True)
        
        logging.info(f"STE Module - Env Init - Config symbols: {self.symbols}, window_size: {self.window_size}, features_all_flatten len: {len(self.features_all_flatten)}, state_dim: {self.state_dim}, action_dim: {self.action_dim}")
        logging.info(f"STE Module - Env Init - rl_data len: {len(self.data)}, first episode states shape: {self.data[0]['states'].shape if self.data else 'Empty'}")
        logging.debug(f"STE Module - Env Init - Features flatten: price {len(self.features_price_flatten)}, ind {len(self.features_ind_flatten)}, senti {len(self.features_senti_flatten)}, risk {len(self.features_risk_flatten)}")
        logging.debug(f"STE Module - Env Init - Indices: price {self.price_feature_index}, ind {self.ind_feature_index}, senti {self.senti_feature_index}, risk {self.risk_feature_index}")
        # Gym Space Definitions
        self.observation_space = Box(low=-np.inf, high=np.inf,
                                     shape=(self.state_dim,),
                                     dtype=np.float32)
        self.action_space = Box(low=-1, high=1, shape=(self.action_dim,), dtype=np.float32)
        logging.info(f"STE Modul - Env initialized: environment type: {self.env_type}, model={self.config.model}, state_dim={self.state_dim}")

        # Internal State
        self.last_prices = None  # Initial for slippage
        self.reset()

    def reset(self, seed=None, options=None):
        """Initialize multi-episode by randomly selecting a start point."""
        super().reset(seed=seed)
        try:
            self.episode_idx = np.random.randint(len(self.data))
            episode_data = self.data[self.episode_idx]
            self.trading_df = episode_data['states'] # [T, D]
            self.targets = episode_data['targets'] # [T, N]
            self.terminal_step = len(self.trading_df) - 1
            self.current_step = self.window_size
            # Agent state
            self.cash = 1.0
            self.position = np.zeros(self.action_dim, dtype=np.float32)
            self.cost = 0.0
            self.total_asset = 1.0
            self.asset_memory = [self.total_asset]
            self.returns_history = []
            # Initial last prices as the origin price when window set
            self.last_prices = self._get_current_prices()

            info = {'Environment Type': self.env_type,
                    'Episode Index': self.episode_idx,
                    'Episode Length': self.terminal_step + 1,
                    'Start Date': episode_data['start_date']
                    'Targets': self.targets[:5],
                    'Cash': self.cash,
                    'Position': self.position,
                    'Total Asset': self.total_asset,
                    'Last Prices': self.last_prices}

            logging.info(f"STE Module - Env Reset - Episode idx: {self.episode_idx}, trading_df shape: {self.trading_df.shape}, targets shape: {self.targets.shape}, terminal_step: {self.terminal_step}")
            logging.info(f"STE Module - Env Reset - Reset information: {info}")
            return self._get_states(), info
        except Exception as e:
            logging.error(f"STE Module - Env reset error: {e}")
            raise ValueError("Error in environment reset")

    def step(self, actions):
        """Execute trading step and compute modular reward."""
        try:
            try:
                # Clip actions to [-1, 1]
                actions = np.clip(actions, -1, 1).astype(np.float32)
                logging.info(f"STE Module - Env Step - Input actions shape: {actions.shape}, values: {actions}")
            except Exception as e:
                logging.error(f"STE Module - Error in actions clip: {e}")
                raise ValueError("Error in actions clip step")
            # Calculate S_f and R_f dynamically for multi-stock
            # Extract current sentiment and risk from window end (last row)
            current_row = self.trading_df[self.current_step - 1]  # Current after trade
            sentiment_per_stock = current_row[self.senti_feature_index] if self.use_senti_factor else np.zeros(self.action_dim, dtype=np.float32)  # (num_symbols,)
            risk_per_stock = current_row[self.risk_feature_index] if self.use_risk_factor else np.zeros(self.action_dim, dtype=np.float32)  # (num_symbols,)

            Senti_factor = np.ones(self.action_dim, dtype=np.float32)  # Initialize array
            if self.use_senti_factor:
                logging.info(f"STE Module - Env Step - Using sentiment factor, sentiment_per_stock: {sentiment_per_stock}")
                for i in range(self.action_dim):
                    if (sentiment_per_stock[i] > 0 and actions[i] > 0) or (sentiment_per_stock[i] < 0 and actions[i] < 0):
                        Senti_factor[i] = 1 + self.config.infusion_strength
                    elif (sentiment_per_stock[i] > 0 and actions[i] < 0) or (sentiment_per_stock[i] < 0 and actions[i] > 0):
                        Senti_factor[i] = 1 - self.config.infusion_strength

            mod_actions = np.clip(Senti_factor * actions, -1, 1)  # Optional re-clip after infusion

            # Execute trades
            self._execute_trades(mod_actions)
            logging.info(f"STE Module - Env Step - Executed trades, new cash: {self.cash}, position: {self.position}")
            logging.debug(f"STE Module - Env Step - After trades: cash: {self.cash}, position: {self.position}, total_asset: {self.total_asset}, cost: {self.cost}")

            Risk_factor = np.ones(self.action_dim)  # Initialize
            if self.use_risk_factor:
                logging.info(f"STE Module - Env Step - Using risk factor, risk_per_stock: {risk_per_stock}")
                for i in range(self.action_dim):
                    if risk_per_stock[i] > 0:
                        Risk_factor[i] = 1 + getattr(self.config, 'infusion_strength', 0.001) * (risk_per_stock[i] / 2)
                    elif risk_per_stock[i] < 0:
                        Risk_factor[i] = 1 - getattr(self.config, 'infusion_strength', 0.001) * (abs(risk_per_stock[i]) / 2)
                    else:
                        Risk_factor[i] = 1.0

            # Aggregate R_f as weighted mean (e.g., by position) for return adjustment
            weights = np.abs(self.position) / (np.sum(np.abs(self.position)) + 1e-8)  # Portfolio weights

            # Adjust raw_return with R_f (reference FinRL_DeepSeek 4.3)
            raw_return = self._calculate_return() * np.dot(weights, Risk_factor)
            if self.use_senti_factor and self.use_risk_factor:
                logging.info(f"STE Module - Env Step - Raw return: {raw_return}, R_f: {Risk_factor}, S_f: {Senti_factor}, sentiment_per_stock: {sentiment_per_stock}, risk_per_stock: {risk_per_stock}")
            
            pennalty = getattr(self.config, 'cash_penalty_proportion', 0.01)
            reward = np.float32(raw_return - self.cash * pennalty)
            self.returns_history.append(raw_return)
            # CVaR shaping
            if len(self.returns_history) >= getattr(self.config, 'cvar_min_history', 10) and getattr(self.config, 'cvar_factor', 0.05) > 0:
                cvar_alpha = getattr(self.config, 'cvar_alpha', 0.05) 
                returns_array = np.array(self.returns_history, dtype=np.float32)
                var = np.percentile(returns_array, 100 * cvar_alpha)  # VaR at alpha level
                cvar = returns_array[returns_array <= var].mean()    # CVaR: average returns which are lower than VaR
                # Smaller CVaR, Bigger Risk
                reward += np.float32(self.config.cvar_factor * cvar)

            self.asset_memory.append(self.total_asset)
            # Terminal condition
            self.current_step += 1
            done = (self.current_step >= self.terminal_step)
            truncated = False

            info = {'Total Asset': self.total_asset,
                    'Cash': self.cash,
                    'Position': self.position.copy(),
                    'Reward': reward,
                    'Cost': self.cost,
                    'Current Step': self.current_step,
                    'Sentiment Factor': Senti_factor,
                    'Risk Factor': Risk_factor,
                    'Done': done,
                    'Truncated': truncated}
            
            logging.info(f"STE Module - Env Step - Step information: {info}")
            return self._get_states(), reward, done, truncated, info
        except Exception as e:
            logging.error(f"STE Module - Env step error: {e}")
            raise ValueError("Error in environment step")
    
    def _get_states(self):
        try:
            window = self.trading_df[self.current_step - self.window_size : self.current_step]  # (window_size, D)
            price_features = window[:, self.price_feature_index].flatten()  # (window_size, len(price))
            ind_features = window[:, self.ind_feature_index].flatten()  # (window_size, len(ind))
            senti_features = window[:, self.senti_feature_index].flatten() if self.use_senti_features else np.zeros(self.window_size * len(self.senti_feature_index), dtype=np.float32)  # (window_size, len(senti))
            risk_features = window[:, self.risk_feature_index].flatten() if self.use_risk_features else np.zeros(self.window_size * len(self.risk_feature_index), dtype=np.float32)  # (window_size, len(risk))
            logging.debug(f"STE Module - Env _get_states - \
                          Window shape: {window.shape}, \
                          price_feats shape: {price_features.shape}, \
                          ind_feats: {ind_features.shape}, \
                          senti_feats: {senti_features.shape}, \
                          risk_feats: {risk_features.shape}")
            cash_state = np.array([self.cash], dtype=np.float32)
            position_state = np.array([np.sum(self.position)], dtype=np.float32)
            return_state = np.array([self.total_asset / self.asset_memory[0] - 1.0], dtype=np.float32)
            logging.debug(f"STE Module - Env _get_states - Cash: {cash_state}, position_state: {position_state}, return_state: {return_state}")
            # Set temp state for ablation experiment
            state_temp = [price_features, ind_features]
            # Switch sentiment features, default True
            if self.use_senti_features:
                logging.info(f"STE Module - Env _get_states - Introduce Sentiment features")
                state_temp.append(senti_features)
            else:
                logging.info(f"STE Module - Env _get_states - No Sentiment features mode")
                state_temp.append(np.zeros_like(senti_features))
            # Switch risk features, default True
            if self.use_risk_features:
                logging.info(f"STE Module - Env _get_states - Introduce Risk features")
                state_temp.append(risk_features)
            else:
                logging.info(f"STE Module - Env _get_states - No Risk features mode")
                state_temp.append(np.zeros_like(risk_features))

            state_temp.extend([cash_state, position_state, return_state])

            state = np.concatenate(state_temp).astype(np.float32)
            logging.info(f"STE Module - Env _get_states - Final state shape: {state.shape}, expected: {self.state_dim}")
            return state
        except Exception as e:
            logging.error(f"STE Module - Error in state retrieval: {e}")
            raise ValueError("Error in state retrieval")

    def _execute_trades(self, actions):
        """Update portfolio given actions."""
        try:
            # Get price info
            current_prices = self._get_current_prices()
            logging.debug(f"STE Module - Env _execute_trades - Current prices: {current_prices}, actions: {actions}")
            
            # Compute current allocation
            current_allocation = self.position * current_prices
            
            # Calculate target allocation(by weights) 
            weights = actions / (np.sum(np.abs(actions)) + 1e-8)
            target_allocation = self.total_asset * weights
            # Calculate trade diff
            trade_volume = np.abs(target_allocation - current_allocation)
            # Calculate cost
            commission_cost = np.sum(trade_volume) * getattr(self.config, 'commission_rate', 0.005)
            # Calculate slippage cost
            price_diff = np.abs(current_prices - self.last_prices)
            slippage_cost = np.sum(price_diff * np.abs(self.position)) * getattr(self.config, 'slippage_rate', 0.0)

            total_cost = commission_cost + slippage_cost
            self.cost += total_cost

            logging.debug(f"STE Module - Env _execute_trades - Weights: {weights}, desired_allocation: {target_allocation}, total cost: {total_cost}")

            # Update cash and position
            self.position = target_allocation / (current_prices + 1e-8)
            self.cash = self.total_asset - np.sum(self.position * current_prices) - total_cost
            self.total_asset = self.cash + np.sum(self.position * current_prices)
            # Update last prices
            self.last_prices = current_prices.copy()
        except Exception as e:
            logging.error(f"STE Module - Error in trade execution: {e}")
            raise ValueError("Error in trade execution")

    def _calculate_return(self):
        if len(self.asset_memory) == 0:
            return 0.0
        return_value = (self.total_asset / self.asset_memory[-1]) - 1.0
        logging.debug(f"STE Module - Env _calculate_return - Previous asset: {self.asset_memory[-1]}, current: {self.total_asset}")
        return return_value
    
    def _get_current_prices(self):
        """Extract adjusted close prices from window end."""
        last_row = self.trading_df[self.current_step]
        prices = last_row[self.price_feature_index]
        return prices.astype(np.float32)

    def render(self):
        print(f"Step: {self.current_step}, Asset: {self.total_asset:.4f}, Cash: {self.cash:.4f}")

    def close(self):
        pass

In [5]:
from finbert_trader.config_trading import ConfigTrading

In [6]:
%load_ext autoreload
%autoreload 2

In [7]:
class MockConfig:
    symbols = ['GOOGL', 'AAPL']
    window_size = 10
    features_all_flatten = [f'features_{i}_{symbol}' for symbol in symbols for i in range(11)]
    features_price_flatten = [f'Adj_Close_{symbol}' for symbol in symbols]
    features_ind_flatten = [f'ind_{i}_{symbol}' for symbol in symbols for i in range(8)]
    features_senti_flatten = [f'sentiment_score_{symbol}' for symbol in symbols]
    features_risk_flatten = [f'risk_score_{symbol}' for symbol in symbols]
    price_feature_index = [0, 11]
    ind_feature_index = list(range(1,9)) + list(range(12,20))
    senti_feature_index = [9, 20]
    risk_feature_index = [10, 21]
    state_dim = 10 * 22 + 3
    action_dim = 2
    model = 'PPO'
    infusion_strength = 0.001
    cvar_factor = 0.05
    commission_rate = 0.005

In [8]:
mock_config = MockConfig()

In [None]:
mock_rl_data = [{'start_date': '2015-01-01', 'states': np.random.rand(50, 22), 'targets': np.random.rand(50, 2)} for _ in range(3)]
mock_rl_data

[{'states': array([[0.34496029, 0.59180029, 0.74468314, ..., 0.90602972, 0.0238243 ,
          0.28745039],
         [0.19313082, 0.29380271, 0.97068685, ..., 0.45400159, 0.0062079 ,
          0.48001823],
         [0.87179629, 0.25780211, 0.06512885, ..., 0.32704705, 0.97469115,
          0.52808242],
         ...,
         [0.39947717, 0.36839762, 0.63866268, ..., 0.21931985, 0.47801845,
          0.42119129],
         [0.67779586, 0.17453751, 0.57078203, ..., 0.41902098, 0.28817546,
          0.63676759],
         [0.14319532, 0.7657596 , 0.9196992 , ..., 0.36996939, 0.32679211,
          0.1258405 ]], shape=(50, 22)),
  'targets': array([[0.34513685, 0.49292539],
         [0.56943261, 0.67754889],
         [0.23363214, 0.23487153],
         [0.04484083, 0.74151533],
         [0.80970584, 0.61727011],
         [0.38654698, 0.6645617 ],
         [0.9468566 , 0.11458474],
         [0.47659322, 0.05330213],
         [0.64411917, 0.71827316],
         [0.61322567, 0.37207209],
         

In [10]:
env = StockTradingEnv(mock_config, mock_rl_data, env_type='test')

2025-08-09 21:21:18,670 - INFO - STE Module - Env Init - Config symbols: ['GOOGL', 'AAPL'], window_size: 10, features_all_flatten len: 22, state_dim: 223, action_dim: 2
2025-08-09 21:21:18,671 - INFO - STE Module - Env Init - rl_data len: 3, first episode states shape: (50, 22)
2025-08-09 21:21:18,672 - INFO - STE Modul - Env initialized: environment type: test, model=PPO, state_dim=223
2025-08-09 21:21:18,672 - INFO - STE Module - Env Reset - Episode idx: 2, trading_df shape: (50, 22), targets shape: (50, 2), terminal_step: 49
2025-08-09 21:21:18,673 - INFO - STE Module - Env Reset - Reset information: {'Environment Type': 'test', 'Episode Index': 2, 'Episode Length': 50, 'Targets': array([[6.70854851e-01, 6.97468806e-01],
       [6.97676262e-01, 4.90803822e-01],
       [6.89973399e-01, 2.74050894e-01],
       [5.48529384e-01, 3.54137669e-01],
       [5.13199242e-05, 4.89342825e-01]]), 'Cash': 1.0, 'Position': array([0., 0.], dtype=float32), 'Total Asset': 1.0, 'Last Prices': array([0

In [11]:
obs, info = env.reset()

2025-08-09 21:21:43,774 - INFO - STE Module - Env Reset - Episode idx: 2, trading_df shape: (50, 22), targets shape: (50, 2), terminal_step: 49
2025-08-09 21:21:43,775 - INFO - STE Module - Env Reset - Reset information: {'Environment Type': 'test', 'Episode Index': 2, 'Episode Length': 50, 'Targets': array([[6.70854851e-01, 6.97468806e-01],
       [6.97676262e-01, 4.90803822e-01],
       [6.89973399e-01, 2.74050894e-01],
       [5.48529384e-01, 3.54137669e-01],
       [5.13199242e-05, 4.89342825e-01]]), 'Cash': 1.0, 'Position': array([0., 0.], dtype=float32), 'Total Asset': 1.0, 'Last Prices': array([0.3986088 , 0.29088143], dtype=float32)}
2025-08-09 21:21:43,776 - INFO - STE Module - Env _get_states - Introduce Sentiment features
2025-08-09 21:21:43,776 - INFO - STE Module - Env _get_states - Introduce Risk features
2025-08-09 21:21:43,776 - INFO - STE Module - Env _get_states - Final state shape: (223,), expected: 223


In [12]:
actions = np.random.rand(env.action_dim) - 0.5

In [13]:
next_obs, reward, done, truncated, info = env.step(actions)

2025-08-09 21:22:00,532 - INFO - STE Module - Env Step - Input actions shape: (2,), values: [0.32876447 0.24625897]
2025-08-09 21:22:00,533 - INFO - STE Module - Env Step - Using sentiment factor, sentiment_per_stock: [0.74856026 0.55629002]
2025-08-09 21:22:00,534 - INFO - STE Module - Env Step - Executed trades, new cash: -0.004999999888241291, position: [1.4343411 1.4722804]
2025-08-09 21:22:00,536 - INFO - STE Module - Env Step - Using risk factor, risk_per_stock: [0.32756823 0.40705664]
2025-08-09 21:22:00,536 - INFO - STE Module - Env Step - Raw return: -0.005000914808759473, R_f: [1.00016378 1.00020353], S_f: [1.001 1.001], sentiment_per_stock: [0.74856026 0.55629002], risk_per_stock: [0.32756823 0.40705664]
2025-08-09 21:22:00,537 - INFO - STE Module - Env Step - Step information: {'Total Asset': np.float32(0.995), 'Cash': np.float32(-0.005), 'Position': array([1.4343411, 1.4722804], dtype=float32), 'Reward': np.float64(-0.004950914810022579), 'Cost': np.float32(0.005), 'Curren

In [14]:
env.render()

Step: 11, Asset: 0.9950, Cash: -0.0050


In [None]:
env.close()