In [1]:
import os
import sys
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Apply seaborn's default theme so subsequent matplotlib figures share one look.
sns.set_theme()

In [2]:
# Resolve the directory one level above the notebook's working directory and
# make sure it is on the import path, so the `src.*` imports in the next cell
# can be resolved. The membership check keeps re-runs of this cell from adding
# the same entry twice.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
from src.strategy.model import Model
from src.strategy.environment import Environment
from src.strategy.agent import Agent
from src.strategy.buffer import Buffer
from src.utils import get_config, read_file
# Load the project configuration through the `get_config` helper. The cells
# below index the result as a nested mapping with 'data' and 'hyperparameters'
# sections.
config = get_config.read_yaml()

In [4]:
# Compute device: CUDA when PyTorch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model input width: number of configured symbols times features per symbol.
input_dim = len(config['data']['symbols']) * config['data']['num_features']

# Pull every training hyperparameter out of the config's 'hyperparameters'
# section into its own notebook-level variable for the cells that follow.
hparams = config['hyperparameters']
gamma = hparams['gamma']
gae_lambda = hparams['gae_lambda']
clip_epsilon = hparams['clip_epsilon']
value_loss_coef = hparams['value_loss_coef']
entropy_loss_coef = hparams['entropy_loss_coef']
batch_size = hparams['batch_size']
epochs = hparams['num_epochs']
lr = hparams['learning_rate']
seq_len = hparams['seq_len']
rollout_steps = hparams['rollout_steps']

# Echo the resolved settings, one "name: value" line each, so the run is
# self-describing when the notebook is read later.
print('\n'.join(
    f'{name}: {value}'
    for name, value in (
        ('input_dim', input_dim),
        ('gamma', gamma),
        ('gae_lambda', gae_lambda),
        ('clip_epsilon', clip_epsilon),
        ('value_loss_coef', value_loss_coef),
        ('entropy_loss_coef', entropy_loss_coef),
        ('batch_size', batch_size),
        ('epochs', epochs),
        ('lr', lr),
        ('seq_len', seq_len),
        ('rollout_steps', rollout_steps),
        ('device', device),
    )
))

input_dim: 126
gamma: 0.99
gae_lambda: 0.95
clip_epsilon: 0.2
value_loss_coef: 0.5
entropy_loss_coef: 0.01
batch_size: 128
epochs: 10
lr: 0.001
seq_len: 72
rollout_steps: 2048
device: cuda


In [5]:
# Load the merged training dataset through the project's `read_file` helper.
data = read_file.read_merged_training_data()
# Bare trailing expression so the notebook renders the table; the captured
# output below is a truncated view (note the `...` row and column) indexed by
# timestamp with (feature, symbol) column labels.
data

Unnamed: 0_level_0,"('open', 'ETH')","('high', 'ETH')","('low', 'ETH')","('close', 'ETH')","('volume', 'ETH')","('rsi', 'ETH')","('sma-50', 'ETH')","('sma-100', 'ETH')","('sma-200', 'ETH')","('ema-50', 'ETH')",...,"('volume', 'XLM')","('rsi', 'XLM')","('sma-50', 'XLM')","('sma-100', 'XLM')","('sma-200', 'XLM')","('ema-50', 'XLM')","('ema-100', 'XLM')","('ema-200', 'XLM')","('atr', 'XLM')","('adx', 'XLM')"
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2020-10-01 00:00:00,359.83,361.27,358.87,359.35,22944.52884,59.583614,355.9574,356.6643,349.52405,356.582296,...,3151658.0,67.676653,0.073433,0.073611,0.072531,0.073530,0.073450,0.073632,0.000659,20.156512
2020-10-01 01:00:00,359.30,361.53,358.76,361.40,17489.77540,65.054769,356.1062,356.7545,349.61270,356.771226,...,1100106.9,67.908916,0.073469,0.073631,0.072548,0.073590,0.073482,0.073646,0.000625,21.776751
2020-10-01 02:00:00,361.41,363.72,361.41,362.61,30783.19186,67.823475,356.2654,356.8490,349.70975,357.000198,...,7302696.5,71.704263,0.073507,0.073650,0.072567,0.073668,0.073523,0.073665,0.000652,24.022470
2020-10-01 03:00:00,362.61,363.16,361.81,362.33,15631.06013,66.510263,356.4370,356.9331,349.80350,357.209209,...,6992099.3,58.169077,0.073527,0.073658,0.072583,0.073706,0.073545,0.073675,0.000688,24.989250
2020-10-01 04:00:00,362.33,363.98,362.06,362.31,20638.86118,66.411350,356.6054,356.9585,349.89180,357.409240,...,3004172.4,56.831370,0.073546,0.073663,0.072597,0.073738,0.073564,0.073683,0.000670,25.886974
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2024-09-29 19:00:00,2663.79,2667.63,2663.10,2666.78,3272.03330,53.187062,2672.7944,2651.1065,2634.73640,2663.166143,...,1169121.0,64.805842,0.100032,0.098768,0.097622,0.099957,0.099081,0.097998,0.000992,27.584291
2024-09-29 20:00:00,2666.77,2673.41,2658.01,2660.89,7578.82390,48.813728,2672.0144,2651.7119,2635.23155,2663.076883,...,2842498.0,55.778534,0.100096,0.098817,0.097642,0.100018,0.099129,0.098032,0.001014,27.504313
2024-09-29 21:00:00,2660.89,2668.77,2660.12,2666.63,3511.73450,52.879972,2671.3196,2652.6241,2635.71270,2663.216221,...,1356436.0,57.377033,0.100142,0.098875,0.097663,0.100088,0.099182,0.098070,0.000984,27.430048
2024-09-29 22:00:00,2666.62,2671.52,2655.50,2655.82,6182.30880,45.542371,2670.5592,2653.3542,2636.15395,2662.926173,...,3521628.0,56.642043,0.100190,0.098931,0.097682,0.100151,0.099232,0.098106,0.000971,26.933674


In [6]:
# Build the network for `input_dim` input features and move it onto `device`.
model = Model(input_dim).to(device)
# Show the module repr; the captured output below lists an LSTM plus separate
# actor and critic heads.
model

Model(
  (lstm): LSTM(126, 256, batch_first=True)
  (actor_head): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=20, bias=True)
  )
  (critic_head): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [7]:
# Ask the model for an initial state pair sized for `batch_size` on `device`.
# NOTE(review): presumably the LSTM hidden and cell states — confirm against
# Model.init_hidden_state.
h_0, c_0 = model.init_hidden_state(batch_size, device)
# Sanity check: the captured output below is torch.Size([1, 128, 256]) for the
# printed batch_size of 128.
h_0.shape

torch.Size([1, 128, 256])

In [8]:
# Collect the agent's keyword settings from the device and hyperparameter
# variables defined in the configuration cell, then build the agent around the
# model created above.
agent_settings = {
    'device': device,
    'learning_rate': lr,
    'gamma': gamma,
    'gae_lambda': gae_lambda,
    'clip_epsilon': clip_epsilon,
    'epochs': epochs,
    'batch_size': batch_size,
    'sequence_length': seq_len,
    'value_loss_coef': value_loss_coef,
    'entropy_loss_coef': entropy_loss_coef,
}
agent = Agent(model, **agent_settings)
# Trailing expression: display the agent's repr as the cell output.
agent

<src.strategy.agent.Agent at 0x24a0a8ac1a0>

In [9]:
# Wrap the loaded dataset in the project's Environment. Constructing it prints
# the summary captured below (timestep count, observation and action dims).
env = Environment(data)
# Trailing expression: display the environment's repr as the cell output.
env

Environment initialized with 35020 timesteps.
Observation space dim: 126
Action space dim: 10


<src.strategy.environment.Environment at 0x24a2acfc050>

In [10]:
# Create the rollout buffer from the configured `rollout_steps`.
# NOTE(review): presumably this argument is the buffer's capacity — confirm
# against Buffer.__init__.
buffer = Buffer(rollout_steps)
# Trailing expression: display the buffer's repr as the cell output.
buffer

<src.strategy.buffer.Buffer at 0x24a2acfc1a0>

In [11]:
# Counters for the training loop in the next cell, both starting at zero.
# `global_timestep_counter` is what that loop compares against the budget;
# `update_counter` is not read by any cell visible here yet.
global_timestep_counter = update_counter = 0
# Total environment-step budget that bounds the training loop.
total_timesteps = 1_000_000

In [None]:
# Training-loop scaffold.
#
# With a bare `pass` body, `global_timestep_counter` never changes, so the
# condition stays true forever and the cell hangs the kernel (and any
# "Restart Kernel -> Run All"). Until the real rollout/update logic is written,
# each iteration accounts for one rollout's worth of environment steps so the
# loop is guaranteed to terminate.
assert rollout_steps > 0, "rollout_steps must be positive or the loop cannot make progress"

while global_timestep_counter < total_timesteps:
    # TODO: collect `rollout_steps` transitions from `env` into `buffer`, then
    # run the agent's update on the collected rollout.

    # Bookkeeping: one rollout consumed, one update performed.
    global_timestep_counter += rollout_steps
    update_counter += 1