In [1]:
import gymnasium as gym  # 替换 gym
import numpy as np
import pandas as pd
from gymnasium import spaces  # 替换 gym.spaces

class StockTradingEnv(gym.Env):
    def __init__(self, df):
        super(StockTradingEnv, self).__init__()

        self.df = df
        self.current_step = 0
        self.initial_balance = 10000  # 初始资金
        self.balance = self.initial_balance
        self.shares_held = 0
        self.total_shares_sold = 0
        self.total_profit = 0

        # 观察空间（收盘价, RSI, MACD, 布林中轨）
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)

        # 动作空间（0: 持有, 1: 买入, 2: 卖出）
        self.action_space = spaces.Discrete(3)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)  # 确保兼容 gymnasium 的 `reset`
        self.current_step = 0
        self.balance = self.initial_balance
        self.shares_held = 0
        self.total_profit = 0
        return self._next_observation(), {}  # `reset` 现在需要返回 (obs, info)

    def _next_observation(self):
        obs = self.df.iloc[self.current_step][["Close", "RSI", "MACD", "Bollinger_Middle"]].values
        return obs.astype(np.float32)

    def step(self, action):
        prev_price = self.df.iloc[self.current_step]["Close"]
        self.current_step += 1
        terminated = self.current_step >= len(self.df) - 1  # `gymnasium` 需要 `terminated`
        truncated = False  # 可根据最大步数设定，例如 `truncated = self.current_step > 10000`

        reward = 0
        if action == 1:  # 买入
            shares_bought = self.balance // prev_price
            self.shares_held += shares_bought
            self.balance -= shares_bought * prev_price
        elif action == 2 and self.shares_held > 0:  # 卖出
            self.balance += self.shares_held * prev_price
            self.shares_held = 0
            self.total_profit += self.balance - self.initial_balance
            reward = self.total_profit  # 以总收益作为奖励

        obs = self._next_observation()
        return obs, reward, terminated, truncated, {}  # `gymnasium` 的 `step` 需要返回 5 个值

    def render(self):
        print(f"Step: {self.current_step}, Balance: {self.balance}, Profit: {self.total_profit}")


In [2]:
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym

# 读取数据
df = pd.read_csv("/home/jesse/Projects/RL_Testing/Q_Learning/Training/NVDA_Preprocessed.csv", index_col="Date", parse_dates=True)

# 使用 `gymnasium.make` 创建环境
env = make_vec_env(lambda: StockTradingEnv(df), n_envs=1)

# 训练 DQN 代理
model = DQN("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=800000)

# 保存训练好的模型
model.save("dqn_trading_model")


Using cuda device
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5e+03    |
|    ep_rew_mean      | 5.91e+09 |
|    exploration_rate | 0.763    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 760      |
|    time_elapsed     | 26       |
|    total_timesteps  | 19996    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 4.99e+05 |
|    n_updates        | 4973     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5e+03    |
|    ep_rew_mean      | 5.36e+09 |
|    exploration_rate | 0.525    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 733      |
|    time_elapsed     | 54       |
|    total_timesteps  | 39992    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 9.9e+05  |
| 

In [3]:
aapl_df = pd.read_csv("/home/jesse/Projects/RL_Testing/Q_Learning/Testing/AAPL_Preprocessed.csv")

env = StockTradingEnv(aapl_df)
obs, _ = env.reset()  # `gymnasium` 的 `reset` 现在返回 (obs, info)

for _ in range(len(aapl_df)):
    action, _states = model.predict(obs)
    obs, reward, terminated, truncated, _ = env.step(action)  # `gymnasium` 需要接收 5 个返回值
    env.render()
    if terminated or truncated:
        break


Step: 1, Balance: 10000, Profit: 0
Step: 2, Balance: 10000, Profit: 0
Step: 3, Balance: 10000, Profit: 0
Step: 4, Balance: 10000, Profit: 0
Step: 5, Balance: 10000, Profit: 0
Step: 6, Balance: 10000, Profit: 0
Step: 7, Balance: 10000, Profit: 0
Step: 8, Balance: 10000, Profit: 0
Step: 9, Balance: 10000, Profit: 0
Step: 10, Balance: 10000, Profit: 0
Step: 11, Balance: 10000, Profit: 0
Step: 12, Balance: 10000, Profit: 0
Step: 13, Balance: 10000, Profit: 0
Step: 14, Balance: 10000, Profit: 0
Step: 15, Balance: 10000, Profit: 0
Step: 16, Balance: 10000, Profit: 0
Step: 17, Balance: 10000, Profit: 0
Step: 18, Balance: 10000, Profit: 0
Step: 19, Balance: 10000, Profit: 0
Step: 20, Balance: 10000, Profit: 0
Step: 21, Balance: 10000, Profit: 0
Step: 22, Balance: 10000, Profit: 0
Step: 23, Balance: 10000, Profit: 0
Step: 24, Balance: 10000, Profit: 0
Step: 25, Balance: 10000, Profit: 0
Step: 26, Balance: 10000, Profit: 0
Step: 27, Balance: 10000, Profit: 0
Step: 28, Balance: 10000, Profit: 0
S