In [2]:
import gym
from gym import spaces

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

from warnings import filterwarnings
filterwarnings("ignore")

In [3]:
data = pd.read_csv("workflow/data/features.csv", index_col=0, parse_dates=True)

data.head()

Unnamed: 0_level_0,AGG_close,DBC_close,VTI_close,^VIX_close,AGG_return,DBC_return,VTI_return,^VIX_return,AGG_std,DBC_std,VTI_std,^VIX_std,AGG_momentum,DBC_momentum,VTI_momentum,^VIX_momentum
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
2006-04-19,57.932709,23.455568,47.14246,11.32,0.000915,0.009699,0.002983,-0.007042,0.002232,0.012989,0.005776,0.042326,-0.010673,0.070248,0.040617,-0.131902
2006-04-20,57.844398,23.183884,47.207253,11.64,-0.001526,-0.011651,0.001373,0.027876,0.002239,0.012361,0.005571,0.042082,-0.011491,0.089362,0.052293,-0.143488
2006-04-21,57.838486,23.672916,47.185673,11.59,-0.000102,0.020874,-0.000457,-0.004305,0.002239,0.012621,0.005503,0.041344,-0.011098,0.117094,0.044298,-0.096648
2006-04-24,58.056412,23.201998,47.04884,11.75,0.003761,-0.020093,-0.002904,0.013711,0.002306,0.012964,0.005515,0.041258,-0.007967,0.084674,0.043257,-0.104421
2006-04-25,57.862076,23.274446,46.897617,11.75,-0.003353,0.003118,-0.003219,0.0,0.002334,0.012639,0.005542,0.041185,-0.00921,0.108236,0.037759,-0.087024


In [11]:
class TradeEnv(gym.Env):
    def __init__(self, data, starting_balance, pos_size=10_000):
        super().__init__()

        self.data = data
        self.n_features = len(data.columns)
        self.close_prices = data[[col for col in data if "close" in col]]
        self.stocks = [stock.replace("_close", "") for stock in self.close_prices.columns]
        
        self.holdings = {stock: {"shares": 0, "balance": 0} for stock in self.stocks}
        self.starting_balance = starting_balance
        self.current_balance = starting_balance
        self.current_step = 0
        self.done = False

        lows = np.array([data[col].min() for col in data])
        highs = np.array([data[col].max() for col in data])
        self.n_stocks = len(self.stocks)
        self.observation_space = spaces.Box(low=lows, high=highs, shape=(self.n_features,), dtype=np.float32)
        self.action_space = spaces.MultiDiscrete([3]*self.n_stocks)
        self.pos_size = pos_size
    
    def step(self, actions):
        self.current_step += 1
        date = self.data.index[self.current_step]
        reward = 0

        if self.current_step >= len(self.data) - 1:
            self.done = True
            return None, None, self.done, {}
        
        for i, a in enumerate(actions):
            stock = self.stocks[i]
            close = self.close_prices[f"{stock}_close"].loc[date]
            n_shares = self.pos_size / close

            if a == 1 and self.current_balance >= self.pos_size:
                self.current_balance -= self.pos_size
                self.holdings[stock]["shares"] += n_shares
                self.holdings[stock]["balance"] += self.pos_size
                reward += 0

            elif a == 2 and self.holdings[stock]["shares"] > 0:
                sell_value = close * self.holdings[stock]["shares"]
                self.holdings[stock]["shares"] = 0
                self.holdings[stock]["balance"] = 0
                self.current_balance += sell_value
                reward += sell_value
                
            else:
                reward += 0

        obs = self.data.loc[date].values

        return obs, reward, self.done, {}
    
    def reset(self):
        self.current_step = 0
        self.current_balance = self.starting_balance
        self.total_profit = 0
        self.holdings = {stock: {"shares": 0, "balance": 0} for stock in self.stocks}
        self.done = False
        return self.data.loc[self.data.index[self.current_step]].values


In [12]:
env = TradeEnv(data, starting_balance=100_000)

obs = env.reset()

while not env.done:
    actions = np.random.choice([0, 1, 2], size=env.n_stocks)
    obs, reward, env.done, _ = env.step(actions)
    print(f"Actions: {[f'{stock}:{action}' for stock, action in zip(actions, env.stocks)]}, Reward: {reward}, Current Balance: {env.current_balance}")

Actions: ['1:AGG', '2:DBC', '2:VTI', '1:^VIX'], Reward: 0, Current Balance: 80000
Actions: ['0:AGG', '1:DBC', '1:VTI', '0:^VIX'], Reward: 0, Current Balance: 60000
Actions: ['0:AGG', '1:DBC', '0:VTI', '2:^VIX'], Reward: 10094.501420474913, Current Balance: 60094.50142047492
Actions: ['0:AGG', '1:DBC', '2:VTI', '2:^VIX'], Reward: 9938.952778868505, Current Balance: 60033.454199343425
Actions: ['2:AGG', '0:DBC', '1:VTI', '1:^VIX'], Reward: 9985.740157886385, Current Balance: 50019.194357229804
Actions: ['0:AGG', '0:DBC', '1:VTI', '1:^VIX'], Reward: 0, Current Balance: 30019.194357229804
Actions: ['2:AGG', '2:DBC', '0:VTI', '2:^VIX'], Reward: 49414.238272413946, Current Balance: 79433.43262964375
Actions: ['1:AGG', '1:DBC', '1:VTI', '1:^VIX'], Reward: 0, Current Balance: 39433.43262964375
Actions: ['0:AGG', '0:DBC', '1:VTI', '1:^VIX'], Reward: 0, Current Balance: 19433.43262964375
Actions: ['0:AGG', '2:DBC', '2:VTI', '2:^VIX'], Reward: 69527.40636207683, Current Balance: 88960.83899172058

In [10]:
env.holdings

{'AGG': 0, 'DBC': 0, 'VTI': 0, '^VIX': 0}

In [79]:
a.step([1, 2, 3, 4])

57.90574264526367
23.573301315307617
47.711334228515625
11.989999771118164


In [62]:
a.current_step

11