<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Imports" data-toc-modified-id="Imports-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Imports</a></span></li><li><span><a href="#Read-Data" data-toc-modified-id="Read-Data-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Read Data</a></span><ul class="toc-item"><li><span><a href="#Quick-and-simple-EDA" data-toc-modified-id="Quick-and-simple-EDA-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Quick and simple EDA</a></span></li></ul></li><li><span><a href="#Feature-engineering" data-toc-modified-id="Feature-engineering-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Feature engineering</a></span><ul class="toc-item"><li><span><a href="#Datetime-semantics" data-toc-modified-id="Datetime-semantics-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Datetime semantics</a></span><ul class="toc-item"><li><span><a href="#Year-based" data-toc-modified-id="Year-based-3.1.1"><span class="toc-item-num">3.1.1&nbsp;&nbsp;</span>Year based</a></span></li><li><span><a href="#Month-based" data-toc-modified-id="Month-based-3.1.2"><span class="toc-item-num">3.1.2&nbsp;&nbsp;</span>Month based</a></span></li><li><span><a href="#Day-based" data-toc-modified-id="Day-based-3.1.3"><span class="toc-item-num">3.1.3&nbsp;&nbsp;</span>Day based</a></span></li></ul></li><li><span><a href="#Make-trade-values-stationary" data-toc-modified-id="Make-trade-values-stationary-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Make trade values stationary</a></span><ul class="toc-item"><li><span><a href="#Price-change" data-toc-modified-id="Price-change-3.2.1"><span class="toc-item-num">3.2.1&nbsp;&nbsp;</span>Price change</a></span></li><li><span><a href="#Volatility" data-toc-modified-id="Volatility-3.2.2"><span class="toc-item-num">3.2.2&nbsp;&nbsp;</span>Volatility</a></span></li><li><span><a href="#Volume-change" data-toc-modified-id="Volume-change-3.2.3"><span class="toc-item-num">3.2.3&nbsp;&nbsp;</span>Volume change</a></span></li></ul></li><li><span><a href="#Graphs" data-toc-modified-id="Graphs-3.3"><span class="toc-item-num">3.3&nbsp;&nbsp;</span>Graphs</a></span></li></ul></li><li><span><a href="#Light-GBM" data-toc-modified-id="Light-GBM-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Light GBM</a></span></li><li><span><a href="#The-reinforcement-learning" data-toc-modified-id="The-reinforcement-learning-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>The reinforcement learning</a></span><ul class="toc-item"><li><span><a href="#Graph-visualisation" data-toc-modified-id="Graph-visualisation-5.1"><span class="toc-item-num">5.1&nbsp;&nbsp;</span>Graph visualisation</a></span></li><li><span><a href="#The-gym-environment" data-toc-modified-id="The-gym-environment-5.2"><span class="toc-item-num">5.2&nbsp;&nbsp;</span>The gym environment</a></span></li><li><span><a href="#Run-it" data-toc-modified-id="Run-it-5.3"><span class="toc-item-num">5.3&nbsp;&nbsp;</span>Run it</a></span></li></ul></li></ul></div>

# Imports

In [1]:
import pandas as pd
import pyarrow as pa
import numpy as np
import datatable as dt

from tqdm import tqdm
from datetime import datetime, timedelta

In [2]:
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:85% !important; }</style>"))

# Read Data

In [3]:
# df_all = pd.read_csv("train.csv")
df_all = dt.fread("train.csv").to_pandas()
df_all.shape

(24236806, 10)

In [4]:
df_all.head()

Unnamed: 0,timestamp,Asset_ID,Count,Open,High,Low,Close,Volume,VWAP,Target
0,1514764860,2,40.0,2376.58,2399.5,2357.14,2374.59,19.233005,2373.116392,-0.004218
1,1514764860,0,5.0,8.53,8.53,8.53,8.53,78.38,8.53,-0.014399
2,1514764860,1,229.0,13835.194,14013.8,13666.11,13850.176,31.550062,13827.062093,-0.014643
3,1514764860,5,32.0,7.6596,7.6596,7.6567,7.6576,6626.71337,7.657713,-0.013922
4,1514764860,7,5.0,25.92,25.92,25.874,25.877,121.08731,25.891363,-0.008264


In [5]:
df_assets = pd.read_csv("asset_details.csv")
df_assets.shape

(14, 3)

In [6]:
df_assets.sort_values(["Asset_ID"]).head(20)

Unnamed: 0,Asset_ID,Weight,Asset_Name
1,0,4.304065,Binance Coin
2,1,6.779922,Bitcoin
0,2,2.397895,Bitcoin Cash
10,3,4.406719,Cardano
13,4,3.555348,Dogecoin
3,5,1.386294,EOS.IO
5,6,5.894403,Ethereum
4,7,2.079442,Ethereum Classic
11,8,1.098612,IOTA
6,9,2.397895,Litecoin


## Quick and simple EDA

In [7]:
df_all.isna().sum()

timestamp         0
Asset_ID          0
Count             0
Open              0
High              0
Low               0
Close             0
Volume            0
VWAP              9
Target       750338
dtype: int64

In [8]:
columns = df_all.columns
for col in sorted(columns):
    print(f"Col: {col.ljust(40, ' '):40} Type: {df_all[col].dtype}")

Col: Asset_ID                                 Type: int32
Col: Close                                    Type: float64
Col: Count                                    Type: float64
Col: High                                     Type: float64
Col: Low                                      Type: float64
Col: Open                                     Type: float64
Col: Target                                   Type: float64
Col: VWAP                                     Type: float64
Col: Volume                                   Type: float64
Col: timestamp                                Type: int32


In [9]:
df_all.describe()

Unnamed: 0,timestamp,Asset_ID,Count,Open,High,Low,Close,Volume,VWAP,Target
count,24236810.0,24236810.0,24236810.0,24236810.0,24236810.0,24236810.0,24236810.0,24236810.0,24236800.0,23486470.0
mean,1577120000.0,6.292544,286.4593,1432.64,1436.35,1429.568,1432.64,286853.0,,7.121752e-06
std,33233500.0,4.091861,867.3982,6029.605,6039.482,6020.261,6029.611,2433935.0,,0.005679042
min,1514765000.0,0.0,1.0,0.0011704,0.001195,0.0002,0.0011714,-0.3662812,-inf,-0.5093509
25%,1549011000.0,3.0,19.0,0.26765,0.26816,0.2669,0.2676484,141.0725,0.2676368,-0.001694354
50%,1578372000.0,6.0,64.0,14.2886,14.3125,14.263,14.2892,1295.415,14.28769,-4.289844e-05
75%,1606198000.0,9.0,221.0,228.8743,229.3,228.42,228.8729,27297.64,228.8728,0.00160152
max,1632182000.0,13.0,165016.0,64805.94,64900.0,64670.53,64808.54,759755400.0,inf,0.9641699


# Feature engineering

Initially this is going to be a bunch of *traditionally useful* financial features. In the paraphrased words of our boy Ernest Chan, "throw heaps of features at your models during prototyping, let the model decide what it thinks is important, it is probably smarter than you."

^This is obviously a joke taken to extremes but you get the point. Who are we to decide what is a "good feature", if we want to find some form of market inefficieny then why not try everything and use feature selection to trim down our list for us.

According to this, https://arxiv.org/abs/2005.12483, LIME is just as promising as SHAP so I am eager to explore this. If it fails we can always fall back on good old SHAP

***Disclaimer***: To be frank, the choice for the rolling window size was totally arbitrary

In [10]:
df_all.head()

Unnamed: 0,timestamp,Asset_ID,Count,Open,High,Low,Close,Volume,VWAP,Target
0,1514764860,2,40.0,2376.58,2399.5,2357.14,2374.59,19.233005,2373.116392,-0.004218
1,1514764860,0,5.0,8.53,8.53,8.53,8.53,78.38,8.53,-0.014399
2,1514764860,1,229.0,13835.194,14013.8,13666.11,13850.176,31.550062,13827.062093,-0.014643
3,1514764860,5,32.0,7.6596,7.6596,7.6567,7.6576,6626.71337,7.657713,-0.013922
4,1514764860,7,5.0,25.92,25.92,25.874,25.877,121.08731,25.891363,-0.008264


## Datetime semantics

In [11]:
df_all["date_time"] = pd.to_datetime(df_all["timestamp"], unit="s")

### Year based
If assuming cyclical behaviour then we would need more cases of each cycle.

Because we have only collected ~8 months of training data we will have less than one cycle making this yearly data almost useless

In [12]:
df_all["year"] = df_all["date_time"].dt.year

In [13]:
df_all["quarter"] = df_all["date_time"].dt.quarter

In [14]:
df_all["month_of_year"] = df_all["date_time"].dt.month

In [15]:
df_all["week_of_year"] = df_all["date_time"].dt.isocalendar().week

### Month based

Month based data is assuming "Fund Flow" interactions where you have large firms re-balancing on a monthly schedule

This might not be part of crypto but could prove valuable anyways

In [16]:
df_all["day_of_month"] = df_all["date_time"].dt.day

In [17]:
df_all["day_of_week"] = df_all["date_time"].dt.dayofweek

### Day based

Day based data will show the influence each region has on the price. 

Because crypto is traded 24/7 it doesn't have a traditional market open and close. However each region around the world still needs to sleep and work so you will end up getting cyclical regional activity.

Each region has different risks and economical situations so you should get different behaviours.

You might also catch algorithmic trading based on certain times of day

In [18]:
df_all["hour_of_day"] = df_all["date_time"].dt.hour

In [19]:
df_all["minute_of_hour"] = df_all["date_time"].dt.minute

In [20]:
df_all["datem"] = df_all["date_time"].dt.date

In [21]:
df_all.tail()

Unnamed: 0,timestamp,Asset_ID,Count,Open,High,Low,Close,Volume,VWAP,Target,date_time,year,quarter,month_of_year,week_of_year,day_of_month,day_of_week,hour_of_day,minute_of_hour,datem
24236801,1632182400,9,775.0,157.181571,157.25,156.7,156.943857,4663.725,156.994319,,2021-09-21,2021,3,9,38,21,1,0,0,2021-09-21
24236802,1632182400,10,34.0,2437.065067,2438.0,2430.2269,2432.907467,3.97546,2434.818747,,2021-09-21,2021,3,9,38,21,1,0,0,2021-09-21
24236803,1632182400,13,380.0,0.09139,0.091527,0.09126,0.091349,2193732.0,0.091388,,2021-09-21,2021,3,9,38,21,1,0,0,2021-09-21
24236804,1632182400,12,177.0,0.282168,0.282438,0.281842,0.282051,182850.8,0.282134,,2021-09-21,2021,3,9,38,21,1,0,0,2021-09-21
24236805,1632182400,11,48.0,232.695,232.8,232.24,232.275,103.5123,232.569697,,2021-09-21,2021,3,9,38,21,1,0,0,2021-09-21


## Make trade values stationary

For timeseries data there is a concept called "stationarity", in short this is the act of transforming the data so that statistical properties such as mean, variance, autocorrelation, etc actually mean something for the future. 

Commonly this is done by converting changes in price to percentages and scaling them across your train data. The model can only act on information it has seen before, by making your data stationary it is possible that it can pick out recurring behaviour.

**NOTE**
For these sections we will need to separate out the assets and stitch them back together

### Price change

Here we have the raw price change and the price change as a percentage of the opening

We also want to look at a "smoother" price change in the form of a moving average. This is because typically when financial data is trending it isn't a nice line, it is a "spiky" line. By taking the moving average we lose some granularity but we can see the underlying momentum trend

In [22]:
default_rolling_window_size = 5

In [23]:
def determine_price_changes(df: pd.DataFrame, rolling_window_size: int) -> pd.DataFrame:
    df["price_change"] = df["Close"] - df["Open"]
    df["price_change_perc"] =  df["price_change"] / df["Open"]
    df["price_change_ma"] = df.rolling(rolling_window_size)["price_change"].mean()
    df["price_change_perc_smooth"] = df["price_change_ma"] / df["Open"]
    return df

In [24]:
parts = []

for asset_id in tqdm(df_assets["Asset_ID"].unique()):
    df_part = df_all.loc[df_all["Asset_ID"] == asset_id]
    df_part = determine_price_changes(df_part, default_rolling_window_size)
    parts.append(df_part)

df_all = pd.concat(parts)

df_part[[
    "Open",
    "Close",
    "price_change",
    "price_change_perc",
    "price_change_ma",
    "price_change_perc_smooth"
]].tail(20)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["price_change"] = df["Close"] - df["Open"]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["price_change_perc"] =  df["price_change"] / df["Open"]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["price_change_ma"] = df.rolling(rolling_window_size)["price_change"].mean()
A value is trying to 

Unnamed: 0,Open,Close,price_change,price_change_perc,price_change_ma,price_change_perc_smooth
24236530,0.206678,0.207477,0.000799,0.003865,5.1e-05,0.000249
24236544,0.207441,0.207581,0.00014,0.000672,0.000119,0.000572
24236558,0.20763,0.207149,-0.000481,-0.002317,1.4e-05,6.7e-05
24236572,0.207128,0.206629,-0.000499,-0.002408,4.3e-05,0.000206
24236586,0.206644,0.206746,0.000102,0.000495,1.2e-05,5.9e-05
24236600,0.206728,0.207441,0.000713,0.003447,-5e-06,-2.5e-05
24236614,0.207459,0.207374,-8.5e-05,-0.000409,-5e-05,-0.000241
24236628,0.207383,0.206935,-0.000448,-0.002158,-4.3e-05,-0.000209
24236642,0.206953,0.207718,0.000765,0.003696,0.000209,0.001012
24236656,0.207865,0.208161,0.000297,0.001426,0.000248,0.001195


In [25]:
df_all.shape

(24236806, 24)

### Volatility

Variance or "volatility" can be used to determine how stable the trading period was.

Again we are using a smooth apporach to see if it helps the model.

In [26]:
def determine_volatility(df: pd.DataFrame, rolling_window_size: int) -> pd.DataFrame:
    df["volatility"] = df["High"] - df["Low"]
    df["volatility_perc"] = df["volatility"] / df["Open"]
    df["volatility_ma"] = df.rolling(rolling_window_size)["volatility"].mean()
    df["volatility_perc_smooth"] = df["volatility_ma"] / df["Open"]
    return df

In [27]:
parts = []

for asset_id in tqdm(df_assets["Asset_ID"].unique()):
    df_part = df_all.loc[df_all["Asset_ID"] == asset_id]
    df_part = determine_volatility(df_part, default_rolling_window_size)
    parts.append(df_part)

df_all = pd.concat(parts)

df_part[[
    "High",
    "Low",
    "volatility",
    "volatility_perc",
    "volatility_ma",
    "volatility_perc_smooth"
]].tail(20)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["volatility"] = df["High"] - df["Low"]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["volatility_perc"] = df["volatility"] / df["Open"]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["volatility_ma"] = df.rolling(rolling_window_size)["volatility"].mean()
A value is trying to be set on a c

Unnamed: 0,High,Low,volatility,volatility_perc,volatility_ma,volatility_perc_smooth
24236530,0.2076,0.2066,0.001,0.004838,0.000766,0.003704
24236544,0.2078,0.20715,0.00065,0.003133,0.0008,0.003855
24236558,0.207806,0.207069,0.000737,0.00355,0.00086,0.00414
24236572,0.2074,0.206544,0.000856,0.004133,0.000771,0.003721
24236586,0.2069,0.2065,0.0004,0.001936,0.000729,0.003526
24236600,0.2075,0.2066,0.0009,0.004354,0.000709,0.003428
24236614,0.2076,0.20721,0.00039,0.00188,0.000657,0.003165
24236628,0.2076,0.206892,0.000708,0.003414,0.000651,0.003138
24236642,0.208,0.20689,0.00111,0.005364,0.000702,0.00339
24236656,0.2083,0.20773,0.00057,0.002742,0.000736,0.003539


In [28]:
df_all.shape

(24236806, 28)

### Volume change

Volume is another indicator of stability. If the volume suddenly jumps above the norm we might expect a change in market direction. 

For this reason we are looking at the volume change since last time period and another smoothed version

In [29]:
def determine_volume_change(df: pd.DataFrame, rolling_window_size: int) -> pd.DataFrame:
    df["last_volume"] = df["Volume"].shift(1)

    df["volume_change"] = df["Volume"] - df["last_volume"]
    df["volume_change_perc"] = df["volume_change"] / df["last_volume"]

    df["volume_change_ma"] = df.rolling(rolling_window_size)["volume_change"].mean()
    df["volume_change_perc_smooth"] = df["volume_change_ma"] / df["last_volume"]

    return df

In [30]:
parts = []

for asset_id in tqdm(df_assets["Asset_ID"].unique()):
    df_part = df_all.loc[df_all["Asset_ID"] == asset_id]
    df_part = determine_volume_change(df_part, default_rolling_window_size)
    parts.append(df_part)

df_all = pd.concat(parts)

df_part[[
    "Volume",
    "last_volume",
    "volume_change",
    "volume_change_perc",
    "volume_change_ma",
    "volume_change_perc_smooth"
]].tail(20)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["last_volume"] = df["Volume"].shift(1)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["volume_change"] = df["Volume"] - df["last_volume"]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["volume_change_perc"] = df["volume_change"] / df["last_volume"]
A value is trying to be set on a copy of 

Unnamed: 0,Volume,last_volume,volume_change,volume_change_perc,volume_change_ma,volume_change_perc_smooth
24236530,926540.0,1138012.0,-211471.9,-0.185826,-58735.790395,-0.051613
24236544,1257670.0,926540.0,331130.3,0.357384,43514.338211,0.046964
24236558,793682.6,1257670.0,-463987.7,-0.368926,38118.060798,0.030308
24236572,856799.7,793682.6,63117.1,0.079524,-141825.404576,-0.178693
24236586,835833.5,856799.7,-20966.19,-0.02447,-60435.692869,-0.070537
24236600,1240643.0,835833.5,404809.2,0.484318,62820.530968,0.075159
24236614,649896.2,1240643.0,-590746.4,-0.476162,-121554.806373,-0.097977
24236628,2017260.0,649896.2,1367364.0,2.103972,244715.432664,0.376545
24236642,1120575.0,2017260.0,-896684.5,-0.444506,52755.110279,0.026152
24236656,1337674.0,1120575.0,217098.9,0.193739,100368.125938,0.089568


In [31]:
df_all.shape

(24236806, 33)

## Graphs

I am eager to get on to the model so will leave the feature engineering here for now, we can always return and add more to it

In [46]:
import plotly.express as px
import plotly.graph_objects as go

df_small = df_part.tail(2000).copy()

fig = go.Figure(
    data=go.Candlestick(
        x=df_small["date_time"],
        open=df_small["Open"],
        high=df_small["High"],
        low=df_small["Low"],
        close=df_small["Close"]
    )
)
fig.update_layout(
    height=1024,
    title=f"{df_assets[df_assets['Asset_ID']==asset_id]['Asset_Name'].values[0]}"
)
fig

# Light GBM

In [32]:
import lightgbm as lgb

# The reinforcement learning

## Graph visualisation

In [35]:
# import numpy as np
# import matplotlib
# import matplotlib.pyplot as plt
# import matplotlib.dates as mdates
#
# from matplotlib import style
#
# # finance module is no longer part of matplotlib
# # see: https://github.com/matplotlib/mpl_finance
# import mplfinance as mpf
#
#
#
# style.use('dark_background')
#
# VOLUME_CHART_HEIGHT = 0.33
#
# UP_COLOUR = '#27A59A'
# DOWN_COLOUR = '#EF534F'
# UP_TEXT_COLOUR = '#73D3CC'
# DOWN_TEXT_COLOUR = '#DC2C27'
#
#
#
# class CryptoTradingGraph:
#     """A crypto trading visualization using matplotlib made to render OpenAI gym environments"""
#
#
#     def __init__(self, df, title=None):
#         self.df = df
#         self.net_worths = np.zeros(len(df["open_time"]))
#
#         # Create a figure on screen and set the title
#         fig = plt.figure()
#         fig.suptitle(title)
#
#         # Create top subplot for net worth axis
#         self.net_worth_ax = plt.subplot2grid(
#             (6, 1),
#             (0, 0),
#             rowspan=2,
#             colspan=1
#         )
#
#         # Create bottom subplot for shared price/volume axis
#         self.price_ax = plt.subplot2grid(
#             (6, 1),
#             (2, 0),
#             rowspan=8,
#             colspan=1,
#             sharex=self.net_worth_ax
#         )
#
#         # Create a new axis for volume which shares its x-axis with price
#         self.volume_ax = self.price_ax.twinx()
#
#         # Add padding to make graph easier to view
#         plt.subplots_adjust(
#             left=0.11,
#             bottom=0.24,
#             right=0.90,
#             top=0.90,
#             wspace=0.2,
#             hspace=0
#         )
#
#         # Show the graph without blocking the rest of the program
#         plt.show(block=False)
#
#
#     def _render_net_worth(self, current_step, net_worth, step_range, dates):
#         # Clear the frame rendered last step
#         self.net_worth_ax.clear()
#
#         # Plot net worths
#         self.net_worth_ax.plot_date(
#             dates,
#             self.net_worths[step_range],
#             '-',
#             label="Net Worth"
#         )
#
#         # Show legend, which uses the label we defined for the plot above
#         self.net_worth_ax.legend()
#         legend = self.net_worth_ax.legend(loc=2, ncol=2, prop={"size": 8})
#         legend.get_frame().set_alpha(0.4)
#
#         last_date = self.df["open_time"].values[current_step]
#         last_net_worth = self.net_worths[current_step]
#
#         # Annotate the current net worth on the net worth graph
#         self.net_worth_ax.annotate(
#             f"{net_worth:,.2f}",
#             (last_date, last_net_worth),
#             xytext=(last_date, last_net_worth),
#             bbox=dict(boxstyle="round", fc='w', ec='k', lw=1),
#             color="black",
#             fontsize="small"
#         )
#
#         # Add space above and below min/max net worth
#         self.net_worth_ax.set_ylim(
#             min(self.net_worths[np.nonzero(self.net_worths)]) / 1.25,
#             max(self.net_worths) * 1.25
#         )
#
#
#     def _render_price(self, current_step, net_worth, dates, step_range):
#         self.price_ax.clear()
#
#         candlesticks = zip(
#             dates,
#             self.df["open"].values[step_range],
#             self.df["close"].values[step_range],
#             self.df["high"].values[step_range],
#             self.df["low"].values[step_range]
#         )
#
#         # Plot price using candlestick graph from mpl_finance
#         mpf.plot(
#             self.price_ax,
#             candlesticks,
#             width=1,
#             colorup=UP_COLOUR,
#             colordown=DOWN_COLOUR,
#             type="candle"
#         )
#
#         last_date = self.df["open_time"].values[current_step]
#         last_close = self.df["close"].values[current_step]
#         last_high = self.df["high"].values[current_step]
#
#         # Print the current price to the price axis
#         self.price_ax.annotate(
#             f"{last_close}:,.2f",
#             (last_date, last_close),
#             xytext=(last_date, last_high),
#             bbox=dict(boxstyle='round', fc='w', ec='k', lw=1),
#             color="black",
#             fontsize="small"
#         )
#
#         # Shift price axis up to give volume chart space
#         ylim = self.price_ax.get_ylim()
#         self.price_ax.set_ylim(
#             ylim[0] - (ylim[1] - ylim[0]) * VOLUME_CHART_HEIGHT,
#             ylim[1]
#         )
#
#
#     def _render_volume(self, current_step, net_worth, dates, step_range):
#         self.volume_ax.clear()
#
#         volume = np.array(self.df["volume"].values[step_range])
#
#         pos = self.df["open"].values[step_range] - self.df["close"].values[step_range] < 0
#         neg = self.df["open"].values[step_range] - self.df["close"].values[step_range] > 0
#
#         self.volume_ax.bar(
#             dates[pos],
#             volume[pos],
#             color=UP_COLOUR,
#             alpha=0.4,
#             width=1,
#             align="center"
#         )
#         self.volume_ax.bar(
#             dates[neg],
#             volume[neg],
#             color=DOWN_COLOUR,
#             alpha=0.4,
#             width=1,
#             align="center"
#         )
#
#         # Cap volume axis height below price chart and hide ticks
#         self.volume_ax.set_ylim(0, max(volume) / VOLUME_CHART_HEIGHT)
#         self.volume_ax.yaxis.set_ticks([])
#
#
#     def _render_trades(self, current_step, trades, step_range):
#         for trade in trades:
#             if trade["step"] in step_range:
#                 date = self.df["open_time"].values[trade["step"]]
#                 high = self.df["high"].values[trade["step"]]
#                 low = self.df["low"].values[trade["step"]]
#
#                 if trade["type"] == 'buy':
#                     high_low = low
#                     colour = UP_TEXT_COLOUR
#                 else:
#                     high_low = high
#                     colour = DOWN_TEXT_COLOUR
#
#                 total = f"{trade['total']:,.2f}"
#
#                 # Print the current price to the price axis
#                 self.price_ax.annotate(
#                     f"${total:,.2f}",
#                     (date, high_low),
#                     xytext=(date, high_low),
#                     color=colour,
#                     fontsize=8,
#                     arrowprops=(dict(color=colour))
#                 )
#
#     def render(self, current_step, net_worth, trades, window_size=40):
#         self.net_worths[current_step] = net_worth
#
#         window_start = max(current_step - window_size, 0)
#         step_range = range(window_start, current_step + 1)
#
#         dates = np.array([
#             datetime.utcfromtimestamp(x).strftime("%Y-%m-%d %H-%M") for x in self.df["open_time"].values[step_range]
#         ])
#
#         self._render_net_worth(current_step, net_worth, step_range, dates)
#         self._render_price(current_step, net_worth, dates, step_range)
#         self._render_volume(current_step, net_worth, dates, step_range)
#         self._render_trades(current_step, trades, step_range)
#
#         # Format the date ticks to be more easily read
#         self.price_ax.set_xticklabels(
#             self.df["open_time"].values[step_range],
#             rotation=45,
#             horizontalalignment="right"
#         )
#
#         # Hide duplicate net worth date labels
#         plt.setp(self.net_worth_ax.get_xticklabels(), visible=False)
#
#         # Necessary to view frames before they are unrendered
#         plt.pause(0.001)
#
#
#     def close(self):
#         plt.close()


## The gym environment

In [36]:
# import gym
#
# from gym import spaces
# from sklearn import preprocessing
#
#
#
# MAX_TRADING_SESSION = 100000  # ~2 months
#
#
#
# class CryptoTradingEnv(gym.Env):
#
#     metadata = {"render.modes": ["live", "file", "none"]}
#     scaler = preprocessing.MinMaxScaler()
#     viewer = None
#
#
#     def __init__(
#         self,
#         df: pd.DataFrame,
#         lookback_window_size : int=50,
#         commission : float = 0.00075,
#         initial_balance : float = 10_000.0,
#         serial : bool = False
#     ):
#         super(CryptoTradingEnv, self).__init__()
#
#         self.df = df.dropna().reset_index()
#         self.lookback_window_size=lookback_window_size
#         self.initial_balance = initial_balance
#         self.commission = commission
#         self.serial = serial
#         self.trades = []
#
#         # The agent can buy, sell, hold, at certain amounts 1/10 through 10/10
#         self.action_space = spaces.MultiDiscrete([3, 10])
#
#         #Observes the OHCLV values, net worth, and trade history
#         self.observation_space = spaces.Box(
#             shape=(10, lookback_window_size + 1),
#             low=0,
#             high=1
#         )
#
#
#     def reset(self):
#         # Reset the whole simulation
#         self.balance = self.initial_balance
#         self.net_worth = self.initial_balance
#         self.assets_held = 0
#
#         self._reset_session()
#
#         self.account_history = np.repeat([
#             [self.net_worth],
#             [0],
#             [0],
#             [0],
#             [0]
#         ],
#             self.lookback_window_size + 1,
#             axis=1
#         )
#
#         self.trades = []
#
#         return self._next_observation()
#
#
#     def _reset_session(self):
#         # I am not convinced on the "random traversal" approach here
#         # but there is some supporting evidence that it works, so I will humour it
#
#         self.current_step = 0
#
#         if self.serial:
#             self.steps_left = len(self.df) - self.lookback_window_size - 1
#             self.frame_start = self.lookback_window_size
#         else:
#             # Random traversal
#             self.steps_left = np.random.randint(1, MAX_TRADING_SESSION)
#             self.frame_start = np.random.randint(
#                 self.lookback_window_size,
#                 len(self.df) - self.steps_left
#             )
#
#         self.active_df = self.df[
#             self.frame_start - self.lookback_window_size : self.frame_start + self.steps_left
#         ]
#
#
#     def _next_observation(self):
#         # It is really important to ONLY scale the data that the model has seen,
#         # This is to prevent 'look-ahead bias'
#         end = self.current_step + self.lookback_window_size + 1
#
#         obs = np.array([
#             self.active_df["open"].values[self.current_step:end],
#             self.active_df["high"].values[self.current_step:end],
#             self.active_df["low"].values[self.current_step:end],
#             self.active_df["close"].values[self.current_step:end],
#             self.active_df["volume"].values[self.current_step:end],
#         ])
#
#         scaled_history = self.scaler.fit_transform(self.account_history)
#
#         obs = np.append(
#             obs,
#             scaled_history[:, -(self.lookback_window_size + 1):],
#             axis=0
#         )
#
#         return obs
#
#
#     def _get_current_price(self):
#         return self.df["close"].values[self.frame_start + self.current_step]
#
#
#     def step(self, action):
#         current_price = self._get_current_price() + 0.01
#         self._take_action(action, current_price)
#         self.steps_left -= 1
#         self.current_step += 1
#
#         if self.steps_left == 0:
#             self.balance += self.assets_held * current_price
#             self.assets_held = 0
#             self._reset_session()
#
#         obs = self._next_observation()
#         reward = self.net_worth
#         done = self.net_worth <= 0
#
#         return obs, reward, done, {}
#
#
#     def _take_action(self, action, current_price):
#         action_type = action[0]
#         amount = action[1] / 10
#
#         assets_bought = 0
#         assets_sold = 0
#         cost = 0
#         sales = 0
#
#         if action_type < 1:
#             # Trigger a buy
#             assets_bought = 0.0 if amount == 0 else self.balance / (current_price * amount)
#             cost = assets_bought * current_price * (1 + self.commission)
#             self.assets_held += assets_bought
#             self.balance -= cost
#
#         elif action_type < 2:
#             # Trigger a sell
#             assets_sold = self.assets_held * amount
#             sales = assets_sold * current_price * (1 - self.commission)
#             self.assets_held -= assets_sold
#             self.balance += sales
#
#         if assets_bought > 0 or assets_sold > 0:
#             self.trades.append({
#                 "step": self.frame_start + self.current_step,
#                 "amount": assets_bought if assets_bought > 0 else assets_sold,
#                 "total": cost if assets_bought > 0 else sales,
#                 "type": "buy" if assets_bought > 0 else "sell"
#             })
#
#         self.net_worth = self.balance + self.assets_held * current_price
#         self.account_history = np.append(
#             self.account_history,
#             [
#                 [self.net_worth],
#                 [assets_bought],
#                 [cost],
#                 [assets_sold],
#                 [sales]
#             ],
#             axis=1
#         )
#
#
#     def render(self, mode="human", **kwargs):
#         if mode == "human":
#             if self.viewer == None:
#                 self.viewer = CryptoTradingGraph(
#                     self.df,
#                     kwargs.get("title", None)
#                 )
#
#             self.viewer.render(
#                 self.frame_start + self.current_step,
#                 self.net_worth,
#                 self.trades,
#                 window_size=self.lookback_window_size
#             )


## Run it

In [37]:
# # The medium article was simple and just used OHCLV
# cols = [
#     "open",
#     "high",
#     "close",
#     "low",
#     "volume"
# ]
#
# df_gym = df[cols + ["open_time"]].copy()

In [38]:
# slice_point = int(len(df) - 100_000)
#
# train_df = df_gym[:slice_point]
# test_df = df_gym[slice_point:]

In [39]:
# train_env = CryptoTradingEnv(train_df, commission=0.00075, serial=False)
#
# test_env = CryptoTradingEnv(test_df, commission=0.00075, serial=True)

In [None]:
# from stable_baselines3 import A2C
#
# model = A2C(
#     "MlpPolicy",
#     train_env,
#     verbose=1,
#     tensorboard_log="./tensorboard/"
# )
#
# model.learn(total_timesteps=50_000)

In [None]:
# train_env.balance