In [1]:
import numpy as np
import pandas as pd
import gymnasium as gym
from sb3_contrib import RecurrentPPO
from stable_baselines3 import PPO
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env

In [2]:
from Simulator.Exchange import Exchange
from Simulator.Strategy import Strategy
from Simulator.Order import Order
from Simulator.OrderState import OrderState
from TradeEnv.TradeGym import TradeEnv
from Simulator.InverseInstrument import InverseInstrument

In [3]:
import os

# Collect the gzip-compressed BTC CSV dumps located in `dirname`.
# os.listdir() returns entries in arbitrary order, so the names are sorted
# here to make positional indexing into `files` reproducible across runs
# and platforms.
files = []
dirname = "./"
for filename in sorted(os.listdir(dirname)):
    # Match on the actual extension: a substring test would also accept
    # names that merely contain ".csv.gz" (e.g. "x.csv.gz.part").
    if "BTC" in filename and filename.endswith(".csv.gz"):
        files.append(os.path.join(dirname, filename))
        print(files[-1])

./data_2024-01-01_BTC-PERPETUAL.csv.gz
./data_2024-02-01_BTC-PERPETUAL.csv.gz
./data_2024-03-01_BTC-PERPETUAL.csv.gz
./data_2024-04-01_BTC-PERPETUAL.csv.gz


In [4]:
from typing import Callable

def make_env(df, rank: int, seed: int = 0) -> Callable:
    """Return a zero-argument factory that builds one seeded ``TradeEnv``.

    Args:
        df: Data handed to ``Exchange`` when the environment is built.
        rank: Offset added to ``seed`` for the environment's initial reset.
        seed: Base seed; passed to ``set_random_seed`` when this function is
            called (not when the returned factory is invoked).

    Returns:
        A callable that constructs the environment, resets it once with
        ``seed + rank`` and returns it.
    """
    # Seeding happens once per make_env() call, before the factory is returned.
    set_random_seed(seed)

    def _init() -> gym.Env:
        # Arguments are evaluated left to right, so the instrument is created
        # first, then the exchange, then the strategy that ties them together.
        env = TradeEnv(
            Strategy(
                InverseInstrument("BTC-PERPETUAL", 0.5, 10, 0, 0.0005),
                Exchange(df),
                0.02,
                0.02,
            ),
            3600,
            "human",
        )
        env.reset(seed=seed + rank)
        return env

    return _init

In [5]:
from IPython.display import clear_output

# Incrementally train one PPO model on random-length chunks of the first two
# data files, then save it to "ppo_recurrent" and drop the in-memory copy.
model = None
files.sort()

for file in files[:2]:
    # Parse the CSV once per file: every one of the 20 passes below works on
    # the same frame, so re-reading it inside the pass loop is wasted work.
    df = pd.read_csv(file, header=0, index_col='timestamp', parse_dates=['timestamp'])
    for z in range(20):
        print(file)
        # Each pass re-chunks the file with a freshly drawn chunk length.
        rdf_len = np.random.randint(low=3600, high=18000)
        length = df.shape[0] // rdf_len
        for i in range(length):
            # Explicit positional slice. Floor division above guarantees every
            # chunk holds exactly rdf_len (>= 3600) rows, so no minimum-size
            # guard is needed. The copy keeps the shared parent frame isolated
            # from whatever the simulator does with its chunk.
            rdf = df.iloc[i*rdf_len:(i+1)*rdf_len].copy()

            print("count: ", (i+1)*rdf_len)
            env = make_env(rdf, 0)()
            if model is None:
                # NOTE(review): n_steps and batch_size are fixed here from the
                # FIRST chunk length; later chunks drawn with a different
                # rdf_len still use these values. Confirm this is intended.
                model = PPO("MlpPolicy",
                            env,
                            verbose=0, gamma=.9,
                            n_steps=rdf_len, batch_size=rdf_len, learning_rate=0.01)
            else:
                model.set_env(env)

            model = model.learn(rdf_len, progress_bar=True)
            clear_output(True)

if model is None:
    # Without this guard the save below fails with an opaque AttributeError
    # on None when no data file (or no full chunk) was available.
    raise RuntimeError("No model was trained: no matching data files or chunks were found.")

model.save("ppo_recurrent")
del model # remove to demonstrate saving and loading

count:  16710


Output()

KeyboardInterrupt: 

In [None]:
# Roll the saved policy through one episode on a held-out data file.
# The checkpoint is written by the training cell from PPO("MlpPolicy"), so it
# has to be restored with the same algorithm class.
model = PPO.load("ppo_recurrent")

# Sort locally so that index 3 always means the fourth file by name, whether
# or not another cell has already sorted `files`.
df = pd.read_csv(sorted(files)[3], header=0, index_col='timestamp', parse_dates=['timestamp'])
# `instrument` is only ever created inside make_env's inner function, so it
# does not exist at notebook scope; build it here with the same arguments.
instrument = InverseInstrument("BTC-PERPETUAL", 0.5, 10, 0, 0.0005)
exchange = Exchange(df)
strategy = Strategy(instrument, exchange, 0.025, 0.0002)
# NOTE(review): the training cell builds TradeEnv(strategy, 3600, "human");
# the second positional argument was missing here, which put "human" in its
# place. Confirm 3600 is the right value for evaluation.
env = TradeEnv(strategy, 3600, "human")
obs, info = env.reset()

# Policy state carried between predict() calls; stays None unless the loaded
# policy returns a state.
lstm_states = None

# One flag per environment, True only on the first step of an episode.
episode_start = np.ones((1,), dtype=bool)
done = False
truncated = False
while not done and not truncated:
    action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_start, deterministic=True)
    obs, reward, done, truncated, info = env.step(action)
    # Refresh the flag so it is no longer stuck at "episode just started".
    episode_start = np.array([bool(done or truncated)])