In [3]:
import gymnasium as gym  # 替换 gym
import numpy as np
import pandas as pd
from gymnasium import spaces  # 替换 gym.spaces

class StockTradingEnv(gym.Env):
    """Single-stock trading environment with a discrete hold/buy/sell action space.

    Each observation is the ``window_size`` rows of every non-"Date" column of
    ``df`` that precede ``current_step``. A trade executes at the
    ``price_column`` value of row ``current_step``, i.e. the first row after
    the observation window.

    Actions:
        0 -- hold
        1 -- buy as many whole shares as the cash balance allows
        2 -- sell all held shares (no-op when nothing is held)

    Reward is 0 except on a sell, where it equals the total profit (cash
    balance minus the initial balance) after the sale.
    NOTE(review): rewarding only on sells gives no signal while a position is
    held; a per-step change in net worth is a common alternative.

    Args:
        df: Price/indicator table. Every column except "Date" is used as a
            feature and must be numeric; ``price_column`` must be present.
        window_size: Number of past rows in each observation.
        initial_balance: Starting cash balance (default 10000).
        price_column: Column used as the trade price (default "Close").

    Raises:
        ValueError: if ``price_column`` is missing from ``df`` or ``df`` does
            not have more rows than ``window_size`` (nothing left to trade on).
    """

    def __init__(self, df, window_size=10, initial_balance=10000, price_column="Close"):
        super().__init__()

        if price_column not in df.columns:
            raise ValueError(f"df has no {price_column!r} column")
        if len(df) <= window_size:
            raise ValueError(
                f"df needs more than window_size={window_size} rows, got {len(df)}"
            )

        self.df = df
        self.window_size = window_size  # observation window length, in rows
        self.price_column = price_column
        self.initial_balance = initial_balance
        # Start at row `window_size` so the first observation is a full window.
        self.current_step = window_size
        self.balance = self.initial_balance
        self.shares_held = 0
        self.total_profit = 0

        # Every column except the date is a feature.
        self.feature_columns = [col for col in df.columns if col != "Date"]
        # Convert once here instead of slicing the DataFrame on every step.
        # These snapshots do not track later mutations of `self.df`.
        self._features = df[self.feature_columns].to_numpy(dtype=np.float32)
        self._prices = df[price_column].to_numpy(dtype=np.float64)

        # Observation space: (window_size, num_features)
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(window_size, len(self.feature_columns)),
            dtype=np.float32,
        )

        # Action space: 0 = hold, 1 = buy, 2 = sell
        self.action_space = spaces.Discrete(3)

    def _next_observation(self):
        """Return the ``window_size`` feature rows preceding ``current_step`` as float32."""
        start = self.current_step - self.window_size
        # Copy so callers never hold a view into the cached feature matrix.
        return self._features[start:self.current_step].copy()

    def step(self, action):
        """Execute ``action`` at the current row's price, then advance one row.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` per the Gymnasium
            API. ``terminated`` becomes True once the last row has been used
            as a trade price.
        """
        # model.predict returns numpy scalars/arrays; normalize to a plain int.
        action = int(np.asarray(action).item())
        price = self._prices[self.current_step]
        self.current_step += 1
        # Compare against len (not len - 1) so the final row is traded on too.
        terminated = self.current_step >= len(self._prices)
        truncated = False  # no step limit is enforced

        reward = 0
        if action == 1:  # buy
            shares_bought = self.balance // price
            self.shares_held += shares_bought
            self.balance -= shares_bought * price
        elif action == 2 and self.shares_held > 0:  # sell everything
            self.balance += self.shares_held * price
            self.shares_held = 0
            # With no shares left the portfolio is all cash, so total profit
            # is just the cash gain. Assign rather than accumulate: `+=`
            # re-counted earlier gains on every subsequent sell.
            self.total_profit = self.balance - self.initial_balance
            reward = self.total_profit  # total profit is used as the reward

        return self._next_observation(), reward, terminated, truncated, {}

    def reset(self, seed=None, options=None):
        """Restore the initial cash/position state and return ``(obs, info)``."""
        super().reset(seed=seed)
        self.current_step = self.window_size  # first full observation window
        self.balance = self.initial_balance
        self.shares_held = 0
        self.total_profit = 0
        return self._next_observation(), {}

    def render(self):
        """Print the current step, cash balance and total profit."""
        print(f"Step: {self.current_step}, Balance: {self.balance}, Profit: {self.total_profit}")


In [4]:
from pathlib import Path

from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
import pandas as pd

# Configuration: everything a reader might want to tune lives here.
PROJECT_DIR = Path("/home/jesse/Projects/CWP_RL")
TRAIN_CSV = PROJECT_DIR / "Train" / "nvidia_stock_with_indicators.csv"
MODEL_PATH = "dqn_NVDA"
SEED = 42  # seeds both the env and the agent so training is reproducible
TOTAL_TIMESTEPS = 10_000

# Load training data; "Date" becomes the index, so every remaining column is a feature.
df = pd.read_csv(TRAIN_CSV, index_col="Date", parse_dates=True)


def create_env():
    """Build one StockTradingEnv over the NVIDIA training data."""
    return StockTradingEnv(df)


# Wrap the env in a vectorized env for stable_baselines3.
env = make_vec_env(create_env, n_envs=1, seed=SEED)

# Train the DQN agent.
model = DQN("MlpPolicy", env, verbose=1, seed=SEED)
model.learn(total_timesteps=TOTAL_TIMESTEPS)

# Save the trained model.
model.save(MODEL_PATH)


Using cpu device


In [5]:
# Out-of-sample evaluation on Microsoft data. "Date" stays a regular column
# here; StockTradingEnv excludes it from the features.
msft_df = pd.read_csv("/home/jesse/Projects/CWP_RL/Test/microsoft_stock_with_indicators.csv")

env = StockTradingEnv(msft_df)
obs, _ = env.reset()  # gymnasium's reset returns (obs, info)

for _ in range(len(msft_df)):
    # deterministic=True: evaluate the greedy policy. With the default
    # (False), SB3's DQN may still take random exploration actions.
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, _ = env.step(action)  # gymnasium's step returns 5 values
    env.render()
    if terminated or truncated:
        break


Step: 11, Balance: 0.02265599999918777, Profit: 0
Step: 12, Balance: 0.02265599999918777, Profit: 0
Step: 13, Balance: 0.02265599999918777, Profit: 0
Step: 14, Balance: 0.02265599999918777, Profit: 0
Step: 15, Balance: 0.02265599999918777, Profit: 0
Step: 16, Balance: 0.02265599999918777, Profit: 0
Step: 17, Balance: 8786.704464, Profit: -1213.2955359999996
Step: 18, Balance: 0.31307399999968766, Profit: -1213.2955359999996
Step: 19, Balance: 0.31307399999968766, Profit: -1213.2955359999996
Step: 20, Balance: 0.31307399999968766, Profit: -1213.2955359999996
Step: 21, Balance: 0.31307399999968766, Profit: -1213.2955359999996
Step: 22, Balance: 0.31307399999968766, Profit: -1213.2955359999996
Step: 23, Balance: 0.31307399999968766, Profit: -1213.2955359999996
Step: 24, Balance: 0.31307399999968766, Profit: -1213.2955359999996
Step: 25, Balance: 0.31307399999968766, Profit: -1213.2955359999996
Step: 26, Balance: 0.31307399999968766, Profit: -1213.2955359999996
Step: 27, Balance: 0.3130739

In [6]:
# Out-of-sample evaluation on Apple data. "Date" stays a regular column
# here; StockTradingEnv excludes it from the features.
aapl_df = pd.read_csv("/home/jesse/Projects/CWP_RL/Test/apple_stock_with_indicators.csv")

env = StockTradingEnv(aapl_df)
obs, _ = env.reset()  # gymnasium's reset returns (obs, info)

for _ in range(len(aapl_df)):
    # deterministic=True: evaluate the greedy policy. With the default
    # (False), SB3's DQN may still take random exploration actions.
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, _ = env.step(action)  # gymnasium's step returns 5 values
    env.render()
    if terminated or truncated:
        break


Step: 11, Balance: 0.08439838889535167, Profit: 0
Step: 12, Balance: 0.08439838889535167, Profit: 0
Step: 13, Balance: 0.08439838889535167, Profit: 0
Step: 14, Balance: 0.08439838889535167, Profit: 0
Step: 15, Balance: 0.08439838889535167, Profit: 0
Step: 16, Balance: 0.08439838889535167, Profit: 0
Step: 17, Balance: 0.08439838889535167, Profit: 0
Step: 18, Balance: 0.08439838889535167, Profit: 0
Step: 19, Balance: 0.08439838889535167, Profit: 0
Step: 20, Balance: 0.08439838889535167, Profit: 0
Step: 21, Balance: 0.08439838889535167, Profit: 0
Step: 22, Balance: 0.08439838889535167, Profit: 0
Step: 23, Balance: 0.08439838889535167, Profit: 0
Step: 24, Balance: 0.08439838889535167, Profit: 0
Step: 25, Balance: 0.08439838889535167, Profit: 0
Step: 26, Balance: 0.08439838889535167, Profit: 0
Step: 27, Balance: 0.08439838889535167, Profit: 0
Step: 28, Balance: 0.08439838889535167, Profit: 0
Step: 29, Balance: 10207.453538537036, Profit: 207.45353853703637
Step: 30, Balance: 0.0233633667048