In [71]:
import gymnasium
from gymnasium import spaces
import numpy as np

class StockTradingEnv(gym.Env):
    def __init__(self, data):
        super().__init__()

        # Stock price data (time series of stock prices)
        self.data = data
        self.n = len(data)

        # Action space (Buy, Sell, Hold)
        self.action_space = spaces.Discrete(3)

        # Observation space (current stock price, number of stocks held)
        self.observation_space = spaces.Box(low=0, high=np.inf, shape=(2,))

        # Initial balance and stock holdings
        self.balance = 10000  # Initial balance
        self.stock_held = 0   # No stocks initially
        self.current_step = 0

    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        self.current_step = 0
        self.balance = 10000
        self.stock_held = 0
        return np.array([self.data[self.current_step], self.stock_held]).astype(np.float32), {}

    def step(self, action):
        assert self.action_space.contains(action)

        current_price = self.data[self.current_step]
        next_price = self.data[self.current_step + 1]

        if action == 0:  # Buy
            max_shares = int(self.balance / current_price)
            shares_bought = min(max_shares, 10)  # Buy at most 10 shares at a time
            cost = shares_bought * current_price
            self.balance -= cost
            self.stock_held += shares_bought
        elif action == 1:  # Sell
            shares_sold = min(self.stock_held, 10)  # Sell at most 10 shares at a time
            revenue = shares_sold * current_price
            self.balance += revenue
            self.stock_held -= shares_sold

        self.current_step += 1

        # Calculate reward (change in portfolio value)
        portfolio_value = self.balance + (self.stock_held * next_price)
        prev_portfolio_value = self.balance + (self.stock_held * current_price)
        reward = portfolio_value - prev_portfolio_value
        terminated = self.balance <= 0 or self.current_step >= self.n - 1
        truncated = False

        done = self.current_step == self.n - 1  # End of data

        info = {}

        return obs, reward, terminated, truncated, info

    def render(self, mode="human", close=False):
        pass

    def close(self):
        pass


In [72]:
from stable_baselines3.common.env_checker import check_env

prices = np.random.randint(1, 100, size=1000)
env = StockTradingEnv(prices)

check_env(env)

AssertionError: The observation returned by the `step()` method does not match the shape of the given observation space Box(0.0, inf, (2,), float32). Expected: (2,), actual shape: (1,)