# SARL for Portfolio Management on DJ30
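This tutorial demonstrates how to use SARL for portfolio management on the Dow Jones 30 (DJ30) stocks. It walks through loading the experiment config, building the dataset and trainer, training with validation, testing, and plotting the test-period results.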

## Step 1: Import Packages

In [1]:
import warnings

warnings.filterwarnings("ignore")
import sys
from pathlib import Path
import os
import torch

# Repository root: the parent of the current working directory. Assumes the
# notebook is run from a folder one level below the repo root (the config path
# built from ROOT below points at <repo>/configs) — confirm if you move it.
# Adding ROOT to the import path lets the local `trademaster` package be imported.
ROOT = str(Path.cwd().parent)
sys.path.append(ROOT)
from trademaster.utils import plot
import argparse
import os.path as osp
from mmcv import Config
from trademaster.utils import replace_cfg_vals
from trademaster.datasets.builder import build_dataset
from trademaster.trainers.builder import build_trainer

2023-02-28 14:34:23,739	INFO worker.py:973 -- Calling ray.init() again after it has already been called.


## Step 2: Load the Config

In [2]:

# Script-style options so the same code can run as a notebook or a CLI script.
# parse_known_args() tolerates unrelated arguments, such as those the Jupyter
# kernel process is launched with.
parser = argparse.ArgumentParser(description="Train and test SARL for portfolio management on DJ30")
parser.add_argument(
    "--config",
    default=osp.join(ROOT, "configs", "portfolio_management", "portfolio_management_dj30_sarl_sarl_adam_mse.py"),
    help="path to the experiment config file",
)
parser.add_argument("--task_name", type=str, default="train")
args, _ = parser.parse_known_args()

# Load the experiment config, then expand any placeholder values in it
# (see trademaster.utils.replace_cfg_vals for the exact substitution rules).
cfg = Config.fromfile(args.config)
task_name = args.task_name
cfg = replace_cfg_vals(cfg)

In [3]:
# Inspect the resolved config (dataset paths, environment and trainer settings).
cfg

Config (path: D:\pycharm_workspace\TradeMaster\configs\portfolio_management\portfolio_management_dj30_sarl_sarl_adam_mse.py): {'data': {'type': 'PortfolioManagementDataset', 'data_path': 'data/portfolio_management/dj30', 'train_path': 'data/portfolio_management/dj30/train.csv', 'valid_path': 'data/portfolio_management/dj30/valid.csv', 'test_path': 'data/portfolio_management/dj30/test.csv', 'tech_indicator_list': ['high', 'low', 'open', 'close', 'adjcp', 'zopen', 'zhigh', 'zlow', 'zadjcp', 'zclose', 'zd_5', 'zd_10', 'zd_15', 'zd_20', 'zd_25', 'zd_30'], 'length_day': 5, 'initial_amount': 10000, 'transaction_cost_pct': 0.001, 'test_dynamic_path': 'data/portfolio_management/dj30/DJI_label_by_DJIindex_3_24_-0.25_0.25.csv'}, 'environment': {'type': 'PortfolioManagementSARLEnvironment'}, 'trainer': {'type': 'PortfolioManagementSARLTrainer', 'agent_name': 'ddpg', 'if_remove': False, 'configs': {'framework': 'tf2', 'num_workers': 0}, 'work_dir': 'work_dir/portfolio_management_dj30_sarl_sarl_ada

## Step 3: Build the Dataset

In [4]:
# Build the dataset described by the config — presumably from its `data` section
# (train/valid/test CSV paths and technical-indicator columns); confirm in build_dataset.
dataset = build_dataset(cfg)

## Step 4: Build the Trainer

In [5]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output directory for this experiment. cfg.trainer.work_dir uses '/', so
# normpath keeps separators consistent when ROOT is a Windows path.
work_dir = osp.normpath(osp.join(ROOT, cfg.trainer.work_dir))
# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(work_dir, exist_ok=True)

# Save a copy of the resolved config alongside the experiment outputs.
cfg.dump(osp.join(work_dir, osp.basename(args.config)))

trainer = build_trainer(cfg, default_args=dict(dataset=dataset, device=device))

| Arguments Keep work_dir: D:\pycharm_workspace\TradeMaster\work_dir/portfolio_management_dj30_sarl_sarl_adam_mse


## Step 5: Train, Validate and Test

In [7]:
# Alternate training and validation episodes (the log below reports 20 episodes
# in total). This is long-running; the saved run was stopped manually during episode 3.
trainer.train_and_valid()

2023-02-28 14:34:31,050	INFO trainer.py:2321 -- Executing eagerly (framework='tf2'), with eager_tracing=False. For production workloads, make sure to set eager_tracing=True  in order to match the speed of tf-static-graph (framework='tf'). For debugging purposes, `eager_tracing=False` is the best choice.
2023-02-28 14:34:31,053	INFO trainer.py:903 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


Train Episode: [1/20]
Valid Episode: [1/20]
+--------------+-------------+------------+--------------+
| Total Return | Sharp Ratio | Volatility | Max Drawdown |
+--------------+-------------+------------+--------------+
|  7.953910%   |   3.778823  | 2.121771%  |  31.771785%  |
+--------------+-------------+------------+--------------+
Train Episode: [2/20]
+--------------+-------------+------------+--------------+
| Total Return | Sharp Ratio | Volatility | Max Drawdown |
+--------------+-------------+------------+--------------+
| -16.804172%  |  -7.464375  | 0.797918%  |  59.806837%  |
+--------------+-------------+------------+--------------+
Valid Episode: [2/20]
+--------------+-------------+------------+--------------+
| Total Return | Sharp Ratio | Volatility | Max Drawdown |
+--------------+-------------+------------+--------------+
|  11.938123%  |   5.751683  | 2.092259%  |  31.535895%  |
+--------------+-------------+------------+--------------+
Train Episode: [3/20]
Valid

KeyboardInterrupt: 

In [None]:
import ray
from ray.tune.registry import register_env
from trademaster.environments.portfolio_management.sarl_environment import PortfolioManagementSARLEnvironment


def env_creator(env_name):
    """Return the environment class for ``env_name``.

    Args:
        env_name: Name the environment is registered under in Ray.

    Returns:
        The environment class (not an instance); the caller instantiates it
        with the env config.

    Raises:
        NotImplementedError: If ``env_name`` is not a supported environment.
    """
    if env_name == 'portfolio_management_sarl':
        return PortfolioManagementSARLEnvironment
    raise NotImplementedError(f"Unsupported environment name: {env_name!r}")


# ignore_reinit_error keeps this from failing when Ray is already running
# (the import cell's log shows ray.init() had already been called).
ray.init(ignore_reinit_error=True)
# NOTE(review): registering the env here is presumably required for trainer.test()
# to rebuild the RLlib agent in this kernel — confirm.
register_env("portfolio_management_sarl", lambda config: env_creator("portfolio_management_sarl")(config))
trainer.test()

In [None]:
# Plot the test environment's asset memory (presumably the portfolio value over
# the test period — confirm in save_asset_memory).
test_asset_memory = trainer.test_environment.save_asset_memory()
plot(test_asset_memory, alg="SARL")