<a href="https://colab.research.google.com/github/NicoleRichards1998/FinRL/blob/master/DayTradingEnv_%20Testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [82]:
import gym
import numpy as np
from numpy import random as rd

class StockTradingEnv(gym.Env):
    def __init__(
        self,
        config,
        initial_account=1e6,
        gamma=0.99,
        turbulence_thresh=99,
        min_stock_rate=0.1,
        max_stock=1e2,
        initial_capital=1e6,
        buy_cost_pct=1e-3,
        sell_cost_pct=1e-3,
        reward_scaling=2 ** -11,
        initial_stocks=None,
    ):
        price_ary = config["price_array"]
        #tech_ary = config["tech_array"]
        turbulence_ary = config["turbulence_array"]
        if_train = config["if_train"]
        self.price_ary = price_ary.astype(np.float32)
        #self.tech_ary = tech_ary.astype(np.float32)
        self.turbulence_ary = turbulence_ary

        #self.tech_ary = self.tech_ary * 2 ** -7
        self.turbulence_bool = (turbulence_ary > turbulence_thresh).astype(np.float32)
        self.turbulence_ary = (
            self.sigmoid_sign(turbulence_ary, turbulence_thresh) * 2 ** -5
        ).astype(np.float32)

        stock_dim = self.price_ary.shape[1]
        self.gamma = gamma
        self.max_stock = max_stock
        self.min_stock_rate = min_stock_rate
        self.buy_cost_pct = buy_cost_pct
        self.sell_cost_pct = sell_cost_pct
        self.reward_scaling = reward_scaling
        self.initial_capital = initial_capital
        self.initial_stocks = (
            np.zeros(stock_dim, dtype=np.float32)
            if initial_stocks is None
            else initial_stocks
        )

        # reset()
        self.day = None
        self.amount = None
        self.stocks = None
        self.total_asset = None
        self.gamma_reward = None
        self.initial_total_asset = None

        # environment information
        self.env_name = "StockEnv"
        # self.state_dim = 1 + 2 + 2 * stock_dim + self.tech_ary.shape[1]
        # # amount + (turbulence, turbulence_bool) + (price, stock) * stock_dim + tech_dim
        self.state_dim = 1 + 2 + 3 * stock_dim #+ self.tech_ary.shape[1]
        # amount + (turbulence, turbulence_bool) + (price, stock) * stock_dim + tech_dim
        self.stocks_cd = None
        self.action_dim = stock_dim
        self.max_step = self.price_ary.shape[0] - 1
        self.if_train = if_train
        self.if_discrete = False
        self.target_return = 10.0
        self.episode_return = 0.0

        self.observation_space = gym.spaces.Box(
            low=-3000, high=3000, shape=(self.state_dim,), dtype=np.float32
        )
        self.action_space = gym.spaces.Box(
            low=-1, high=1, shape=(self.action_dim,), dtype=np.float32
        )

    def reset(self):
        self.day = 0
        price = self.price_ary[self.day]

        if self.if_train:
            self.stocks = (
                self.initial_stocks + rd.randint(0, 64, size=self.initial_stocks.shape)
            ).astype(np.float32)
            self.stocks_cool_down = np.zeros_like(self.stocks)
            self.amount = (
                self.initial_capital * rd.uniform(0.95, 1.05)
                - (self.stocks * price).sum()
            )
        else:
            self.stocks = self.initial_stocks.astype(np.float32)
            self.stocks_cool_down = np.zeros_like(self.stocks)
            self.amount = self.initial_capital

        self.total_asset = self.amount + (self.stocks * price).sum()
        self.initial_total_asset = self.total_asset
        self.gamma_reward = 0.0
        return self.get_state(price)  # state

    def step(self, actions):
        actions = (actions * self.max_stock).astype(int)

        self.day += 1
        price = self.price_ary[self.day]
        self.stocks_cool_down += 1

        if self.turbulence_bool[self.day] == 0:
            min_action = int(self.max_stock * self.min_stock_rate)  # stock_cd
            for index in np.where(actions < -min_action)[0]:  # sell_index:
                if price[index] > 0:  # Sell only if current asset is > 0
                    sell_num_shares = min(self.stocks[index], -actions[index])
                    self.stocks[index] -= sell_num_shares
                    self.amount += (
                        price[index] * sell_num_shares * (1 - self.sell_cost_pct)
                    )
                    self.stocks_cool_down[index] = 0
            for index in np.where(actions > min_action)[0]:  # buy_index:
                if (
                    price[index] > 0
                ):  # Buy only if the price is > 0 (no missing data in this particular date)
                    buy_num_shares = min(self.amount // price[index], actions[index])
                    self.stocks[index] += buy_num_shares
                    self.amount -= (
                        price[index] * buy_num_shares * (1 + self.buy_cost_pct)
                    )
                    self.stocks_cool_down[index] = 0

        else:  # sell all when turbulence
            self.amount += (self.stocks * price).sum() * (1 - self.sell_cost_pct)
            self.stocks[:] = 0
            self.stocks_cool_down[:] = 0

        state = self.get_state(price)
        total_asset = self.amount + (self.stocks * price).sum()
        #print(total_asset)
        reward = (total_asset - self.total_asset) * self.reward_scaling
        #print(reward)
        self.total_asset = total_asset

        self.gamma_reward = self.gamma_reward * self.gamma + reward
        done = self.day == self.max_step
        if done:
            reward = self.gamma_reward
            self.episode_return = total_asset / self.initial_total_asset

        return state, reward, done, dict()

    def get_state(self, price):
        amount = np.array(self.amount * (2 ** -12), dtype=np.float32)
        scale = np.array(2 ** -6, dtype=np.float32)
        return np.hstack(
            (
                amount,
                self.turbulence_ary[self.day],
                self.turbulence_bool[self.day],
                price * scale,
                self.stocks * scale,
                self.stocks_cool_down,
                #self.tech_ary[self.day],
            )
        )  # state.astype(np.float32)

    @staticmethod
    def sigmoid_sign(ary, thresh):
        def sigmoid(x):
            return 1 / (1 + np.exp(-x * np.e)) - 0.5

        return sigmoid(ary / thresh) * thresh

In [None]:
!pip uninstall -y pyarrow
!pip uninstall -y ray # clean removal of previous install, otherwise version number may cause pip not to upgrade
!pip install tf-estimator-nightly==2.8.0.dev2021122109
!pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl # minimal install
!pip install lz4
# A hack to force the runtime to restart, needed to include the above dependencies.
print("Done installing! Restarting via forced crash (this is not an issue).")
import os
os._exit(0)

Found existing installation: pyarrow 6.0.1
Uninstalling pyarrow-6.0.1:
  Successfully uninstalled pyarrow-6.0.1
Collecting tf-estimator-nightly==2.8.0.dev2021122109
  Downloading tf_estimator_nightly-2.8.0.dev2021122109-py2.py3-none-any.whl (462 kB)
[K     |████████████████████████████████| 462 kB 5.5 MB/s 
[?25hInstalling collected packages: tf-estimator-nightly
Successfully installed tf-estimator-nightly-2.8.0.dev2021122109
Collecting ray==2.0.0.dev0
  Downloading https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl (53.6 MB)
[K     |████████████████████████████████| 53.6 MB 1.3 MB/s 
Collecting aiosignal
  Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)
Collecting frozenlist
  Downloading frozenlist-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (144 kB)
[K     |████████████████████████████████| 144 kB 5.3 MB/s 
[?25hCollecting virtualenv
  Downloading virtualenv-20.

In [24]:
import pandas as pd
import ray
from ray.rllib.agents.ppo.ppo import PPOTrainer
import array as arr

In [2]:
JSEIndexes = [ 'ACL' ]
ticker_list = JSEIndexes
action_dim = len(ticker_list)

INDICATORS = ['macd', 'boll_ub', 'boll_lb', 'rsi_30', 'dx_30', 'close_30_sma', 'close_60_sma']
tech_indicator_list = INDICATORS

Number_Train_Days = 15

In [74]:
turbulence_array = np.array([0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,22.1084222,9.67023344,41.1693169,19.1532296,13.7609015,16.1861931,19.8504372,20.5604525,25.4778107,22.997469,24.344648,5.8348802,20.4890941,39.4762356,35.2420246,38.073237,37.3956777,16.0432548,27.2026388,74.9701996,36.6446395,17.2356173,8.26889763,35.0190994,16.0213982,8.25901877,13.068116,19.0129904,35.8202011,7.51760058,13.2072048,16.3513083,21.0994217,8.55916103,20.1677023,16.4633453,28.4869495,209.912747,9.35136978,10.4790715,12.6530392,20.3009207,22.9551434,22.2838805,7.69247255,9.64157207,52.9446413,12.4934834,83.0413624,41.2709711,11.4005592,16.8718384,49.8116725,16.4329166,18.0112372,12.7541583,53.6965149,39.5256077,42.3696768,8.38575579,17.6195145,8.49749731,10.6554826,327.555984,13.404072,38.6838739,6.89542406,50.8017775,15.875114,28.166164,6.61625544,14.7403782,21.6496009,75.2799184,28.65897,29.5858385,80.866183,15.6118035,150.627846,41.3192567,17.3458798,41.4452621,17.3350963,34.435636,24.4467436,19.7274451,40.4424791,12.5573777,4.31003216,21.5752337,42.1299684,19.3930096,39.6746395,6.13717037,8.48234693,11.4112414,18.1626798,16.4212011,6.36508936,8.79028764,23.1960385,14.0110924,181.715205,39.7069046,21.0737136,26.9749472,18.3048448,36.7074679,18.5079738,22.518794,26.3575788,57.1069322,31.0806945,17.5397772,11.182937,6.29789109,9.41790723,12.463508,43.2013987,10.0444978,26.3908327,20.7115269,19.281301,8.03978782,5.05517939,14.0612951,12.4372776,12.6839771,9.79354063,17.7252164,57.9712669,10.4970353,25.1568842,20.0762626,16.3526005,63.9168339,42.4571091,19.5677932,11.8701156,11.7354534,60.9160619,60.4283975,29.9133144,87.5994515,51.676612,13.1022182,13.1103592,29.1036824,16.6776413,25.3430243,129.704933,69.1849802,155.516511,13.8374865,31.015936,35.4348424,24.6180295,32.0368875,32.4055994,21.5060411,24.1078962,21.837055,89.7767473,17.7047876,63.7570172,61.6263854,30.8347809,62.3633355,10.9963346,37.0861236,26.394429,77.5570819,29.6463996,39.8668353,64.2246258,56.6692965,29.5769609,72.6053599,35.039095,22.1941964,20.5263666,28.5081171,18.7149107,23.6185352,50.8732245,29.8859014,42.3946544,78.2155865,97.0857206,16.0103771,166.568788,19.3798815,38.9854351,444.952546,50.2320839,84.8867705,89.5068006,111.080589,31.3344678,42.7724476,79.9825813,117.383846,85.1735406,96.0192724,80.3376729,50.6395052,41.2773278,54.2815102,73.0918743,25.1789145,20.8267735,46.111002,78.8774273,20.1950518,120.432424,301.67039,0.0586534497,0.0611318037,0.0597172208,0.0615158172,0.064245163,0.0652299704,0.0644154261,0.0647695637,0.0631947183,185.490985,368.770942])

In [81]:
price_array = np.array([[849],
       [845],
       [845],
       [845],
       [845],
       [859],
       [855],
       [870],
       [870],
       [870],
       [870],
       [868],
       [868],
       [868],
       [868],
       [870],
       [870],
       [870],
       [861],
       [870],
       [870],
       [870],
       [870],
       [870],
       [870],
       [865],
       [870],
       [870],
       [865],
       [859],
       [859],
       [859],
       [859],
       [859],
       [859],
       [861],
       [861],
       [861],
       [861],
       [861],
       [856],
       [856],
       [856],
       [856],
       [856],
       [856],
       [862],
       [862],
       [856],
       [856],
       [868],
       [868],
       [868],
       [871],
       [871],
       [871],
       [871],
       [871],
       [871],
       [871],
       [868],
       [871],
       [871],
       [871],
       [862],
       [862],
       [862],
       [862],
       [862],
       [868],
       [868],
       [875],
       [875],
       [875],
       [875],
       [871],
       [875],
       [875],
       [875],
       [875],
       [875],
       [880],
       [882],
       [884],
       [881],
       [881],
       [884],
       [884],
       [886],
       [886],
       [886],
       [886],
       [886],
       [886],
       [886],
       [886],
       [889],
       [884],
       [881],
       [881],
       [876],
       [881],
       [881],
       [881],
       [872],
       [872],
       [880],
       [880],
       [882],
       [885],
       [884],
       [884],
       [888],
       [880],
       [880],
       [888],
       [888],
       [884],
       [888],
       [888],
       [885],
       [880],
       [880],
       [880],
       [880],
       [880],
       [880],
       [885],
       [885],
       [885],
       [885],
       [878],
       [878],
       [883],
       [883],
       [883],
       [883],
       [883],
       [883],
       [883],
       [883],
       [883],
       [874],
       [880],
       [872],
       [877],
       [877],
       [877],
       [877],
       [877],
       [877],
       [877],
       [877],
       [877],
       [870],
       [870],
       [879],
       [879],
       [868],
       [868],
       [868],
       [868],
       [868],
       [868],
       [873],
       [873],
       [873],
       [873],
       [873],
       [873],
       [873],
       [873],
       [869],
       [869],
       [869],
       [869],
       [870],
       [870],
       [870],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [865],
       [865],
       [865],
       [865],
       [865],
       [865],
       [865],
       [865],
       [865],
       [865],
       [862],
       [862],
       [862],
       [862],
       [862],
       [862],
       [862],
       [862],
       [862],
       [867],
       [867],
       [867],
       [867],
       [867],
       [867],
       [859],
       [859],
       [859],
       [859],
       [865],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [866],
       [861],
       [861],
       [861],
       [861],
       [861],
       [861],
       [861],
       [861],
       [861],
       [861],
       [861],
       [861],
       [866],
       [866],
       [860],
       [860],
       [860],
       [860],
       [860],
       [860],
       [860],
       [863],
       [863],
       [863],
       [863],
       [863],
       [863],
       [861],
       [861],
       [860],
       [857],
       [857],
       [857],
       [857],
       [859],
       [859],
       [859],
       [855],
       [858],
       [858],
       [854],
       [854],
       [854],
       [854],
       [854],
       [854],
       [854],
       [858],
       [858],
       [858],
       [858],
       [853],
       [853],
       [853],
       [853],
       [853],
       [853],
       [853],
       [853],
       [858],
       [858],
       [858],
       [861],
       [861],
       [861],
       [861],
       [856],
       [856],
       [855],
       [855],
       [855],
       [855],
       [855],
       [855],
       [859],
       [860],
       [860],
       [860],
       [863],
       [863],
       [864],
       [864],
       [864],
       [864],
       [864],
       [864],
       [864],
       [864],
       [864],
       [864],
       [864],
       [861],
       [861],
       [861],
       [863],
       [862],
       [862],
       [862],
       [864],
       [864],
       [864],
       [864],
       [861],
       [861],
       [861],
       [861],
       [861],
       [861],
       [861],
       [861],
       [860],
       [860],
       [860],
       [860],
       [860],
       [860],
       [860],
       [859],
       [859],
       [859],
       [860],
       [860],
       [860],
       [863],
       [863],
       [864],
       [866],
       [866],
       [867],
       [867],
       [867],
       [867],
       [867],
       [868],
       [869],
       [869],
       [869],
       [869],
       [869],
       [870],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [870],
       [870],
       [876],
       [878],
       [878],
       [878],
       [878],
       [878],
       [878],
       [878],
       [873],
       [873],
       [873],
       [873],
       [873],
       [876],
       [876],
       [876],
       [876],
       [876],
       [876],
       [880],
       [880],
       [880],
       [880],
       [880],
       [880],
       [880],
       [879],
       [879],
       [877],
       [882],
       [882],
       [882],
       [885],
       [885],
       [885],
       [885],
       [885],
       [885],
       [884],
       [884],
       [882],
       [881],
       [881],
       [881],
       [881],
       [883],
       [880],
       [880],
       [885],
       [884],
       [877],
       [885],
       [885],
       [882],
       [885],
       [885],
       [880],
       [880],
       [880],
       [881],
       [881],
       [881],
       [884],
       [884],
       [884],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [875],
       [877],
       [877],
       [882],
       [882],
       [882],
       [885],
       [885],
       [885],
       [885],
       [885],
       [885],
       [885],
       [885],
       [885],
       [885],
       [885],
       [879]])

In [83]:
env_configuration = {'price_array': price_array, 'turbulence_array': turbulence_array,'if_train': True}

In [84]:
trainer = PPOTrainer(
    config={
        # Env class to use (here: our gym.Env sub-class from above).
        "env": StockTradingEnv,
        # Config dict to be passed to our custom env's constructor.
        "env_config": env_configuration,
        # Parallelize environment rollouts.
        "num_workers": 0,
        "model":{
            "use_lstm": True,
            "lstm_cell_size": int(256),
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        }
    }
)


Install gputil for GPU system monitoring.


In [85]:
for i in range(5):
    results = trainer.train()
    print(f"Iter: {i}; avg. reward={results['episode_reward_mean']}")

Iter: 0; avg. reward=-15.656882054500713
Iter: 1; avg. reward=-14.643561861698235
Iter: 2; avg. reward=-13.974187483991054
Iter: 3; avg. reward=-13.268664536086169
Iter: 4; avg. reward=-13.325204951462647


In [86]:
env = StockTradingEnv(env_configuration)
# Get the initial observation (some value between -10.0 and 10.0).
obs = env.reset()
state = [np.zeros([256], dtype=np.float64) for _ in range(2)]
prev_a = 0.0
prev_r = 0.0
done = False
total_reward = 0.0
# Play one episode.
while not done:
    # Compute a single action, given the current observation
    # from the environment.
    action, state, _ = trainer.compute_single_action(obs, state, prev_action=prev_a, prev_reward=prev_r)
    # Apply the computed action in the environment.
    obs, reward, done, info = env.step(action)
    prev_a = action
    prev_r = reward
    # Sum up rewards for reporting purposes.
    total_reward += reward
# Report results.
print(f"Played 1 episode; total-reward={total_reward}")

2022-04-25 08:56:10,868	ERROR tf_run_builder.py:52 -- Error fetching: [<tf.Tensor 'default_policy/cond_1/Merge:0' shape=(?, 1) dtype=float32>, <tf.Tensor 'default_policy/model_1/lstm/while/Exit_3:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'default_policy/model_1/lstm/while/Exit_4:0' shape=(?, 256) dtype=float32>, {'action_prob': <tf.Tensor 'default_policy/Exp_1:0' shape=(?,) dtype=float32>, 'action_logp': <tf.Tensor 'default_policy/cond_2/Merge:0' shape=(?,) dtype=float32>, 'action_dist_inputs': <tf.Tensor 'default_policy/Reshape_3:0' shape=(?, 2) dtype=float32>, 'vf_preds': <tf.Tensor 'default_policy/Reshape_4:0' shape=(?,) dtype=float32>}], feed_dict={<tf.Tensor 'default_policy/obs:0' shape=(?, 6) dtype=float32>: array([[242.4522  ,   0.      ,   0.      ,  13.265625,   0.296875,
          0.      ]], dtype=float32), <tf.Tensor 'default_policy/state_in_0:0' shape=(?, 256) dtype=float32>: array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.,

ValueError: ignored