In [1]:

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
import gym
import numpy as np
from ppo_utils import *
from network import *
from data import *
from utils import *
import time
import json
import os  # for os.path.join below

# Default hyper-parameters for PPO and the environment; any keys present in
# params.json (loaded below) override these.
params = {'action_std' : 0.5,
          'eps_clip' : 0.2,
          'K_epochs' : 80,
          'gamma' : 0.99,
          'lr' : 0.0003,
          'action_dim' : 1,
          'rl_latent' : 50,
          'device' : torch.device("cuda:0" if torch.cuda.is_available() else "cpu")}

# Directory holding the pretrained forecasting checkpoint and its params.json.
# Named model_dir (not `dir`) so the builtin dir() is not shadowed.
model_dir = "D/"
path = os.path.join(model_dir, "model_198800.pt")  # alternative checkpoint: "model_175200.pt"

with open(os.path.join(model_dir, 'params.json'), 'r') as fp:
    params.update(json.load(fp))

# The RL state is the forecasting model's output with its fc_out head removed,
# presumably of width dim_model — TODO confirm against the network definition.
params['state_dim'] = params['dim_model']

print("Loading : ")
# Earlier experiments loaded years ['2018', '2019'] instead.
data = load_data(['2020'], ['BTCUSDT', 'ETHUSDT', 'LTCUSDT'], '5m')

memory = Memory()
ppo = PPO(params)

Loading : 
done


In [18]:
class TradingEnv(nn.Module):
    """Single-asset (BTC) trading environment driven by a frozen forecasting model.

    The observation comes from a pretrained model whose ``fc_out`` layer has
    been replaced by an identity. The action is the target fraction of
    ``money`` to hold invested. Re-balancing creates ``[price, fiat_amount]``
    buy or sell orders. Sells are matched against open buys to realize profit.
    """

    # Headroom kept below ``indexer.r_top`` when picking a random episode start.
    # NOTE(review): each step calls ``indexer.next()`` and the indexer is built
    # with ``increment=2``, so an episode may advance well past 100 positions —
    # confirm this margin is large enough for max_timesteps.
    RESET_MARGIN = 100

    def __init__(self, params, data, path):
        """Load the frozen model and build one batch generator per asset.

        Args:
            params: config dict (needs 'quantiles', 'device', 'past_seq_len',
                'future_seq_len', 'continuous_columns', 'discrete_columns',
                'target_columns').
            data: list of per-asset frames; ``data[0]`` must have a 'Close'
                column and is the asset actually traded.
            path: checkpoint file whose 'model_state' entry is used as a module.
        """
        super().__init__()

        # NOTE(review): assumes the checkpoint pickles a full nn.Module under
        # 'model_state' (it is used as a module below). No map_location is
        # passed, so tensors load onto the device they were saved from.
        self.model = torch.load(path)['model_state']
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        # Strip the output head so forward passes expose the latent features.
        # nn.Identity (unlike a lambda) is a real submodule: picklable, and
        # handled by .to()/state_dict like the rest of the model.
        self.model.fc_out = nn.Identity()
        self.quantiles = torch.FloatTensor(params['quantiles']).to(params['device'])

        self.data = data

        self.batch_size = 1
        increment = 2  # passed to Indexer; presumably the stride used by next()

        self.indexer = Indexer(1, data[0].shape[0] - (params['past_seq_len'] + params['future_seq_len'] + 1),
                               self.batch_size, random=False, increment=increment)

        # One batch generator per asset, all sharing the same indexer so they
        # stay aligned in time.
        self.data_gens = [
            get_batches(d, params['past_seq_len'], params['future_seq_len'],
                        params['continuous_columns'], params['discrete_columns'],
                        params['target_columns'], batch_size=self.batch_size,
                        indexer=self.indexer)
            for d in data
        ]

        self.reset()

    def get_state(self):
        """Return the current observation from the frozen forecasting model.

        Runs ``forward_pass`` over all asset generators at the current indexer
        position. It keeps element ``[0]`` of the network output (BTC) and
        averages it over its first dimension. The result has a leading
        dimension of size 1.
        """
        net_out, _, _ = forward_pass(self.model, self.data_gens,
                                     self.batch_size, self.quantiles, self.indexer, loss=False)
        # only use btc for now: index 0 presumably corresponds to data[0]
        # (BTCUSDT) — TODO confirm forward_pass output ordering.
        return torch.mean(net_out.detach()[0], dim=0).unsqueeze(0)

    def step(self, action):
        """Re-balance to ``action`` at the current price, then advance one step.

        Args:
            action: target invested fraction of ``money``. The intended range
                is [0, 1], but it is not clipped. Accepts a Python number or a
                single-element numpy array / torch tensor.

        Returns:
            ``(state, reward)``. ``state`` is the next observation from
            ``get_state``. ``reward`` is the sum of relative price changes
            over the buy/sell pairs matched during this step.
        """
        # Collapse array/tensor actions to a plain float so order sizes are
        # immutable scalars. With numpy arrays, the in-place ``-=`` below would
        # also overwrite the amounts stored in the order logs.
        if hasattr(action, 'item'):
            action = action.item()
        action = float(action)

        # only use btc for now
        current_price = float(self.data[0].iloc[self.indexer.indices]['Close'].values[0])
        self.indexer.next()

        reward = 0

        delta_pct = action - self.pct_invested
        self.pct_invested = action

        delta = abs(delta_pct * self.money)  # fiat units
        order = [current_price, delta]
        if delta_pct > 0:
            self.buy_orders.append(list(order))  # buying delta at current price
            self.log_buy_orders.append([self.n_step] + order)
        elif delta_pct < 0:
            self.sell_orders.append(list(order))  # selling delta at current price
            self.log_sell_orders.append([self.n_step] + order)

        # Match sells against open buys (oldest buy first) and realize profit
        # on the filled amount, whichever side is smaller.
        for buy in self.buy_orders:
            for sell in self.sell_orders:
                if buy[1] == 0:
                    break  # this buy is fully closed
                if sell[1] == 0:
                    continue
                price_change = (sell[0] - buy[0]) / buy[0]
                filled = min(buy[1], sell[1])
                self.money += filled * price_change
                buy[1] -= filled
                sell[1] -= filled
                reward += price_change

        self.buy_orders[:] = [b for b in self.buy_orders if b[1] != 0]
        if not self.buy_orders:
            # No open position remains: unmatched sell volume is discarded.
            self.sell_orders.clear()
        self.sell_orders[:] = [s for s in self.sell_orders if s[1] != 0]

        self.n_step += 1

        return self.get_state(), reward

    def reset(self):
        """Start a new episode at a random position and clear the order book.

        Returns:
            The initial observation from ``get_state``.
        """
        # NOTE(review): `random` is not imported here directly; it reaches this
        # scope through one of the star imports in the first cell.
        self.indexer.indices[-1] = random.randint(self.indexer.r_bottom,
                                                  self.indexer.r_top - self.RESET_MARGIN)
        self.n_step = 0
        self.pct_invested = 0

        self.money = 1  # in multiples of USD

        self.buy_orders = []
        self.sell_orders = []

        self.log_buy_orders = []
        self.log_sell_orders = []
        return self.get_state()
        
# Build the environment: loads the frozen model from `path` and performs an
# initial reset.
env = TradingEnv(params=params, data=data, path=path)

In [25]:
# Scratch check: a one-element list holding an empty list.
# Note that [[]] * n repeats the SAME inner list n times; a comprehension
# creates independent inner lists.
t = [[] for _ in range(1)]
t

[[]]

In [20]:
print("Beginning training...")
log_interval = 20           # print avg reward in the interval
n_episodes = 10000          # max training episodes
max_timesteps = 200         # max timesteps in one episode
update_timestep = 400       # update policy every n timesteps
env_name = "TradingEnv"     # used in the checkpoint filename below

# logging variables
running_reward = 0.0
avg_length = 0
time_step = 0

# training loop
memory.clear_memory()
for i_episode in range(1, n_episodes + 1):
    state = env.reset()

    for t in range(max_timesteps):
        time_step += 1

        # Running policy_old:
        action = ppo.select_action(state, memory)
        state, reward = env.step(action)

        # The env has no natural "done" signal, so the episode's last step is
        # marked terminal. This stops the discounted return from leaking into
        # the next episode (assumes PPO.update resets the return at
        # is_terminal — TODO confirm in ppo_utils).
        memory.rewards.append(reward)
        memory.is_terminals.append(t == max_timesteps - 1)

        # update if its time
        if time_step % update_timestep == 0:
            ppo.update(memory)
            memory.clear_memory()
            time_step = 0

        running_reward += reward

    avg_length += t + 1  # t is the last 0-based step index

    ##### logging ######

    # save every 500 episodes
    if i_episode % 500 == 0:
        torch.save(ppo.policy.state_dict(), './PPO_continuous_{}.pth'.format(env_name))

    if i_episode % log_interval == 0:
        avg_length = int(avg_length / log_interval)
        # Rewards are small fractional price changes; int() would truncate them to 0.
        running_reward = running_reward / log_interval

        print('Episode {} \t Avg length: {} \t Avg reward: {:.4f}'.format(
            i_episode, avg_length, float(running_reward)))
        running_reward = 0.0
        avg_length = 0


Beginning training...
Episode execution time : --- 0.0 seconds ---
Episode execution time : --- 0.0009970664978027344 seconds ---
Episode execution time : --- 0.0009975433349609375 seconds ---
Episode execution time : --- 0.0 seconds ---
Episode execution time : --- 0.0 seconds ---
Episode execution time : --- 0.0 seconds ---
Episode execution time : --- 0.0009968280792236328 seconds ---
Episode execution time : --- 0.0009970664978027344 seconds ---
Episode execution time : --- 0.0009965896606445312 seconds ---
Episode execution time : --- 0.0009980201721191406 seconds ---
Episode execution time : --- 0.0 seconds ---
Episode execution time : --- 0.000997304916381836 seconds ---
Episode execution time : --- 0.0 seconds ---
Episode execution time : --- 0.0009970664978027344 seconds ---
Episode execution time : --- 0.0 seconds ---
Episode execution time : --- 0.000997304916381836 seconds ---
Episode execution time : --- 0.0 seconds ---
Episode execution time : --- 0.0009970664978027344 se

KeyboardInterrupt: 