In [13]:
import pandas as pd
from stable_baselines3 import DQN
import gymnasium as gym
import numpy as np
from gymnasium import spaces



In [14]:
class StockTradingEnv(gym.Env):
    def __init__(self, df):
        super(StockTradingEnv, self).__init__()

        self.df = df
        self.current_step = 0
        self.initial_balance = 10000  # 初始资金
        self.balance = self.initial_balance
        self.shares_held = 0
        self.total_profit = 0

        # 观察空间（收盘价, RSI, MACD, 布林中轨）
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)

        # 动作空间（0: 持有, 1: 买入, 2: 卖出）
        self.action_space = spaces.Discrete(3)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)  # 适配 gymnasium 语法
        self.current_step = 0
        self.balance = self.initial_balance
        self.shares_held = 0
        self.total_profit = 0
        return self._next_observation(), {}  # 返回 observation 和 info（空字典）

    def _next_observation(self):
        obs = self.df.iloc[self.current_step][["Close", "RSI", "MACD", "Bollinger_Middle"]].values
        return obs.astype(np.float32)

    def step(self, action):
        prev_price = self.df.iloc[self.current_step]["Close"]
        self.current_step += 1
        done = self.current_step >= len(self.df) - 1
        reward = 0

        if action == 1:  # 买入
            shares_bought = self.balance // prev_price
            self.shares_held += shares_bought
            self.balance -= shares_bought * prev_price
        elif action == 2 and self.shares_held > 0:  # 卖出
            self.balance += self.shares_held * prev_price
            self.shares_held = 0
            self.total_profit += self.balance - self.initial_balance
            reward = self.total_profit  # 以总收益作为奖励

        obs = self._next_observation()
        return obs, reward, done, False, {}  # `False` 是 `truncated` 标志，符合 Gymnasium 语法

    def render(self):
        print(f"Step: {self.current_step}, Balance: {self.balance}, Profit: {self.total_profit}")


In [15]:
aapl_df = pd.read_csv("/home/jesse/Projects/RL_Testing/Q_Learning/Testing/AAPL_Preprocessed.csv", index_col="Date", parse_dates=True)

aapl_df.head()

Unnamed: 0_level_0,Unnamed: 0,Open,High,Low,Close,Volume,RSI,MACD,MACD_Signal,MACD_Hist,Bollinger_Upper,Bollinger_Middle,Bollinger_Lower
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
1981-01-30,33,0.127232,0.127232,0.126116,0.126116,46188800,39.642837,-0.000377,0.001789,-0.002166,0.151717,0.140681,0.129644
1981-02-02,34,0.11942,0.11942,0.118862,0.118862,23766400,34.943807,-0.00193,0.001045,-0.002975,0.152763,0.13909,0.125418
1981-02-03,35,0.123326,0.123884,0.123326,0.123326,19152000,39.682098,-0.002769,0.000282,-0.003051,0.153145,0.138058,0.122971
1981-02-04,36,0.12779,0.128348,0.12779,0.12779,27865600,44.069116,-0.003038,-0.000382,-0.002656,0.153294,0.137556,0.121818
1981-02-05,37,0.12779,0.128906,0.12779,0.12779,7929600,44.069116,-0.003215,-0.000949,-0.002266,0.153471,0.137193,0.120915


In [16]:
env = StockTradingEnv(aapl_df)

model = DQN.load("/home/jesse/Projects/RL_Testing/Old_Models/DQN_NVDA_20250206.zip")

obs, _ = env.reset()  # Gymnasium 需要两个返回值，第二个是 info 字典

for _ in range(len(aapl_df)):
    action, _states = model.predict(obs, deterministic=True)  # 确保是确定性策略
    obs, reward, done, _, _ = env.step(action)  # Gymnasium 语法
    env.render()
    if done:
        break


Step: 1, Balance: 10000, Profit: 0
Step: 2, Balance: 10000, Profit: 0
Step: 3, Balance: 10000, Profit: 0
Step: 4, Balance: 0.048815816629939945, Profit: 0
Step: 5, Balance: 0.048815816629939945, Profit: 0
Step: 6, Balance: 0.048815816629939945, Profit: 0
Step: 7, Balance: 9519.682744614793, Profit: -480.3172553852073
Step: 8, Balance: 9519.682744614793, Profit: -480.3172553852073
Step: 9, Balance: 9519.682744614793, Profit: -480.3172553852073
Step: 10, Balance: 9519.682744614793, Profit: -480.3172553852073
Step: 11, Balance: 9519.682744614793, Profit: -480.3172553852073
Step: 12, Balance: 9519.682744614793, Profit: -480.3172553852073
Step: 13, Balance: 0.048815816629939945, Profit: -480.3172553852073
Step: 14, Balance: 8951.957006104269, Profit: -1528.3602492809387
Step: 15, Balance: 8951.957006104269, Profit: -1528.3602492809387
Step: 16, Balance: 8951.957006104269, Profit: -1528.3602492809387
Step: 17, Balance: 8951.957006104269, Profit: -1528.3602492809387
Step: 18, Balance: 0.05993