In [6]:
import random
import json
import gym
from gym import spaces
import pandas as pd
import numpy as np

import warnings
warnings.filterwarnings('ignore')

In [7]:
# Scale factors / bounds used by StockTradingEnv when normalising observations.
MAX_ACCOUNT_BALANCE = 2**31 - 1  # == 2147483647, the largest signed 32-bit int
MAX_NUM_SHARES = 2**31 - 1       # share counts and volume are divided by this
MAX_SHARE_PRICE = 5_000          # prices and cost basis are divided by this
MAX_OPEN_POSITIONS = 5
MAX_STEPS = 20_000               # divisor of the step-based reward delay modifier

# Cash the agent starts each episode with.
INITIAL_ACCOUNT_BALANCE = 10_000

In [8]:
class StockTradingEnv(gym.Env):
    """A single-stock trading environment for OpenAI gym.

    Actions are 2-vectors ``[action_type, amount]`` (see ``action_space``):

    * ``action_type < 1``      -> buy ``amount`` (a fraction in [0, 1]) of the
      whole shares the current balance can afford;
    * ``1 <= action_type < 2`` -> sell ``amount`` of the shares currently held;
    * otherwise                -> hold.

    Trades fill at a price drawn uniformly between the current row's
    ``Open`` and ``Close``.

    Args:
        df: price history with ``Open``, ``High``, ``Low``, ``Close`` and
            ``Volume`` columns and at least ``_WINDOW`` (6) rows. Rows are
            addressed by *position* (``iloc``), so the index labels do not
            have to be 0..n-1 (e.g. after ``sort_values`` without
            ``reset_index``).

    ``reset()`` must be called before the first ``step()``; the account
    state attributes are only created there.
    """
    metadata = {'render.modes': ['human']}

    # Number of consecutive data rows included in each observation.
    _WINDOW = 6

    def __init__(self, df):
        super().__init__()

        self.df = df
        self.reward_range = (0, MAX_ACCOUNT_BALANCE)

        # Actions of the format [action_type, amount]: buy / sell / hold
        # selected by action_type in [0, 3], amount a fraction in [0, 1].
        self.action_space = spaces.Box(
            low=np.array([0, 0]), high=np.array([3, 1]), dtype=np.float16)

        # 5 scaled price/volume rows (Open, High, Low, Close, Volume) over a
        # _WINDOW-row window, plus 1 row of scaled account state -> (6, 6).
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(6, 6), dtype=np.float16)

    def _next_observation(self):
        """Build the (6, 6) observation for ``self.current_step``.

        Rows 0-4 are the Open, High, Low, Close (divided by MAX_SHARE_PRICE)
        and Volume (divided by MAX_NUM_SHARES) of the ``_WINDOW`` data rows
        starting at position ``current_step`` (inclusive). Row 5 is the
        scaled account state.

        NOTE(review): the window starts at the same row whose Open/Close
        bracket the next trade price and extends 5 rows beyond it, so the
        agent sees future prices (look-ahead). Kept to preserve the original
        behaviour; consider a window ending at ``current_step - 1`` instead.
        """
        window = self.df.iloc[self.current_step:self.current_step + self._WINDOW]
        frame = np.array([
            window['Open'].values / MAX_SHARE_PRICE,
            window['High'].values / MAX_SHARE_PRICE,
            window['Low'].values / MAX_SHARE_PRICE,
            window['Close'].values / MAX_SHARE_PRICE,
            window['Volume'].values / MAX_NUM_SHARES,
        ])

        # Append the account state as a sixth row, each value scaled by its bound.
        obs = np.append(frame, [[
            self.balance / MAX_ACCOUNT_BALANCE,
            self.max_net_worth / MAX_ACCOUNT_BALANCE,
            self.shares_held / MAX_NUM_SHARES,
            self.cost_basis / MAX_SHARE_PRICE,
            self.total_shares_sold / MAX_NUM_SHARES,
            self.total_sales_value / (MAX_NUM_SHARES * MAX_SHARE_PRICE),
        ]], axis=0)

        return obs

    def _take_action(self, action):
        """Apply ``action`` to the account at the current row's random fill price.

        Updates balance, shares held, cost basis, sales totals, net worth and
        max net worth in place.
        """
        # Fill price: uniform between this row's Open and Close.
        current_price = random.uniform(
            self.df['Open'].iloc[self.current_step],
            self.df['Close'].iloc[self.current_step])

        action_type = action[0]
        amount = action[1]

        if action_type < 1:
            # Buy `amount` of the whole shares the balance can afford.
            total_possible = int(self.balance / current_price)
            shares_bought = int(total_possible * amount)
            # Skipping a zero-share buy avoids dividing by zero below when no
            # shares are held; it is otherwise a no-op.
            if shares_bought > 0:
                prev_cost = self.cost_basis * self.shares_held
                additional_cost = shares_bought * current_price

                self.balance -= additional_cost
                # Volume-weighted average purchase price of all held shares.
                self.cost_basis = (
                    prev_cost + additional_cost) / (self.shares_held + shares_bought)
                self.shares_held += shares_bought

        elif action_type < 2:
            # Sell `amount` of the shares held.
            shares_sold = int(self.shares_held * amount)
            self.balance += shares_sold * current_price
            self.shares_held -= shares_sold
            self.total_shares_sold += shares_sold
            self.total_sales_value += shares_sold * current_price

        self.net_worth = self.balance + self.shares_held * current_price

        if self.net_worth > self.max_net_worth:
            self.max_net_worth = self.net_worth

        if self.shares_held == 0:
            self.cost_basis = 0

    def step(self, action):
        """Execute one time step.

        Returns:
            ``(obs, reward, done, info)``. ``reward`` is the cash balance
            scaled by ``current_step / MAX_STEPS``; ``done`` is True once net
            worth drops to 0 or below. ``info`` is always empty.
        """
        self._take_action(action)

        self.current_step += 1

        # Wrap to the start once a full observation window no longer fits.
        if self.current_step > len(self.df) - self._WINDOW:
            self.current_step = 0

        delay_modifier = (self.current_step / MAX_STEPS)

        reward = self.balance * delay_modifier
        done = self.net_worth <= 0

        obs = self._next_observation()

        return obs, reward, done, {}

    def reset(self):
        """Reset the account and start at a random row; return the first observation.

        Raises:
            ValueError: (from ``random.randint``) if ``df`` has fewer than
                ``_WINDOW`` rows.
        """
        self.balance = INITIAL_ACCOUNT_BALANCE
        self.net_worth = INITIAL_ACCOUNT_BALANCE
        self.max_net_worth = INITIAL_ACCOUNT_BALANCE
        self.shares_held = 0
        self.cost_basis = 0
        self.total_shares_sold = 0
        self.total_sales_value = 0

        # Random start position that still leaves room for a full window.
        self.current_step = random.randint(
            0, len(self.df) - self._WINDOW)

        return self._next_observation()

    def render(self, mode='human', close=False):
        """Print the current step and account state to stdout.

        ``mode`` and ``close`` are accepted for the gym interface but unused.
        """
        profit = self.net_worth - INITIAL_ACCOUNT_BALANCE

        print(f'Step: {self.current_step}')
        print(f'Balance: {self.balance}')
        print(
            f'Shares held: {self.shares_held} (Total sold: {self.total_shares_sold})')
        print(
            f'Avg cost for held shares: {self.cost_basis} (Total sales value: {self.total_sales_value})')
        print(
            f'Net worth: {self.net_worth} (Max net worth: {self.max_net_worth})')
        print(f'Profit: {profit}')

In [9]:
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

In [11]:
# Load the price history in chronological order. sort_values keeps the old
# index labels, so reset them: the environment steps through rows by integer
# position (0..len(df)-1).
df = pd.read_csv('./data/AAPL.csv')
df = df.sort_values('Date').reset_index(drop=True)

# The algorithms require a vectorized environment to run
env = DummyVecEnv([lambda: StockTradingEnv(df)])

# Train a PPO2 agent with an MLP policy.
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=20000)

# Roll out the trained policy, rendering the account state every 100 steps.
obs = env.reset()
for i in range(2000):
    action, _states = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    if i % 100 == 0:
        env.render()

--------------------------------------
| approxkl           | 6.0495677e-07 |
| clipfrac           | 0.0           |
| explained_variance | -8.34e-07     |
| fps                | 82            |
| n_updates          | 1             |
| policy_entropy     | 2.838792      |
| policy_loss        | -0.0002163588 |
| serial_timesteps   | 128           |
| time_elapsed       | 1.1e-05       |
| total_timesteps    | 128           |
| value_loss         | 2658634.2     |
--------------------------------------
--------------------------------------
| approxkl           | 2.62119e-07   |
| clipfrac           | 0.0           |
| explained_variance | 1.31e-06      |
| fps                | 588           |
| n_updates          | 2             |
| policy_entropy     | 2.8405898     |
| policy_loss        | -8.391682e-06 |
| serial_timesteps   | 256           |
| time_elapsed       | 1.54          |
| total_timesteps    | 256           |
| value_loss         | 5069503.5     |
-------------------------

--------------------------------------
| approxkl           | 0.00025323714 |
| clipfrac           | 0.0           |
| explained_variance | 2.25e-05      |
| fps                | 575           |
| n_updates          | 14            |
| policy_entropy     | 2.844299      |
| policy_loss        | 0.0016092756  |
| serial_timesteps   | 1792          |
| time_elapsed       | 4.26          |
| total_timesteps    | 1792          |
| value_loss         | 654384500.0   |
--------------------------------------
-------------------------------------
| approxkl           | 8.701298e-05 |
| clipfrac           | 0.0          |
| explained_variance | -8.94e-06    |
| fps                | 599          |
| n_updates          | 15           |
| policy_entropy     | 2.8427773    |
| policy_loss        | 0.0010756388 |
| serial_timesteps   | 1920         |
| time_elapsed       | 4.49         |
| total_timesteps    | 1920         |
| value_loss         | 1320706800.0 |
-------------------------------------

---------------------------------------
| approxkl           | 7.480756e-05   |
| clipfrac           | 0.0            |
| explained_variance | 0.000288       |
| fps                | 512            |
| n_updates          | 31             |
| policy_entropy     | 2.8419755      |
| policy_loss        | -0.00024281652 |
| serial_timesteps   | 3968           |
| time_elapsed       | 8.05           |
| total_timesteps    | 3968           |
| value_loss         | 14423764.0     |
---------------------------------------
--------------------------------------
| approxkl           | 0.00023474726 |
| clipfrac           | 0.0           |
| explained_variance | -2.85e-05     |
| fps                | 523           |
| n_updates          | 32            |
| policy_entropy     | 2.8386154     |
| policy_loss        | -0.0027171748 |
| serial_timesteps   | 4096          |
| time_elapsed       | 8.3           |
| total_timesteps    | 4096          |
| value_loss         | 5699711.5     |
------------

--------------------------------------
| approxkl           | 0.0011787857  |
| clipfrac           | 0.0           |
| explained_variance | 8.82e-06      |
| fps                | 590           |
| n_updates          | 48            |
| policy_entropy     | 2.8255663     |
| policy_loss        | -0.0019873218 |
| serial_timesteps   | 6144          |
| time_elapsed       | 12.2          |
| total_timesteps    | 6144          |
| value_loss         | 436471650.0   |
--------------------------------------
-------------------------------------
| approxkl           | 0.0005582856 |
| clipfrac           | 0.0          |
| explained_variance | -1.06e-05    |
| fps                | 547          |
| n_updates          | 49           |
| policy_entropy     | 2.8267405    |
| policy_loss        | 0.0021762028 |
| serial_timesteps   | 6272         |
| time_elapsed       | 12.4         |
| total_timesteps    | 6272         |
| value_loss         | 982304830.0  |
-------------------------------------

--------------------------------------
| approxkl           | 9.172194e-05  |
| clipfrac           | 0.0           |
| explained_variance | -2.15e-06     |
| fps                | 598           |
| n_updates          | 65            |
| policy_entropy     | 2.8311324     |
| policy_loss        | 0.00064480654 |
| serial_timesteps   | 8320          |
| time_elapsed       | 16            |
| total_timesteps    | 8320          |
| value_loss         | 4127219200.0  |
--------------------------------------
--------------------------------------
| approxkl           | 0.0023020024  |
| clipfrac           | 0.001953125   |
| explained_variance | -3.58e-07     |
| fps                | 601           |
| n_updates          | 66            |
| policy_entropy     | 2.8318894     |
| policy_loss        | -0.0010496795 |
| serial_timesteps   | 8448          |
| time_elapsed       | 16.2          |
| total_timesteps    | 8448          |
| value_loss         | 3495198200.0  |
-------------------------

--------------------------------------
| approxkl           | 0.0013287942  |
| clipfrac           | 0.0           |
| explained_variance | -1.31e-06     |
| fps                | 595           |
| n_updates          | 82            |
| policy_entropy     | 2.8228345     |
| policy_loss        | 0.0050444202  |
| serial_timesteps   | 10496         |
| time_elapsed       | 19.6          |
| total_timesteps    | 10496         |
| value_loss         | 25680228000.0 |
--------------------------------------
---------------------------------------
| approxkl           | 0.0005310024   |
| clipfrac           | 0.0            |
| explained_variance | 1.79e-07       |
| fps                | 603            |
| n_updates          | 83             |
| policy_entropy     | 2.8215232      |
| policy_loss        | -0.00080943573 |
| serial_timesteps   | 10624          |
| time_elapsed       | 19.9           |
| total_timesteps    | 10624          |
| value_loss         | 38453117000.0  |
-------------

---------------------------------------
| approxkl           | 2.5862806e-05  |
| clipfrac           | 0.0            |
| explained_variance | -2.38e-07      |
| fps                | 589            |
| n_updates          | 98             |
| policy_entropy     | 2.8207214      |
| policy_loss        | 0.00068474584  |
| serial_timesteps   | 12544          |
| time_elapsed       | 23.1           |
| total_timesteps    | 12544          |
| value_loss         | 569218560000.0 |
---------------------------------------
---------------------------------------
| approxkl           | 0.00306934     |
| clipfrac           | 0.033203125    |
| explained_variance | 2.98e-07       |
| fps                | 589            |
| n_updates          | 99             |
| policy_entropy     | 2.820898       |
| policy_loss        | -0.0074388143  |
| serial_timesteps   | 12672          |
| time_elapsed       | 23.3           |
| total_timesteps    | 12672          |
| value_loss         | 606599500000.0 |


--------------------------------------
| approxkl           | 0.0006950877  |
| clipfrac           | 0.0           |
| explained_variance | 4.77e-07      |
| fps                | 599           |
| n_updates          | 114           |
| policy_entropy     | 2.818776      |
| policy_loss        | -0.0008511363 |
| serial_timesteps   | 14592         |
| time_elapsed       | 26.5          |
| total_timesteps    | 14592         |
| value_loss         | 6844143600.0  |
--------------------------------------
--------------------------------------
| approxkl           | 0.00065737707 |
| clipfrac           | 0.0           |
| explained_variance | 1.19e-07      |
| fps                | 603           |
| n_updates          | 115           |
| policy_entropy     | 2.8174565     |
| policy_loss        | -0.0017994004 |
| serial_timesteps   | 14720         |
| time_elapsed       | 26.7          |
| total_timesteps    | 14720         |
| value_loss         | 11464653000.0 |
-------------------------

---------------------------------------
| approxkl           | 4.139446e-05   |
| clipfrac           | 0.0            |
| explained_variance | 5.96e-08       |
| fps                | 599            |
| n_updates          | 130            |
| policy_entropy     | 2.8151846      |
| policy_loss        | -0.0008798352  |
| serial_timesteps   | 16640          |
| time_elapsed       | 29.9           |
| total_timesteps    | 16640          |
| value_loss         | 524981530000.0 |
---------------------------------------
---------------------------------------
| approxkl           | 0.0004987514   |
| clipfrac           | 0.0            |
| explained_variance | 0              |
| fps                | 603            |
| n_updates          | 131            |
| policy_entropy     | 2.8147333      |
| policy_loss        | 9.734894e-05   |
| serial_timesteps   | 16768          |
| time_elapsed       | 30.1           |
| total_timesteps    | 16768          |
| value_loss         | 926619900000.0 |


----------------------------------------
| approxkl           | 3.6758928e-07   |
| clipfrac           | 0.0             |
| explained_variance | 0               |
| fps                | 598             |
| n_updates          | 146             |
| policy_entropy     | 2.8159616       |
| policy_loss        | -4.3098e-05     |
| serial_timesteps   | 18688           |
| time_elapsed       | 33.4            |
| total_timesteps    | 18688           |
| value_loss         | 1960809900000.0 |
----------------------------------------
----------------------------------------
| approxkl           | 1.3227126e-08   |
| clipfrac           | 0.0             |
| explained_variance | 0               |
| fps                | 604             |
| n_updates          | 147             |
| policy_entropy     | 2.8159707       |
| policy_loss        | 2.2591557e-06   |
| serial_timesteps   | 18816           |
| time_elapsed       | 33.6            |
| total_timesteps    | 18816           |
| value_loss    

Step: 1860
Balance: 11164.226201026897
Shares held: 0 (Total sold: 23892)
Avg cost for held shares: 0 (Total sales value: 519336.95104736544)
Net worth: 11164.226201026897 (Max net worth: 13497.093043230021)
Profit: 1164.226201026897
Step: 1960
Balance: 13692.754038085559
Shares held: 0 (Total sold: 24543)
Avg cost for held shares: 0 (Total sales value: 549701.8894804294)
Net worth: 13692.754038085559 (Max net worth: 13726.45979354484)
Profit: 3692.754038085559
Step: 2060
Balance: 12464.907727671965
Shares held: 37 (Total sold: 25486)
Avg cost for held shares: 65.89820357918059 (Total sales value: 614982.009313867)
Net worth: 14894.141894325181 (Max net worth: 16055.557382796924)
Profit: 4894.141894325181
Step: 2160
Balance: 4826.5593806908155
Shares held: 187 (Total sold: 26038)
Avg cost for held shares: 51.2736931684249 (Total sales value: 646521.983641735)
Net worth: 17702.73683257217 (Max net worth: 17702.73683257217)
Profit: 7702.736832572169
Step: 2260
Balance: 19673.224180557165