In [1]:
import gymnasium as gym
import numpy as np
import pandas as pd
from gymnasium import spaces

class StockTradingEnv(gym.Env):
    def __init__(self, df, window_size=10, initial_balance=10000, trading_fee=0.001):
        super(StockTradingEnv, self).__init__()

        self.df = df
        self.window_size = window_size
        self.initial_balance = initial_balance
        self.trading_fee = trading_fee
        self.current_step = window_size  # Ensure enough historical data

        # Portfolio variables
        self.balance = initial_balance
        self.shares_held = 0
        self.total_asset_value = initial_balance

        # Filter features (remove Date column)
        self.feature_columns = [col for col in df.columns if col != "Date"]

        # 🔥 Continuous action space: -1 (sell all) to 1 (buy with 75% balance)
        self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)

        # 🔥 Observation space: (window_size, num_features)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(window_size, len(self.feature_columns)), dtype=np.float32
        )

    def _next_observation(self):
        """Returns past `window_size` days of stock data."""
        obs = self.df.iloc[self.current_step - self.window_size:self.current_step][self.feature_columns].values
        return obs.astype(np.float32)

    def step(self, action):
        """Execute trade and calculate reward based on portfolio value change."""
        prev_asset_value = self.total_asset_value
        current_price = self.df.iloc[self.current_step]["Close"]
        
        # 🔥 Convert action [-1, 1] to trade percentage (sell all to buy 75% max)
        trade_amount = (action[0] / 2 + 0.5) * self.balance  

        if trade_amount > 0:  # Buy
            shares_bought = trade_amount / current_price
            self.shares_held += shares_bought
            self.balance -= trade_amount
        elif trade_amount < 0:  # Sell
            shares_sold = min(abs(trade_amount) / current_price, self.shares_held)
            self.shares_held -= shares_sold
            self.balance += shares_sold * current_price

        # 🔥 Apply trading fee (0.1%)
        self.balance -= abs(trade_amount) * self.trading_fee

        # 🔥 Compute new total asset value
        self.total_asset_value = self.balance + (self.shares_held * current_price)

        # 🔥 Reward: normalized return with time penalty
        reward = (self.total_asset_value - prev_asset_value) / prev_asset_value - 0.001  

        self.current_step += 1
        terminated = self.current_step >= len(self.df) - 1
        truncated = False  # Can modify based on max steps

        obs = self._next_observation()
        return obs, reward, terminated, truncated, {}

    def reset(self, seed=None, options=None):
        """Reset environment to initial state."""
        super().reset(seed=seed)
        self.current_step = self.window_size
        self.balance = self.initial_balance
        self.shares_held = 0
        self.total_asset_value = self.initial_balance
        return self._next_observation(), {}

    def render(self):
        print(f"Step: {self.current_step}, Balance: {self.balance}, Portfolio Value: {self.total_asset_value}")


In [2]:
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
import pandas as pd

# 🔥 Load dataset
df = pd.read_csv("/home/jesse/Projects/CWP_RL/Train/nvidia_stock_with_indicators.csv", index_col="Date", parse_dates=True)

# 🔥 Create environment
def create_env():
    return StockTradingEnv(df)  

env = make_vec_env(create_env, n_envs=1)  # SAC supports multi-envs

# 🔥 Train SAC model
model = SAC("MlpPolicy", env, verbose=1, tensorboard_log="./sac_logs/")
model.learn(total_timesteps=1000)

# 🔥 Save trained model
model.save("sac_NVDA")


Using cpu device
Logging to ./sac_logs/SAC_9


In [3]:
aapl_df = pd.read_csv("/home/jesse/Projects/CWP_RL/Test/microsoft_stock_with_indicators.csv")

env = StockTradingEnv(aapl_df)
obs, _ = env.reset()  # `gymnasium` 的 `reset` 现在返回 (obs, info)

for _ in range(len(aapl_df)):
    action, _states = model.predict(obs)
    obs, reward, terminated, truncated, _ = env.step(action)  # `gymnasium` 需要接收 5 个返回值
    env.render()
    if terminated or truncated:
        break


Step: 11, Balance: -10.0, Portfolio Value: 9990.0
Step: 12, Balance: -0.010000000707805157, Portfolio Value: 9109.37427687631
Step: 13, Balance: -1.0000000707805157e-05, Portfolio Value: 9187.561380818759
Step: 14, Balance: -1.0000000707805158e-08, Portfolio Value: 9422.144986638761
Step: 15, Balance: -1.0000000707805158e-11, Portfolio Value: 9226.654895436974
Step: 16, Balance: -1.0000000707805158e-14, Portfolio Value: 8835.697232339357
Step: 17, Balance: -1.0000000707805158e-17, Portfolio Value: 8777.056960770542
Step: 18, Balance: -1.0000000707805158e-20, Portfolio Value: 8894.337503908171
Step: 19, Balance: -1.0000000707805158e-23, Portfolio Value: 8835.697232339357
Step: 20, Balance: -1.0000000707805158e-26, Portfolio Value: 8562.02011237937
Step: 21, Balance: -1.0000000707805158e-29, Portfolio Value: 8835.697232339357
Step: 22, Balance: -1.000000070780656e-32, Portfolio Value: 8640.229660443309
Step: 23, Balance: -1.0000000707806561e-35, Portfolio Value: 8757.510203580938
Step: 2

In [4]:
aapl_df = pd.read_csv("/home/jesse/Projects/CWP_RL/Test/apple_stock_with_indicators.csv")

env = StockTradingEnv(aapl_df)
obs, _ = env.reset()  # `gymnasium` 的 `reset` 现在返回 (obs, info)

for _ in range(len(aapl_df)):
    action, _states = model.predict(obs)
    obs, reward, terminated, truncated, _ = env.step(action)  # `gymnasium` 需要接收 5 个返回值
    env.render()
    if terminated or truncated:
        break


Step: 11, Balance: -10.0, Portfolio Value: 9990.0
Step: 12, Balance: -0.010000000707805157, Portfolio Value: 9948.498720514026
Step: 13, Balance: -1.0000000707805157e-05, Portfolio Value: 10031.397941185814
Step: 14, Balance: -1.0000000707805158e-08, Portfolio Value: 10321.54274551547
Step: 15, Balance: -1.0000000707805158e-11, Portfolio Value: 10611.687549574015
Step: 16, Balance: -1.0000000707805158e-14, Portfolio Value: 10860.382779656436
Step: 17, Balance: -1.0000000707805158e-17, Portfolio Value: 10777.484738591082
Step: 18, Balance: -1.0000000707805158e-20, Portfolio Value: 10404.44078658109
Step: 19, Balance: -1.0000000707805158e-23, Portfolio Value: 9948.498793505745
Step: 20, Balance: -1.0000000707805158e-26, Portfolio Value: 9907.049219529888
Step: 21, Balance: -1.0000000707805158e-29, Portfolio Value: 9451.033065067686
Step: 22, Balance: -1.0000000707805159e-32, Portfolio Value: 9989.9483674816
Step: 23, Balance: -1.0000000707805159e-35, Portfolio Value: 10031.397941457459
S