In [36]:
import gym
from gym import spaces
import numpy as np
import pandas as pd
import pickle
import json
from stable_baselines3 import SAC

# ========= 1. 数据预处理 =========
with open("trade_signals.json", "r", encoding="utf-8") as f:
    signals_data = json.load(f)

filtered_signals = [
    d for d in signals_data
    if 'Signal' in d and 'Pattern' in d and 'Take Profit' in d and 'Stop Loss' in d
]

# 统计所有出现过的类别
all_signals = sorted(list(set(d['Signal'] for d in filtered_signals)))
all_patterns = sorted(list(set(d['Pattern'] for d in filtered_signals)))
all_codes = sorted(list(set(d['Code'] for d in filtered_signals)))

# 增加"other"类别，必须放在最后一位
all_signals.append('other')
all_patterns.append('other')
all_codes.append('other')

# 保存类别，推理时要用
with open("env_categories.pkl", "wb") as f:
    pickle.dump({
        "signal_types": all_signals,
        "patterns": all_patterns,
        "codes": all_codes
    }, f)

# ========= 2. 环境定义 =========
def get_index_with_other(cat_list, item):
    try:
        return cat_list.index(item)
    except ValueError:
        return len(cat_list) - 1  # "other"

class TradingEnv(gym.Env):
    def __init__(self, data, signal_types, patterns, codes,
                 tp_range=(0.9, 1.25), sl_range=(0.95, 1.01)):
        super().__init__()
        self.data = data
        self.signal_types = signal_types
        self.patterns = patterns
        self.codes = codes
        self.n_signals = len(signal_types)
        self.n_patterns = len(patterns)
        self.n_codes = len(codes)
        self.numeric_features = 8
        self.current_step = 0

        obs_len = self.n_signals + self.n_patterns + self.n_codes + self.numeric_features
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(obs_len,),
            dtype=np.float32
        )
        self.action_space = spaces.Box(
            low=np.array([tp_range[0], sl_range[0]], dtype=np.float32),
            high=np.array([tp_range[1], sl_range[1]], dtype=np.float32),
            dtype=np.float32
        )

    def reset(self):
        self.current_step = 0
        return self._get_obs()

    def _get_obs(self):
        if self.current_step >= len(self.data):
            return np.zeros(self.observation_space.shape, dtype=np.float32)
        d = self.data[self.current_step]
        # one-hot，未知类别归到"other"位
        signal_onehot = np.zeros(self.n_signals, dtype=np.float32)
        pattern_onehot = np.zeros(self.n_patterns, dtype=np.float32)
        code_onehot = np.zeros(self.n_codes, dtype=np.float32)
        si = get_index_with_other(self.signal_types, d['Signal'])
        pi = get_index_with_other(self.patterns, d['Pattern'])
        ci = get_index_with_other(self.codes, d['Code'])
        signal_onehot[si] = 1
        pattern_onehot[pi] = 1
        code_onehot[ci] = 1

        entry_price = float(d.get('Entry Price', 1))
        volume = float(d.get('Volume', 0)) / entry_price
        volatility = float(d.get('Volatility', 0)) / entry_price
        tp = float(d.get('Take Profit', 0)) / entry_price
        sl = float(d.get('Stop Loss', 0)) / entry_price
        atr = float(d.get('ATR', 0)) / entry_price

        try:
            dt = pd.to_datetime(d['Timestamp'])
            hour_of_day = dt.hour / 23.0
            day_of_week = dt.weekday() / 6.0
        except Exception:
            hour_of_day = 0.0
            day_of_week = 0.0

        obs = np.concatenate([
            signal_onehot, pattern_onehot, code_onehot,
            [volume, volatility, tp, sl, hour_of_day, day_of_week, 1.0, atr]
        ])
        if obs.shape[0] != self.observation_space.shape[0]:
            raise ValueError(f"obs shape {obs.shape} != {self.observation_space.shape}")
        return obs

    def step(self, action):
        tp, sl = action
        d = self.data[self.current_step]
        entry_price = float(d.get('Entry Price', 1))
        best_tp = float(d['Take Profit']) / entry_price
        best_sl = float(d['Stop Loss']) / entry_price
        diff = abs(tp - best_tp) + abs(sl - best_sl)
        reward = 1.0 if diff < 0.01 else -1.0
        self.current_step += 1
        done = self.current_step >= len(self.data)
        obs_dim = self.observation_space.shape[0]
        obs = self._get_obs() if not done else np.zeros(obs_dim, dtype=np.float32)
        info = {}
        return obs, reward, done, info

    def render(self, mode='human'):
        pass

    def close(self):
        pass

# ========= 3. 训练 =========
env = TradingEnv(
    filtered_signals,
    signal_types=all_signals,
    patterns=all_patterns,
    codes=all_codes
)
model = SAC('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=200_000)
model.save("sac_trading_agent")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 6.12e+03 |
|    ep_rew_mean     | -1.6e+03 |
| time/              |          |
|    episodes        | 4        |
|    fps             | 18       |
|    time_elapsed    | 1341     |
|    total_timesteps | 24480    |
| train/             |          |
|    actor_loss      | 38.3     |
|    critic_loss     | 0.547    |
|    ent_coef        | 0.317    |
|    ent_coef_loss   | -0.013   |
|    learning_rate   | 0.0003   |
|    n_updates       | 24379    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 6.12e+03 |
|    ep_rew_mean     | -107     |
| time/              |          |
|    episodes        | 8        |
|    fps             | 17       |
|    time_elapsed    | 2744     |
|    total_timesteps | 48960    |
| train/             |

In [50]:
import gym
from gym import spaces
import numpy as np
import pandas as pd
import pickle
import json
from stable_baselines3 import SAC
import pickle

# ========= 1. 数据预处理 =========
with open("test_trade_signals.json", "r", encoding="utf-8") as f:
    signals_data = json.load(f)

filtered_signals = [
    d for d in signals_data
    if 'Signal' in d and 'Pattern' in d and 'Take Profit' in d and 'Stop Loss' in d
]

# 统计所有出现过的类别
all_signals = sorted(list(set(d['Signal'] for d in filtered_signals)))
all_patterns = sorted(list(set(d['Pattern'] for d in filtered_signals)))
all_codes = sorted(list(set(d['Code'] for d in filtered_signals)))

# 增加"other"类别，必须放在最后一位
all_signals.append('other')
all_patterns.append('other')
all_codes.append('other')

# 保存类别，推理时要用
with open("env_categories.pkl", "wb") as f:
    pickle.dump({
        "signal_types": all_signals,
        "patterns": all_patterns,
        "codes": all_codes
    }, f)

# ========= 2. 环境定义 =========
def get_index_with_other(cat_list, item):
    try:
        return cat_list.index(item)
    except ValueError:
        return len(cat_list) - 1  # "other"

class TradingEnv(gym.Env):
    def __init__(self, data, signal_types, patterns, codes,
                 tp_range=(0.9, 1.25), sl_range=(0.95, 1.01)):
        super().__init__()
        self.data = data
        self.signal_types = signal_types
        self.patterns = patterns
        self.codes = codes
        self.n_signals = len(signal_types)
        self.n_patterns = len(patterns)
        self.n_codes = len(codes)
        self.numeric_features = 8
        self.current_step = 0

        obs_len = self.n_signals + self.n_patterns + self.n_codes + self.numeric_features
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(obs_len,),
            dtype=np.float32
        )
        self.action_space = spaces.Box(
            low=np.array([tp_range[0], sl_range[0]], dtype=np.float32),
            high=np.array([tp_range[1], sl_range[1]], dtype=np.float32),
            dtype=np.float32
        )

    def reset(self):
        self.current_step = 0
        return self._get_obs()

    def _get_obs(self):
        if self.current_step >= len(self.data):
            return np.zeros(self.observation_space.shape, dtype=np.float32)
        d = self.data[self.current_step]
        # one-hot，未知类别归到"other"位
        signal_onehot = np.zeros(self.n_signals, dtype=np.float32)
        pattern_onehot = np.zeros(self.n_patterns, dtype=np.float32)
        code_onehot = np.zeros(self.n_codes, dtype=np.float32)
        si = get_index_with_other(self.signal_types, d['Signal'])
        pi = get_index_with_other(self.patterns, d['Pattern'])
        ci = get_index_with_other(self.codes, d['Code'])
        signal_onehot[si] = 1
        pattern_onehot[pi] = 1
        code_onehot[ci] = 1

        entry_price = float(d.get('Entry Price', 1))
        volume = float(d.get('Volume', 0)) / entry_price
        volatility = float(d.get('Volatility', 0)) / entry_price
        tp = float(d.get('Take Profit', 0)) / entry_price
        sl = float(d.get('Stop Loss', 0)) / entry_price
        atr = float(d.get('ATR', 0)) / entry_price

        try:
            dt = pd.to_datetime(d['Timestamp'])
            hour_of_day = dt.hour / 23.0
            day_of_week = dt.weekday() / 6.0
        except Exception:
            hour_of_day = 0.0
            day_of_week = 0.0

        obs = np.concatenate([
            signal_onehot, pattern_onehot, code_onehot,
            [volume, volatility, tp, sl, hour_of_day, day_of_week, 1.0, atr]
        ])
        if obs.shape[0] != self.observation_space.shape[0]:
            raise ValueError(f"obs shape {obs.shape} != {self.observation_space.shape}")
        return obs

    def step(self, action):
        tp, sl = action
        d = self.data[self.current_step]
        entry_price = float(d.get('Entry Price', 1))
        best_tp = float(d['Take Profit']) / entry_price
        best_sl = float(d['Stop Loss']) / entry_price
        diff = abs(tp - best_tp) + abs(sl - best_sl)
        reward = 1.0 if diff < 0.01 else -1.0
        self.current_step += 1
        done = self.current_step >= len(self.data)
        obs_dim = self.observation_space.shape[0]
        obs = self._get_obs() if not done else np.zeros(obs_dim, dtype=np.float32)
        info = {}
        return obs, reward, done, info

    def render(self, mode='human'):
        pass

    def close(self):
        pass

with open("env_categories.pkl", "rb") as f:
    cats = pickle.load(f)

new_env = TradingEnv(
    filtered_signals,
    signal_types=cats['signal_types'],
    patterns=cats['patterns'],
    codes=cats['codes']
)
model = SAC.load("sac_trading_agent", env=new_env)
obs = new_env.reset()
done = False
while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = new_env.step(action)
    print(f"Predicted Take Profit: {action[0]:.4f}, Predicted Stop Loss: {action[1]:.4f}")
    print("Number of trading signals in the dataset:", len(filtered_signals))

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


ValueError: Observation spaces do not match: Box(-inf, inf, (29,), float32) != Box(-inf, inf, (17,), float32)

In [82]:
import gym
from gym import spaces
import numpy as np
import pandas as pd
import pickle
import json
from stable_baselines3 import SAC

class TradingEnv(gym.Env):
    def __init__(self, data, categories_dict=None, tp_range=(1.0, 1.25), sl_range=(0.95, 0.99)):
        super().__init__()
        self.data = data
        
        if categories_dict is None:
            # 训练时收集所有类别
            self.signal_types = sorted(list(set(d['Signal'] for d in data)))
            self.patterns = sorted(list(set(d['Pattern'] for d in data)))
            self.codes = sorted(list(set(d['Code'] for d in data)))
            
            # 增加"other"类别
            self.signal_types.append('other')
            self.patterns.append('other')
            self.codes.append('other')
            
            # 保存类别字典
            self.categories_dict = {
                'signal_types': self.signal_types,
                'patterns': self.patterns,
                'codes': self.codes
            }
        else:
            # 使用提供的类别字典
            self.categories_dict = categories_dict
            self.signal_types = categories_dict['signal_types']
            self.patterns = categories_dict['patterns']
            self.codes = categories_dict['codes']
        
        self.n_signals = len(self.signal_types)
        self.n_patterns = len(self.patterns)
        self.n_codes = len(self.codes)
        self.numeric_features = 8
        self.current_step = 0
        
        obs_len = self.n_signals + self.n_patterns + self.n_codes + self.numeric_features
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(obs_len,),
            dtype=np.float32
        )
        self.action_space = spaces.Box(
            low=np.array([tp_range[0], sl_range[0]], dtype=np.float32),
            high=np.array([tp_range[1], sl_range[1]], dtype=np.float32),
            dtype=np.float32
        )
    
    def reset(self):
        self.current_step = 0
        return self._get_obs()
    
    def _get_obs(self):
        if self.current_step >= len(self.data):
            return np.zeros(self.observation_space.shape, dtype=np.float32)
        
        d = self.data[self.current_step]
        # one-hot编码
        signal_onehot = np.zeros(self.n_signals, dtype=np.float32)
        pattern_onehot = np.zeros(self.n_patterns, dtype=np.float32)
        code_onehot = np.zeros(self.n_codes, dtype=np.float32)
        
        # 使用索引，如果类别不存在则使用"other"
        si = self.signal_types.index(d['Signal']) if d['Signal'] in self.signal_types else self.n_signals - 1
        pi = self.patterns.index(d['Pattern']) if d['Pattern'] in self.patterns else self.n_patterns - 1
        ci = self.codes.index(d['Code']) if d['Code'] in self.codes else self.n_codes - 1
        
        signal_onehot[si] = 1
        pattern_onehot[pi] = 1
        code_onehot[ci] = 1
        
        # 其他特征
        entry_price = float(d.get('Entry Price', 1))
        volume = float(d.get('Volume', 0)) / entry_price
        volatility = float(d.get('Volatility', 0)) / entry_price
        tp = float(d.get('Take Profit', 0)) / entry_price
        sl = float(d.get('Stop Loss', 0)) / entry_price
        atr = float(d.get('ATR', 0)) / entry_price
        
        # 时间特征
        try:
            dt = pd.to_datetime(d['Timestamp'])
            hour_of_day = dt.hour / 23.0
            day_of_week = dt.weekday() / 6.0
        except Exception:
            hour_of_day = 0.0
            day_of_week = 0.0
        
        obs = np.concatenate([
            signal_onehot, pattern_onehot, code_onehot,
            [volume, volatility, tp, sl, hour_of_day, day_of_week, 1.0, atr]
        ])
        return obs
    
    def step(self, action):
        tp, sl = action
        d = self.data[self.current_step]
        entry_price = float(d.get('Entry Price', 1))
        best_tp = float(d['Take Profit']) / entry_price
        best_sl = float(d['Stop Loss']) / entry_price
        diff = abs(tp - best_tp) + abs(sl - best_sl)
        reward = 1.0 if diff < 0.01 else -1.0
        self.current_step += 1
        done = self.current_step >= len(self.data)
        obs = self._get_obs() if not done else np.zeros(self.observation_space.shape, dtype=np.float32)
        info = {}
        
        return obs, reward, done, info
    
    def render(self, mode='human'):
        pass
    
    def close(self):
        pass

# 训练代码
'''
with open("trade_signals.json", "r", encoding="utf-8") as f:
    signals_data = json.load(f)

env = TradingEnv(signals_data)
model = SAC('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=200_000)
model.save("sac_trading_agent")

# 保存类别字典
with open("categories_dict.json", "w") as f:
    json.dump(env.categories_dict, f)
'''

# 加载和预测代码
# 首先加载类别字典
with open("categories_dict.json", "r") as f:
    categories_dict = json.load(f)

# 加载测试数据
with open("test_trade_signals.json", "r", encoding="utf-8") as f:
    test_signals_data = json.load(f)

# 使用相同的类别字典创建测试环境
new_env = TradingEnv(test_signals_data, categories_dict=categories_dict)

# 加载模型
model = SAC.load("sac_trading_agent", env=new_env)

# 进行预测
obs = new_env.reset()
done = False
total_reward = 0
predictions = []

while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = new_env.step(action)
    total_reward += reward
    
    # 获取当前数据用于记录
    if new_env.current_step - 1 < len(test_signals_data):
        current_data = test_signals_data[new_env.current_step - 1]
        predictions.append({
            'Signal': current_data['Signal'],
            'Pattern': current_data['Pattern'],
            'Code': current_data['Code'],
            'Predicted_TP': float(action[0]),
            'Predicted_SL': float(action[1]),
            'Actual_TP': float(current_data['Take Profit']) / float(current_data.get('Entry Price', 1)),
            'Actual_SL': float(current_data['Stop Loss']) / float(current_data.get('Entry Price', 1)),
            'Reward': reward
        })
        print(f"Signal: {current_data['Signal']}, TP: {action[0]:.4f}, SL: {action[1]:.4f}, Reward: {reward}")

print(f"\n总奖励: {total_reward}")
print(f"平均奖励: {total_reward / len(test_signals_data):.4f}")

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Signal: Buy Long, TP: 1.0309, SL: 0.9831, Reward: -1.0
Signal: Buy Long, TP: 1.1613, SL: 0.9756, Reward: -1.0
Signal: Buy Long, TP: 1.0268, SL: 0.9833, Reward: -1.0
Signal: Buy Long, TP: 1.0113, SL: 0.9889, Reward: 1.0
Signal: Buy Long, TP: 1.0113, SL: 0.9889, Reward: 1.0
Signal: Buy Long, TP: 1.0114, SL: 0.9888, Reward: 1.0
Signal: Buy Long, TP: 1.0110, SL: 0.9890, Reward: 1.0
Signal: Buy Long, TP: 1.1476, SL: 0.9717, Reward: -1.0
Signal: Buy Long, TP: 1.0107, SL: 0.9890, Reward: 1.0
Signal: Buy Long, TP: 1.0106, SL: 0.9891, Reward: 1.0
Signal: Buy Long, TP: 1.0105, SL: 0.9891, Reward: 1.0
Signal: Buy Long, TP: 1.0101, SL: 0.9890, Reward: 1.0
Signal: Buy Long, TP: 1.1205, SL: 0.9721, Reward: -1.0
Signal: Buy Long, TP: 1.1164, SL: 0.9725, Reward: -1.0
Signal: Buy Long, TP: 1.0144, SL: 0.9885, Reward: 1.0
Signal: Buy Long, TP: 1.0144, SL: 0.9886, Reward: 1.0
Signal: Buy Long, TP: 1.0141, SL: 0.9887, Reward: 1.0