In [4]:
import os

import pandas as pd

# Location of the wide per-ticker price/indicator table (one row per date,
# columns such as BP_Open ... WFC_Williams_%R). The path below is the original
# author's local checkout; set the ALL_STOCKS_CSV environment variable to point
# at your own copy instead of editing the notebook.
DEFAULT_STOCKS_CSV = "/home/jesse/Projects/RL_Testing/LSTM_Attention/all_stocks.csv"
file_path = os.environ.get("ALL_STOCKS_CSV", DEFAULT_STOCKS_CSV)
df = pd.read_csv(file_path)

df.head()

Unnamed: 0,Date,BP_Open,BP_High,BP_Low,BP_Close,BP_Volume,BP_RSI_7,BP_RSI_14,BP_MACD,BP_Signal_Line,...,WFC_MACD,WFC_Signal_Line,WFC_BB_upper,WFC_BB_lower,WFC_High-Low,WFC_High-Close,WFC_Low-Close,WFC_True_Range,WFC_ATR_14,WFC_Williams_%R
0,2012-06-15,19.459518,19.709501,19.4056,19.709501,7102900,81.962132,62.93712,-0.075738,-0.298102,...,-0.120055,-0.249732,22.782834,20.964491,0.529357,0.362185,0.167172,0.529357,0.5811,-3.63632
1,2012-06-18,19.400699,19.479125,19.253648,19.371288,4552900,64.606711,59.288562,-0.035472,-0.245576,...,-0.06158,-0.212101,22.878807,20.97439,0.369157,0.195031,0.174126,0.369157,0.587568,-9.215037
2,2012-06-19,19.704596,19.836941,19.616366,19.724203,6269700,77.436056,72.2992,0.024632,-0.191535,...,0.012717,-0.167138,23.037969,20.923886,0.501499,0.529361,0.027862,0.529361,0.589061,-7.6024
3,2012-06-20,19.729107,19.792829,19.444811,19.601665,4938300,71.464775,75.542735,0.061666,-0.140895,...,0.062448,-0.12122,23.147359,20.8939,0.585081,0.139308,0.445773,0.585081,0.582095,-11.988209
4,2012-06-21,19.263448,19.371285,18.562513,18.587021,10639600,39.622607,56.682005,0.009038,-0.110908,...,0.074584,-0.082059,23.190337,20.892712,0.605976,0.229846,0.37613,0.605976,0.523388,-25.731185


In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pandas as pd

class TradingEnv(gym.Env):
    """Single-asset, one-share-per-step trading environment.

    Actions (``Discrete(3)``): ``1`` buys one share, ``2`` sells one share,
    any other value holds.

    Observation: the last ``window_size`` rows of the numeric columns of
    ``df``, flattened row-major, followed by ``[balance, shares_held]``; the
    result is a 1-D ``float32`` vector of length
    ``window_size * n_numeric_columns + 2``.

    Reward: a weighted sum of cumulative profit, the step-over-step portfolio
    return, a risk-adjusted profit term, minus a penalty once drawdown exceeds
    ``max_drawdown``.

    Args:
        df: Price/feature table, one row per time step. Must contain
            ``price_column``. Non-numeric columns (e.g. a date string column)
            are ignored when building observations.
        window_size: Number of past rows included in each observation.
        initial_balance: Starting cash.
        max_drawdown: Drawdown fraction above which the reward is penalised.
        risk_free_rate: Constant subtracted from profit in the risk-adjusted
            reward term.
        transaction_cost: Proportional fee charged on every buy and sell.
        price_column: Name of the column used as the execution price.

    Raises:
        KeyError: If ``price_column`` is not a column of ``df``.
        ValueError: If ``df`` has no numeric columns or has ``window_size``
            rows or fewer.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, df, window_size=10, initial_balance=10000, max_drawdown=0.2,
                 risk_free_rate=0.01, transaction_cost=0.001, price_column="Close"):
        super().__init__()

        if price_column not in df.columns:
            raise KeyError(f"price column {price_column!r} not found in df")
        if len(df) <= window_size:
            raise ValueError(
                f"df needs more than window_size={window_size} rows, got {len(df)}"
            )

        numeric = df.select_dtypes(include=[np.number])
        if numeric.shape[1] == 0:
            raise ValueError("df has no numeric columns to build observations from")

        self.df = df
        self.window_size = window_size
        self.initial_balance = initial_balance
        self.max_drawdown = max_drawdown  # maximum tolerated drawdown (fraction of peak)
        self.risk_free_rate = risk_free_rate  # used in the risk-adjusted reward term
        self.transaction_cost = transaction_cost  # proportional fee per trade
        self.price_column = price_column

        # Cache plain arrays once so step()/_next_observation() do not go
        # through pandas indexing on every call.
        self._features = numeric.to_numpy(dtype=np.float32)
        self._prices = df[price_column].to_numpy(dtype=np.float64)

        # Observation = window of every numeric feature + (balance, shares_held).
        # For a single-column frame this is the original window_size + 2.
        obs_len = window_size * self._features.shape[1] + 2
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(obs_len,), dtype=np.float32
        )
        self.action_space = spaces.Discrete(3)

        self.reset()

    def reset(self, seed=None, options=None):
        """Restore the starting account state and return ``(observation, info)``."""
        super().reset(seed=seed)

        self.current_step = self.window_size
        self.balance = self.initial_balance
        self.shares_held = 0
        self.total_profit = 0
        self.peak_value = self.initial_balance  # running peak, used for drawdown
        self.done = False
        self.history = []  # portfolio value recorded after every step

        return self._next_observation(), {}

    def _next_observation(self):
        """Build the float32 observation vector for ``current_step``."""
        start = self.current_step - self.window_size
        window = self._features[start:self.current_step].ravel()
        account = np.array([self.balance, self.shares_held], dtype=np.float32)
        return np.concatenate((window, account))

    def step(self, action):
        """Apply ``action`` at the current price and advance one row.

        Returns:
            ``(observation, reward, terminated, truncated, info)``;
            ``truncated`` is always ``False`` and ``info`` is always empty.
        """
        if self.done:
            return self._next_observation(), 0, True, False, {}

        price = float(self._prices[self.current_step])
        fee = price * self.transaction_cost

        # Execute the trade. A buy must be able to cover price *and* fee,
        # otherwise the balance would go negative.
        if action == 1:  # buy
            if self.balance >= price + fee:
                self.shares_held += 1
                self.balance -= price + fee
        elif action == 2:  # sell
            if self.shares_held > 0:
                self.shares_held -= 1
                self.balance += price - fee

        # Mark the portfolio to market.
        total_value = self.balance + self.shares_held * price

        # Step return is measured against the value recorded at the previous
        # step, so it reflects price movement as well as fees. Comparing
        # before/after the trade at the same price can only ever yield -fee.
        prev_value = self.history[-1] if self.history else self.initial_balance
        daily_return = (total_value - prev_value) / prev_value if prev_value else 0.0

        # Cumulative profit.
        self.total_profit = total_value - self.initial_balance

        # Drawdown relative to the running peak (skipped on the first step).
        if self.current_step > self.window_size:
            self.peak_value = max(self.peak_value, total_value)
            drawdown = (self.peak_value - total_value) / self.peak_value
        else:
            drawdown = 0

        # Risk-adjusted term.
        # NOTE(review): this divides currency profit by the std of portfolio
        # *values*, so it is a heuristic rather than a true Sharpe ratio;
        # formula kept as authored.
        if len(self.history) > 1:
            value_std = np.std(self.history)
            sharpe_ratio = (self.total_profit - self.risk_free_rate) / (value_std if value_std > 0 else 1)
        else:
            sharpe_ratio = 0  # not enough data yet

        # Reward shaping.
        reward = self.total_profit * 0.3  # cumulative profit
        reward += daily_return * 0.3  # step return
        if drawdown > self.max_drawdown:
            reward -= (drawdown - self.max_drawdown) * 10  # penalise excess drawdown
        reward += sharpe_ratio * 0.2  # risk-adjusted profit

        # Record the value for the next step's return / risk term.
        self.history.append(total_value)

        # Advance; the episode ends on the last row.
        self.current_step += 1
        if self.current_step >= len(self._prices) - 1:
            self.done = True

        return self._next_observation(), reward, self.done, False, {}

    def render(self, mode="human"):
        """Print step, cash, position, profit and distance below the peak value."""
        price = self._prices[self.current_step]
        print(f'Step: {self.current_step}, Balance: {self.balance}, Shares: {self.shares_held}, Profit: {self.total_profit}, Drawdown: {self.peak_value - (self.balance + self.shares_held * price)}')



In [6]:
import pandas as pd
import numpy as np

# Fixed seed so the synthetic prices and the sampled actions are the same on
# every Restart & Run All.
SEED = 42
N_ROWS = 100
N_DEMO_STEPS = 10

# Synthetic price series: uniform noise, only meant to exercise the env.
rng = np.random.default_rng(SEED)
dates = pd.date_range(start="2020-01-01", periods=N_ROWS)
data = pd.DataFrame({
    "Close": rng.uniform(50, 150, size=N_ROWS),
}, index=dates)

# Build the environment and seed its action sampler.
env = TradingEnv(df=data)
env.action_space.seed(SEED)

# Random-policy rollout; stop early if the episode ends.
obs, _ = env.reset(seed=SEED)
for _ in range(N_DEMO_STEPS):
    action = env.action_space.sample()
    obs, reward, done, _, _ = env.step(action)
    env.render()
    if done:
        break


Step: 11, Balance: 9856.493597563234, Shares: 1, Profit: 0.0, Drawdown: 31.759079812678465
Step: 12, Balance: 9968.240920187322, Shares: 0, Profit: -31.759079812678465, Drawdown: 31.759079812678465
Step: 13, Balance: 9968.240920187322, Shares: 0, Profit: -31.759079812678465, Drawdown: 31.759079812678465
Step: 14, Balance: 9824.470902239711, Shares: 1, Profit: -31.759079812678465, Drawdown: 44.12282790839163
Step: 15, Balance: 9955.877172091608, Shares: 0, Profit: -44.12282790839163, Drawdown: 44.12282790839163
Step: 16, Balance: 9955.877172091608, Shares: 0, Profit: -44.12282790839163, Drawdown: 44.12282790839163
Step: 17, Balance: 9955.877172091608, Shares: 0, Profit: -44.12282790839163, Drawdown: 44.12282790839163
Step: 18, Balance: 9955.877172091608, Shares: 0, Profit: -44.12282790839163, Drawdown: 44.12282790839163
Step: 19, Balance: 9955.877172091608, Shares: 0, Profit: -44.12282790839163, Drawdown: 44.12282790839163
Step: 20, Balance: 9955.877172091608, Shares: 0, Profit: -44.122