From ec3bc0bda366ac6eec4945812dc977aa5036a916 Mon Sep 17 00:00:00 2001 From: qml <937883273@qq.com> Date: Mon, 20 Feb 2023 20:29:00 +0800 Subject: [PATCH] fix tutorial --- docs/source/tutorial/tutorial1.md | 22 +- docs/source/tutorial/tutorial2.md | 23 +- docs/source/tutorial/tutorial5.md | 15 +- tutorial/Tutorial1_DeepScalper.ipynb | 704 +++++++++++++++++++++++++++ tutorial/Tutorial2_EIIE.ipynb | 461 ++++++++++++++++++ tutorial/Tutorial5_HFT.ipynb | 451 +++++++++++++++++ 6 files changed, 1647 insertions(+), 29 deletions(-) create mode 100644 tutorial/Tutorial1_DeepScalper.ipynb create mode 100644 tutorial/Tutorial2_EIIE.ipynb create mode 100644 tutorial/Tutorial5_HFT.ipynb diff --git a/docs/source/tutorial/tutorial1.md b/docs/source/tutorial/tutorial1.md index cdfd3aeb..a9f257a0 100644 --- a/docs/source/tutorial/tutorial1.md +++ b/docs/source/tutorial/tutorial1.md @@ -1,20 +1,18 @@ # Tutorial 1: Intraday Crypto Trading with DeepScalper +![DeepScalper.png](DeepScalper.png) -## Task Intraday trading is a fundamental quantitative trading task, where traders actively long/short one pre-selected financial asset within the same trading day to maximize future profit. -## Algorithm -DeepScalper contains 4 technical contributions which all together make it better than direct use of RL algorithms. -- RL optimization with action branching -- reward function with hindsight bonus -- intraday market embedding -- risk-aware auxiliary task +DeepScalper use deep q network to optimize the reward sum got from reinforcement learning where a hindsight reward is used to capture the long-term porfit trends and embedding from both micro-level and macro-level market information. -Here is the construction of the DeepScalper: -
- -
-Here is a [tutorial](https://github.com/DVampire/TradeMasterReBuild/tree/main/tutorial/DeepScalper.ipynb) about how you can build DeepScalper in a few lines of codes using TradeMaster. \ No newline at end of file +## Notebook and Script +In this notebook, we implement the training and testing process of DeepScalper based on the TradeMaster framework. + +[Tutorial1_DeepScalper](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/tutorial/Tutorial1_DeepScalper.ipynb) + +And this is the script for training and testing. + +[train.py](https://github.com/TradeMaster-NTU/TradeMaster/blob/1.0.0/tools/algorithmic_trading/train.py) \ No newline at end of file diff --git a/docs/source/tutorial/tutorial2.md b/docs/source/tutorial/tutorial2.md index eaba8091..3a6a81f5 100644 --- a/docs/source/tutorial/tutorial2.md +++ b/docs/source/tutorial/tutorial2.md @@ -1,19 +1,20 @@ # Tutorial 2: Portfolio Management with EIIE on US stocks - -## Task +![EIIE.png](EIIE.png) Portfolio management is the action of continuous reallocation of a capital into a number of financial assets periodically. -## Algorithm -EIIE contains 2 technical contributions which all together make it better than direct use of RL algorithms. -- Deterministic Policy Gradient -- Portfolio-Vector Memor +The framework consists of the Ensemble of Identical Independent Evaluators +(EIIE) topology, a Portfolio-Vector Memory (PVM), an Online Stochastic Batch Learning +(OSBL) scheme, and a fully exploiting and explicit reward function. + + + -Here is the construction of the EIIE: -
- -
+## Notebook and Script +In this notebook, we implement the training and testing process of EIIE based on the TradeMaster framework. +[Tutorial2_EIIE](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/tutorial/Tutorial2_EIIE.ipynb) +And this is the script for training and testing. -Here is a [tutorial](https://github.com/DVampire/TradeMasterReBuild/tree/main/tutorial/EIIE.ipynb) about how you can build EIIE in a few lines of codes using TradeMaster. \ No newline at end of file +[train_eiie.py](https://github.com/TradeMaster-NTU/TradeMaster/blob/1.0.0/tools/portfolio_management/train_eiie.py) \ No newline at end of file diff --git a/docs/source/tutorial/tutorial5.md b/docs/source/tutorial/tutorial5.md index 7ee6a7fb..6fa1f6ce 100644 --- a/docs/source/tutorial/tutorial5.md +++ b/docs/source/tutorial/tutorial5.md @@ -1,11 +1,14 @@ # Tutorial 5: High Frequency Trading with Double DQN -## Task High Frequency Trading is a fundamental quantitative trading task, where traders actively buy/sell one pre-selected financial periodically in seconds with the consideration of order execution. -## Algorithm -HFT_DDQN contains 2 technical contributions which all together make it better than direct use of RL algorithms. -- Double deep q network -- Regulator from the true q table +HFT_DDQN use a decayed supervised regulator genereated from the real q table based on the future price information and a double q network to optimizer the portfit margine. -Here is a [tutorial](https://github.com/DVampire/TradeMasterReBuild/tree/main/tutorial/HFT.ipynb) about how you can build DDQN for HFT in a few lines of codes using TradeMaster. \ No newline at end of file +## Notebook and Script +In this notebook, we implement the training and testing process of HFTDDQN based on the TradeMaster framework. + +[Tutorial5_HFT](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/tutorial/Tutorial5_HFT.ipynb) + +And this is the script for training and testing. + +[train.py](https://github.com/TradeMaster-NTU/TradeMaster/blob/1.0.0/tools/high_frequency_trading/train.py) \ No newline at end of file diff --git a/tutorial/Tutorial1_DeepScalper.ipynb b/tutorial/Tutorial1_DeepScalper.ipynb new file mode 100644 index 00000000..816dbb4a --- /dev/null +++ b/tutorial/Tutorial1_DeepScalper.ipynb @@ -0,0 +1,704 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Import Packages\n", + "Modify the system path and load the corresponding packages and functions " + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "ROOT = str(Path(\"__file__\").resolve().parents[1])\n", + "sys.path.append(ROOT)\n", + "import torch\n", + "import argparse\n", + "import os.path as osp\n", + "from mmcv import Config\n", + "from trademaster.utils import replace_cfg_vals\n", + "from trademaster.nets.builder import build_net\n", + "from trademaster.environments.builder import build_environment\n", + "from trademaster.datasets.builder import build_dataset\n", + "from trademaster.agents.builder import build_agent\n", + "from trademaster.optimizers.builder import build_optimizer\n", + "from trademaster.losses.builder import build_loss\n", + "from trademaster.trainers.builder import build_trainer\n", + "from trademaster.transition.builder import build_transition" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Load Configs\n", + "Load default config from the folder `configs/algorithmic_trading/algorithmic_trading_BTC_dqn_dqn_adam_mse.py`" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser(description='Download Alpaca Datasets')\n", + "parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"algorithmic_trading\", \"algorithmic_trading_BTC_dqn_dqn_adam_mse.py\"),\n", + " help=\"download datasets config file path\")\n", + "parser.add_argument(\"--task_name\", type=str, default=\"train\")\n", + "parser.add_argument(\"--test_style\", type=str, default='-1')\n", + "args = parser.parse_args([])\n", + "cfg = Config.fromfile(args.config)\n", + "task_name = args.task_name\n", + "\n", + "cfg = replace_cfg_vals(cfg)\n", + "# update test style\n", + "cfg.data.update({'test_style': args.test_style})\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/algorithmic_trading/algorithmic_trading_BTC_dqn_dqn_adam_mse.py): {'data': {'type': 'AlgorithmicTradingDataset', 'data_path': 'data/algorithmic_trading/BTC', 'train_path': 'data/algorithmic_trading/BTC/train.csv', 'valid_path': 'data/algorithmic_trading/BTC/valid.csv', 'test_path': 'data/algorithmic_trading/BTC/test.csv', 'test_style_path': 'data/algorithmic_trading/BTC/test_labeled_3_24_-0.15_0.15.csv', 'tech_indicator_list': ['high', 'low', 'open', 'close', 'adjcp', 'zopen', 'zhigh', 'zlow', 'zadjcp', 'zclose', 'zd_5', 'zd_10', 'zd_15', 'zd_20', 'zd_25', 'zd_30'], 'backward_num_day': 5, 'forward_num_day': 5, 'test_style': '-1'}, 'environment': {'type': 'AlgorithmicTradingEnvironment'}, 'agent': {'type': 'AlgorithmicTradingDQN', 'max_step': 12345, 'reward_scale': 1, 'repeat_times': 1, 'gamma': 0.9, 'batch_size': 64, 'clip_grad_norm': 3.0, 'soft_update_tau': 0, 'state_value_tau': 0.005}, 'trainer': {'type': 'AlgorithmicTradingTrainer', 'epochs': 20, 'work_dir': 'work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse', 'seeds_list': (12345,), 'batch_size': 64, 'horizon_len': 1024, 'buffer_size': 1000000.0, 'num_threads': 8, 'if_remove': False, 'if_discrete': True, 'if_off_policy': True, 'if_keep_save': True, 'if_over_write': False, 'if_save_buffer': False}, 'loss': {'type': 'MSELoss'}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'QNet', 'state_dim': 82, 'action_dim': 3, 'dims': (64, 32), 'explore_rate': 0.25}, 'cri': None, 'transition': {'type': 'Transition'}, 'task_name': 'algorithmic_trading', 'dataset_name': 'BTC', 'optimizer_name': 'adam', 'loss_name': 'mse', 'net_name': 'dqn', 'agent_name': 'dqn', 'work_dir': 'work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse', 'batch_size': 64}" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cfg" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Build Dataset\n", + "Build datasets from cfg defined above" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = build_dataset(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Build Reinforcement Learning Environments\n", + "Build environments based on cfg and previously-defined dataset\n", + "\n", + "A style-test is provided as an option to test the algorithm's performance under different market conditions" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "train_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"train\"))\n", + "valid_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"valid\"))\n", + "test_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"test\"))\n", + "if task_name.startswith(\"style_test\"):\n", + " test_style_environments = []\n", + " for i, path in enumerate(dataset.test_style_paths):\n", + " test_style_environments.append(build_environment(cfg, default_args=dict(dataset=dataset, task=\"test_style\",\n", + " style_test_path=path,\n", + " task_index=i)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "valid_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_environment" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Build Net \n", + "Update information about the state and action dimension in the config and create nets and optimizer for DQN\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "action_dim = train_environment.action_dim\n", + "state_dim = train_environment.state_dim\n", + "\n", + "cfg.act.update(dict(action_dim=action_dim, state_dim=state_dim))\n", + "act = build_net(cfg.act)\n", + "act_optimizer = build_optimizer(cfg, default_args=dict(params=act.parameters()))\n", + "if cfg.cri:\n", + " cfg.cri.update(dict(action_dim=action_dim, state_dim=state_dim))\n", + " cri = build_net(cfg.cri)\n", + " cri_optimizer = build_optimizer(cfg, default_args=dict(params=cri.parameters()))\n", + "else:\n", + " cri = None\n", + " cri_optimizer = None" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Build Loss\n", + "Build loss from config" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = build_loss(cfg)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Build Transition\n", + "Build transition from config" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "transition = build_transition(cfg)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Build Agent\n", + "Build agent from config and detect device" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "agent = build_agent(cfg, default_args=dict(action_dim = action_dim,\n", + " state_dim = state_dim,\n", + " act = act,\n", + " cri = cri,\n", + " act_optimizer = act_optimizer,\n", + " cri_optimizer = cri_optimizer,\n", + " criterion = criterion,\n", + " transition = transition,\n", + " device=device))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 9: Build Trainer\n", + "Build trainer from config and create work directionary to save the result, model and config" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "| Arguments Keep work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse\n" + ] + } + ], + "source": [ + "if task_name.startswith(\"style_test\"):\n", + " trainers = []\n", + " for env in test_style_environments:\n", + " trainers.append(build_trainer(cfg, default_args=dict(train_environment=train_environment,\n", + " valid_environment=valid_environment,\n", + " test_environment=env,\n", + " agent=agent,\n", + " device=device)))\n", + "else:\n", + " trainer = build_trainer(cfg, default_args=dict(train_environment=train_environment,\n", + " valid_environment=valid_environment,\n", + " test_environment=test_environment,\n", + " agent=agent,\n", + " device=device))\n", + "\n", + "cfg.dump(osp.join(ROOT, cfg.work_dir, osp.basename(args.config)))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 10: Train the Trainer\n", + "Train the trainer based on the config and get results from workdir" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Episode: [1/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -344.513937% | -0.000214 | 6862.304186 | 6.214777 | -27423.674698 | -0.556031 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode: [1/20]\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "| 36.243173% | 0.001800 | 4108.651796 | 0.583186 | 62313.081138 | 0.551188 |\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00001.pth\n", + "Train Episode: [2/20]\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| 2550.222626% | 0.000294 | 74654.187304 | 0.984371 | 2592355.042556 | 0.737448 |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "Valid Episode: [2/20]\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "| 38.089801% | 0.002813 | 2762.654677 | 0.456196 | 83702.883876 | 0.945576 |\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 57332.169254\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00002.pth\n", + "Train Episode: [3/20]\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| 2129.816364% | 0.000301 | 60992.463565 | 0.990535 | 2151453.831529 | 0.761704 |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "Valid Episode: [3/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00003.pth\n", + "Train Episode: [4/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -2334.279839% | -0.000274 | 62958.198404 | 24.260127 | -82729.892836 | -0.704792 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [4/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.064518% | -0.001804 | 4108.595508 | 2.561829 | -14212.530944 | -0.820658 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56365.177630\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00004.pth\n", + "Train Episode: [5/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1641.409323% | -0.000287 | 43526.174039 | 14.819907 | -97857.697010 | -0.740099 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [5/20]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| -0.000000% | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0nan |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 0.000000\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00005.pth\n", + "Train Episode: [6/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1691.332913% | -0.000290 | 46229.282308 | 36.828678 | -42306.683350 | -0.752619 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [6/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -227.824723% | -0.001990 | 3952.099453 | 2.539431 | -15219.010810 | -0.887165 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -60502.612906\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00006.pth\n", + "Train Episode: [7/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -2174.481852% | -0.000292 | 59661.553550 | 47.345066 | -42758.782985 | -0.759986 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [7/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -176.839873% | -0.001765 | 3528.124043 | 2.195090 | -13939.332801 | -0.796095 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -52680.509863\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00007.pth\n", + "Train Episode: [8/20]\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| 1885.983659% | 0.000443 | 36656.692362 | 0.972584 | 1941130.966521 | 1.154473 |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "Valid Episode: [8/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00008.pth\n", + "Train Episode: [9/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -2724.499907% | -0.000291 | 74498.020183 | 52.412219 | -48016.107359 | -0.744638 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [9/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00009.pth\n", + "Train Episode: [10/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -2031.685492% | -0.000290 | 54747.564770 | 43.636693 | -42363.764305 | -0.757473 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [10/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00010.pth\n", + "Train Episode: [11/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -2399.121800% | -0.000292 | 64957.047457 | 51.684147 | -42618.599118 | -0.761166 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [11/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.053823% | -0.001886 | 4043.229698 | 2.569442 | -14579.510155 | -0.866852 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -57221.583660\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00011.pth\n", + "Train Episode: [12/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1481.424033% | -0.000284 | 39030.031859 | 31.315723 | -41117.470365 | -0.737378 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [12/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00012.pth\n", + "Train Episode: [13/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -788.802406% | -0.000152 | 19854.508607 | 6.946587 | -50441.493440 | -0.380641 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [13/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -170.396222% | -0.002120 | 3555.765470 | 2.254471 | -16430.348220 | -0.931926 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -60797.589909\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00013.pth\n", + "Train Episode: [14/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1499.160477% | -0.000287 | 39627.327133 | 30.212022 | -43745.116737 | -0.735405 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [14/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -169.507466% | -0.002174 | 3551.437396 | 2.245583 | -16891.155853 | -0.954286 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -61053.498659\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00014.pth\n", + "Train Episode: [15/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -688.419270% | 0.000683 | 12404.350216 | 2.003406 | 491928.343157 | 1.901195 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [15/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -170.679616% | -0.002148 | 3483.353434 | 2.133488 | -17229.228497 | -0.957504 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -60056.333731\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00015.pth\n", + "Train Episode: [16/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -972.226134% | 0.000515 | 15091.307919 | 2.283666 | 395975.151491 | 1.384039 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [16/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00016.pth\n", + "Train Episode: [17/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1400.570121% | -0.000281 | 36908.469387 | 29.389929 | -41092.361952 | -0.717731 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [17/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00017.pth\n", + "Train Episode: [18/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -744.326652% | 0.000579 | 12140.578923 | 2.151154 | 379703.556119 | 1.557721 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [18/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -180.970136% | -0.001587 | 3330.163550 | 2.008192 | -12932.140987 | -0.734850 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -41071.900165\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00018.pth\n", + "Train Episode: [19/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -948.403058% | -0.000274 | 24082.344472 | 19.543640 | -39251.932592 | -0.693784 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [19/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00019.pth\n", + "Train Episode: [20/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1874.527862% | -0.000289 | 51213.539849 | 40.429782 | -42627.916754 | -0.738897 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [20/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -133.446994% | -0.002186 | 3066.785470 | 1.822390 | -18069.448586 | -0.946982 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -54479.743152\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00020.pth\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/best.pth\n", + "Resume checkpoint /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/best.pth\n", + "Test Best Episode\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| 49.996525% | -0.005364 | 4336.971275 | 1.165761 | -98033.866998 | -1.900394 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Test Best Episode Reward Sum: -231100.467385\n", + "train end\n" + ] + } + ], + "source": [ + "if task_name.startswith(\"train\"):\n", + " trainer.train_and_valid()\n", + " trainer.test()\n", + " print(\"train end\")\n", + "elif task_name.startswith(\"test\"):\n", + " trainer.test()\n", + " print(\"test end\")\n", + "elif task_name.startswith(\"style_test\"):\n", + " daily_return_list = []\n", + " for trainer in trainers:\n", + " daily_return_list.extend(trainer.test())\n", + " print('win rate is: ', sum(r > 0 for r in daily_return_list) / len(daily_return_list))\n", + " print(\"style test end\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "HFT", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.15" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorial/Tutorial2_EIIE.ipynb b/tutorial/Tutorial2_EIIE.ipynb new file mode 100644 index 00000000..d79203f2 --- /dev/null +++ b/tutorial/Tutorial2_EIIE.ipynb @@ -0,0 +1,461 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Import Packages\n", + "Modify the system path and load the corresponding packages and functions " + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "ROOT = str(Path(\"__file__\").resolve().parents[1])\n", + "sys.path.append(ROOT)\n", + "import torch\n", + "import argparse\n", + "import os.path as osp\n", + "from mmcv import Config\n", + "from trademaster.utils import replace_cfg_vals\n", + "from trademaster.nets.builder import build_net\n", + "from trademaster.environments.builder import build_environment\n", + "from trademaster.datasets.builder import build_dataset\n", + "from trademaster.agents.builder import build_agent\n", + "from trademaster.optimizers.builder import build_optimizer\n", + "from trademaster.losses.builder import build_loss\n", + "from trademaster.trainers.builder import build_trainer\n", + "from trademaster.transition.builder import build_transition" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Load Configs\n", + "Load default config from the folder `configs/portfolio_management/portfolio_management_dj30_eiie_eiie_adam_mse.py`" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser(description='Download Alpaca Datasets')\n", + "parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"portfolio_management\", \"portfolio_management_dj30_eiie_eiie_adam_mse.py\"),\n", + " help=\"download datasets config file path\")\n", + "parser.add_argument(\"--task_name\", type=str, default=\"train\")\n", + "parser.add_argument(\"--test_style\", type=str, default='-1')\n", + "args = parser.parse_args([])\n", + "cfg = Config.fromfile(args.config)\n", + "task_name = args.task_name\n", + "\n", + "cfg = replace_cfg_vals(cfg)\n", + "# update test style\n", + "cfg.data.update({'test_style': args.test_style})\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/portfolio_management/portfolio_management_dj30_eiie_eiie_adam_mse.py): {'data': {'type': 'PortfolioManagementDataset', 'data_path': 'data/portfolio_management/dj30', 'train_path': 'data/portfolio_management/dj30/train.csv', 'valid_path': 'data/portfolio_management/dj30/valid.csv', 'test_path': 'data/portfolio_management/dj30/test.csv', 'tech_indicator_list': ['zopen', 'zhigh', 'zlow', 'zadjcp', 'zclose', 'zd_5', 'zd_10', 'zd_15', 'zd_20', 'zd_25', 'zd_30'], 'length_day': 10, 'initial_amount': 100000, 'transaction_cost_pct': 0.001, 'test_style_path': 'data/portfolio_management/dj30/DJI_label_by_DJIindex_3_24_-0.25_0.25.csv', 'test_style': '-1'}, 'environment': {'type': 'PortfolioManagementEIIEEnvironment'}, 'agent': {'type': 'PortfolioManagementEIIE', 'memory_capacity': 1000, 'gamma': 0.99, 'policy_update_frequency': 500}, 'trainer': {'type': 'PortfolioManagementEIIETrainer', 'epochs': 10, 'work_dir': 'work_dir/portfolio_management_dj30_eiie_eiie_adam_mse', 'if_remove': True}, 'loss': {'type': 'MSELoss'}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'EIIEConv', 'input_dim': None, 'output_dim': 1, 'time_steps': 10, 'kernel_size': 3, 'dims': [32]}, 'cri': {'type': 'EIIECritic', 'input_dim': None, 'action_dim': None, 'output_dim': 1, 'time_steps': None, 'num_layers': 1, 'hidden_size': 32}, 'transition': {'type': 'Transition'}, 'task_name': 'portfolio_management', 'dataset_name': 'dj30', 'net_name': 'eiie', 'agent_name': 'eiie', 'optimizer_name': 'adam', 'loss_name': 'mse', 'work_dir': 'work_dir/portfolio_management_dj30_eiie_eiie_adam_mse'}" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cfg" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Build Dataset\n", + "Build datasets from cfg defined above" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = build_dataset(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Build Reinforcement Learning Environments\n", + "Build environments based on cfg and previously-defined dataset\n", + "\n", + "A style-test is provided as an option to test the algorithm's performance under different market conditions" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "train_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"train\"))\n", + "valid_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"valid\"))\n", + "test_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"test\"))\n", + "if task_name.startswith(\"style_test\"):\n", + " test_style_environments = []\n", + " for i, path in enumerate(dataset.test_style_paths):\n", + " test_style_environments.append(build_environment(cfg, default_args=dict(dataset=dataset, task=\"test_style\",\n", + " style_test_path=path,\n", + " task_index=i)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "valid_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_environment" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Build Net \n", + "Update information about the state and action dimension in the config and create nets and optimizer for EIIE\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "action_dim = train_environment.action_dim # 29\n", + "state_dim = train_environment.state_dim # 11\n", + "input_dim = len(train_environment.tech_indicator_list)\n", + "time_steps = train_environment.time_steps\n", + "\n", + "cfg.act.update(dict(input_dim=input_dim, time_steps=time_steps))\n", + "cfg.cri.update(dict(input_dim=input_dim, action_dim= action_dim, time_steps=time_steps))\n", + "\n", + "act = build_net(cfg.act)\n", + "cri = build_net(cfg.cri)\n", + "act_optimizer = build_optimizer(cfg, default_args=dict(params=act.parameters()))\n", + "cri_optimizer = build_optimizer(cfg, default_args=dict(params=cri.parameters()))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Build Loss\n", + "Build loss from config" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = build_loss(cfg)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Build Transition\n", + "Build transition from config" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "transition = build_transition(cfg)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Build Agent\n", + "Build agent from config and detect device" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "agent = build_agent(cfg, default_args=dict(action_dim=action_dim,\n", + " state_dim=state_dim,\n", + " time_steps = time_steps,\n", + " act=act,\n", + " cri=cri,\n", + " act_optimizer=act_optimizer,\n", + " cri_optimizer = cri_optimizer,\n", + " criterion=criterion,\n", + " transition = transition,\n", + " device = device))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 9: Build Trainer\n", + "Build trainer from config and create work directionary to save the result, model and config" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "| Arguments Remove work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse\n" + ] + } + ], + "source": [ + "if task_name.startswith(\"style_test\"):\n", + " trainers = []\n", + " for env in test_style_environments:\n", + " trainers.append(build_trainer(cfg, default_args=dict(train_environment=train_environment,\n", + " valid_environment=valid_environment,\n", + " test_environment=env,\n", + " agent=agent,\n", + " device=device)))\n", + "else:\n", + " trainer = build_trainer(cfg, default_args=dict(train_environment=train_environment,\n", + " valid_environment=valid_environment,\n", + " test_environment=test_environment,\n", + " agent=agent,\n", + " device=device,\n", + " ))\n", + "work_dir = os.path.join(ROOT, cfg.trainer.work_dir)\n", + "\n", + "if not os.path.exists(work_dir):\n", + " os.makedirs(work_dir)\n", + "cfg.dump(osp.join(work_dir, osp.basename(args.config)))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 10: Train the Trainer\n", + "Train the trainer based on the config and get results from workdir" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Episode: [1/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 180.543176% | 0.001605 | 0.007575 | 0.645235 | 1.688276 | 4.176283 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode: [1/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 9.519411% | 0.001804 | 0.021395 | 0.361807 | 0.406514 | 0.529803 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 0.090932\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00001.pth\n", + "Train Episode: [2/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 180.534263% | 0.001605 | 0.007575 | 0.645216 | 1.688270 | 4.176311 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode: [2/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 9.520920% | 0.001804 | 0.021394 | 0.361798 | 0.406550 | 0.529857 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 0.090945\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00002.pth\n", + "Train Episode: [3/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 180.519361% | 0.001605 | 0.007574 | 0.645198 | 1.688229 | 4.176275 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode: [3/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 9.521613% | 0.001804 | 0.021394 | 0.361794 | 0.406567 | 0.529882 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 0.090952\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00003.pth\n", + "Train Episode: [4/10]\n" + ] + } + ], + "source": [ + "if task_name.startswith(\"train\"):\n", + " trainer.train_and_valid()\n", + " trainer.test()\n", + " print(\"train end\")\n", + "elif task_name.startswith(\"test\"):\n", + " trainer.test()\n", + " print(\"test end\")\n", + "elif task_name.startswith(\"style_test\"):\n", + " daily_return_list = []\n", + " for trainer in trainers:\n", + " daily_return_list.extend(trainer.test())\n", + " print('win rate is: ', sum(r > 0 for r in daily_return_list) / len(daily_return_list))\n", + " print(\"style test end\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "HFT", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.15" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorial/Tutorial5_HFT.ipynb b/tutorial/Tutorial5_HFT.ipynb new file mode 100644 index 00000000..16e82ec7 --- /dev/null +++ b/tutorial/Tutorial5_HFT.ipynb @@ -0,0 +1,451 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Import Packages\n", + "Modify the system path and load the corresponding packages and functions " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "ROOT = str(Path(\"__file__\").resolve().parents[1])\n", + "sys.path.append(ROOT)\n", + "import torch\n", + "import argparse\n", + "import os.path as osp\n", + "from mmcv import Config\n", + "from trademaster.utils import replace_cfg_vals\n", + "from trademaster.nets.builder import build_net\n", + "from trademaster.environments.builder import build_environment\n", + "from trademaster.datasets.builder import build_dataset\n", + "from trademaster.agents.builder import build_agent\n", + "from trademaster.optimizers.builder import build_optimizer\n", + "from trademaster.losses.builder import build_loss\n", + "from trademaster.trainers.builder import build_trainer\n", + "from trademaster.transition.builder import build_transition" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Load Configs\n", + "Load default config from the folder `configs/high_frequency_trading/high_frequency_trading_BTC_dqn_dqn_adam_mse.py`" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser(description='Download Alpaca Datasets')\n", + "parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"high_frequency_trading\", \"high_frequency_trading_BTC_dqn_dqn_adam_mse.py\"),\n", + " help=\"download datasets config file path\")\n", + "parser.add_argument(\"--task_name\", type=str, default=\"train\")\n", + "parser.add_argument(\"--test_style\", type=str, default='-1')\n", + "args = parser.parse_args([])\n", + "cfg = Config.fromfile(args.config)\n", + "task_name = args.task_name\n", + "\n", + "cfg = replace_cfg_vals(cfg)\n", + "# update test style\n", + "cfg.data.update({'test_style': args.test_style})\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/high_frequency_trading/high_frequency_trading_BTC_dqn_dqn_adam_mse.py): {'data': {'type': 'HighFrequencyTradingDataset', 'data_path': 'data/high_frequency_trading/small_BTC', 'train_path': 'data/high_frequency_trading/small_BTC/train.csv', 'valid_path': 'data/high_frequency_trading/small_BTC/valid.csv', 'test_path': 'data/high_frequency_trading/small_BTC/test.csv', 'test_style_path': 'data/high_frequency_trading/small_BTC/test.csv', 'tech_indicator_list': ['imblance_volume_oe', 'sell_spread_oe', 'buy_spread_oe', 'kmid2', 'bid1_size_n', 'ksft2', 'ma_10', 'ksft', 'kmid', 'ask1_size_n', 'trade_diff', 'qtlu_10', 'qtld_10', 'cntd_10', 'beta_10', 'roc_10', 'bid5_size_n', 'rsv_10', 'imxd_10', 'ask5_size_n', 'ma_30', 'max_10', 'qtlu_30', 'imax_10', 'imin_10', 'min_10', 'qtld_30', 'cntn_10', 'rsv_30', 'cntp_10', 'ma_60', 'max_30', 'qtlu_60', 'qtld_60', 'cntd_30', 'roc_30', 'beta_30', 'bid4_size_n', 'rsv_60', 'ask4_size_n', 'imxd_30', 'min_30', 'max_60', 'imax_30', 'imin_30', 'cntd_60', 'roc_60', 'beta_60', 'cntn_30', 'min_60', 'cntp_30', 'bid3_size_n', 'imxd_60', 'ask3_size_n', 'sell_volume_oe', 'imax_60', 'imin_60', 'cntn_60', 'buy_volume_oe', 'cntp_60', 'bid2_size_n', 'kup', 'bid1_size', 'ask1_size', 'std_30', 'ask2_size_n'], 'transcation_cost': 0, 'backward_num_timestamp': 1, 'max_holding_number': 0.01, 'num_action': 11, 'max_punish': 1000000000000.0, 'episode_length': 14400, 'test_style': '-1'}, 'environment': {'type': 'HighFrequencyTradingEnvironment'}, 'agent': {'type': 'HighFrequencyTradingDDQN', 'auxiliary_coffient': 512, 'reward_scale': 1, 'repeat_times': 1, 'gamma': 0.99, 'batch_size': 64, 'clip_grad_norm': 3.0, 'soft_update_tau': 0, 'state_value_tau': 0.005}, 'trainer': {'type': 'HighFrequencyTradingTrainer', 'epochs': 10, 'work_dir': 'work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse', 'seeds': 12345, 'batch_size': 512, 'horizon_len': 512, 'buffer_size': 100000.0, 'num_threads': 8, 'if_remove': False, 'if_discrete': True, 'if_off_policy': True, 'if_keep_save': True, 'if_over_write': False, 'if_save_buffer': False}, 'loss': {'type': 'HFTLoss', 'ada': 1}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'HFTQNet', 'state_dim': 66, 'action_dim': 11, 'dims': 16, 'explore_rate': 0.25, 'max_punish': 0}, 'cri': None, 'task_name': 'high_frequency_trading', 'dataset_name': 'BTC', 'optimizer_name': 'adam', 'loss_name': 'mse', 'auxiliry_loss_name': 'KLdiv', 'net_name': 'high_frequency_trading_dqn', 'agent_name': 'ddqn', 'work_dir': 'work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse', 'batch_size': 512, 'train_environment': {'type': 'HighFrequencyTradingTrainingEnvironment'}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cfg" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Build Dataset\n", + "Build datasets from cfg defined above" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = build_dataset(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Build Reinforcement Learning Environments\n", + "Build environments based on cfg and previously-defined dataset\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "valid\n", + "test\n" + ] + } + ], + "source": [ + "valid_environment = build_environment(\n", + " cfg, default_args=dict(dataset=dataset, task=\"valid\")\n", + " )\n", + "test_environment = build_environment(\n", + " cfg, default_args=dict(dataset=dataset, task=\"test\")\n", + ")\n", + "cfg.environment = cfg.train_environment\n", + "train_environment = build_environment(\n", + " cfg, default_args=dict(dataset=dataset, task=\"train\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "valid_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_environment" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Build Net \n", + "Update information about the state and action dimension in the config and create nets and optimizer for DQN\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "action_dim = train_environment.action_dim\n", + "state_dim = train_environment.state_dim\n", + "\n", + "cfg.act.update(dict(action_dim=action_dim, state_dim=state_dim))\n", + "act = build_net(cfg.act)\n", + "act_optimizer = build_optimizer(cfg, default_args=dict(params=act.parameters()))\n", + "if cfg.cri:\n", + " cfg.cri.update(dict(action_dim=action_dim, state_dim=state_dim))\n", + " cri = build_net(cfg.cri)\n", + " cri_optimizer = build_optimizer(cfg, default_args=dict(params=cri.parameters()))\n", + "else:\n", + " cri = None\n", + " cri_optimizer = None" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Build Loss\n", + "Build loss from config" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = build_loss(cfg)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Build Agent\n", + "Build agent from config and detect device" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "agent = build_agent(\n", + " cfg,\n", + " default_args=dict(\n", + " action_dim=action_dim,\n", + " state_dim=state_dim,\n", + " act=act,\n", + " cri=cri,\n", + " act_optimizer=act_optimizer,\n", + " cri_optimizer=cri_optimizer,\n", + " criterion=criterion,\n", + " device=device,\n", + " ),\n", + " )\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Build Trainer\n", + "Build trainer from config and create work directionary to save the result, model and config" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "| Arguments Keep work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse\n" + ] + } + ], + "source": [ + "\n", + "trainer = build_trainer(cfg, default_args=dict(train_environment=train_environment,\n", + " valid_environment=valid_environment,\n", + " test_environment=test_environment,\n", + " agent=agent,\n", + " device=device))\n", + "\n", + "cfg.dump(osp.join(ROOT, cfg.work_dir, osp.basename(args.config)))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 9: Train the Trainer\n", + "Train the trainer based on the config and get results from workdir" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "Train Episode: [1/10]\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n" + ] + } + ], + "source": [ + "if task_name.startswith(\"train\"):\n", + " trainer.train_and_valid()\n", + " trainer.test()\n", + " print(\"train end\")\n", + "elif task_name.startswith(\"test\"):\n", + " trainer.test()\n", + " print(\"test end\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "HFT", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.15" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}