In [2]:
%pip install gymnasium
%pip install numpy



In [3]:
from typing import Optional
import numpy as np
import gymnasium as gym
from gymnasium import spaces

## Base Env

In [5]:
class STBaseEnv(gym.Env):
  """Skeleton single-asset trading environment.

  The observation is the feature row of ``data`` at the current step followed
  by the account balance and the held position. ``step`` is still a
  placeholder: it yields a zero reward and does not execute any trade.

  Args:
    data: 2-D array indexed as ``data[step]``; ``data.shape[1]`` is the
      number of features per step.
    ibalance: Initial cash balance restored on every ``reset``.
  """

  def __init__(self, data, ibalance=1e4):
    super().__init__()
    self.data = data
    self.ibalance = ibalance
    self.action_space = spaces.Discrete(3)
    self.balance = ibalance
    self.current_step = 0
    self.position = 0

    # data.shape[1] is the number of features; the 2 extra slots hold balance
    # and position. The shape must be a tuple, hence the trailing comma.
    self.observation_space = spaces.Box(
        low=-np.inf, high=np.inf, shape=(data.shape[1] + 2,), dtype=np.float32)
    self.reset()

  def reset(self, seed=None, options=None):
    """Restore the initial account state and return ``(observation, info)``."""
    # Seeds self.np_random as required by the Gymnasium reset contract.
    super().reset(seed=seed)
    self.balance = self.ibalance
    self.position = 0
    self.current_step = 0

    return self._get_observation(), self._get_info()

  def step(self, action):
    """Return ``(obs, reward, terminated, truncated, info)`` for one step.

    Placeholder implementation: ``action`` is ignored, the reward is 0 and
    ``current_step`` is not advanced.
    """
    obs = self._get_observation()
    reward = 0  # TODO: calculate reward

    # TODO: real episode logic; for now the episode ends on the last data row.
    terminated = self.current_step >= self.data.shape[0] - 1

    # TODO: truncation conditions (e.g. a time limit)
    truncated = False
    info = self._get_info()

    return obs, reward, terminated, truncated, info

  def _get_observation(self):
    """Build the observation: current feature row + [balance, position]."""
    obs = np.concatenate([self.data[self.current_step], [self.balance, self.position]])
    # Cast so the observation matches the dtype declared by observation_space.
    return obs.astype(np.float32)

  def _get_info(self):
    """Return auxiliary diagnostics; empty until something is worth reporting."""
    return {}

  def render(self):
    """Rendering is not implemented for this skeleton environment."""
    return None



In [None]:
import matplotlib.pyplot as plt
import numpy as np
import gymnasium as gym
from gymnasium import spaces

class STBaseEnv(gym.Env):
  """Single-asset trading environment with optional Matplotlib rendering.

  Actions: 0 = hold, 1 = buy one unit, 2 = sell one unit.
  Observation: the feature row of ``data`` at the current step followed by
  the cash balance and the held position.
  Column 0 of ``data`` is used as the asset price.

  Args:
    data: 2-D array indexed as ``data[step, feature]``.
    ibalance: Initial cash balance restored on every ``reset``.
    render_mode: ``None`` (no rendering), ``"human"`` (live window updated on
      every step) or ``"rgb_array"`` (``render()`` returns an image array).
  """

  # 'rgb_array' lets callers grab frames, e.g. to assemble a video.
  metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': 10}

  def __init__(self, data, ibalance=1e4, render_mode=None):
    super().__init__()
    self.data = data
    self.ibalance = ibalance
    self.action_space = spaces.Discrete(3)  # 0: hold, 1: buy, 2: sell
    self.balance = ibalance
    self.current_step = 0
    self.position = 0

    # Features plus [balance, position]; the shape must be a tuple.
    self.observation_space = spaces.Box(
        low=-np.inf, high=np.inf, shape=(data.shape[1] + 2,), dtype=np.float32)

    # Matplotlib handles are created lazily on the first render() call.
    self.render_mode = render_mode
    self.fig = None
    self.ax = None
    self.line_price = None
    self.line_value = None
    self.buy_scatter = None
    self.sell_scatter = None

    # Per-episode history consumed by render().
    self.prices_history = []
    self.total_value_history = []
    self.buy_steps = []
    self.sell_steps = []

    self.reset()

  def reset(self, seed=None, options=None):
    """Start a new episode and return ``(observation, info)``."""
    super().reset(seed=seed)

    self.balance = self.ibalance
    self.position = 0
    self.current_step = 0
    self.prices_history = []
    self.total_value_history = []
    self.buy_steps = []
    self.sell_steps = []

    return self._get_observation(), self._get_info()

  def step(self, action):
    """Apply ``action`` at the current price and advance one time step.

    Returns:
      The Gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)``
      in every render mode. Frames for ``"rgb_array"`` are obtained by
      calling ``render()`` separately.
    """
    n_steps = self.data.shape[0]
    current_price = self.data[self.current_step, 0]  # column 0 is the price
    # On the last row there is no next price, so the reward term becomes 0.
    has_next = self.current_step + 1 < n_steps
    next_price = self.data[self.current_step + 1, 0] if has_next else current_price
    reward = 0

    # Record the price and the portfolio value seen at this step.
    self.prices_history.append(current_price)
    self.total_value_history.append(self.balance + self.position * current_price)

    # Trading logic; the reward is a small share of the next-step price move.
    if action == 1:  # Buy one unit if it is affordable
      if self.balance >= current_price:
        self.position += 1
        self.balance -= current_price
        reward = 0.001 * (next_price - current_price)
        self.buy_steps.append(self.current_step)
    elif action == 2:  # Sell one unit if one is held
      if self.position >= 1:
        self.position -= 1
        self.balance += current_price
        reward = 0.001 * (current_price - next_price)
        self.sell_steps.append(self.current_step)
    # Any other action is a hold.

    self.current_step += 1

    # ">=" keeps the flag set if step() is called again after the last row.
    terminated = self.current_step >= n_steps - 1
    truncated = False
    obs = self._get_observation()
    info = self._get_info()

    # Only the live window is refreshed here; step() never returns a frame.
    if self.render_mode == "human":
      self.render()

    return obs, reward, terminated, truncated, info

  def _last_valid_index(self):
    """Clamp ``current_step`` to the last row of ``data``."""
    return min(self.current_step, self.data.shape[0] - 1)

  def _get_observation(self):
    """Return the current feature row + [balance, position] as float32."""
    row = self.data[self._last_valid_index()]
    obs = np.concatenate([row, [self.balance, self.position]])
    # Cast so the observation matches the dtype declared by observation_space.
    return obs.astype(np.float32)

  def _get_info(self):
    """Return total portfolio value, balance and position as a dict."""
    current_price = self.data[self._last_valid_index(), 0]
    total_value = self.balance + self.position * current_price
    return {"total_value": total_value, "balance": self.balance, "position": self.position}

  def _init_figure(self):
    """Create the figure and the (initially empty) artists."""
    self.fig, self.ax = plt.subplots(figsize=(12, 6))
    self.line_price, = self.ax.plot([], [], label='Price')
    self.line_value, = self.ax.plot([], [], label='Total Value', color='orange')
    self.buy_scatter = self.ax.scatter([], [], color='green', marker='^', s=100, label='Buy')
    self.sell_scatter = self.ax.scatter([], [], color='red', marker='v', s=100, label='Sell')
    self.ax.legend()
    self.ax.set_title('Trading Simulation')
    self.ax.set_xlabel('Time Step')
    self.ax.set_ylabel('Value')
    self.ax.grid(True)
    if self.render_mode == "human":
      # A live window is only wanted in human mode.
      plt.ion()
      plt.show(block=False)

  def _marker_offsets(self, steps):
    """Return an (n, 2) array of (step, price) points for the given steps."""
    if not steps:
      # An empty array clears markers left over from a previous episode.
      return np.empty((0, 2))
    idx = np.asarray(steps, dtype=int)
    prices = np.asarray(self.prices_history, dtype=float)
    return np.column_stack([idx, prices[idx]])

  def render(self):
    """Draw the price, portfolio value and trade markers recorded so far.

    Returns:
      ``None`` when ``render_mode`` is ``None`` or ``"human"``; an
      ``(height, width, 3)`` uint8 array when it is ``"rgb_array"``.
    """
    if self.render_mode is None:
      return None

    if self.fig is None:
      self._init_figure()

    # Refresh the line data.
    x_data = np.arange(len(self.prices_history))
    self.line_price.set_data(x_data, self.prices_history)
    self.line_value.set_data(x_data, self.total_value_history)

    # Refresh buy/sell markers.
    self.buy_scatter.set_offsets(self._marker_offsets(self.buy_steps))
    self.sell_scatter.set_offsets(self._marker_offsets(self.sell_steps))

    # Fit the x axis to the elapsed steps, showing at least 10.
    self.ax.set_xlim(0, max(self.current_step + 1, 10))

    # Fit the y axis with a 10% margin; abs() keeps the margin pointing
    # outwards when values are negative.
    values = list(self.prices_history) + list(self.total_value_history)
    if values:
      lo, hi = min(values), max(values)
      y_min = lo - 0.1 * abs(lo)
      y_max = hi + 0.1 * abs(hi)
      if y_min == y_max:
        # Avoid a zero-height axis.
        y_min, y_max = y_min - 1.0, y_max + 1.0
      self.ax.set_ylim(y_min, y_max)

    self.fig.canvas.draw()

    if self.render_mode == "rgb_array":
      # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib removed;
      # the alpha channel is dropped to return plain RGB.
      rgba = np.asarray(self.fig.canvas.buffer_rgba())
      return rgba[..., :3].copy()

    self.fig.canvas.flush_events()
    return None

  def close(self):
    """Close the figure, if any, and drop every Matplotlib handle."""
    if self.fig is not None:
      plt.close(self.fig)
    self.fig = None
    self.ax = None
    self.line_price = None
    self.line_value = None
    self.buy_scatter = None
    self.sell_scatter = None

# Example usage
if __name__ == "__main__":
    # Synthetic sine-wave prices: 200 time steps, a single feature column.
    prices = np.sin(np.linspace(0, 20, 200)) * 50 + 150
    mock_data = prices.reshape(-1, 1)

    # Seeded generator so the random policy gives the same run every time.
    rng = np.random.default_rng(42)

    # Exercise the "human" render mode.
    env_human = STBaseEnv(data=mock_data, ibalance=10000, render_mode="human")
    total_reward = 0

    print("--- Starting Simulation (Matplotlib Human Render) ---")
    try:
        obs, info = env_human.reset()
        for _ in range(mock_data.shape[0] - 1):  # run until the data is exhausted
            action = int(rng.integers(0, 3))  # random policy
            obs, reward, terminated, truncated, info = env_human.step(action)
            total_reward += reward
            if terminated or truncated:
                break
    finally:
        # Release the figure even if a step raises.
        env_human.close()
    print("\n--- Simulation Finished (Human Render) ---")
    print(f"Final Total Value: {info['total_value']:.2f}")
    print(f"Total Reward: {total_reward:.2f}")

    # 测试 rgb_array 模式 (用于保存视频或集成到其他可视化工具)
    # env_rgb = STBaseEnv(data=mock_data, ibalance=10000, render_mode="rgb_array")
    # obs, info = env_rgb.reset()
    # frames = []
    # done = False
    # while not done:
    #     action = np.random.randint(0, 3)
    #     obs, reward, terminated, truncated, info = env_rgb.step(action)
    #     frame = env_rgb.render()
    #     if frame is not None:
    #         frames.append(frame)
    #     done = terminated or truncated
    # env_rgb.close()
    # print(f"\n--- Simulation Finished (RGB Array Render) ---")
    # print(f"Collected {len(frames)} frames.")
    # # 你可以使用 imageio 或 moviepy 将这些帧保存为视频

## Wrappers

In [None]:
class MAIndicatorWrapper(gym.ObservationWrapper):
  """Append a trailing moving average of the price column to each observation.

  The average is taken over column 0 of the base environment's ``data`` for
  the last ``window`` rows up to and including its ``current_step``; fewer
  rows are used at the start of an episode.

  Args:
    env: Environment whose unwrapped instance exposes ``data`` and
      ``current_step``.
    window: Number of rows in the moving average; must be at least 1.

  Raises:
    ValueError: If ``window`` is smaller than 1.
  """

  def __init__(self, env, window=5):
    super().__init__(env)
    if window < 1:
      raise ValueError(f"window must be >= 1, got {window}")
    self.window = window

    # Grow the declared space by one slot so it matches the observations
    # this wrapper returns. Only done for flat Box spaces.
    base_space = env.observation_space
    if isinstance(base_space, spaces.Box) and len(base_space.shape) == 1:
      self.observation_space = spaces.Box(
          low=np.append(base_space.low, -np.inf).astype(base_space.dtype),
          high=np.append(base_space.high, np.inf).astype(base_space.dtype),
          dtype=base_space.dtype)

  def observation(self, observation):
    """Return ``observation`` with the moving average appended."""
    # unwrapped reaches the base env even when other wrappers sit in between.
    base_env = self.env.unwrapped
    # Clamp so the slice is never empty once the step counter passes the data.
    last = min(base_env.current_step, base_env.data.shape[0] - 1)
    first = max(0, last - self.window + 1)
    ma = base_env.data[first:last + 1, 0].mean()
    extended = np.append(observation, ma)
    return extended.astype(self.observation_space.dtype)