In [1]:
import numpy as np

class PriceModel():
    '''
    Note: This is a stateless class, gathering price evolution models in one place.

    Every model is a ``staticmethod`` so it can be called either on the class
    (``PriceModel.price_model_1(...)``) or on an instance
    (``PriceModel().price_model_1(...)``) with the same argument list.
    '''

    @staticmethod
    def price_model_1(current_price, current_action, tau, vol_matrix, perm_impact_matrix, random_vector):
        '''
        One-step price update with a random shock and a permanent-impact term.

        Computes
            current_price + sqrt(tau) * (vol_matrix @ random_vector)
                          - perm_impact_matrix @ current_action

        Args:
            current_price: price vector before the step.
            current_action: vector that is multiplied by ``perm_impact_matrix``.
            tau: scalar; its square root scales the random term.
            vol_matrix: matrix applied to ``random_vector``.
            perm_impact_matrix: matrix applied to ``current_action``.
            random_vector: random draw supplied by the caller.

        Returns:
            The updated price vector.
        '''
        shock = (tau ** 0.5) * (vol_matrix @ random_vector)
        permanent_impact = perm_impact_matrix @ current_action
        return current_price + shock - permanent_impact

    @staticmethod
    def price_model_2(current_price):
        '''
        Placeholder for a second price model (not implemented yet).

        Raises:
            NotImplementedError: always, so that a missing model is reported
                immediately instead of silently yielding ``None`` as a price.
        '''
        raise NotImplementedError("price_model_2 is not implemented yet")

In [36]:
import gym
from gym import spaces
import pandas as pd

class LiquidationEnv(gym.Env):
    """
    Gym environment for selling down a multi-asset position over a fixed
    number of steps.

    Action: one number per asset in [0, 1], the fraction of the *remaining*
    shares of that asset to sell in this step.

    Observation (dict):
        "prices":      current price vector,
        "remaining":   shares still held per asset,
        "acc_revenue": revenue accumulated so far (length-1 array).

    Reward: ``sold . (prices - temp_price_matrix @ sold)``, evaluated at the
    prices in effect *before* the price update of the step.

    The environment follows the classic gym API used elsewhere in this
    notebook: ``reset()`` returns the observation only and ``step()`` returns
    the 4-tuple ``(obs, reward, done, info)``.
    """
    metadata = {'render.modes': ['human']}

    def __init__(self,
                 n_assets=3,
                 initial_shares=100,
                 initial_prices=100,
                 max_steps=5,
                 price_model=PriceModel,
                 tau = 1,
                 temp_price_matrix = None,
                 vol_matrix = None,
                 perm_impact_matrix = None
                 ):
        """
        Args:
            n_assets: number of assets in the position.
            initial_shares: shares held per asset at reset (same for all assets).
            initial_prices: price per asset at reset (same for all assets).
            max_steps: the episode ends once this many steps have been taken.
            price_model: object exposing ``price_model_1``; that callable is
                used to generate the next price vector.
            tau: passed through to the price model.
            temp_price_matrix: (n_assets, n_assets) matrix used in the reward.
                ``None`` (default) means the identity matrix.
            vol_matrix: (n_assets, n_assets) matrix passed to the price model.
                ``None`` (default) means the identity matrix.
            perm_impact_matrix: (n_assets, n_assets) matrix passed to the price
                model. ``None`` (default) means the identity matrix.

        Raises:
            ValueError: if a supplied matrix is not (n_assets, n_assets).
        """
        super(LiquidationEnv, self).__init__()

        # Environment parameters
        self.n_assets = n_assets
        self.initial_shares = np.full(n_assets, initial_shares, dtype=np.float32)
        self.initial_prices = np.full(n_assets, initial_prices, dtype=np.float32)
        self.max_steps = max_steps
        self.price_generator = price_model.price_model_1
        self.tau = tau
        # Defaults are built per instance and sized by n_assets, so they are
        # neither shared between environments nor tied to 3 assets.
        self.temp_price_matrix = self._as_square_matrix(temp_price_matrix, n_assets, "temp_price_matrix")
        self.vol_matrix = self._as_square_matrix(vol_matrix, n_assets, "vol_matrix")
        self.perm_impact_matrix = self._as_square_matrix(perm_impact_matrix, n_assets, "perm_impact_matrix")

        # Define action and observation spaces
        self.action_space = spaces.Box(
            low=0,
            high=1,
            shape=(n_assets,),
            dtype=np.float32
        )

        self.observation_space = spaces.Dict({
            "prices": spaces.Box(low=-np.inf, high=np.inf, shape=(n_assets,), dtype=np.float32),
            "remaining": spaces.Box(low=0, high=initial_shares, shape=(n_assets,), dtype=np.float32),
            "acc_revenue": spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
        })

        # Initialize state
        self.state = None
        self.current_step = 0
        self.reset()

    @staticmethod
    def _as_square_matrix(matrix, n_assets, name):
        """Return `matrix` as an (n_assets, n_assets) array; identity when None."""
        if matrix is None:
            return np.identity(n_assets)
        matrix = np.asarray(matrix)
        if matrix.shape != (n_assets, n_assets):
            raise ValueError(
                f"{name} must have shape ({n_assets}, {n_assets}), got {matrix.shape}"
            )
        return matrix

    def _get_obs(self):
        """Build the observation dict (float32 copies of the internal state)."""
        return {
            "prices": self.state['prices'].copy().astype(np.float32),
            "remaining": self.state['remaining'].copy().astype(np.float32),
            "acc_revenue": np.array([self.state['acc_revenue']], dtype=np.float32)
        }

    def _next_price(self, current_price , current_action, tau, vol_matrix, perm_impact_matrix, random_vector):
        """Delegate the price update to the configured price model."""
        return self.price_generator(current_price, current_action, tau, vol_matrix, perm_impact_matrix, random_vector)

    def reset(self):
        """Restore initial prices/shares, zero the revenue and step counter, and return the first observation."""
        # Reset initial prices (customize with your price initialization)
        self.state = {
            'prices': self.initial_prices.copy(),
            'remaining': self.initial_shares.copy(),
            'acc_revenue': 0.0
        }
        self.current_step = 0
        return self._get_obs()

    def _get_reward(self, state, action, temp_price_matrix):
        '''
        The function to calculate the reward.

        `action` is the number of shares sold per asset (not the fraction).
        Returns ``action . (state['prices'] - temp_price_matrix @ action)``.
        '''
        reward = action.dot(state['prices'] - temp_price_matrix.dot(action))
        return reward

    def step(self, action):
        """
        Sell ``action`` (fraction per asset) of the remaining shares.

        Returns:
            (observation, reward, done, info) with ``reward`` a Python float,
            ``done`` a Python bool and ``info`` an empty dict.
        """
        # TODO: need a better way than clipping
        # Keep the fraction inside the declared action space so that
        # 'remaining' can neither go negative nor grow.
        fraction = np.clip(np.asarray(action), 0.0, 1.0)
        actual_action = self.state['remaining'] * fraction

        # Reward is evaluated at the pre-trade prices; cast to a plain float
        # so the value does not depend on the dtype of the incoming action.
        reward = float(self._get_reward(self.state, actual_action, self.temp_price_matrix))

        # Update state
        self.state['remaining'] -= actual_action
        random_vector = np.random.normal(size = self.n_assets)
        self.state['prices'] = self._next_price(self.state['prices'] , actual_action, self.tau, self.vol_matrix, self.perm_impact_matrix, random_vector)
        # The revenue of the step is exactly the reward just computed, so the
        # accumulated revenue always equals the sum of rewards of the episode.
        self.state['acc_revenue'] += reward

        # Update step counter
        self.current_step += 1

        # Check termination conditions (plain bool, not np.bool_)
        done = bool((np.sum(self.state['remaining']) <= 0) or (self.current_step >= self.max_steps))

        # The reward of the terminal transition is kept: it is revenue that
        # was actually earned by the last trade.
        return self._get_obs(), reward, done, {}

    def render(self, mode='human'):
        """Print the step counter and the current state to stdout."""
        print(f"Step: {self.current_step}")
        print(f"Prices: {self.state['prices']}")
        print(f"Remaining: {self.state['remaining']}")
        print(f"Accumulated Revenue: {self.state['acc_revenue']:.2f}\n")

    def close(self):
        """Nothing to release."""
        pass

In [7]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

# Build a 3-asset environment holding 100 shares of each asset.
env_kwargs = dict(n_assets=3, initial_shares=100)
env = LiquidationEnv(**env_kwargs)
#env.render()
# Verify environment compatibility
# check_env(env)

# Create the PPO agent (dict observations -> "MultiInputPolicy") and train it.
# Note: the log below reports 2048 timesteps although only 50 were requested;
# PPO's rollout length (n_steps) is left at its default here.
model = PPO(policy="MultiInputPolicy", env=env, verbose=1)
model.learn(total_timesteps=50)

# # Test trained model
# obs = env.reset()
# for _ in range(100):

# if done:
#     break

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 4.84     |
|    ep_rew_mean     | 3.51e+03 |
| time/              |          |
|    fps             | 5906     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x734887f78df0>

In [37]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

# Create environment
# (a fresh instance; this rebinds `env` used by the cells below)
env = LiquidationEnv(n_assets=3, initial_shares=100)
# Run stable-baselines3's environment checker against the new instance.
check_env(env)

In [None]:
# Re-run the environment checker on the current `env`.
check_env(env)

In [26]:
# Put the environment back into its initial state (the returned observation
# is discarded) and print that state.
env.reset()
env.render()

Step: 0
Prices: [100. 100. 100.]
Remaining: [100. 100. 100.]
Accumulated Revenue: 0.00



In [34]:
# Sell 1% of the remaining shares of each of the three assets.
action = np.full(3, 0.01)
step_result = env.step(action)
obs, rewards, done, info = step_result
# The environment does not reset itself: running this cell repeatedly keeps
# stepping the same episode, even after `done` is True.
print(done)
env.render()

True
Step: 8
Prices: [92.45442804 97.5168319  93.08483972]
Remaining: [92.27447 92.27447 92.27447]
Accumulated Revenue: 2224.51

