In [1]:
import gymnasium as gym
import numpy as np
import pandas as pd
from gymnasium import spaces

class StockTradingEnv(gym.Env):
    def __init__(self, df, window_size=10, initial_balance=10000):
        super(StockTradingEnv, self).__init__()

        self.df = df
        self.window_size = window_size
        self.initial_balance = initial_balance
        self.current_step = window_size  # Ensure enough historical data

        # Portfolio variables
        self.balance = initial_balance
        self.shares_held = 0
        self.total_asset_value = initial_balance

        # Filter features (remove Date column)
        self.feature_columns = [col for col in df.columns if col != "Date"]

        # 🔥 Continuous action space: -1 (sell all) to 1 (buy with all balance)
        self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)

        # 🔥 Observation space: (window_size, num_features)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(window_size, len(self.feature_columns)), dtype=np.float32
        )

    def _next_observation(self):
        """Returns past `window_size` days of stock data."""
        obs = self.df.iloc[self.current_step - self.window_size:self.current_step][self.feature_columns].values
        return obs.astype(np.float32)

    def step(self, action):
        """Execute trade and calculate reward based on portfolio value change."""
        prev_asset_value = self.total_asset_value
        current_price = self.df.iloc[self.current_step]["Close"]
        
        # 🔥 Convert action [-1, 1] to trade percentage
        trade_amount = action[0] * self.balance  # Buy if > 0, sell if < 0

        if trade_amount > 0:  # Buy
            shares_bought = trade_amount / current_price
            self.shares_held += shares_bought
            self.balance -= trade_amount
        elif trade_amount < 0:  # Sell
            shares_sold = min(abs(trade_amount) / current_price, self.shares_held)
            self.shares_held -= shares_sold
            self.balance += shares_sold * current_price

        # 🔥 Compute new total asset value
        self.total_asset_value = self.balance + (self.shares_held * current_price)
        reward = self.total_asset_value - prev_asset_value  # 🔥 Incremental reward

        self.current_step += 1
        terminated = self.current_step >= len(self.df) - 1
        truncated = False  # Can modify based on max steps

        obs = self._next_observation()
        return obs, reward, terminated, truncated, {}

    def reset(self, seed=None, options=None):
        """Reset environment to initial state."""
        super().reset(seed=seed)
        self.current_step = self.window_size
        self.balance = self.initial_balance
        self.shares_held = 0
        self.total_asset_value = self.initial_balance
        return self._next_observation(), {}

    def render(self):
        print(f"Step: {self.current_step}, Balance: {self.balance}, Portfolio Value: {self.total_asset_value}")


In [2]:
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
import pandas as pd

# 🔥 Load dataset
df = pd.read_csv("/home/jesse/Projects/CWP_RL/Train/nvidia_stock_with_indicators.csv", index_col="Date", parse_dates=True)

# 🔥 Create environment
def create_env():
    return StockTradingEnv(df)  

env = make_vec_env(create_env, n_envs=1)  # SAC supports multi-envs

# 🔥 Train SAC model
model = SAC("MlpPolicy", env, verbose=1, tensorboard_log="./sac_logs/")
model.learn(total_timesteps=500)

# 🔥 Save trained model
model.save("sac_NVDA")


Using cpu device
Logging to ./sac_logs/SAC_6


In [3]:
aapl_df = pd.read_csv("/home/jesse/Projects/CWP_RL/Test/microsoft_stock_with_indicators.csv")

env = StockTradingEnv(aapl_df)
obs, _ = env.reset()  # `gymnasium` 的 `reset` 现在返回 (obs, info)

for _ in range(len(aapl_df)):
    action, _states = model.predict(obs)
    obs, reward, terminated, truncated, _ = env.step(action)  # `gymnasium` 需要接收 5 个返回值
    env.render()
    if terminated or truncated:
        break


Step: 11, Balance: 0.0, Portfolio Value: 10000.0
Step: 12, Balance: 0.0, Portfolio Value: 9119.384276877017
Step: 13, Balance: 0.0, Portfolio Value: 9197.657222212203
Step: 14, Balance: 0.0, Portfolio Value: 9432.498602268834
Step: 15, Balance: 0.0, Portfolio Value: 9236.793694879794
Step: 16, Balance: 0.0, Portfolio Value: 8845.406424152794
Step: 17, Balance: 0.0, Portfolio Value: 8786.701715151406
Step: 18, Balance: 0.0, Portfolio Value: 8904.111133154183
Step: 19, Balance: 0.0, Portfolio Value: 8845.406424152794
Step: 20, Balance: 0.0, Portfolio Value: 8571.42857142857
Step: 21, Balance: 0.0, Portfolio Value: 8845.406424152794
Step: 22, Balance: 0.0, Portfolio Value: 8649.724060814831
Step: 23, Balance: 0.0, Portfolio Value: 8767.133478817608
Step: 24, Balance: 0.0, Portfolio Value: 9099.81604054322
Step: 25, Balance: 0.0, Portfolio Value: 8962.81584215557
Step: 26, Balance: 0.0, Portfolio Value: 9080.247804209424
Step: 27, Balance: 0.0, Portfolio Value: 9001.974858874239
Step: 28, 

In [4]:
aapl_df = pd.read_csv("/home/jesse/Projects/CWP_RL/Test/apple_stock_with_indicators.csv")

env = StockTradingEnv(aapl_df)
obs, _ = env.reset()  # `gymnasium` 的 `reset` 现在返回 (obs, info)

for _ in range(len(aapl_df)):
    action, _states = model.predict(obs)
    obs, reward, terminated, truncated, _ = env.step(action)  # `gymnasium` 需要接收 5 个返回值
    env.render()
    if terminated or truncated:
        break


Step: 11, Balance: 10000.0, Portfolio Value: 10000.0
Step: 12, Balance: 0.0, Portfolio Value: 10000.0
Step: 13, Balance: 0.0, Portfolio Value: 10083.328298743754
Step: 14, Balance: 0.0, Portfolio Value: 10374.975119113958
Step: 15, Balance: 0.0, Portfolio Value: 10666.621939484166
Step: 16, Balance: 0.0, Portfolio Value: 10916.604610482495
Step: 17, Balance: 0.0, Portfolio Value: 10833.277424355207
Step: 18, Balance: 0.0, Portfolio Value: 10458.302305241246
Step: 19, Balance: 0.0, Portfolio Value: 10000.0
Step: 20, Balance: 0.0, Portfolio Value: 9958.335850628122
Step: 21, Balance: 0.0, Portfolio Value: 9499.959000082708
Step: 22, Balance: 0.0, Portfolio Value: 10041.664149371876
Step: 23, Balance: 0.0, Portfolio Value: 10083.328298743754
Step: 24, Balance: 0.0, Portfolio Value: 9708.278634325627
Step: 25, Balance: 0.0, Portfolio Value: 9166.648030340624
Step: 26, Balance: 0.0, Portfolio Value: 9583.286186209994
Step: 27, Balance: 0.0, Portfolio Value: 9708.278634325627
Step: 28, Balan