# In [1]:
import numpy as np
import pandas as pd

# In [4]:
class TradingEnv:
    """Single-asset trading environment backed by a tabular price history.

    ``train_data`` is indexed positionally with ``.iloc``: column 0 is read
    as the traded price, and columns 1-3 are read as the feature values
    returned by ``_get_state``. The table therefore needs at least four
    columns and must support ``.iloc`` after ``np.around`` (e.g. a pandas
    DataFrame).

    The environment does not track the "holding" flag itself; the caller
    passes the current flag into ``_step`` and receives the updated one back.
    """

    def __init__(self, train_data, init_invest=20000):
        """Store the rounded price history and reset the environment.

        Args:
            train_data: 2-D table of prices/features (see class docstring).
            init_invest: Cash available at the start of every episode.
        """
        # data
        # Round to the nearest integer (np.around rounds halves to even) to
        # reduce the state space. Prices below 0.5 become 0 here.
        self.stock_price_history = np.around(train_data)
        self.n_step, self.n_stock = self.stock_price_history.shape
        # Only one asset is traded, regardless of how many columns exist.
        self.n_stock = 1

        # instance attributes
        self.init_invest = init_invest
        self.cur_step = None
        self.stock_owned = None
        self.stock_price = None
        self.cash_in_hand = None

        # action space: 0 = keep the current position, 1 = flip it
        self.action_space = [0, 1]

        self._reset()

    def _reset(self):
        """Rewind to the first row with no shares and the initial cash.

        Returns:
            The observation ``[stock_owned, stock_price, cash_in_hand]``.
        """
        self.cur_step = 0
        self.stock_owned = 0 * self.n_stock
        self.stock_price = self.stock_price_history.iloc[self.cur_step, 0]
        self.cash_in_hand = self.init_invest
        return self._get_obs()

    def _step(self, action, holding):
        """Advance one row, apply ``action`` and report the outcome.

        Args:
            action: Truthy to flip the position (sell everything when
                ``holding`` is truthy, otherwise buy), falsy to do nothing.
            holding: The caller's current position flag.

        Returns:
            A 5-tuple ``(next_state, obs, reward, done, holding)`` where
            ``next_state`` is the three feature values of the new row
            followed by the updated holding flag and the reward, ``obs`` is
            ``[stock_owned, stock_price, cash_in_hand]``, ``reward`` is the
            change in portfolio value, ``done`` is True on the last row, and
            ``holding`` is the (possibly flipped) position flag.

        Note:
            Stepping again after ``done`` is True indexes past the last row
            of the price history.
        """
        # Valued at the previous row's price, before the price is updated.
        prev_val = self._get_val()
        self.cur_step += 1
        self.stock_price = self.stock_price_history.iloc[self.cur_step, 0]  # update price
        self._trade(action, holding)
        cur_val = self._get_val()
        reward = cur_val - prev_val
        done = self.cur_step == self.n_step - 1

        # A truthy action always flips the flag, whether or not any shares
        # actually changed hands.
        if action:
            holding = 0 if holding else 1

        next_state = self._get_state()
        next_state.append(holding)
        next_state.append(reward)
        return next_state, self._get_obs(), reward, done, holding

    def _get_state(self):
        """Return the values of columns 1, 2 and 3 at the current row."""
        return [
            self.stock_price_history.iloc[self.cur_step, col]
            for col in (1, 2, 3)
        ]

    def _get_obs(self):
        """Return ``[stock_owned, stock_price, cash_in_hand]``."""
        return [self.stock_owned, self.stock_price, self.cash_in_hand]

    def _get_val(self):
        """Return the portfolio value: shares at the current price plus cash."""
        return np.sum(self.stock_owned * self.stock_price) + self.cash_in_hand

    def _trade(self, action, holding):
        """Execute the trade implied by ``action`` and ``holding``.

        Nothing happens when ``action`` is falsy. Otherwise all shares are
        sold at the current price when ``holding`` is truthy, or shares are
        bought one at a time while cash strictly exceeds the price.
        """
        if not action:
            return

        if holding:
            # Liquidate the whole position at the current price.
            self.cash_in_hand += self.stock_price * self.stock_owned
            self.stock_owned = 0
            return

        price = self.stock_price
        # A zero or negative price would never reduce cash, so the buy loop
        # below could not terminate; NaN also fails this test. Buy nothing.
        if not price > 0:
            return

        while self.cash_in_hand > price:
            self.stock_owned += 1  # buy one share
            self.cash_in_hand -= price