# Construct a custom Environment for Pair Trading

Some existing examples:
* [custom env example](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/5_custom_gym_env.ipynb#scrollTo=RqxatIwPOXe_)
* [StockTradingEnv by Adam King](https://github.com/notadamking/Stock-Trading-Environment)
* [FinRL](https://github.com/AI4Finance-Foundation/FinRL)

The goal is to construct a custom environment for pair trading.

This environment restricts the RL agent's behaviour to pair trading only.

In [1]:
import warnings
# Silence every warning for the whole session to keep notebook output clean.
# NOTE(review): this also hides pandas SettingWithCopyWarning and stable_baselines3
# configuration warnings (e.g. a PPO batch_size that does not divide n_steps * n_envs);
# consider narrowing it with `category=` while developing.
warnings.filterwarnings('ignore')

import os, csv, shutil
import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt

from datetime import date
from envs.env_gridsearch import kellycriterion
from sklearn.model_selection import train_test_split
from stable_baselines3.common.vec_env import DummyVecEnv
from utils.read2df import read2df, unify_dfs
from utils.rlmetrics import get_return, get_metrics
from utils.clearlogs import clear_logs
from envs.env_rl_restrict import PairTradingEnv
from params import *

from stable_baselines3 import PPO, A2C, DQN
import quantstats as qs

# Output directory for trained models and net-worth CSVs; created on first run.
folder_path = "result/rl-restrict-thres"
os.makedirs(name=folder_path, exist_ok=True)

Load the pair selected in `preliminaries.ipynb` and the best grid-search parameters produced by `trade_gridsearch.ipynb`.

In [2]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code; only load pickles this project produced.
# Output of `preliminaries.ipynb`: data[0] is used as the symbol list, data[1] as the frequency key.
with open('result/cointncorr.pickle', 'rb') as pk:
    data = pickle.load(pk)

# Best grid-search result (profit, params) from `trade_gridsearch.ipynb`.
# Fail fast when it is missing: later cells read `best_params`, which would
# otherwise surface as a confusing NameError far from the real cause.
bestres_pickle = 'result/gridsearch/best_res.pickle'
if not os.path.exists(bestres_pickle):
    raise FileNotFoundError(
        f"{bestres_pickle} is not found, please execute `trade_gridsearch.ipynb` first"
    )
with open(bestres_pickle, 'rb') as pk:
    best_profit, best_params = pickle.load(pk)

dfs = read2df(symbols=data[0], freqs={data[1]: freqs[data[1]]})

In [3]:
# Merge the per-symbol frames into one frame (both closes, spread, zscore — see the preview).
# `period` comes from the best grid-search params (presumably the z-score window — TODO confirm).
df = unify_dfs(dfs, period=best_params['period'], symbols=data[0])

In [5]:
# Preview the unified frame (pandas truncates long frames to head/tail rows).
df

Unnamed: 0,time,close_ETHEUR,itvl,datetime,close_ETHGBP,spread,zscore
0,1592560859999,0.004827,1m,2020-06-19 10:00:59.999,0.005422,-0.000594,0.000000
1,1592560919999,0.004827,1m,2020-06-19 10:01:59.999,0.005412,-0.000584,0.000000
2,1592560979999,0.004827,1m,2020-06-19 10:02:59.999,0.005412,-0.000584,0.000000
3,1592561039999,0.004854,1m,2020-06-19 10:03:59.999,0.005412,-0.000557,0.000000
4,1592561099999,0.004854,1m,2020-06-19 10:04:59.999,0.005412,-0.000557,0.000000
...,...,...,...,...,...,...,...
1812156,1701388559999,0.000530,1m,2023-11-30 23:55:59.999,0.000615,-0.000084,-0.628750
1812157,1701388619999,0.000530,1m,2023-11-30 23:56:59.999,0.000615,-0.000084,-0.603341
1812158,1701388679999,0.000530,1m,2023-11-30 23:57:59.999,0.000615,-0.000084,-0.623742
1812159,1701388739999,0.000530,1m,2023-11-30 23:58:59.999,0.000615,-0.000084,-0.613500


In [None]:
def _inverted_leg(frame, symbol):
    """Return one leg of the pair with its close price inverted.

    Keeps the rows of `frame` whose 'tic' equals `symbol`, restricted to the
    columns used downstream and re-indexed from 0, with 'close' replaced by
    1 / close. A new frame is returned; `frame` itself is never modified.
    """
    columns = ['time', 'close', 'volume', 'tic', 'itvl', 'datetime']
    leg = frame.loc[frame['tic'] == symbol, columns].reset_index(drop=True)
    # Profit is meant to be calculated based on BTC, so the quote is inverted (1/x).
    # Vectorised division via `assign` avoids chained assignment on a slice.
    return leg.assign(close=1 / leg['close'])


df0 = _inverted_leg(dfs[0], data[0][0])
df1 = _inverted_leg(dfs[0], data[0][1])

Data before `trade_date` is used as training data; data on or after `trade_date` is held out as test (trading) data.

In [None]:
# Chronological split: rows strictly before `trade_date` train the agent,
# rows on/after it are kept for the out-of-sample trading test.
train0, test0 = df0[df0['datetime'] < trade_date], df0[df0['datetime'] >= trade_date]
train1, test1 = df1[df1['datetime'] < trade_date], df1[df1['datetime'] >= trade_date]

# Training budget base; the first `period` + 1 rows are presumably warm-up for the env (TODO confirm).
max_train_len = len(train0) - best_params['period'] - 1
print(f"The length of our training data: {len(train0)}")

## Check with Stable Baselines3 `env_checker`

Check that the environment meets the requirements of `stable_baselines3`.

In [None]:
from stable_baselines3.common.env_checker import check_env
# Validate the env against the API stable_baselines3 expects (spaces, reset/step outputs).
# NOTE: SB3 only supports Box, Discrete, MultiDiscrete and MultiBinary *action* spaces;
# a Dict/Tuple action space triggers "The action space is not based off a numpy array"
# and would have to be flattened (e.g. with a wrapper) before training.
# NOTE(review): training uses "MultiInputPolicy", which SB3 provides for Dict
# observation spaces — presumably this env's observation space is a Dict; confirm.

env = PairTradingEnv(train0, train1)
check_env(env)

## Do an experimental run with randomly generated actions

In [None]:
# env = PairTradingEnv(train0, train1, tc=0.002, verbose=1, model=f"{folder_path}/networth_experiment.csv")
# obs, _ = env.reset()

# print(f"observation_space: {env.observation_space}")
# print(f"action_space: {env.action_space}")
# print(f"action_space.sample: {env.action_space.sample()}")

# n_steps = 20

# for step in range(n_steps):
#     obs, reward, terminated, truncated, info = env.step(action=env.action_space.sample())
#     done = terminated or truncated
#     env.render()
#     if done:
#         break

## Models from stable_baselines3

Delete existing tensorboard logs

In [None]:
# Start from an empty TensorBoard directory so curves from previous runs don't mix in.
log_path = "logs/restrict_thres/"
clear_logs(log_path)

# TensorBoard references:
# https://github.com/tensorflow/tensorboard/blob/master/README.md
# https://www.tensorflow.org/tensorboard/get_started

Train on the training data

In [None]:
# --- PPO ---

# Train on a fraction of the usable training length; SB3 expects an int step count.
PPO_TIMESTEP_DIVISOR = 20
PPO_SEED = 0  # fixed seed so training runs are reproducible

env = PairTradingEnv(train0, train1, tc=0.00, model=f"{folder_path}/networth_ppo.csv")

# NOTE: with n_steps=2048 (the SB3 default) and a single env, the rollout buffer holds
# 2048 samples, so batch_size=3000 is effectively one 2048-sample minibatch per epoch.
# SB3 warns about this, but the warning is hidden by the global filterwarnings('ignore').
model_ppo = PPO(
    "MultiInputPolicy",
    env,
    n_steps=2048,
    batch_size=3000,
    gamma=1,
    learning_rate=0.003,
    seed=PPO_SEED,
    verbose=0,
    tensorboard_log=log_path,
)
model_ppo.learn(total_timesteps=max_train_len // PPO_TIMESTEP_DIVISOR)
model_ppo.save(f"{folder_path}/ppo_pairtrading")

In [None]:
# '''A2C'''

# env = PairTradingEnv(train0, train1, tc=0.00, model=f"{folder_path}/networth_a2c.csv")

# model_a2c = A2C("MultiInputPolicy", env, verbose=0, gamma=1, tensorboard_log=log_path)
# model_a2c.learn(total_timesteps=max_train_len/10)
# model_a2c.save(f"{folder_path}/a2c_pairtrading")

In [None]:
# '''DQN'''

# env = PairTradingEnv(train0, train1, tc=0.00, model=f"{folder_path}/networth_dqn.csv")

# model_dqn = DQN("MultiInputPolicy", env, verbose=0, gamma=1, tensorboard_log=log_path)
# model_dqn.learn(total_timesteps=max_train_len/10)
# model_dqn.save(f"{folder_path}/dqn_pairtrading")

## Evaluate the model on the test data

In [None]:
# # del model_ppo, model_a2c, model_dqn

# model_ppo = PPO.load(f"{folder_path}/ppo_pairtrading")
# model_a2c = A2C.load(f"{folder_path}/a2c_pairtrading")
# model_dqn = DQN.load(f"{folder_path}/dqn_pairtrading")

In [None]:
# Start the test run with a fresh net-worth CSV (the training env above used the same path).
try:
    os.remove(f"{folder_path}/networth_ppo.csv")
except FileNotFoundError:
    pass  # nothing to clean up yet; other OS errors (e.g. permissions) still surface

env = PairTradingEnv(test0, test1, tc=0.00, verbose=1, model=f"{folder_path}/networth_ppo.csv")
obs, _ = env.reset()

done = False
while not done:
    # Greedy policy for evaluation; feed back the observation returned by step()
    # so the model acts on the current state rather than the initial one.
    action, _states = model_ppo.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    env.render()

if terminated:
    print("Test Finished!")
elif truncated:
    print("bankrupted!")

In [None]:
# try:
#     os.remove(f"{folder_path}/networth_a2c.csv")
# except OSError:
#     pass

# env = PairTradingEnv(test0, test1, tc=0.00, model=f"{folder_path}/networth_a2c.csv")
# obs, _ = env.reset()

# while True:
#     action, _states = model_a2c.predict(obs)
#     observation, reward, terminated, truncated, info = env.step(action)
#     done = terminated or truncated
#     env.render()
#     if terminated:
#         print("Test Finished!")
#         break
#     elif truncated:
#         print("bankrupted!")
#         break

In [None]:
# try:
#     os.remove(f"{folder_path}/networth_dqn.csv")
# except OSError:
#     pass

# env = PairTradingEnv(test0, test1, tc=0.00, verbose=0, model=f"{folder_path}/networth_dqn.csv")
# obs, _ = env.reset()

# while True:
#     action, _states = model_dqn.predict(obs)
#     observation, reward, terminated, truncated, info = env.step(action)
#     done = terminated or truncated
#     env.render()
#     if terminated:
#         print("Test Finished!")
#         break
#     elif truncated:
#         print("bankrupted!")
#         break

### Analyze with QuantStats

In [None]:
# Drop the scratch CSV from the random-action experiment so it can't be ranked as a model.
experiment_csv = f"{folder_path}/networth_experiment.csv"
if os.path.exists(experiment_csv):
    os.remove(experiment_csv)

csv_files = [file for file in os.listdir(folder_path) if file.endswith('.csv')]

best_res, best_model = None, None
for file_name in csv_files:
    file_path = os.path.join(folder_path, file_name)

    # Only the final row matters: its second column is compared as the ending capital.
    last_line = None
    with open(file_path, 'r', newline='') as csv_file:
        for last_line in csv.reader(csv_file):
            pass

    # An empty file (or one ending in a blank/short row) has no ending capital to compare.
    if not last_line or len(last_line) < 2:
        print(f"Skipping {file_name}: no final row with an ending capital")
        continue

    if best_res is None or float(best_res) < float(last_line[1]):
        best_res = last_line[1]
        best_model = file_name

    print(f"The ending capital of {file_name} is {last_line[0:2]}")

if best_model is None:
    raise FileNotFoundError(f"No usable net-worth CSV found in {folder_path}; run a test cell first")

print(f"The best model is {best_model}")

In [None]:
# Load the best run's net-worth CSV via get_return (result has 'pnl' and 'returns' columns).
best_return = get_return(f"{folder_path}/{best_model}")
# NOTE(review): the return value of get_metrics is not the cell output (len() is);
# wrap it in display() if it returns a table rather than printing.
get_metrics(best_return)

len(best_return)

# Some graphs

In [None]:
fig, ax = plt.subplots()

# Cumulative profit-and-loss curve of the best model's test run.
ax.set_title("Profit and Loss")
ax.plot(best_return['pnl'])
ax.set_ylabel("PnL")
# NOTE(review): the x-axis is best_return's index; add an xlabel once its unit (step/time) is confirmed.

plt.show()

In [None]:
# fig, (ax1, ax2) = plt.subplots(2, 1)

# ax1.plot(best_return['return'])

In [None]:
# Full QuantStats tearsheet (performance metrics and plots) for the best model's returns series.
qs.reports.full(best_return['returns'])