In [1]:
import numpy as np
import pandas as pd
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.evaluation import evaluate_policy

In [2]:
from Simulator.Exchange import Exchange
from Simulator.Strategy import Strategy
from Simulator.Order import Order
from Simulator.OrderState import OrderState
from TradeEnv.TradeGym import TradeEnv
from Simulator.InverseInstrument import InverseInstrument

In [3]:
from IPython.display import clear_output
import glob

# Sort so the train/eval split (first N_TRAIN_FILES for training, the next one for
# evaluation in the following cell) is reproducible; glob order is arbitrary.
files = sorted(glob.glob("*.csv.gz"))

# Training-loop tunables.
N_TRAIN_FILES = 3          # number of (sorted) files used for training
N_EPISODES_PER_FILE = 100  # random start offsets sampled per file
MIN_EPISODE_ROWS = 7200    # rows that must remain after the random start offset
SEED = 42

rng = np.random.default_rng(SEED)

model = None

for file in files[:N_TRAIN_FILES]:
    # The CSV is the same for every episode drawn from this file, so read it once.
    df = pd.read_csv(file, header=0, index_col='timestamp', parse_dates=['timestamp'])
    row_count = df.shape[0]
    if row_count <= MIN_EPISODE_ROWS:
        # Sampling an offset in [0, row_count - MIN_EPISODE_ROWS) is impossible here.
        print("skipping", file, "- only", row_count, "rows")
        continue

    for j in range(N_EPISODES_PER_FILE):
        # Random start offset, leaving at least MIN_EPISODE_ROWS rows in the episode.
        index = int(rng.integers(low=0, high=row_count - MIN_EPISODE_ROWS))
        print("iteration: ", j, "count: ", row_count, "index: ", index)
        # NOTE(review): the slice below has row_count - index rows; confirm the +1 is intended.
        length = row_count - index + 1
        instrument = InverseInstrument("BTC-PERPETUAL", 0.5, 10, 0, 0.0005)
        # copy() so nothing the Exchange does to its frame can leak into the shared df,
        # matching the isolation the old per-iteration read_csv gave.
        exchange = Exchange(df.iloc[index:, :].copy())
        strategy = Strategy(instrument, exchange, 0.002, 0.0002)
        env = TradeEnv(strategy, "human")

        if model is None:
            model = RecurrentPPO("MlpLstmPolicy", env, verbose=0, gamma=.999, n_steps=120, seed=SEED)
        else:
            model.set_env(env)

        # Continue training the same model: keep the timestep counter accumulating
        # across episodes instead of restarting it on every call.
        model = model.learn(length, progress_bar=False, reset_num_timesteps=False)
        clear_output(True)

if model is None:
    raise RuntimeError(f"nothing was trained: need *.csv.gz files with more than {MIN_EPISODE_ROWS} rows")

model.save("ppo_recurrent")
del model  # remove to demonstrate saving and loading

iteration:  0 count:  86392 index:  75009
{'balance': 0.002, 'trade_count': 131, 'trading_pnl_pct': -0.0, 'inventory_pnl_pct': -0.25, 'leverage': 3.31, 'reward': -3.56, 'steps': 1200}
{'balance': 0.002, 'trade_count': 247, 'trading_pnl_pct': -0.26, 'inventory_pnl_pct': -0.26, 'leverage': 3.32, 'reward': -3.58, 'steps': 2400}


KeyboardInterrupt: 

In [None]:
model = RecurrentPPO.load("ppo_recurrent")

# Evaluate on the file after the training files; `files` comes from the training cell
# so the held-out file is the one that was not trained on.
assert len(files) > 3, "need at least 4 *.csv.gz files: 3 for training, 1 held out for evaluation"
df = pd.read_csv(files[3], header=0, index_col='timestamp', parse_dates=['timestamp'])

# Build the instrument here instead of relying on the one left over from the training loop,
# so this cell does not depend on hidden kernel state.
instrument = InverseInstrument("BTC-PERPETUAL", 0.5, 10, 0, 0.0005)
exchange = Exchange(df)
# NOTE(review): training passed 0.002 here (the value reported as 'balance' in the training
# output); confirm that 0.02 is intended for evaluation.
strategy = Strategy(instrument, exchange, 0.02, 0.0002)
env = TradeEnv(strategy, "human")
obs, info = env.reset()

# cell and hidden state of the LSTM
lstm_states = None

# One flag per environment: True tells the policy to reset the LSTM state before this step.
episode_start = np.ones((1,), dtype=bool)
done = False
truncated = False
while not done and not truncated:
    action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_start, deterministic=True)
    obs, reward, done, truncated, info = env.step(action)
    # Only reset the LSTM at an episode boundary; leaving this at 1 would wipe the
    # recurrent memory on every step.
    episode_start = np.array([done or truncated])