# Imports

In [1]:
import pandas as pd
import pyarrow as pa
import numpy as np

from tqdm import tqdm
from datetime import datetime, timedelta

In [2]:
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:85% !important; }</style>"))

# Read Data

If you do not have the data then please run the *data-collector.ipynb* notebook and collect the specific crypto symbol you want to simulate.

In the future I will probably change this to a streamlit utility to make it much more user friendly but I am lazy and this is a prototype

In [3]:
symbol = "BTCUSDT"

df = pd.read_parquet(f"../data_binance_crypto/symbol={symbol}/", engine="pyarrow")
df.shape

(349290, 13)

In [4]:
df.head()

Unnamed: 0,open_time,open,high,low,close,volume,close_time,quote_asset_vol,num_trades,taker_buy_base_asset_vol,taker_buy_quote_asset_vol,ignore_this,day
0,1609459200000,28923.63,28961.66,28913.12,28961.66,27.457032,1609459259999,794382.04398665,1292,16.777195,485390.8268246,0,2021-01-01
1,1609459260000,28961.67,29017.5,28961.01,29009.91,58.477501,1609459319999,1695802.89696884,1651,33.733818,978176.46820208,0,2021-01-01
2,1609459320000,29009.54,29016.71,28973.58,28989.3,42.470329,1609459379999,1231358.69059884,986,13.247444,384076.85445305,0,2021-01-01
3,1609459380000,28989.68,28999.85,28972.33,28982.69,30.360677,1609459439999,880016.76348383,959,9.456028,274083.07514154,0,2021-01-01
4,1609459440000,28982.67,28995.93,28971.8,28975.65,24.124339,1609459499999,699226.20560386,726,6.814644,197519.37488805,0,2021-01-01


## Quick and simple EDA

In [5]:
df.isna().sum()

open_time                    0
open                         0
high                         0
low                          0
close                        0
volume                       0
close_time                   0
quote_asset_vol              0
num_trades                   0
taker_buy_base_asset_vol     0
taker_buy_quote_asset_vol    0
ignore_this                  0
day                          0
dtype: int64

In [6]:
columns = df.columns
for col in sorted(columns):
    print(f"Col: {col.ljust(40, ' '):40} Type: {df[col].dtype}")

Col: close                                    Type: object
Col: close_time                               Type: int64
Col: day                                      Type: category
Col: high                                     Type: object
Col: ignore_this                              Type: object
Col: low                                      Type: object
Col: num_trades                               Type: int64
Col: open                                     Type: object
Col: open_time                                Type: int64
Col: quote_asset_vol                          Type: object
Col: taker_buy_base_asset_vol                 Type: object
Col: taker_buy_quote_asset_vol                Type: object
Col: volume                                   Type: object


In [7]:
float_columns = [
    "open",
    "high",
    "low",
    "close",
    "volume",
    "quote_asset_vol",
    "taker_buy_base_asset_vol",
    "taker_buy_quote_asset_vol",
    "ignore_this",
]

for col in float_columns:
    df[col] = df[col].astype(float)

In [8]:
df.describe()

Unnamed: 0,open_time,open,high,low,close,volume,close_time,quote_asset_vol,num_trades,taker_buy_base_asset_vol,taker_buy_quote_asset_vol,ignore_this
count,349290.0,349290.0,349290.0,349290.0,349290.0,349290.0,349290.0,349290.0,349290.0,349290.0,349290.0,349290.0
mean,1619954000000.0,44291.498376,44329.90394,44253.400252,44291.562309,57.039721,1619954000000.0,2450697.0,1420.229265,28.079218,1206923.0,0.0
std,6061868000.0,9600.903307,9602.704045,9598.646857,9600.898419,70.018457,6061868000.0,2843840.0,1203.971846,36.765979,1508756.0,0.0
min,1609459000000.0,28241.95,28764.23,28130.0,28235.47,0.0,1609459000000.0,0.0,0.0,0.0,0.0,0.0
25%,1614700000000.0,35398.9975,35436.6325,35360.0,35399.0,23.10222,1614700000000.0,1012094.0,776.0,10.716988,468984.3,0.0
50%,1619967000000.0,43857.955,43915.2,43800.025,43858.66,37.500648,1619967000000.0,1668570.0,1136.0,17.937238,797938.9,0.0
75%,1625202000000.0,53726.9775,53778.4825,53676.9375,53726.97,64.550522,1625203000000.0,2831762.0,1677.0,31.669045,1389989.0,0.0
max,1630454000000.0,64800.0,64854.0,64685.17,64800.0,2636.713888,1630454000000.0,113686300.0,42282.0,2014.965612,89475510.0,0.0


# Feature engineering

Initially this is going to be a bunch of *traditionally useful* financial features. In the paraphrased words of our boy Ernest Chan, "throw heaps of features at your models during prototyping, let the model decide what it thinks is important, it is probably smarter than you."

^This is obviously a joke taken to extremes but you get the point. Who are we to decide what is a "good feature", if we want to find some form of market inefficieny then why not try everything and use feature selection to trim down our list for us.

According to this, https://arxiv.org/abs/2005.12483, LIME is just as promising as SHAP so I am eager to explore this. If it fails we can always fall back on good old SHAP

***Disclaimer***: To be frank, the choice for the rolling window size was totally arbitrary

In [9]:
rolling_window_size = 5

In [10]:
df.head()

Unnamed: 0,open_time,open,high,low,close,volume,close_time,quote_asset_vol,num_trades,taker_buy_base_asset_vol,taker_buy_quote_asset_vol,ignore_this,day
0,1609459200000,28923.63,28961.66,28913.12,28961.66,27.457032,1609459259999,794382.0,1292,16.777195,485390.826825,0.0,2021-01-01
1,1609459260000,28961.67,29017.5,28961.01,29009.91,58.477501,1609459319999,1695803.0,1651,33.733818,978176.468202,0.0,2021-01-01
2,1609459320000,29009.54,29016.71,28973.58,28989.3,42.470329,1609459379999,1231359.0,986,13.247444,384076.854453,0.0,2021-01-01
3,1609459380000,28989.68,28999.85,28972.33,28982.69,30.360677,1609459439999,880016.8,959,9.456028,274083.075142,0.0,2021-01-01
4,1609459440000,28982.67,28995.93,28971.8,28975.65,24.124339,1609459499999,699226.2,726,6.814644,197519.374888,0.0,2021-01-01


## Sell Conditions
The "taker buy" columns come from when an order is filled based on an existing limit sell

This number will always be less than the raw traded volume. This is because the raw volume will include market price sells

We can therefore assume that when we take a ratio it will be a number between zero and one and represent the number of intentional sales. The inverse of this number will be the trades that come from panic selling at market or other market forces like liquidation

In [11]:
df["ratio_intentional_trades"] = df["taker_buy_base_asset_vol"] / df["volume"]

In [12]:
df["ratio_intentional_trades_ma"] = df.rolling(rolling_window_size)["ratio_intentional_trades"].mean()

In [13]:
df.head()

Unnamed: 0,open_time,open,high,low,close,volume,close_time,quote_asset_vol,num_trades,taker_buy_base_asset_vol,taker_buy_quote_asset_vol,ignore_this,day,ratio_intentional_trades,ratio_intentional_trades_ma
0,1609459200000,28923.63,28961.66,28913.12,28961.66,27.457032,1609459259999,794382.0,1292,16.777195,485390.826825,0.0,2021-01-01,0.611035,
1,1609459260000,28961.67,29017.5,28961.01,29009.91,58.477501,1609459319999,1695803.0,1651,33.733818,978176.468202,0.0,2021-01-01,0.576868,
2,1609459320000,29009.54,29016.71,28973.58,28989.3,42.470329,1609459379999,1231359.0,986,13.247444,384076.854453,0.0,2021-01-01,0.311922,
3,1609459380000,28989.68,28999.85,28972.33,28982.69,30.360677,1609459439999,880016.8,959,9.456028,274083.075142,0.0,2021-01-01,0.311456,
4,1609459440000,28982.67,28995.93,28971.8,28975.65,24.124339,1609459499999,699226.2,726,6.814644,197519.374888,0.0,2021-01-01,0.28248,0.418752


## Datetime semantics

### Year based
If assuming cyclical behaviour then we would need more cases of each cycle.

Because we have only collected ~8 months of training data we will have less than one cycle making this yearly data almost useless

In [14]:
# df["quarter"] = pd.to_datetime(df["day"]).dt.quarter

In [15]:
# df["month_of_year"] = pd.to_datetime(df["day"]).dt.month

In [16]:
# df["week_of_year"] = pd.to_datetime(df["day"]).dt.isocalendar().week

### Month based

Month based data is assuming "Fund Flow" interactions where you have large firms re-balancing on a monthly schedule

This might not be part of crypto but could prove valuable anyways

In [17]:
df["day_of_month"] = pd.to_datetime(df["day"]).dt.day

In [18]:
df["day_of_week"] = pd.to_datetime(df["day"]).dt.dayofweek

### Day based

Day based data will show the influence each region has on the price. 

Because crypto is traded 24/7 it doesn't have a traditional market open and close. However each region around the world still needs to sleep and work so you will end up getting cyclical regional activity.

Each region has different risks and economical situations so you should get different behaviours.

You might also catch algorithmic trading based on certain times of day

In [19]:
df["hour_of_day"] = pd.to_datetime(df["open_time"], unit="ms").dt.hour

In [20]:
df["minute_of_hour"] = pd.to_datetime(df["open_time"], unit="ms").dt.minute

In [21]:
df.head()

Unnamed: 0,open_time,open,high,low,close,volume,close_time,quote_asset_vol,num_trades,taker_buy_base_asset_vol,taker_buy_quote_asset_vol,ignore_this,day,ratio_intentional_trades,ratio_intentional_trades_ma,day_of_month,day_of_week,hour_of_day,minute_of_hour
0,1609459200000,28923.63,28961.66,28913.12,28961.66,27.457032,1609459259999,794382.0,1292,16.777195,485390.826825,0.0,2021-01-01,0.611035,,1,4,0,0
1,1609459260000,28961.67,29017.5,28961.01,29009.91,58.477501,1609459319999,1695803.0,1651,33.733818,978176.468202,0.0,2021-01-01,0.576868,,1,4,0,1
2,1609459320000,29009.54,29016.71,28973.58,28989.3,42.470329,1609459379999,1231359.0,986,13.247444,384076.854453,0.0,2021-01-01,0.311922,,1,4,0,2
3,1609459380000,28989.68,28999.85,28972.33,28982.69,30.360677,1609459439999,880016.8,959,9.456028,274083.075142,0.0,2021-01-01,0.311456,,1,4,0,3
4,1609459440000,28982.67,28995.93,28971.8,28975.65,24.124339,1609459499999,699226.2,726,6.814644,197519.374888,0.0,2021-01-01,0.28248,0.418752,1,4,0,4


## Make trade values stationary

For timeseries data there is a concept called "stationarity", in short this is the act of transforming the data so that statistical properties such as mean, variance, autocorrelation, etc actually mean something for the future. 

Commonly this is done by converting changes in price to percentages and scaling them across your train data. The model can only act on information it has seen before, by making your data stationary it is possible that it can pick out recurring behaviour.

### Price change

Here we have the raw price change and the price change as a percentage of the opening

We also want to look at a "smoother" price change in the form of a moving average. This is because typically when financial data is trending it isn't a nice line, it is a "spiky" line. By taking the moving average we lose some granularity but we can see the underlying momentum trend

In [22]:
df["price_change"] = df["close"] - df["open"]
df["price_change_perc"] =  df["price_change"] / df["open"]

In [23]:
df["price_change_ma"] = df.rolling(rolling_window_size)["price_change"].mean()
df["price_change_perc_smooth"] = df["price_change_ma"] / df["open"]

In [24]:
df[[
    "open",
    "close",
    "price_change",
    "price_change_perc",
    "price_change_ma",
    "price_change_perc_smooth"
]].head(20)

Unnamed: 0,open,close,price_change,price_change_perc,price_change_ma,price_change_perc_smooth
0,28923.63,28961.66,38.03,0.001315,,
1,28961.67,29009.91,48.24,0.001666,,
2,29009.54,28989.3,-20.24,-0.000698,,
3,28989.68,28982.69,-6.99,-0.000241,,
4,28982.67,28975.65,-7.02,-0.000242,10.404,0.000359
5,28975.65,28937.11,-38.54,-0.00133,-4.91,-0.000169
6,28937.11,28943.87,6.76,0.000234,-13.206,-0.000456
7,28943.88,28934.84,-9.04,-0.000312,-10.966,-0.000379
8,28934.84,28900.0,-34.84,-0.001204,-16.536,-0.000571
9,28900.0,28858.94,-41.06,-0.001421,-23.344,-0.000808


### Volatility

Variance or "volatility" can be used to determine how stable the trading period was.

Again we are using a smooth apporach to see if it helps the model.

In [25]:
df["volatility"] = df["high"] - df["low"]
df["volatility_perc"] = df["volatility"] / df["open"]

In [26]:
df["volatility_ma"] = df.rolling(rolling_window_size)["volatility"].mean()
df["volatility_perc_smooth"] = df["volatility_ma"] / df["open"]

In [27]:
df[[
    "high",
    "low",
    "volatility",
    "volatility_perc",
    "volatility_ma",
    "volatility_perc_smooth"
]].head(20)

Unnamed: 0,high,low,volatility,volatility_perc,volatility_ma,volatility_perc_smooth
0,28961.66,28913.12,48.54,0.001678,,
1,29017.5,28961.01,56.49,0.001951,,
2,29016.71,28973.58,43.13,0.001487,,
3,28999.85,28972.33,27.52,0.000949,,
4,28995.93,28971.8,24.13,0.000833,39.962,0.001379
5,28979.53,28933.16,46.37,0.0016,39.528,0.001364
6,28963.25,28937.1,26.15,0.000904,33.46,0.001156
7,28954.48,28930.0,24.48,0.000846,29.73,0.001027
8,28936.15,28889.24,46.91,0.001621,33.608,0.001162
9,28920.06,28846.28,73.78,0.002553,43.538,0.001507


### Volume change

Volume is another indicator of stability. If the volume suddenly jumps above the norm we might expect a change in market direction. 

For this reason we are looking at the volume change since last time period and another smoothed version

In [28]:
df["last_volume"] = df["volume"].shift(1)

df["volume_change"] = df["volume"] - df["last_volume"]
df["volume_change_perc"] = df["volume_change"] / df["last_volume"]

In [29]:
df["volume_change_ma"] = df.rolling(rolling_window_size)["volume_change"].mean()
df["volume_change_perc_smooth"] = df["volume_change_ma"] / df["last_volume"]

In [30]:
df[[
    "volume",
    "last_volume",
    "volume_change",
    "volume_change_perc",
    "volume_change_ma",
    "volume_change_perc_smooth"
]].head(20)

Unnamed: 0,volume,last_volume,volume_change,volume_change_perc,volume_change_ma,volume_change_perc_smooth
0,27.457032,,,,,
1,58.477501,27.457032,31.020469,1.129782,,
2,42.470329,58.477501,-16.007172,-0.273732,,
3,30.360677,42.470329,-12.109652,-0.285132,,
4,24.124339,30.360677,-6.236338,-0.205408,,
5,22.396014,24.124339,-1.728325,-0.071642,-1.012204,-0.041958
6,20.480294,22.396014,-1.91572,-0.085538,-7.599441,-0.339321
7,20.962343,20.480294,0.482049,0.023537,-4.301597,-0.210036
8,52.645478,20.962343,31.683135,1.511431,4.45696,0.212617
9,98.083975,52.645478,45.438497,0.863104,14.791927,0.280972


## Graphs

I am eager to get on to the model so will leave the feature engineering here for now, we can always return and add more to it

In [None]:
import plotly.express as px

fig = px.line(
    df.tail(2000),
    x="open_time",
    y=["price_change_perc", "price_change_perc_smooth"],
    height=1024
)

fig.update_traces(
    marker=dict(
        size=12,
        opacity=0.7,
        line=dict(
            width=2,
            color='DarkSlateGrey'
        )
    ),
    selector=dict(mode='markers')
)

In [None]:
import plotly.express as px

px.scatter(
    df.tail(2000),
    x="open_time",
    y="open",
    height=1024
)

# The reinforcement learning

## Graph visualisation

In [31]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from matplotlib import style

# finance module is no longer part of matplotlib
# see: https://github.com/matplotlib/mpl_finance
import mplfinance as mpf



style.use('dark_background')

VOLUME_CHART_HEIGHT = 0.33

UP_COLOUR = '#27A59A'
DOWN_COLOUR = '#EF534F'
UP_TEXT_COLOUR = '#73D3CC'
DOWN_TEXT_COLOUR = '#DC2C27'



class CryptoTradingGraph:
    """A crypto trading visualization using matplotlib made to render OpenAI gym environments"""

    
    def __init__(self, df, title=None):
        self.df = df
        self.net_worths = np.zeros(len(df["open_time"]))

        # Create a figure on screen and set the title
        fig = plt.figure()
        fig.suptitle(title)

        # Create top subplot for net worth axis
        self.net_worth_ax = plt.subplot2grid(
            (6, 1), 
            (0, 0), 
            rowspan=2, 
            colspan=1
        )

        # Create bottom subplot for shared price/volume axis
        self.price_ax = plt.subplot2grid(
            (6, 1), 
            (2, 0), 
            rowspan=8, 
            colspan=1, 
            sharex=self.net_worth_ax
        )

        # Create a new axis for volume which shares its x-axis with price
        self.volume_ax = self.price_ax.twinx()

        # Add padding to make graph easier to view
        plt.subplots_adjust(
            left=0.11, 
            bottom=0.24,
            right=0.90, 
            top=0.90, 
            wspace=0.2, 
            hspace=0
        )

        # Show the graph without blocking the rest of the program
        plt.show(block=False)
    

    def _render_net_worth(self, current_step, net_worth, step_range, dates):
        # Clear the frame rendered last step
        self.net_worth_ax.clear()

        # Plot net worths
        self.net_worth_ax.plot_date(
            dates, 
            self.net_worths[step_range], 
            '-', 
            label="Net Worth"
        )

        # Show legend, which uses the label we defined for the plot above
        self.net_worth_ax.legend()
        legend = self.net_worth_ax.legend(loc=2, ncol=2, prop={"size": 8})
        legend.get_frame().set_alpha(0.4)

        last_date = self.df["open_time"].values[current_step]
        last_net_worth = self.net_worths[current_step]

        # Annotate the current net worth on the net worth graph
        self.net_worth_ax.annotate(
            f"{net_worth:,.2f}", 
            (last_date, last_net_worth),
            xytext=(last_date, last_net_worth),
            bbox=dict(boxstyle="round", fc='w', ec='k', lw=1),
            color="black",
            fontsize="small"
        )

        # Add space above and below min/max net worth
        self.net_worth_ax.set_ylim(
            min(self.net_worths[np.nonzero(self.net_worths)]) / 1.25, 
            max(self.net_worths) * 1.25
        )

    
    def _render_price(self, current_step, net_worth, dates, step_range):
        self.price_ax.clear()

        candlesticks = zip(
            dates,
            self.df["open"].values[step_range], 
            self.df["close"].values[step_range],
            self.df["high"].values[step_range], 
            self.df["low"].values[step_range]
        )

        # Plot price using candlestick graph from mpl_finance
        mpf.plot(
            self.price_ax, 
            candlesticks, 
            width=1,
            colorup=UP_COLOUR, 
            colordown=DOWN_COLOUR,
            type="candle"
        )

        last_date = self.df["open_time"].values[current_step]
        last_close = self.df["close"].values[current_step]
        last_high = self.df["high"].values[current_step]

        # Print the current price to the price axis
        self.price_ax.annotate(
            f"{last_close}:,.2f", 
            (last_date, last_close),
            xytext=(last_date, last_high),
            bbox=dict(boxstyle='round', fc='w', ec='k', lw=1),
            color="black",
            fontsize="small"
        )

        # Shift price axis up to give volume chart space
        ylim = self.price_ax.get_ylim()
        self.price_ax.set_ylim(
            ylim[0] - (ylim[1] - ylim[0]) * VOLUME_CHART_HEIGHT, 
            ylim[1]
        )

    
    def _render_volume(self, current_step, net_worth, dates, step_range):
        self.volume_ax.clear()

        volume = np.array(self.df["volume"].values[step_range])

        pos = self.df["open"].values[step_range] - self.df["close"].values[step_range] < 0
        neg = self.df["open"].values[step_range] - self.df["close"].values[step_range] > 0

        self.volume_ax.bar(
            dates[pos], 
            volume[pos], 
            color=UP_COLOUR,
            alpha=0.4, 
            width=1, 
            align="center"
        )
        self.volume_ax.bar(
            dates[neg], 
            volume[neg], 
            color=DOWN_COLOUR,
            alpha=0.4, 
            width=1, 
            align="center"
        )

        # Cap volume axis height below price chart and hide ticks
        self.volume_ax.set_ylim(0, max(volume) / VOLUME_CHART_HEIGHT)
        self.volume_ax.yaxis.set_ticks([])

    
    def _render_trades(self, current_step, trades, step_range):
        for trade in trades:
            if trade["step"] in step_range:
                date = self.df["open_time"].values[trade["step"]]
                high = self.df["high"].values[trade["step"]]
                low = self.df["low"].values[trade["step"]]

                if trade["type"] == 'buy':
                    high_low = low
                    colour = UP_TEXT_COLOUR
                else:
                    high_low = high
                    colour = DOWN_TEXT_COLOUR

                total = f"{trade['total']:,.2f}"

                # Print the current price to the price axis
                self.price_ax.annotate(
                    f"${total:,.2f}", 
                    (date, high_low),
                    xytext=(date, high_low),
                    color=colour,
                    fontsize=8,
                    arrowprops=(dict(color=colour))
                )

    def render(self, current_step, net_worth, trades, window_size=40):
        self.net_worths[current_step] = net_worth

        window_start = max(current_step - window_size, 0)
        step_range = range(window_start, current_step + 1)

        dates = np.array([
            datetime.utcfromtimestamp(x).strftime("%Y-%m-%d %H-%M") for x in self.df["open_time"].values[step_range]
        ])

        self._render_net_worth(current_step, net_worth, step_range, dates)
        self._render_price(current_step, net_worth, dates, step_range)
        self._render_volume(current_step, net_worth, dates, step_range)
        self._render_trades(current_step, trades, step_range)

        # Format the date ticks to be more easily read
        self.price_ax.set_xticklabels(
            self.df["open_time"].values[step_range], 
            rotation=45,
            horizontalalignment="right"
        )

        # Hide duplicate net worth date labels
        plt.setp(self.net_worth_ax.get_xticklabels(), visible=False)

        # Necessary to view frames before they are unrendered
        plt.pause(0.001)

    
    def close(self):
        plt.close()


## The gym environment

In [32]:
import gym

from gym import spaces
from sklearn import preprocessing



MAX_TRADING_SESSION = 100000  # ~2 months



class CryptoTradingEnv(gym.Env):
    
    metadata = {"render.modes": ["live", "file", "none"]}
    scaler = preprocessing.MinMaxScaler()
    viewer = None
    
    
    def __init__(
        self, 
        df: pd.DataFrame, 
        lookback_window_size : int=50,
        commission : float = 0.00075,
        initial_balance : float = 10_000.0,
        serial : bool = False
    ):
        super(CryptoTradingEnv, self).__init__()
        
        self.df = df.dropna().reset_index()
        self.lookback_window_size=lookback_window_size
        self.initial_balance = initial_balance
        self.commission = commission
        self.serial = serial
        self.trades = []
        
        # The agent can buy, sell, hold, at certain amounts 1/10 through 10/10
        self.action_space = spaces.MultiDiscrete([3, 10])
        
        #Observes the OHCLV values, net worth, and trade history
        self.observation_space = spaces.Box(
            shape=(10, lookback_window_size + 1),
            low=0,
            high=1
        )
    
    
    def reset(self):
        # Reset the whole simulation
        self.balance = self.initial_balance
        self.net_worth = self.initial_balance
        self.assets_held = 0
        
        self._reset_session()
        
        self.account_history = np.repeat([
            [self.net_worth],
            [0],
            [0],
            [0],
            [0]
        ], 
            self.lookback_window_size + 1, 
            axis=1
        )
        
        self.trades = []
        
        return self._next_observation()
    
    
    def _reset_session(self):
        # I am not convinced on the "random traversal" approach here 
        # but there is some supporting evidence that it works, so I will humour it
        
        self.current_step = 0
        
        if self.serial:
            self.steps_left = len(self.df) - self.lookback_window_size - 1
            self.frame_start = self.lookback_window_size
        else:
            # Random traversal
            self.steps_left = np.random.randint(1, MAX_TRADING_SESSION)
            self.frame_start = np.random.randint(
                self.lookback_window_size, 
                len(self.df) - self.steps_left
            )

        self.active_df = self.df[
            self.frame_start - self.lookback_window_size : self.frame_start + self.steps_left
        ]
            
    
    def _next_observation(self):
        # It is really important to ONLY scale the data that the model has seen,
        # This is to prevent 'look-ahead bias'
        end = self.current_step + self.lookback_window_size + 1
        
        obs = np.array([
            self.active_df["open"].values[self.current_step:end],  
            self.active_df["high"].values[self.current_step:end],
            self.active_df["low"].values[self.current_step:end],
            self.active_df["close"].values[self.current_step:end],
            self.active_df["volume"].values[self.current_step:end],
        ])
        
        scaled_history = self.scaler.fit_transform(self.account_history)
        
        obs = np.append(
            obs, 
            scaled_history[:, -(self.lookback_window_size + 1):], 
            axis=0
        )
        
        return obs
    
    
    def _get_current_price(self):
        return self.df["close"].values[self.frame_start + self.current_step]
    
    
    def step(self, action):
        current_price = self._get_current_price() + 0.01
        self._take_action(action, current_price)
        self.steps_left -= 1
        self.current_step += 1
        
        if self.steps_left == 0:
            self.balance += self.assets_held * current_price
            self.assets_held = 0
            self._reset_session()
            
        obs = self._next_observation()
        reward = self.net_worth
        done = self.net_worth <= 0
        
        return obs, reward, done, {}
    
    
    def _take_action(self, action, current_price):
        action_type = action[0]
        amount = action[1] / 10
        
        assets_bought = 0
        assets_sold = 0
        cost = 0
        sales = 0
        
        if action_type < 1:
            # Trigger a buy
            assets_bought = 0.0 if amount == 0 else self.balance / (current_price * amount)
            cost = assets_bought * current_price * (1 + self.commission)
            self.assets_held += assets_bought
            self.balance -= cost
        
        elif action_type < 2:
            # Trigger a sell
            assets_sold = self.assets_held * amount
            sales = assets_sold * current_price * (1 - self.commission)
            self.assets_held -= assets_sold
            self.balance += sales
            
        if assets_bought > 0 or assets_sold > 0:
            self.trades.append({
                "step": self.frame_start + self.current_step,
                "amount": assets_bought if assets_bought > 0 else assets_sold,
                "total": cost if assets_bought > 0 else sales,
                "type": "buy" if assets_bought > 0 else "sell"
            })
            
        self.net_worth = self.balance + self.assets_held * current_price
        self.account_history = np.append(
            self.account_history,
            [
                [self.net_worth],
                [assets_bought],
                [cost],
                [assets_sold],
                [sales]
            ],
            axis=1
        )
    
    
    def render(self, mode="human", **kwargs):
        if mode == "human":
            if self.viewer == None:
                self.viewer = CryptoTradingGraph(
                    self.df,
                    kwargs.get("title", None)
                )
            
            self.viewer.render(
                self.frame_start + self.current_step,
                self.net_worth,
                self.trades,
                window_size=self.lookback_window_size
            )


## Run it

In [35]:
# The medium article was simple and just used OHCLV
cols = [
    "open",
    "high",
    "close",
    "low",
    "volume"
]

df_gym = df[cols + ["open_time"]].copy()

In [36]:
slice_point = int(len(df) - 100_000)

train_df = df_gym[:slice_point]
test_df = df_gym[slice_point:]

In [37]:
train_env = CryptoTradingEnv(train_df, commission=0.00075, serial=False)

test_env = CryptoTradingEnv(test_df, commission=0.00075, serial=True)

In [38]:
from stable_baselines3 import A2C

model = A2C(
    "MlpPolicy",
    train_env,
    verbose=1,
    tensorboard_log="./tensorboard/"
)

model.learn(total_timesteps=50_000)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./tensorboard/A2C_9
------------------------------------
| time/                 |          |
|    fps                | 177      |
|    iterations         | 100      |
|    time_elapsed       | 2        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -2.89    |
|    explained_variance | 0        |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 8.03e+04 |
|    value_loss         | 9.05e+08 |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 252      |
|    iterations         | 200      |
|    time_elapsed       | 3        |
|    total_timesteps    | 1000     |
| train/                |          |
|    entropy_loss       | -2.52    |
|    explained_variance | 0        |
|    learning_rate      

<stable_baselines3.a2c.a2c.A2C at 0x7f630cf33d00>

In [40]:
train_env.balance

5439.5463607221645