# Construct a custom Environment for Pair Trading

Some examples on the market
* [custom env example](https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/5_custom_gym_env.ipynb#scrollTo=RqxatIwPOXe_)
* [StockTradingEnv by Adam King](https://github.com/notadamking/Stock-Trading-Environment)
* [FinRL](https://github.com/AI4Finance-Foundation/FinRL)

Target is to construct a custom Env for pair trading

This env restrict the behaviour of RL learner to pair trading only

In [6]:
import warnings
warnings.filterwarnings('ignore')

import os
import csv
import numpy as np
import pandas as pd
import statsmodels.api as sm

from datetime import date
from envs.env_gridsearch import kellycriterion
from sklearn.model_selection import train_test_split
from stable_baselines3.common.vec_env import DummyVecEnv
from utils.read2df import read2df
from envs.env_rl_restrict import PairTradingEnv
from params import *

from stable_baselines3 import PPO, A2C, DQN

os.makedirs("result/rl-restrict", exist_ok=True)

# for root, dirs, files in os.walk(f"result/rl-restrict/"):
#     for file in files:
#         os.remove(os.path.join(root, file))

# PERIOD = 150 # Only look at the current price
# CASH = 10000
# ISKELLY = True
# OPEN_THRE = 6.0
# CLOS_THRE = 0.6

Load data from `preliminaries.ipynb`

In [2]:
import pickle

with open('result/cointncorr.pickle', 'rb') as pk:
    data = pickle.load(pk)

dfs = read2df(symbols=data[0], freqs={data[1]: freqs[data[1]]})

df0 = dfs[0][dfs[0]['tic']==data[0][0]].reset_index(drop=True)
df1 = dfs[0][dfs[0]['tic']==data[0][1]].reset_index(drop=True)

Set data before `trade_data` as training data, after `trade_data` is trade_data

In [3]:
train0 = df0[df0['datetime'] < trade_date]
train1 = df1[df1['datetime'] < trade_date]

test0 = df0[df0['datetime'] >= trade_date]
test1 = df1[df1['datetime'] >= trade_date]

print(f"The length of our training data: {len(train0)}")

The length of our training data: 1589703


## Check with baselin3 `env_checker`

Check if the env meets the requirements of `stable_baseline3`

In [4]:
from stable_baselines3.common.env_checker import check_env
# > UserWarning: The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the action using a wrapper.
# Baseline 3 does not support Dict/Tuple action spaces....only Box Discrete MultiDiscrete MultiBinary
# Is there another way to achieve the same functionality?

env = PairTradingEnv(train0, train1)
check_env(env)

## Do a test run with random generated actions

In [5]:
env = PairTradingEnv(train0, train1, tc=0, model="test")
obs, _ = env.reset()

print(f"observation_space: {env.observation_space}")
print(f"action_space: {env.action_space}")
print(f"action_space.sample: {env.action_space.sample()}")

n_steps = 20

for step in range(n_steps):
    obs, reward, terminated, truncated, info = env.step(action=env.action_space.sample())
    done = terminated or truncated
    env.render()
    if done:
        break

observation_space: Dict('compare_clos_thre': Discrete(3), 'compare_open_thre': Discrete(3), 'position': Discrete(3), 'zscore': Box(-inf, inf, (1,), float64))
action_space: Discrete(4)
action_space.sample: 1


## Models from stable_baselines3

Train with training data

In [7]:
'''PPO'''

env = PairTradingEnv(train0, train1, tc=0, model="ppo")

model_ppo = PPO("MultiInputPolicy", env, verbose=0, tensorboard_log="logs")
model_ppo.learn(total_timesteps=30000)
model_ppo.save("result/rl-restrict/ppo_pairtrading")

In [8]:
'''A2C'''

from stable_baselines3 import A2C

env = PairTradingEnv(train0, train1, tc=0, model="a2c")

model_a2c = A2C("MultiInputPolicy", env, verbose=0)
model_a2c.learn(total_timesteps=30000)
model_a2c.save("result/rl-restrict/a2c_pairtrading")

In [9]:
'''DQN'''

from stable_baselines3 import DQN

env = PairTradingEnv(train0, train1, tc=0, model="dqn")

model_dqn = DQN("MultiInputPolicy", env, verbose=0)
model_dqn.learn(total_timesteps=30000)
model_dqn.save("result/rl-restrict/dqn_pairtrading")

## Use the model on Test data

In [7]:
# del model_ppo, model_a2c, model_dqn

model_ppo = PPO.load("result/rl-restrict/ppo_pairtrading")
model_a2c = A2C.load("result/rl-restrict/a2c_pairtrading")
model_dqn = DQN.load("result/rl-restrict/dqn_pairtrading")

In [8]:
env = PairTradingEnv(test0, test1, tc=0, model="ppo", isKelly=True)
obs, _ = env.reset()

while True:
    action, _states = model_ppo.predict(obs)
    observation, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    env.render()
    if terminated:
        print("Test Finished!")
        break
    elif truncated:
        print("bankrupted!")
        break

Test Finished!


In [9]:
env = PairTradingEnv(test0, test1, tc=0, model="a2c")
obs, _ = env.reset()

while True:
    action, _states = model_a2c.predict(obs)
    observation, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    env.render()
    if terminated:
        print("Test Finished!")
        break
    elif truncated:
        print("bankrupted!")
        break

Test Finished!


In [10]:
env = PairTradingEnv(test0, test1, tc=0, model="dqn")
obs, _ = env.reset()

while True:
    action, _states = model_dqn.predict(obs)
    observation, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    env.render()
    if terminated:
        print("Test Finished!")
        break
    elif truncated:
        print("bankrupted!")
        break

Test Finished!


### Analyze with PyFolio

In [11]:
folder_path = f"result/rl-restrict/"
os.remove(f"{folder_path}networth_test.csv") if os.path.exists(f"{folder_path}networth_test.csv") else None
csv_files = [file for file in os.listdir(folder_path) if file.endswith('.csv')]

best_res, best_model = None, None
for file_name in csv_files:
    file_path = os.path.join(folder_path, file_name)
    
    with open(file_path, 'r') as csv_file:
        csv_reader = csv.reader(csv_file)
        
        # Loop through the lines in the CSV file
        last_line = None
        for row in csv_reader:
            last_line = row  # Update last_line with the current row
    
    if best_res is None or float(best_res) < float(last_line[1]):
        best_res = last_line[1]
        best_model = file_name

    print(f"The ending capital of {file_name} is {last_line[0:2]}")

print(f"The best model is {best_model}")

The ending capital of networth_a2c.csv is ['2023-10-31 23:59:59.999000', '9989.969862670621']
The ending capital of networth_dqn.csv is ['2023-10-31 23:59:59.999000', '9955.08075602882']
The ending capital of networth_ppo.csv is ['2023-10-31 23:59:59.999000', '9907.432616865963']
The best model is networth_a2c.csv


In [12]:
def get_return(networthcsv):
    returns = pd.read_csv(networthcsv, names=['datetime', 'returns', "action", "position", "order0", "order1"])
    returns['datetime'] = pd.to_datetime(returns['datetime'])
    returns.set_index('datetime', inplace=True)
    res_daily = returns.resample('D').mean()
    res_daily['returns'] = res_daily['returns'].pct_change()
    res_daily = res_daily.dropna()
    return res_daily

best_return = get_return(f'result/rl-restrict/{best_model}')

In [None]:
best_df = pd.read_csv(f'result/rl-restrict/{best_model}', names=["datetime", "networth"])

In [17]:
best_df

Unnamed: 0,datetime,networth
0,2023-03-19 08:48:59.999000,10000.000000
1,2023-03-19 08:49:59.999000,10000.000000
2,2023-03-19 08:50:59.999000,10000.000000
3,2023-03-19 08:51:59.999000,10000.000000
4,2023-03-19 08:52:59.999000,10000.000000
...,...,...
654014,2023-10-31 23:55:59.999000,9999.283894
654015,2023-10-31 23:56:59.999000,9999.283894
654016,2023-10-31 23:57:59.999000,9999.283894
654017,2023-10-31 23:58:59.999000,9999.283894


In [18]:
# import matplotlib.pyplot as plt

# plt.plot(best_df['datetime'], best_df['networth'])

In [19]:
# # Calculate total orders count
# total_orders_count = best_df.shape[0]

# # Calculate won orders count
# won_orders_count = best_df[(best_df['order1'] == 1) & (best_df['position'] == 0)].shape[0]

# # Calculate lost orders count
# lost_orders_count = best_df[(best_df['order1'] == 2) & (best_df['position'] == 0)].shape[0]

# # Calculate Win/Loss order ratio
# win_loss_order_ratio = won_orders_count / lost_orders_count if lost_orders_count != 0 else np.inf

# # Calculate Avg order pnl
# avg_order_pnl = best_df['order0'].mean()

# # Calculate Avg order pnl won
# avg_order_pnl_won = best_df[(best_df['order1'] == 1) & (best_df['position'] == 0)]['order0'].mean()

# # Calculate Avg order pnl lost
# avg_order_pnl_lost = best_df[(best_df['order1'] == 2) & (best_df['position'] == 0)]['order0'].mean()

# # Calculate Avg long order pnl
# avg_long_order_pnl = best_df[(best_df['order1'] == 1) & (best_df['position'] == 2)]['order0'].mean()

# # Calculate Avg short order pnl
# avg_short_order_pnl = best_df[(best_df['order1'] == 1) & (best_df['position'] == 0)]['order1'].mean()

# # Print the calculated indices
# print("Total orders count:", total_orders_count)
# print("Won orders count:", won_orders_count)
# print("Lost orders count:", lost_orders_count)
# print("Win/Loss order ratio:", win_loss_order_ratio)
# print("Avg order pnl:", avg_order_pnl)
# print("Avg order pnl won:", avg_order_pnl_won)
# print("Avg order pnl lost:", avg_order_pnl_lost)
# print("Avg long order pnl:", avg_long_order_pnl)
# print("Avg short order pnl:", avg_short_order_pnl)


In [20]:
# import pyfolio

# pyfolio.tears.create_full_tear_sheet(best_return['returns'])