In [1]:
import torch
from torch import nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import datetime
from typing import Callable, Optional, Type, List, Union, Dict
import os

import gymnasium as gym
import gym_trading_env
from gym_trading_env.downloader import download
from gymnasium import spaces

from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3 import A2C, PPO
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import MlpExtractor, BaseFeaturesExtractor
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

In [2]:
# download(exchange_names=["binance", "bitfinex2", "huobi"],
#     symbols=["BTC/USDT", "ETH/USDT"],
#     timeframe="1h",
#     dir="data",
#     since=datetime.datetime(year=2018, month=1, day=1),
# )

In [3]:
WINDOW_SIZE = 24 * 7

def preprocess(df):
    df["close_feature"] = df["close"]
    df["high_feature"] = df["high"]
    df["low_feature"] = df["low"]
    df["open_feature"] = df["open"]
    df["timestamp_growth_feature"] = df["close"] / df["open"]
    df["range_feature"] = df["high"] / df["low"]
    df["max_recent_cost_feature"] = df["volume"].rolling(24).max()
    df["min_recent_cost_feature"] = df["volume"].rolling(24).min()
    df.dropna(inplace=True)
    return df

In [4]:
env = gym.make(
    "MultiDatasetTradingEnv",
    dataset_dir="data/*.pkl",
    preprocess=preprocess,
    windows=24 * 7
)

In [5]:
def make_env():
    return gym.make(
        "MultiDatasetTradingEnv",
        dataset_dir="data/*.pkl",
        preprocess=preprocess,
        windows=WINDOW_SIZE
    )

vec_env = DummyVecEnv([make_env for _ in range(10)])

In [6]:
class CustomRNN(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Box, input_dim: int = 10,
                 hidden_dim: int = 512, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=2, bidirectional=True, batch_first=True)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.fc = nn.Linear(WINDOW_SIZE * hidden_dim * 2, features_dim)
        
    def forward(self, observations):
        observations = self.lstm(observations)[0]
        observations = self.relu(observations)
        observations = self.flat(observations)
        observations = self.fc(observations)
        return observations

policy_kwargs = dict(
    features_extractor_class=CustomRNN,
    features_extractor_kwargs=dict(hidden_dim = 256, features_dim = 128),
)
model = A2C("MlpPolicy", vec_env,
            policy_kwargs=policy_kwargs, verbose=1)

Using cuda device


In [7]:
eval_callback = EvalCallback(Monitor(env), best_model_save_path="./models/",
                              log_path="./logs/", eval_freq=1000,
                              n_eval_episodes=5, deterministic=True,
                              render=False)

In [8]:
model.learn(total_timesteps=5e6)

------------------------------------
| time/                 |          |
|    fps                | 297      |
|    iterations         | 100      |
|    time_elapsed       | 16       |
|    total_timesteps    | 5000     |
| train/                |          |
|    entropy_loss       | -0.677   |
|    explained_variance | 0.891    |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | -0.0012  |
|    value_loss         | 6.05e-05 |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 310      |
|    iterations         | 200      |
|    time_elapsed       | 32       |
|    total_timesteps    | 10000    |
| train/                |          |
|    entropy_loss       | -0.658   |
|    explained_variance | 0.289    |
|    learning_rate      | 0.0007   |
|    n_updates          | 199      |
|    policy_loss        | 0.0054   |
|    value_loss         | 0.000243 |
-

In [9]:
model.save("./models/trader.zip")

In [10]:
model.load("./models/trader.zip")

<stable_baselines3.a2c.a2c.A2C at 0x22dfd8769e0>