In [81]:
import numpy as np
import numpy.typing as npt
import typing
import gymnasium as gym
from gymnasium import spaces
import python_modules.Server as Server
from typing import Callable, List, Dict, Tuple, Any

In [82]:
def BlackScholesPathSingle(
    S_0: float,
    r: float,
    sigma: float,
    T: float,
    N: int,
    rng: np.random.Generator
) -> np.ndarray:
    """Simulate one geometric Brownian motion (Black-Scholes) path.

    The horizon ``T`` is split into ``N`` equal steps of length ``T / N`` and
    the exact log-normal transition is applied at every step.

    Args:
        S_0: Value of the process at time 0 (stored as ``path[0]``).
        r: Drift rate of the process.
        sigma: Volatility of the process.
        T: Total time horizon.
        N: Number of time steps; must be at least 1 (``N == 0`` fails with a
            division by zero when computing the step length).
        rng: NumPy random generator supplying the ``N`` standard-normal shocks.

    Returns:
        A float64 array of length ``N + 1`` holding the path at every grid
        point, including the starting value.
    """
    delta = T / N
    Z = rng.normal(loc=0, scale=1, size=(N))

    # Log-return of every step; summing them and exponentiating once replaces
    # the step-by-step multiplicative recursion with a single vectorised pass.
    log_increments = (r - 0.5 * sigma ** 2) * delta + sigma * np.sqrt(delta) * Z

    path = np.empty(N + 1, dtype=np.float64)
    path[0] = S_0
    path[1:] = S_0 * np.exp(np.cumsum(log_increments))
    return path

class MarginCash(Server.ISecurity):
    """Cash security that compounds a margin rate on balances after every step.

    The security reports itself as non-tradeable; all lifecycle hooks other
    than ``after_step`` are no-ops.
    """

    def __init__(self, ticker: str, margin_interest_rate: float):
        """Store the cash ticker and the margin interest rate.

        Args:
            ticker: Name under which this security is registered in the simulation.
            margin_interest_rate: Rate that is scaled by the simulation ``dt``
                each step.
        """
        super().__init__()
        self.ticker = ticker
        self.margin_interest_rate = margin_interest_rate

    def is_tradeable(self) -> bool:
        """Cash cannot be traded directly."""
        return False

    def on_simulation_start(self, simulation: Server.ISimulation, portfolio: Server.IPortfolioManager):
        """No set-up is required."""

    def before_step(self, simulation: Server.ISimulation, portfolio: Server.IPortfolioManager):
        """Nothing happens before a step."""

    def on_trade_executed(self, simulation: Server.ISimulation, portfolio: Server.IPortfolioManager, buyer_id: typing.SupportsInt, seller_id: typing.SupportsInt, transacted_price: typing.SupportsFloat, transacted_volume: typing.SupportsFloat):
        """Trades need no handling here."""

    def after_step(self, simulation: Server.ISimulation, portfolio: Server.IPortfolioManager):
        """Apply one step of margin interest to every user's cash balance.

        The growth factor is ``1 + dt * margin_interest_rate``. Judging by its
        name, ``multiply_to_security_if_negative`` presumably scales only
        negative (borrowed) balances -- confirm against the Server API.
        """
        step_length = simulation.get_dt()
        cash_security_id = simulation.get_security_id(self.ticker)
        interest_factor = 1 + step_length * self.margin_interest_rate
        user_total = portfolio.get_user_count()
        for account in range(user_total):
            portfolio.multiply_to_security_if_negative(account, cash_security_id, interest_factor)

    def on_simulation_end(self, simulation: Server.ISimulation, portfolio: Server.IPortfolioManager):
        """No tear-down is required."""

class DividendStock(Server.ISecurity):
    """Stock that pays a per-step dividend in a given currency.

    Trades are settled by moving stock and cash between buyer and seller, and
    after each step a dividend looked up via ``dividend_function`` is paid out.
    """

    def __init__(self, ticker: str, currency: str, dividend_function: Callable[[int], float]):
        """Store the tickers and the dividend lookup.

        Args:
            ticker: Name under which this stock is registered in the simulation.
            currency: Ticker of the security that dividends and trade proceeds
                are paid in.
            dividend_function: Called with the current simulation tick; returns
                the dividend rate that is scaled by ``dt`` in ``after_step``.
        """
        super().__init__()
        self.ticker = ticker
        self.currency = currency
        self.dividend_function = dividend_function

    def after_step(self, simulation: Server.ISimulation, portfolio: Server.IPortfolioManager):
        """Pay this step's dividend to every user.

        ``multiply_and_add_1_to_2`` presumably adds
        ``holding(stock) * dt * dividend`` to the currency balance -- confirm
        against the Server API.
        """
        dt = simulation.get_dt()
        dividend = self.dividend_function(simulation.get_tick())
        stock_id = simulation.get_security_id(self.ticker)
        currency_id = simulation.get_security_id(self.currency)
        for user_id in range(portfolio.get_user_count()):
            portfolio.multiply_and_add_1_to_2(user_id, stock_id, currency_id, dt * dividend)
        return None

    def before_step(self, simulation: Server.ISimulation, portfolio: Server.IPortfolioManager):
        """Nothing happens before a step."""
        return None

    def is_tradeable(self) -> bool:
        """Report the stock as tradeable.

        NOTE(review): this previously returned ``False`` (apparently copied
        from the cash security) even though this class settles trades in
        ``on_trade_executed``. Confirm that the Server uses this flag to gate
        order submission.
        """
        return True

    def on_simulation_end(self, simulation: Server.ISimulation, portfolio: Server.IPortfolioManager):
        """No tear-down is required."""
        return None

    def on_simulation_start(self, simulation: Server.ISimulation, portfolio: Server.IPortfolioManager):
        """No set-up is required."""
        return None

    def on_trade_executed(self, simulation: Server.ISimulation, portfolio: Server.IPortfolioManager, buyer_id: typing.SupportsInt, seller_id: typing.SupportsInt, transacted_price: typing.SupportsFloat, transacted_volume: typing.SupportsFloat):
        """Settle an executed trade between buyer and seller.

        The buyer receives ``transacted_volume`` shares and pays
        ``transacted_price * transacted_volume`` in the currency; the seller
        receives the mirror-image adjustment.
        """
        stock_id = simulation.get_security_id(self.ticker)
        cash_id = simulation.get_security_id(self.currency)
        notional = transacted_price * transacted_volume
        portfolio.add_to_two_securities(buyer_id, stock_id, transacted_volume, cash_id, -notional)
        portfolio.add_to_two_securities(seller_id, stock_id, -transacted_volume, cash_id, notional)
        return None

class MultiAgentSimulationEnv(gym.Env):    
    def _get_agent_dividend_estimate(self, agent_id: int) -> float:
        tick = self.simulation.get_tick()
        annualized_dividend = self.annualized_dividend_path[tick]
        analyst_mean_error = self.agent_analyst_mean_error[agent_id]
        analyst_error = self.rng.normal(analyst_mean_error, 0.001)
        return min(annualized_dividend + analyst_error, 0.0)
    
    def __init__(self, agent_count: int):
        super().__init__()
        self.rng = np.random.default_rng()
        self.step_count = 1000
        
        self.cash_borrowing_cost = 0.08
        self.bond_interest_rate = 0.05
        self.bond_face_value = 100.0
        
        self.dividend_initial_value = 0.06
        self.dividend_growth_rate = 0.1
        self.dividend_volatility = 0.1
        self.annualized_dividend_path = BlackScholesPathSingle(
            self.dividend_initial_value, 
            self.dividend_growth_rate,
            self.dividend_volatility,
            1.0,
            self.step_count,
            self.rng
        )
        
        self.MAX_STOCK_POSITION = 20
        self.MAX_BOND_POSITION = 30
        self.lambda_risk = 0.001
        
        self.initial_stock_price = self.step_count * self.dividend_initial_value
        
        self.cash = MarginCash("CASH", self.cash_borrowing_cost)
        self.bond = Server.GenericSecurities.GenericBond("BOND", "CASH", self.bond_interest_rate, self.bond_face_value) 
        self.stock = DividendStock("STOCK", "CASH", dividend_function=lambda t, dt: self._get_dividend(t, dt))
        
        self.simulation = Server.GenericSimulation({
            "CASH": self.cash,
            "BOND": self.bond,
            "STOCK": self.stock    
        }, 1.0, self.step_count)
        
        self.cash_id = self.simulation.get_security_id("CASH")
        self.bond_id = self.simulation.get_security_id("BOND")
        self.stock_id = self.simulation.get_security_id("STOCK")
        
        self.agent_ids: set[int] = set()
        self.previous_portfolio_value: dict[int, float] = {}
        self.agent_analyst_mean_error: dict[int, float] = {}
        for agent_id in range(agent_count):
            id = self.simulation.add_user(f"AGENT_{agent_id}")
            self.agent_ids.add(id)
            self.previous_portfolio_value[id] = 0.0
            self.agent_analyst_mean_error[id] = self.rng.uniform(-0.01, 0.01)
            pass
        
        self.past_bond_midpoints: List[float] = [self.bond_face_value]
        self.past_stock_midpoints: List[float] = [self.initial_stock_price]
        self.observation_book_length = 5 # Top 5
        
        self.observation_spaces = {
            agent_id: spaces.Dict({
                "order_book_bond_asks": spaces.Box(low=0.0, high=np.inf, shape=(self.observation_book_length,2), dtype=np.float32), # (price, volume)
                "order_book_bond_bids": spaces.Box(low=0.0, high=np.inf, shape=(self.observation_book_length,2), dtype=np.float32), # (price, volume)
                "order_book_stock_asks": spaces.Box(low=0.0, high=np.inf, shape=(self.observation_book_length,2), dtype=np.float32), # (price, volume)
                "order_book_stock_bids": spaces.Box(low=0.0, high=np.inf, shape=(self.observation_book_length,2), dtype=np.float32), # (price, volume)
                "cash_borrowing_cost": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "bond_interest_rate": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "bond_face_value": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "bond_midpoint_price": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "stock_midpoint_price": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "bond_volatility_5": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "bond_volatility_20": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "bond_volatility_100": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "stock_volatility_5": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "stock_volatility_20": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "stock_volatility_100": spaces.Box(low=0.0, high=np.inf, dtype=np.float32),
                "bond_price_delta": spaces.Box(low=-np.inf, high=np.inf, dtype=np.float32),
                "stock_price_delta": spaces.Box(low=-np.inf, high=np.inf, dtype=np.float32),
                "current_expected_dividend": spaces.Box(low=0.0, high=np.inf, dtype=np.float32), # The expected dividend per share we will receive if we own by the end of the current tick
                "remaining_bond_position": spaces.Box(low=-np.inf, high=np.inf, dtype=np.float32), # The remaining position until the limit is reached (it can be breached using limit orders, that's why we also have negative)
                "remaining_stock_position": spaces.Box(low=-np.inf, high=np.inf, dtype=np.float32),
                "open_orders_count": spaces.Box(low=0, high=np.inf, shape=(2,), dtype=np.int32), # (bond, stock)
                "positions": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32), # (cash, bond, stock)
            }) for agent_id in self.agent_ids
        }

        self.RELATIVE_PRICE_BINS = 201
        self.action_spaces = {agent_id: gym.spaces.MultiDiscrete([
            4, # Order type: cancel (all own open limit orders), limit, market, (nothing),
            2, # Security: 0=bond, 1=stock
            2, # Side: 0=buy, 1=sell
            self.RELATIVE_PRICE_BINS, # Relative price (gets normalized to [-1$, 1$] inclusive)
            5  # Quantity: (1-5)
        ]) for agent_id in self.agent_ids}
        pass
    
    def _get_midpoints(self):
        bond_midpoint = self.bond_face_value
        if self.simulation.get_ask_count(self.bond_id) > 0 and self.simulation.get_bid_count(self.bond_id) > 0:
            bond_midpoint = (self.simulation.get_top_ask(self.bond_id).price + self.simulation.get_top_bid(self.bond_id).price) / 2
            
        stock_midpoint = self.initial_stock_price
        if self.simulation.get_ask_count(self.stock_id) > 0 and self.simulation.get_bid_count(self.stock_id) > 0:
            stock_midpoint = (self.simulation.get_top_ask(self.stock_id).price + self.simulation.get_top_bid(self.stock_id).price) / 2
        
        return (bond_midpoint, stock_midpoint)
    
    def _get_obs(self):
        bond_midpoint, stock_midpoint = self._get_midpoints()

        bond_bids, bond_asks = self.simulation.get_order_book(self.bond_id)
        stock_bids, stock_asks = self.simulation.get_order_book(self.stock_id)
        
        def book_to_array(orders: list[Server.LimitOrder]):
            length = self.observation_book_length
            arr = np.zeros((length, 2), dtype=np.float32)
            for i in range(min(length, len(orders))):
                arr[i, 0] = orders[i].price
                arr[i, 1] = orders[i].volume
            return arr

        bond_ask_arr = book_to_array(bond_asks)
        bond_bid_arr = book_to_array(bond_bids)
        stock_ask_arr = book_to_array(stock_asks)
        stock_bid_arr = book_to_array(stock_bids)

        bond_midpoints_5 = self.past_bond_midpoints[-5:]
        bond_midpoints_20 = self.past_bond_midpoints[-20:]
        bond_midpoints_100 = self.past_bond_midpoints[-100:]
        stock_midpoints_5 = self.past_stock_midpoints[-5:]
        stock_midpoints_20 = self.past_stock_midpoints[-20:]
        stock_midpoints_100 = self.past_stock_midpoints[-100:]

        bond_volatility_5 = np.std(bond_midpoints_5) if len(bond_midpoints_5) >= 5 else 0.0
        bond_volatility_20 = np.std(bond_midpoints_20) if len(bond_midpoints_20) >= 20 else 0.0
        bond_volatility_100 = np.std(bond_midpoints_100) if len(bond_midpoints_100) >= 100 else 0.0
        
        stock_volatility_5 = np.std(stock_midpoints_5) if len(stock_midpoints_5) >= 5 else 0.0
        stock_volatility_20 = np.std(stock_midpoints_20) if len(stock_midpoints_20) >= 20 else 0.0
        stock_volatility_100 = np.std(stock_midpoints_100) if len(stock_midpoints_100) >= 100 else 0.0
        
        bond_price_delta = (self.past_bond_midpoints[-1] - self.past_bond_midpoints[-2]) if len(self.past_bond_midpoints) >= 2 else 0.0
        stock_price_delta = (self.past_stock_midpoints[-1] - self.past_stock_midpoints[-2]) if len(self.past_stock_midpoints) >= 2 else 0.0

        obs = {}

        for agent_id in self.agent_ids:
            portfolio = self.simulation.get_user_portfolio(agent_id)
            bond_position = portfolio[self.bond_id]
            stock_position = portfolio[self.stock_id]
            cash_position = portfolio[self.cash_id]

            open_bond_orders = len(self.simulation.get_all_open_user_orders(agent_id, self.bond_id))
            open_stock_orders = len(self.simulation.get_all_open_user_orders(agent_id, self.stock_id))

            remaining_bond_position = (self.MAX_BOND_POSITION - bond_position)
            remaining_stock_position = (self.MAX_STOCK_POSITION - stock_position)

            current_expected_dividend = self._get_agent_dividend_estimate(agent_id)

            obs[agent_id] = {
                "order_book_bond_asks": bond_ask_arr.copy(),
                "order_book_bond_bids": bond_bid_arr.copy(),
                "order_book_stock_asks": stock_ask_arr.copy(),
                "order_book_stock_bids": stock_bid_arr.copy(),
                "cash_borrowing_cost": self.cash_borrowing_cost,
                "bond_interest_rate": self.bond_interest_rate,
                "bond_face_value": self.bond_face_value,
                "bond_midpoint_price": bond_midpoint,
                "stock_midpoint_price": stock_midpoint,
                "bond_volatility_5": bond_volatility_5,
                "bond_volatility_20": bond_volatility_20,
                "bond_volatility_100": bond_volatility_100,
                "stock_volatility_5": stock_volatility_5,
                "stock_volatility_20": stock_volatility_20,
                "stock_volatility_100": stock_volatility_100,
                "bond_price_delta": bond_price_delta,
                "stock_price_delta": stock_price_delta,
                "current_expected_dividend": current_expected_dividend,
                "remaining_bond_position": remaining_bond_position,
                "remaining_stock_position": remaining_stock_position,
                "open_orders_count": np.array([open_bond_orders, open_stock_orders], dtype=np.int32),
                "positions": np.array([cash_position, bond_position, stock_position], dtype=np.float32),
            }
            pass

        return obs

    def reset(self, seed=None, options=None):
        self.simulation.reset_simulation()
        bond_midpoint, stock_midpoint = self._get_midpoints()
        self.past_bond_midpoints = [bond_midpoint]
        self.past_stock_midpoints = [stock_midpoint]
        for agent_id in self.agent_ids:
            self.previous_portfolio_value[agent_id] = 0.0
            self.agent_analyst_mean_error[agent_id] = self.rng.uniform(-0.01, 0.01)
            pass
        self.annualized_dividend_path = BlackScholesPathSingle(
            self.dividend_initial_value, 
            self.dividend_growth_rate,
            self.dividend_volatility,
            1.0,
            self.step_count,
            self.rng
        )
        obs = self._get_obs()
        return obs, {}
    
    def step(self, actions: dict[int, npt.NDArray[np.int64]]):
        items = list(actions.items())
        self.rng.shuffle(items)
        
        bond_midpoint, stock_midpoint = self._get_midpoints()
        
        for agent_id, action in items:
            agent_portfolio = self.simulation.get_user_portfolio(agent_id)
            
            action_tuple = tuple(map(lambda x: x, action))
            
            match action_tuple:
                case (0, _, _, _, _): # Cancel all of the agent's limit orders
                    bond_orders = self.simulation.get_all_open_user_orders(agent_id, self.bond_id)
                    for bond_order in bond_orders:    
                        self.simulation.submit_cancel_order(agent_id, self.bond_id, bond_order)
                    stock_orders = self.simulation.get_all_open_user_orders(agent_id, self.stock_id)
                    for stock_order in stock_orders:
                        self.simulation.submit_cancel_order(agent_id, self.stock_id, stock_order)
                    pass
                case (1, security, side, relative_price_idx, quantity): # Limit order
                    quantity += 1 # Since the discrete starts at 0
                    # Technically the agent can go beyond the max positions using limit order:
                    # i.e. they submit a bunch but none is executed, then quickly all get executed
                    # resulting in them exceeding the position limit. In this case we only allow
                    # reduce to close their positions.
                    relative_price = -1.0 + (relative_price_idx) / (self.RELATIVE_PRICE_BINS - 1)  # From bin to cents
                    if security == 0:
                        # Bond
                        position = agent_portfolio[self.bond_id]
                        price = round(bond_midpoint + relative_price, 2)
                        if side == 0 and position + quantity <= self.MAX_BOND_POSITION:
                            # Buy
                            self.simulation.submit_limit_order(agent_id, self.bond_id, Server.OrderSide.BID, price, quantity)
                            pass
                        elif side == 1 and position - quantity >= -self.MAX_BOND_POSITION:
                            # Sell
                            self.simulation.submit_limit_order(agent_id, self.bond_id, Server.OrderSide.ASK, price, quantity)
                            pass
                        pass
                    elif security == 1:
                        # Stock
                        position = agent_portfolio[self.stock_id]
                        price = round(stock_midpoint + relative_price, 2)
                        if side == 0 and position + quantity <= self.MAX_STOCK_POSITION:
                            # Buy
                            self.simulation.submit_limit_order(agent_id, self.stock_id, Server.OrderSide.BID, price, quantity)
                            pass
                        elif side == 1 and position - quantity >= -self.MAX_STOCK_POSITION:
                            # Sell
                            self.simulation.submit_limit_order(agent_id, self.stock_id, Server.OrderSide.ASK, price, quantity)
                            pass
                        pass
                    pass
                case (2, security, side, _, quantity): # Market order
                    quantity += 1
                    # We allow agents to close positions even if outside limits
                    if security == 0:
                        # Bond
                        position = agent_portfolio[self.bond_id]
                        if side == 0 and position + quantity <= self.MAX_BOND_POSITION:
                            # Buy
                            self.simulation.submit_market_order(agent_id, self.bond_id, Server.OrderAction.BUY, quantity)
                            pass
                        elif side == 1 and position - quantity >= -self.MAX_BOND_POSITION:
                            # Sell
                            self.simulation.submit_market_order(agent_id, self.bond_id, Server.OrderAction.SELL, quantity)
                            pass
                        pass
                    elif security == 1:
                        # Stock
                        position = agent_portfolio[self.stock_id]
                        if side == 0 and position + quantity <= self.MAX_STOCK_POSITION:
                            # Buy
                            self.simulation.submit_market_order(agent_id, self.stock_id, Server.OrderAction.BUY, quantity)
                            pass
                        elif side == 1 and position - quantity >= -self.MAX_STOCK_POSITION:
                            # Sell
                            self.simulation.submit_market_order(agent_id, self.stock_id, Server.OrderAction.SELL, quantity)
                            pass
                        pass
                    pass
                case (3, _, _, _, _): # Do nothing
                    pass
                case unmatched:
                    print("Received unmatched case")
                    print(type(unmatched))
                    raise Exception("Received unmatched case.")
            pass
        
        result = self.simulation.do_simulation_step()
        bond_midpoint, stock_midpoint = self._get_midpoints()
        self.past_bond_midpoints.append(bond_midpoint)
        self.past_stock_midpoints.append(stock_midpoint)
        
        terminated = not result.has_next_step
        obs = self._get_obs()
        
        rewards = {}
        for agent_id in self.agent_ids:
            portfolio = self.simulation.get_user_portfolio(agent_id)
            cash = portfolio[self.cash_id]
            bond = portfolio[self.bond_id]
            stock = portfolio[self.stock_id]
            
            current_portfolio_value = cash + bond * bond_midpoint + stock * stock_midpoint
            previous_portfolio_value = self.previous_portfolio_value[agent_id]
            
            risk_penalty = self.lambda_risk * (bond ** 2 + stock ** 2)
            
            pnl_reward = current_portfolio_value - previous_portfolio_value
            rewards[agent_id] = np.clip(pnl_reward - risk_penalty, -10.0, 10.0)
            self.previous_portfolio_value[agent_id] = current_portfolio_value
            pass
        return obs, rewards, terminated, False, {}
    pass

In [83]:
# trading_parallel_env.py
from pettingzoo.utils import ParallelEnv

# Assume MultiAgentSimulationEnv is the class you already wrote.
class TradingParallelEnv(ParallelEnv):
    """PettingZoo parallel wrapper around ``MultiAgentSimulationEnv``.

    Agents are exposed under stable string names (``"agent_0"``, ...) that are
    mapped to the integer user ids used inside the wrapped environment.
    """

    metadata = {"name": "trading_parallel_v0"}

    def __init__(self, agent_count: int = 20):
        """Create the wrapped environment and the name-to-id mapping.

        Args:
            agent_count: Number of trading agents in the wrapped environment.
        """
        super().__init__()
        self.base_env = MultiAgentSimulationEnv(agent_count=agent_count)

        # Map stable string names ("agent_0", ...) to the integer IDs used inside base_env
        self.possible_agents = [f"agent_{i}" for i in range(agent_count)]  # required by PettingZoo
        self.agents = self.possible_agents[:]
        self._name_to_id = {
            name: id_ for name, id_ in zip(self.possible_agents, sorted(self.base_env.agent_ids))
        }

        # Single-agent spaces (identical for every agent)
        example_id = next(iter(self.base_env.agent_ids))
        self._observation_space = self.base_env.observation_spaces[example_id]
        self._action_space = self.base_env.action_spaces[example_id]

    # ------------------------------------------------------------------ #
    #  PettingZoo API                                                    #
    # ------------------------------------------------------------------ #
    def observation_space(self, agent: str) -> spaces.Space:
        """Return the observation space shared by every agent."""
        return self._observation_space

    def action_space(self, agent: str) -> spaces.Space:
        """Return the action space shared by every agent."""
        return self._action_space

    def reset(
        self, seed: int | None = None, options: Dict[str, Any] | None = None
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, Dict]]:
        """Start a new episode and return name-keyed observations and infos.

        ``step`` empties ``self.agents`` when an episode ends, so the live
        agent list is restored here before anything is built from it.
        """
        self.agents = self.possible_agents[:]
        obs_dict, info = self.base_env.reset(seed=seed, options=options or {})
        # Translate integer IDs -> agent names
        obs = {name: obs_dict[self._name_to_id[name]] for name in self.agents}
        # One dict per agent so that mutating one agent's info cannot leak into another's
        infos = {name: dict(info) for name in self.agents}
        return obs, infos

    def step(
        self, actions: Dict[str, np.ndarray]
    ) -> Tuple[
        Dict[str, np.ndarray],
        Dict[str, float],
        Dict[str, bool],
        Dict[str, bool],
        Dict[str, Dict],
    ]:
        """Forward the joint action to the wrapped environment.

        Returns:
            ``(observations, rewards, terminations, truncations, infos)``, each
            keyed by agent name. Termination and truncation are reported
            separately, exactly as the wrapped environment returned them.
        """
        # Translate names -> integer IDs so your simulator stays unchanged
        joint_action = {self._name_to_id[a]: act for a, act in actions.items()}

        obs_dict, rew_dict, terminated, truncated, info = self.base_env.step(joint_action)

        # Convert back to name-keyed dicts
        live_agents = self.agents
        obs = {a: obs_dict[self._name_to_id[a]] for a in live_agents}
        rewards = {a: rew_dict[self._name_to_id[a]] for a in live_agents}
        terminations = {a: bool(terminated) for a in live_agents}
        truncations = {a: bool(truncated) for a in live_agents}
        infos = {a: dict(info) for a in live_agents}

        # PettingZoo expects `agents` to be emptied when episode is over
        if terminated or truncated:
            self.agents = []

        return obs, rewards, terminations, truncations, infos


In [None]:
# train_parallel_ppo.py
import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize

# ------------------------------------------------------------------ #
#  Parameters you can tweak                                          #
# ------------------------------------------------------------------ #
AGENT_COUNT         = 20         # agents per environment (passed to TradingParallelEnv)
NUM_PARALLEL_ENVS   = 8          # change at will
TOTAL_TIMESTEPS     = 1_000_000  # passed to model.learn(total_timesteps=...)
PPO_N_STEPS         = 2048       # PPO `n_steps` argument
PPO_BATCH_SIZE      = 512        # PPO `batch_size` argument
LEARNING_RATE       = 3e-4       # PPO `learning_rate` argument
SEED                = 0          # PPO `seed` argument
# ------------------------------------------------------------------ #

def make_raw_env():
    """Build a fresh, unwrapped trading environment with ``AGENT_COUNT`` agents."""
    environment = TradingParallelEnv(agent_count=AGENT_COUNT)
    return environment

# 1.  Create the raw PettingZoo env
raw_env = make_raw_env()

# 2.  Standard SuperSuit wrappers
# NOTE(review): every agent here shares identical Dict observation and
# MultiDiscrete action spaces. The padding wrappers may only support
# Box/Discrete spaces -- confirm against the SuperSuit docs, and drop them if
# they reject these spaces (no padding is needed when spaces already match).
env = ss.black_death_v3(raw_env)             # handles agents that finish early
env = ss.pad_observations_v0(env)
env = ss.pad_action_space_v0(env)

# 3.  Convert to a Stable-Baselines "VecEnv"
env = ss.pettingzoo_env_to_vec_env_v1(env)

# 4.  Add parallelism  ----------------------------------------------
# `base_class` selects the vector-env API the result is wrapped in; it must
# name a framework ("stable_baselines3"), not a process model. Running in
# sub-processes is requested through `num_cpus`.
env = ss.concat_vec_envs_v1(
    env,
    NUM_PARALLEL_ENVS,
    num_cpus=NUM_PARALLEL_ENVS,
    base_class="stable_baselines3",
)

# 5.  Normalise obs & rewards for stability
env = VecNormalize(env, norm_obs=True, norm_reward=True)

# 6.  PPO with a shared network (MultiInputPolicy handles Dict obs)
model = PPO(
    policy="MultiInputPolicy",
    env=env,
    learning_rate=LEARNING_RATE,
    n_steps=PPO_N_STEPS,
    batch_size=PPO_BATCH_SIZE,
    verbose=1,
    seed=SEED,
    tensorboard_log="./v4/tensorboard_logs",
    device="cpu"
)

# 7.  Train!
model.learn(total_timesteps=TOTAL_TIMESTEPS)

# 8.  Optional: save artefacts (model weights and VecNormalize statistics)
model.save("./v4/ppo_trading_shared")
env.save("./v4/vecnorm.pkl")
