From ec3bc0bda366ac6eec4945812dc977aa5036a916 Mon Sep 17 00:00:00 2001
From: qml <937883273@qq.com>
Date: Mon, 20 Feb 2023 20:29:00 +0800
Subject: [PATCH] fix tutorial
---
docs/source/tutorial/tutorial1.md | 22 +-
docs/source/tutorial/tutorial2.md | 23 +-
docs/source/tutorial/tutorial5.md | 15 +-
tutorial/Tutorial1_DeepScalper.ipynb | 704 +++++++++++++++++++++++++++
tutorial/Tutorial2_EIIE.ipynb | 461 ++++++++++++++++++
tutorial/Tutorial5_HFT.ipynb | 451 +++++++++++++++++
6 files changed, 1647 insertions(+), 29 deletions(-)
create mode 100644 tutorial/Tutorial1_DeepScalper.ipynb
create mode 100644 tutorial/Tutorial2_EIIE.ipynb
create mode 100644 tutorial/Tutorial5_HFT.ipynb
diff --git a/docs/source/tutorial/tutorial1.md b/docs/source/tutorial/tutorial1.md
index cdfd3aeb..a9f257a0 100644
--- a/docs/source/tutorial/tutorial1.md
+++ b/docs/source/tutorial/tutorial1.md
@@ -1,20 +1,18 @@
# Tutorial 1: Intraday Crypto Trading with DeepScalper
+![DeepScalper.png](DeepScalper.png)
-## Task
Intraday trading is a fundamental quantitative trading task, where traders actively long/short one pre-selected financial asset within the same trading day to maximize future profit.
-## Algorithm
-DeepScalper contains 4 technical contributions which all together make it better than direct use of RL algorithms.
-- RL optimization with action branching
-- reward function with hindsight bonus
-- intraday market embedding
-- risk-aware auxiliary task
+DeepScalper uses a deep Q-network to optimize the cumulative reward, where a hindsight reward captures long-term profit trends and the market embedding combines micro-level and macro-level market information.
-Here is the construction of the DeepScalper:
-
-
![](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/docs/source/tutorial/DeepScalper.jpg)
-
-Here is a [tutorial](https://github.com/DVampire/TradeMasterReBuild/tree/main/tutorial/DeepScalper.ipynb) about how you can build DeepScalper in a few lines of codes using TradeMaster.
\ No newline at end of file
+## Notebook and Script
+In this notebook, we implement the training and testing process of DeepScalper based on the TradeMaster framework.
+
+[Tutorial1_DeepScalper](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/tutorial/Tutorial1_DeepScalper.ipynb)
+
+And this is the script for training and testing.
+
+[train.py](https://github.com/TradeMaster-NTU/TradeMaster/blob/1.0.0/tools/algorithmic_trading/train.py)
\ No newline at end of file
diff --git a/docs/source/tutorial/tutorial2.md b/docs/source/tutorial/tutorial2.md
index eaba8091..3a6a81f5 100644
--- a/docs/source/tutorial/tutorial2.md
+++ b/docs/source/tutorial/tutorial2.md
@@ -1,19 +1,20 @@
# Tutorial 2: Portfolio Management with EIIE on US stocks
-
-## Task
+![EIIE.png](EIIE.png)
Portfolio management is the action of continuous reallocation of a capital into a number of financial assets periodically.
-## Algorithm
-EIIE contains 2 technical contributions which all together make it better than direct use of RL algorithms.
-- Deterministic Policy Gradient
-- Portfolio-Vector Memor
+The framework consists of the Ensemble of Identical Independent Evaluators
+(EIIE) topology, a Portfolio-Vector Memory (PVM), an Online Stochastic Batch Learning
+(OSBL) scheme, and a fully exploiting and explicit reward function.
+
+
+
-Here is the construction of the EIIE:
-
-
![](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/docs/source/tutorial/EIIE.jpg)
-
+## Notebook and Script
+In this notebook, we implement the training and testing process of EIIE based on the TradeMaster framework.
+[Tutorial2_EIIE](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/tutorial/Tutorial2_EIIE.ipynb)
+And this is the script for training and testing.
-Here is a [tutorial](https://github.com/DVampire/TradeMasterReBuild/tree/main/tutorial/EIIE.ipynb) about how you can build EIIE in a few lines of codes using TradeMaster.
\ No newline at end of file
+[train_eiie.py](https://github.com/TradeMaster-NTU/TradeMaster/blob/1.0.0/tools/portfolio_management/train_eiie.py)
\ No newline at end of file
diff --git a/docs/source/tutorial/tutorial5.md b/docs/source/tutorial/tutorial5.md
index 7ee6a7fb..6fa1f6ce 100644
--- a/docs/source/tutorial/tutorial5.md
+++ b/docs/source/tutorial/tutorial5.md
@@ -1,11 +1,14 @@
# Tutorial 5: High Frequency Trading with Double DQN
-## Task
High Frequency Trading is a fundamental quantitative trading task, where traders actively buy/sell one pre-selected financial periodically in seconds with the consideration of order execution.
-## Algorithm
-HFT_DDQN contains 2 technical contributions which all together make it better than direct use of RL algorithms.
-- Double deep q network
-- Regulator from the true q table
+HFT_DDQN uses a decayed supervised regulator, generated from the real Q-table based on future price information, together with a double Q-network to optimize the profit margin.
-Here is a [tutorial](https://github.com/DVampire/TradeMasterReBuild/tree/main/tutorial/HFT.ipynb) about how you can build DDQN for HFT in a few lines of codes using TradeMaster.
\ No newline at end of file
+## Notebook and Script
+In this notebook, we implement the training and testing process of HFT_DDQN based on the TradeMaster framework.
+
+[Tutorial5_HFT](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/tutorial/Tutorial5_HFT.ipynb)
+
+And this is the script for training and testing.
+
+[train.py](https://github.com/TradeMaster-NTU/TradeMaster/blob/1.0.0/tools/high_frequency_trading/train.py)
\ No newline at end of file
diff --git a/tutorial/Tutorial1_DeepScalper.ipynb b/tutorial/Tutorial1_DeepScalper.ipynb
new file mode 100644
index 00000000..816dbb4a
--- /dev/null
+++ b/tutorial/Tutorial1_DeepScalper.ipynb
@@ -0,0 +1,704 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 1: Import Packages\n",
+ "Modify the system path and load the corresponding packages and functions "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "from pathlib import Path\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "ROOT = str(Path(\"__file__\").resolve().parents[1])\n",
+ "sys.path.append(ROOT)\n",
+ "import torch\n",
+ "import argparse\n",
+ "import os.path as osp\n",
+ "from mmcv import Config\n",
+ "from trademaster.utils import replace_cfg_vals\n",
+ "from trademaster.nets.builder import build_net\n",
+ "from trademaster.environments.builder import build_environment\n",
+ "from trademaster.datasets.builder import build_dataset\n",
+ "from trademaster.agents.builder import build_agent\n",
+ "from trademaster.optimizers.builder import build_optimizer\n",
+ "from trademaster.losses.builder import build_loss\n",
+ "from trademaster.trainers.builder import build_trainer\n",
+ "from trademaster.transition.builder import build_transition"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 2: Load Configs\n",
+ "Load the default config file `configs/algorithmic_trading/algorithmic_trading_BTC_dqn_dqn_adam_mse.py`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "parser = argparse.ArgumentParser(description='Train and test DeepScalper with TradeMaster')\n",
+ "parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"algorithmic_trading\", \"algorithmic_trading_BTC_dqn_dqn_adam_mse.py\"),\n",
+ " help=\"path to the training config file\")\n",
+ "parser.add_argument(\"--task_name\", type=str, default=\"train\")\n",
+ "parser.add_argument(\"--test_style\", type=str, default='-1')\n",
+ "args = parser.parse_args([])\n",
+ "cfg = Config.fromfile(args.config)\n",
+ "task_name = args.task_name\n",
+ "\n",
+ "cfg = replace_cfg_vals(cfg)\n",
+ "# update test style\n",
+ "cfg.data.update({'test_style': args.test_style})\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/algorithmic_trading/algorithmic_trading_BTC_dqn_dqn_adam_mse.py): {'data': {'type': 'AlgorithmicTradingDataset', 'data_path': 'data/algorithmic_trading/BTC', 'train_path': 'data/algorithmic_trading/BTC/train.csv', 'valid_path': 'data/algorithmic_trading/BTC/valid.csv', 'test_path': 'data/algorithmic_trading/BTC/test.csv', 'test_style_path': 'data/algorithmic_trading/BTC/test_labeled_3_24_-0.15_0.15.csv', 'tech_indicator_list': ['high', 'low', 'open', 'close', 'adjcp', 'zopen', 'zhigh', 'zlow', 'zadjcp', 'zclose', 'zd_5', 'zd_10', 'zd_15', 'zd_20', 'zd_25', 'zd_30'], 'backward_num_day': 5, 'forward_num_day': 5, 'test_style': '-1'}, 'environment': {'type': 'AlgorithmicTradingEnvironment'}, 'agent': {'type': 'AlgorithmicTradingDQN', 'max_step': 12345, 'reward_scale': 1, 'repeat_times': 1, 'gamma': 0.9, 'batch_size': 64, 'clip_grad_norm': 3.0, 'soft_update_tau': 0, 'state_value_tau': 0.005}, 'trainer': {'type': 'AlgorithmicTradingTrainer', 'epochs': 20, 'work_dir': 'work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse', 'seeds_list': (12345,), 'batch_size': 64, 'horizon_len': 1024, 'buffer_size': 1000000.0, 'num_threads': 8, 'if_remove': False, 'if_discrete': True, 'if_off_policy': True, 'if_keep_save': True, 'if_over_write': False, 'if_save_buffer': False}, 'loss': {'type': 'MSELoss'}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'QNet', 'state_dim': 82, 'action_dim': 3, 'dims': (64, 32), 'explore_rate': 0.25}, 'cri': None, 'transition': {'type': 'Transition'}, 'task_name': 'algorithmic_trading', 'dataset_name': 'BTC', 'optimizer_name': 'adam', 'loss_name': 'mse', 'net_name': 'dqn', 'agent_name': 'dqn', 'work_dir': 'work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse', 'batch_size': 64}"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cfg"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 3: Build Dataset\n",
+ "Build datasets from cfg defined above"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = build_dataset(cfg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 4: Build Reinforcement Learning Environments\n",
+ "Build environments based on cfg and previously-defined dataset\n",
+ "\n",
+ "A style-test is provided as an option to test the algorithm's performance under different market conditions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"train\"))\n",
+ "valid_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"valid\"))\n",
+ "test_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"test\"))\n",
+ "if task_name.startswith(\"style_test\"):\n",
+ " test_style_environments = []\n",
+ " for i, path in enumerate(dataset.test_style_paths):\n",
+ " test_style_environments.append(build_environment(cfg, default_args=dict(dataset=dataset, task=\"test_style\",\n",
+ " style_test_path=path,\n",
+ " task_index=i)))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "valid_environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_environment"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 5: Build Net \n",
+ "Update the state and action dimensions in the config, then create the network and optimizer for DQN\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "action_dim = train_environment.action_dim\n",
+ "state_dim = train_environment.state_dim\n",
+ "\n",
+ "cfg.act.update(dict(action_dim=action_dim, state_dim=state_dim))\n",
+ "act = build_net(cfg.act)\n",
+ "act_optimizer = build_optimizer(cfg, default_args=dict(params=act.parameters()))\n",
+ "if cfg.cri:\n",
+ " cfg.cri.update(dict(action_dim=action_dim, state_dim=state_dim))\n",
+ " cri = build_net(cfg.cri)\n",
+ " cri_optimizer = build_optimizer(cfg, default_args=dict(params=cri.parameters()))\n",
+ "else:\n",
+ " cri = None\n",
+ " cri_optimizer = None"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 6: Build Loss\n",
+ "Build loss from config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "criterion = build_loss(cfg)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 7: Build Transition\n",
+ "Build transition from config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "transition = build_transition(cfg)\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 8: Build Agent\n",
+ "Build agent from config and detect device"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "agent = build_agent(cfg, default_args=dict(action_dim = action_dim,\n",
+ " state_dim = state_dim,\n",
+ " act = act,\n",
+ " cri = cri,\n",
+ " act_optimizer = act_optimizer,\n",
+ " cri_optimizer = cri_optimizer,\n",
+ " criterion = criterion,\n",
+ " transition = transition,\n",
+ " device=device))"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 9: Build Trainer\n",
+ "Build the trainer from the config and create the work directory to save the results, model, and config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "| Arguments Keep work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse\n"
+ ]
+ }
+ ],
+ "source": [
+ "if task_name.startswith(\"style_test\"):\n",
+ " trainers = []\n",
+ " for env in test_style_environments:\n",
+ " trainers.append(build_trainer(cfg, default_args=dict(train_environment=train_environment,\n",
+ " valid_environment=valid_environment,\n",
+ " test_environment=env,\n",
+ " agent=agent,\n",
+ " device=device)))\n",
+ "else:\n",
+ " trainer = build_trainer(cfg, default_args=dict(train_environment=train_environment,\n",
+ " valid_environment=valid_environment,\n",
+ " test_environment=test_environment,\n",
+ " agent=agent,\n",
+ " device=device))\n",
+ "\n",
+ "cfg.dump(osp.join(ROOT, cfg.work_dir, osp.basename(args.config)))"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 10: Train the Trainer\n",
+ "Train the trainer based on the config and get results from workdir"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train Episode: [1/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -344.513937% | -0.000214 | 6862.304186 | 6.214777 | -27423.674698 | -0.556031 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [1/20]\n",
+ "+---------------+-------------+-------------+--------------+--------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+--------------+---------------+\n",
+ "| 36.243173% | 0.001800 | 4108.651796 | 0.583186 | 62313.081138 | 0.551188 |\n",
+ "+---------------+-------------+-------------+--------------+--------------+---------------+\n",
+ "Valid Episode Reward Sum: 56390.914075\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00001.pth\n",
+ "Train Episode: [2/20]\n",
+ "+---------------+-------------+--------------+--------------+----------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+----------------+---------------+\n",
+ "| 2550.222626% | 0.000294 | 74654.187304 | 0.984371 | 2592355.042556 | 0.737448 |\n",
+ "+---------------+-------------+--------------+--------------+----------------+---------------+\n",
+ "Valid Episode: [2/20]\n",
+ "+---------------+-------------+-------------+--------------+--------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+--------------+---------------+\n",
+ "| 38.089801% | 0.002813 | 2762.654677 | 0.456196 | 83702.883876 | 0.945576 |\n",
+ "+---------------+-------------+-------------+--------------+--------------+---------------+\n",
+ "Valid Episode Reward Sum: 57332.169254\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00002.pth\n",
+ "Train Episode: [3/20]\n",
+ "+---------------+-------------+--------------+--------------+----------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+----------------+---------------+\n",
+ "| 2129.816364% | 0.000301 | 60992.463565 | 0.990535 | 2151453.831529 | 0.761704 |\n",
+ "+---------------+-------------+--------------+--------------+----------------+---------------+\n",
+ "Valid Episode: [3/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -56390.914075\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00003.pth\n",
+ "Train Episode: [4/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -2334.279839% | -0.000274 | 62958.198404 | 24.260127 | -82729.892836 | -0.704792 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [4/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -230.064518% | -0.001804 | 4108.595508 | 2.561829 | -14212.530944 | -0.820658 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -56365.177630\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00004.pth\n",
+ "Train Episode: [5/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -1641.409323% | -0.000287 | 43526.174039 | 14.819907 | -97857.697010 | -0.740099 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [5/20]\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| -0.000000% | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0nan |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "Valid Episode Reward Sum: 0.000000\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00005.pth\n",
+ "Train Episode: [6/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -1691.332913% | -0.000290 | 46229.282308 | 36.828678 | -42306.683350 | -0.752619 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [6/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -227.824723% | -0.001990 | 3952.099453 | 2.539431 | -15219.010810 | -0.887165 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -60502.612906\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00006.pth\n",
+ "Train Episode: [7/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -2174.481852% | -0.000292 | 59661.553550 | 47.345066 | -42758.782985 | -0.759986 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [7/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -176.839873% | -0.001765 | 3528.124043 | 2.195090 | -13939.332801 | -0.796095 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -52680.509863\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00007.pth\n",
+ "Train Episode: [8/20]\n",
+ "+---------------+-------------+--------------+--------------+----------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+----------------+---------------+\n",
+ "| 1885.983659% | 0.000443 | 36656.692362 | 0.972584 | 1941130.966521 | 1.154473 |\n",
+ "+---------------+-------------+--------------+--------------+----------------+---------------+\n",
+ "Valid Episode: [8/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -56390.914075\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00008.pth\n",
+ "Train Episode: [9/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -2724.499907% | -0.000291 | 74498.020183 | 52.412219 | -48016.107359 | -0.744638 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [9/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -56390.914075\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00009.pth\n",
+ "Train Episode: [10/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -2031.685492% | -0.000290 | 54747.564770 | 43.636693 | -42363.764305 | -0.757473 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [10/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -56390.914075\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00010.pth\n",
+ "Train Episode: [11/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -2399.121800% | -0.000292 | 64957.047457 | 51.684147 | -42618.599118 | -0.761166 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [11/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -230.053823% | -0.001886 | 4043.229698 | 2.569442 | -14579.510155 | -0.866852 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -57221.583660\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00011.pth\n",
+ "Train Episode: [12/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -1481.424033% | -0.000284 | 39030.031859 | 31.315723 | -41117.470365 | -0.737378 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [12/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -56390.914075\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00012.pth\n",
+ "Train Episode: [13/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -788.802406% | -0.000152 | 19854.508607 | 6.946587 | -50441.493440 | -0.380641 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [13/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -170.396222% | -0.002120 | 3555.765470 | 2.254471 | -16430.348220 | -0.931926 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -60797.589909\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00013.pth\n",
+ "Train Episode: [14/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -1499.160477% | -0.000287 | 39627.327133 | 30.212022 | -43745.116737 | -0.735405 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [14/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -169.507466% | -0.002174 | 3551.437396 | 2.245583 | -16891.155853 | -0.954286 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -61053.498659\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00014.pth\n",
+ "Train Episode: [15/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -688.419270% | 0.000683 | 12404.350216 | 2.003406 | 491928.343157 | 1.901195 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [15/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -170.679616% | -0.002148 | 3483.353434 | 2.133488 | -17229.228497 | -0.957504 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -60056.333731\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00015.pth\n",
+ "Train Episode: [16/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -972.226134% | 0.000515 | 15091.307919 | 2.283666 | 395975.151491 | 1.384039 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [16/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -56390.914075\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00016.pth\n",
+ "Train Episode: [17/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -1400.570121% | -0.000281 | 36908.469387 | 29.389929 | -41092.361952 | -0.717731 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [17/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -56390.914075\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00017.pth\n",
+ "Train Episode: [18/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -744.326652% | 0.000579 | 12140.578923 | 2.151154 | 379703.556119 | 1.557721 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [18/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -180.970136% | -0.001587 | 3330.163550 | 2.008192 | -12932.140987 | -0.734850 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -41071.900165\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00018.pth\n",
+ "Train Episode: [19/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -948.403058% | -0.000274 | 24082.344472 | 19.543640 | -39251.932592 | -0.693784 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [19/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -56390.914075\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00019.pth\n",
+ "Train Episode: [20/20]\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "| -1874.527862% | -0.000289 | 51213.539849 | 40.429782 | -42627.916754 | -0.738897 |\n",
+ "+---------------+-------------+--------------+--------------+---------------+---------------+\n",
+ "Valid Episode: [20/20]\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| -133.446994% | -0.002186 | 3066.785470 | 1.822390 | -18069.448586 | -0.946982 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Valid Episode Reward Sum: -54479.743152\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00020.pth\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/best.pth\n",
+ "Resume checkpoint /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/best.pth\n",
+ "Test Best Episode\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "| 49.996525% | -0.005364 | 4336.971275 | 1.165761 | -98033.866998 | -1.900394 |\n",
+ "+---------------+-------------+-------------+--------------+---------------+---------------+\n",
+ "Test Best Episode Reward Sum: -231100.467385\n",
+ "train end\n"
+ ]
+ }
+ ],
+ "source": [
+ "if task_name.startswith(\"train\"):\n",
+ " trainer.train_and_valid()\n",
+ " trainer.test()\n",
+ " print(\"train end\")\n",
+ "elif task_name.startswith(\"test\"):\n",
+ " trainer.test()\n",
+ " print(\"test end\")\n",
+ "elif task_name.startswith(\"style_test\"):\n",
+ " daily_return_list = []\n",
+ " for trainer in trainers:\n",
+ " daily_return_list.extend(trainer.test())\n",
+ " print('win rate is: ', sum(r > 0 for r in daily_return_list) / len(daily_return_list))\n",
+ " print(\"style test end\")\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "HFT",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.15"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tutorial/Tutorial2_EIIE.ipynb b/tutorial/Tutorial2_EIIE.ipynb
new file mode 100644
index 00000000..d79203f2
--- /dev/null
+++ b/tutorial/Tutorial2_EIIE.ipynb
@@ -0,0 +1,461 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 1: Import Packages\n",
+ "Modify the system path, then import the required packages and functions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "from pathlib import Path\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "ROOT = str(Path(\"__file__\").resolve().parents[1])\n",
+ "sys.path.append(ROOT)\n",
+ "import torch\n",
+ "import argparse\n",
+ "import os.path as osp\n",
+ "from mmcv import Config\n",
+ "from trademaster.utils import replace_cfg_vals\n",
+ "from trademaster.nets.builder import build_net\n",
+ "from trademaster.environments.builder import build_environment\n",
+ "from trademaster.datasets.builder import build_dataset\n",
+ "from trademaster.agents.builder import build_agent\n",
+ "from trademaster.optimizers.builder import build_optimizer\n",
+ "from trademaster.losses.builder import build_loss\n",
+ "from trademaster.trainers.builder import build_trainer\n",
+ "from trademaster.transition.builder import build_transition"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 2: Load Configs\n",
+ "Load the default config from the file `configs/portfolio_management/portfolio_management_dj30_eiie_eiie_adam_mse.py`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "parser = argparse.ArgumentParser(description='Train or test the EIIE portfolio management agent')\n",
+ "parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"portfolio_management\", \"portfolio_management_dj30_eiie_eiie_adam_mse.py\"),\n",
+ " help=\"path to the config file\")\n",
+ "parser.add_argument(\"--task_name\", type=str, default=\"train\")\n",
+ "parser.add_argument(\"--test_style\", type=str, default='-1')\n",
+ "args = parser.parse_args([])\n",
+ "cfg = Config.fromfile(args.config)\n",
+ "task_name = args.task_name\n",
+ "\n",
+ "cfg = replace_cfg_vals(cfg)\n",
+ "# update test style\n",
+ "cfg.data.update({'test_style': args.test_style})\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/portfolio_management/portfolio_management_dj30_eiie_eiie_adam_mse.py): {'data': {'type': 'PortfolioManagementDataset', 'data_path': 'data/portfolio_management/dj30', 'train_path': 'data/portfolio_management/dj30/train.csv', 'valid_path': 'data/portfolio_management/dj30/valid.csv', 'test_path': 'data/portfolio_management/dj30/test.csv', 'tech_indicator_list': ['zopen', 'zhigh', 'zlow', 'zadjcp', 'zclose', 'zd_5', 'zd_10', 'zd_15', 'zd_20', 'zd_25', 'zd_30'], 'length_day': 10, 'initial_amount': 100000, 'transaction_cost_pct': 0.001, 'test_style_path': 'data/portfolio_management/dj30/DJI_label_by_DJIindex_3_24_-0.25_0.25.csv', 'test_style': '-1'}, 'environment': {'type': 'PortfolioManagementEIIEEnvironment'}, 'agent': {'type': 'PortfolioManagementEIIE', 'memory_capacity': 1000, 'gamma': 0.99, 'policy_update_frequency': 500}, 'trainer': {'type': 'PortfolioManagementEIIETrainer', 'epochs': 10, 'work_dir': 'work_dir/portfolio_management_dj30_eiie_eiie_adam_mse', 'if_remove': True}, 'loss': {'type': 'MSELoss'}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'EIIEConv', 'input_dim': None, 'output_dim': 1, 'time_steps': 10, 'kernel_size': 3, 'dims': [32]}, 'cri': {'type': 'EIIECritic', 'input_dim': None, 'action_dim': None, 'output_dim': 1, 'time_steps': None, 'num_layers': 1, 'hidden_size': 32}, 'transition': {'type': 'Transition'}, 'task_name': 'portfolio_management', 'dataset_name': 'dj30', 'net_name': 'eiie', 'agent_name': 'eiie', 'optimizer_name': 'adam', 'loss_name': 'mse', 'work_dir': 'work_dir/portfolio_management_dj30_eiie_eiie_adam_mse'}"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cfg"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 3: Build Dataset\n",
+ "Build the dataset from the cfg loaded above"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = build_dataset(cfg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 4: Build Reinforcement Learning Environments\n",
+ "Build the train, valid, and test environments from cfg and the dataset built above\n",
+ "\n",
+ "An optional style test evaluates the algorithm's performance under different market conditions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"train\"))\n",
+ "valid_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"valid\"))\n",
+ "test_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"test\"))\n",
+ "if task_name.startswith(\"style_test\"):\n",
+ " test_style_environments = []\n",
+ " for i, path in enumerate(dataset.test_style_paths):\n",
+ " test_style_environments.append(build_environment(cfg, default_args=dict(dataset=dataset, task=\"test_style\",\n",
+ " style_test_path=path,\n",
+ " task_index=i)))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "valid_environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_environment"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 5: Build Net \n",
+ "Update the config with the input dimension, action dimension, and time steps, then create the actor and critic nets and their optimizers for EIIE\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "action_dim = train_environment.action_dim # 29\n",
+ "state_dim = train_environment.state_dim # 11\n",
+ "input_dim = len(train_environment.tech_indicator_list)\n",
+ "time_steps = train_environment.time_steps\n",
+ "\n",
+ "cfg.act.update(dict(input_dim=input_dim, time_steps=time_steps))\n",
+ "cfg.cri.update(dict(input_dim=input_dim, action_dim= action_dim, time_steps=time_steps))\n",
+ "\n",
+ "act = build_net(cfg.act)\n",
+ "cri = build_net(cfg.cri)\n",
+ "act_optimizer = build_optimizer(cfg, default_args=dict(params=act.parameters()))\n",
+ "cri_optimizer = build_optimizer(cfg, default_args=dict(params=cri.parameters()))"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 6: Build Loss\n",
+ "Build loss from config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "criterion = build_loss(cfg)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 7: Build Transition\n",
+ "Build transition from config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "transition = build_transition(cfg)\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 8: Build Agent\n",
+ "Detect the available device (GPU or CPU) and build the agent from the config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "agent = build_agent(cfg, default_args=dict(action_dim=action_dim,\n",
+ " state_dim=state_dim,\n",
+ " time_steps = time_steps,\n",
+ " act=act,\n",
+ " cri=cri,\n",
+ " act_optimizer=act_optimizer,\n",
+ " cri_optimizer = cri_optimizer,\n",
+ " criterion=criterion,\n",
+ " transition = transition,\n",
+ " device = device))"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 9: Build Trainer\n",
+ "Build the trainer from the config and create the work directory for saving the results, model, and config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "| Arguments Remove work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse\n"
+ ]
+ }
+ ],
+ "source": [
+ "if task_name.startswith(\"style_test\"):\n",
+ " trainers = []\n",
+ " for env in test_style_environments:\n",
+ " trainers.append(build_trainer(cfg, default_args=dict(train_environment=train_environment,\n",
+ " valid_environment=valid_environment,\n",
+ " test_environment=env,\n",
+ " agent=agent,\n",
+ " device=device)))\n",
+ "else:\n",
+ " trainer = build_trainer(cfg, default_args=dict(train_environment=train_environment,\n",
+ " valid_environment=valid_environment,\n",
+ " test_environment=test_environment,\n",
+ " agent=agent,\n",
+ " device=device,\n",
+ " ))\n",
+ "work_dir = os.path.join(ROOT, cfg.trainer.work_dir)\n",
+ "\n",
+ "if not os.path.exists(work_dir):\n",
+ " os.makedirs(work_dir)\n",
+ "cfg.dump(osp.join(work_dir, osp.basename(args.config)))"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 10: Train the Trainer\n",
+ "Run training as specified by the config; results are saved to the work directory"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train Episode: [1/10]\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| 180.543176% | 0.001605 | 0.007575 | 0.645235 | 1.688276 | 4.176283 |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "Valid Episode: [1/10]\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| 9.519411% | 0.001804 | 0.021395 | 0.361807 | 0.406514 | 0.529803 |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "Valid Episode Reward Sum: 0.090932\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00001.pth\n",
+ "Train Episode: [2/10]\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| 180.534263% | 0.001605 | 0.007575 | 0.645216 | 1.688270 | 4.176311 |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "Valid Episode: [2/10]\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| 9.520920% | 0.001804 | 0.021394 | 0.361798 | 0.406550 | 0.529857 |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "Valid Episode Reward Sum: 0.090945\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00002.pth\n",
+ "Train Episode: [3/10]\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| 180.519361% | 0.001605 | 0.007574 | 0.645198 | 1.688229 | 4.176275 |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "Valid Episode: [3/10]\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "| 9.521613% | 0.001804 | 0.021394 | 0.361794 | 0.406567 | 0.529882 |\n",
+ "+---------------+-------------+------------+--------------+--------------+---------------+\n",
+ "Valid Episode Reward Sum: 0.090952\n",
+ "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00003.pth\n",
+ "Train Episode: [4/10]\n"
+ ]
+ }
+ ],
+ "source": [
+ "if task_name.startswith(\"train\"):\n",
+ " trainer.train_and_valid()\n",
+ " trainer.test()\n",
+ " print(\"train end\")\n",
+ "elif task_name.startswith(\"test\"):\n",
+ " trainer.test()\n",
+ " print(\"test end\")\n",
+ "elif task_name.startswith(\"style_test\"):\n",
+ " daily_return_list = []\n",
+ " for trainer in trainers:\n",
+ " daily_return_list.extend(trainer.test())\n",
+ " print('win rate is: ', sum(r > 0 for r in daily_return_list) / len(daily_return_list))\n",
+ " print(\"style test end\")\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "HFT",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.15"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tutorial/Tutorial5_HFT.ipynb b/tutorial/Tutorial5_HFT.ipynb
new file mode 100644
index 00000000..16e82ec7
--- /dev/null
+++ b/tutorial/Tutorial5_HFT.ipynb
@@ -0,0 +1,451 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 1: Import Packages\n",
+ "Modify the system path, then import the required packages and functions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "from pathlib import Path\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "ROOT = str(Path(\"__file__\").resolve().parents[1])\n",
+ "sys.path.append(ROOT)\n",
+ "import torch\n",
+ "import argparse\n",
+ "import os.path as osp\n",
+ "from mmcv import Config\n",
+ "from trademaster.utils import replace_cfg_vals\n",
+ "from trademaster.nets.builder import build_net\n",
+ "from trademaster.environments.builder import build_environment\n",
+ "from trademaster.datasets.builder import build_dataset\n",
+ "from trademaster.agents.builder import build_agent\n",
+ "from trademaster.optimizers.builder import build_optimizer\n",
+ "from trademaster.losses.builder import build_loss\n",
+ "from trademaster.trainers.builder import build_trainer\n",
+ "from trademaster.transition.builder import build_transition"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 2: Load Configs\n",
+ "Load the default config from the file `configs/high_frequency_trading/high_frequency_trading_BTC_dqn_dqn_adam_mse.py`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "parser = argparse.ArgumentParser(description='Train or test the high-frequency trading agent')\n",
+ "parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"high_frequency_trading\", \"high_frequency_trading_BTC_dqn_dqn_adam_mse.py\"),\n",
+ " help=\"path to the config file\")\n",
+ "parser.add_argument(\"--task_name\", type=str, default=\"train\")\n",
+ "parser.add_argument(\"--test_style\", type=str, default='-1')\n",
+ "args = parser.parse_args([])\n",
+ "cfg = Config.fromfile(args.config)\n",
+ "task_name = args.task_name\n",
+ "\n",
+ "cfg = replace_cfg_vals(cfg)\n",
+ "# update test style\n",
+ "cfg.data.update({'test_style': args.test_style})\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/high_frequency_trading/high_frequency_trading_BTC_dqn_dqn_adam_mse.py): {'data': {'type': 'HighFrequencyTradingDataset', 'data_path': 'data/high_frequency_trading/small_BTC', 'train_path': 'data/high_frequency_trading/small_BTC/train.csv', 'valid_path': 'data/high_frequency_trading/small_BTC/valid.csv', 'test_path': 'data/high_frequency_trading/small_BTC/test.csv', 'test_style_path': 'data/high_frequency_trading/small_BTC/test.csv', 'tech_indicator_list': ['imblance_volume_oe', 'sell_spread_oe', 'buy_spread_oe', 'kmid2', 'bid1_size_n', 'ksft2', 'ma_10', 'ksft', 'kmid', 'ask1_size_n', 'trade_diff', 'qtlu_10', 'qtld_10', 'cntd_10', 'beta_10', 'roc_10', 'bid5_size_n', 'rsv_10', 'imxd_10', 'ask5_size_n', 'ma_30', 'max_10', 'qtlu_30', 'imax_10', 'imin_10', 'min_10', 'qtld_30', 'cntn_10', 'rsv_30', 'cntp_10', 'ma_60', 'max_30', 'qtlu_60', 'qtld_60', 'cntd_30', 'roc_30', 'beta_30', 'bid4_size_n', 'rsv_60', 'ask4_size_n', 'imxd_30', 'min_30', 'max_60', 'imax_30', 'imin_30', 'cntd_60', 'roc_60', 'beta_60', 'cntn_30', 'min_60', 'cntp_30', 'bid3_size_n', 'imxd_60', 'ask3_size_n', 'sell_volume_oe', 'imax_60', 'imin_60', 'cntn_60', 'buy_volume_oe', 'cntp_60', 'bid2_size_n', 'kup', 'bid1_size', 'ask1_size', 'std_30', 'ask2_size_n'], 'transcation_cost': 0, 'backward_num_timestamp': 1, 'max_holding_number': 0.01, 'num_action': 11, 'max_punish': 1000000000000.0, 'episode_length': 14400, 'test_style': '-1'}, 'environment': {'type': 'HighFrequencyTradingEnvironment'}, 'agent': {'type': 'HighFrequencyTradingDDQN', 'auxiliary_coffient': 512, 'reward_scale': 1, 'repeat_times': 1, 'gamma': 0.99, 'batch_size': 64, 'clip_grad_norm': 3.0, 'soft_update_tau': 0, 'state_value_tau': 0.005}, 'trainer': {'type': 'HighFrequencyTradingTrainer', 'epochs': 10, 'work_dir': 'work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse', 'seeds': 12345, 'batch_size': 512, 'horizon_len': 512, 'buffer_size': 100000.0, 'num_threads': 8, 'if_remove': False, 'if_discrete': True, 'if_off_policy': True, 'if_keep_save': True, 'if_over_write': False, 'if_save_buffer': False}, 'loss': {'type': 'HFTLoss', 'ada': 1}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'HFTQNet', 'state_dim': 66, 'action_dim': 11, 'dims': 16, 'explore_rate': 0.25, 'max_punish': 0}, 'cri': None, 'task_name': 'high_frequency_trading', 'dataset_name': 'BTC', 'optimizer_name': 'adam', 'loss_name': 'mse', 'auxiliry_loss_name': 'KLdiv', 'net_name': 'high_frequency_trading_dqn', 'agent_name': 'ddqn', 'work_dir': 'work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse', 'batch_size': 512, 'train_environment': {'type': 'HighFrequencyTradingTrainingEnvironment'}}"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cfg"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 3: Build Dataset\n",
+ "Build the dataset from the cfg loaded above"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = build_dataset(cfg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 4: Build Reinforcement Learning Environments\n",
+ "Build the training, validation, and test environments from cfg and the dataset built above\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "valid\n",
+ "test\n"
+ ]
+ }
+ ],
+ "source": [
+ "valid_environment = build_environment(\n",
+ " cfg, default_args=dict(dataset=dataset, task=\"valid\")\n",
+ ")\n",
+ "test_environment = build_environment(\n",
+ " cfg, default_args=dict(dataset=dataset, task=\"test\")\n",
+ ")\n",
+ "cfg.environment = cfg.train_environment\n",
+ "train_environment = build_environment(\n",
+ " cfg, default_args=dict(dataset=dataset, task=\"train\")\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "valid_environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_environment"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 5: Build Net \n",
+ "Update the config with the state and action dimensions taken from the training environment, then build the network and its optimizer for DQN\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "action_dim = train_environment.action_dim\n",
+ "state_dim = train_environment.state_dim\n",
+ "\n",
+ "cfg.act.update(dict(action_dim=action_dim, state_dim=state_dim))\n",
+ "act = build_net(cfg.act)\n",
+ "act_optimizer = build_optimizer(cfg, default_args=dict(params=act.parameters()))\n",
+ "if cfg.cri:\n",
+ " cfg.cri.update(dict(action_dim=action_dim, state_dim=state_dim))\n",
+ " cri = build_net(cfg.cri)\n",
+ " cri_optimizer = build_optimizer(cfg, default_args=dict(params=cri.parameters()))\n",
+ "else:\n",
+ " cri = None\n",
+ " cri_optimizer = None"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 6: Build Loss\n",
+ "Build the loss function from the config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "criterion = build_loss(cfg)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 7: Build Agent\n",
+ "Detect the available device (GPU if present, otherwise CPU) and build the agent from the config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "agent = build_agent(\n",
+ " cfg,\n",
+ " default_args=dict(\n",
+ " action_dim=action_dim,\n",
+ " state_dim=state_dim,\n",
+ " act=act,\n",
+ " cri=cri,\n",
+ " act_optimizer=act_optimizer,\n",
+ " cri_optimizer=cri_optimizer,\n",
+ " criterion=criterion,\n",
+ " device=device,\n",
+ " ),\n",
+ " )\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 8: Build Trainer\n",
+ "Build the trainer from the config and create the working directory where the results, model, and config are saved"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "| Arguments Keep work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "trainer = build_trainer(cfg, default_args=dict(train_environment=train_environment,\n",
+ " valid_environment=valid_environment,\n",
+ " test_environment=test_environment,\n",
+ " agent=agent,\n",
+ " device=device))\n",
+ "\n",
+ "cfg.dump(osp.join(ROOT, cfg.work_dir, osp.basename(args.config)))"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 9: Train the Trainer\n",
+ "Run training and validation as specified in the config, then test; the results are saved in the work directory"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "the holding could not be bought all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "Train Episode: [1/10]\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be sell all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n",
+ "the holding could not be bought all clear\n"
+ ]
+ }
+ ],
+ "source": [
+ "if task_name.startswith(\"train\"):\n",
+ " trainer.train_and_valid()\n",
+ " trainer.test()\n",
+ " print(\"train end\")\n",
+ "elif task_name.startswith(\"test\"):\n",
+ " trainer.test()\n",
+ " print(\"test end\")\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "HFT",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.15"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}