In [1]:
%reload_ext autoreload

In [2]:
%autoreload 2

In [3]:
import warnings

In [4]:
# Install a blanket "ignore" filter so Python warnings are not shown by later cells.
warnings.filterwarnings('ignore')

In [5]:
import datetime
import glob
import mlflow
import os

import numpy as np
import pandas as pd

from pathlib import Path

In [6]:
import qlib

from qlib.data.dataset import DataHandlerLP
from qlib.constant import REG_CN, REG_US
from qlib.contrib.report import analysis_model, analysis_position
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import PortAnaRecord, SigAnaRecord

In [7]:
from longcapital.data.dataset.processor import ChangeInstrument, DropInstrument, Fillna
from longcapital.utils.io import get_params_from_file, update_params_to_file, update_report_df
from longcapital.utils.time import get_diff_date
from longcapital.workflow.record_temp import SignalRecord


`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.



In [8]:
# Experiment name passed to `R.start(experiment_name=...)` in train_model.
EXP_NAME = "long-capital"

In [9]:
# records
# All paths are relative to the notebook's working directory.
STRATEGY_PARAMS_FILE = "../data/params/strategy.json"
# Read and updated by train_model through get_params_from_file / update_params_to_file.
MODEL_PARAMS_FILE = "../data/params/model.json"
PERFORMANCE_FILE = "../data/params/performance.json"
REPORT_DF_FOLDER = "../data/report_df"

In [10]:
# model
# Keyed by the `loss_type` values accepted by get_all_config.
# NOTE(review): the values look like validation-metric names reported by the model — confirm.
MODEL_VALID_LOSS_KEY_DICT = {
    "mse": "l2",
    "mse_log": "l2",
    "binary": "binary_logloss",
    "lambdarank": "ndcg@5",
}
# Keyed by `loss_type`; "mse_log" shares the "mse" entry.
# NOTE(review): presumably the objective name handed to the model — confirm.
MODEL_LOSS_NAME_DICT = {
    "mse": "mse",
    "mse_log": "mse",
    "binary": "binary",
    "lambdarank": "lambdarank"
}

# strategy
# Instrument pool name -> benchmark instrument code. Used as the backtest
# "benchmark" and as the instrument handled by ChangeInstrument / DropInstrument
# in load_dataset. (The name is spelled "BECHMARK" throughout this notebook.)
BECHMARK_PARAMS = {
    "csi300": "SH000300",
    "csi500": "SH000905",
    "csi800": "SH000906",
    # https://github.com/microsoft/qlib/issues/720
    "SP500": "^gspc",
    "NASDAQ100": "^ndx",
}

In [11]:
def get_last_date_from_calendar(region=REG_CN):
    """Return the last date listed in the region's daily calendar file.

    The file ``~/.qlib/qlib_data/{region}_data/calendars/day.txt`` is read
    directly in Python (with ``~`` expanded explicitly) instead of shelling
    out to ``tail``, so the lookup does not depend on a POSIX shell and the
    region string is never interpolated into a shell command.

    Args:
        region: Region identifier used to build the data folder name.

    Returns:
        The last non-blank line of the calendar file with surrounding
        whitespace stripped, or ``""`` if the file holds no such line.

    Raises:
        FileNotFoundError: If the calendar file does not exist (the previous
            shell-based version silently returned ``""`` in that case).
    """
    file = os.path.expanduser(f"~/.qlib/qlib_data/{region}_data/calendars/day.txt")
    last_date = ""
    with open(file, "r") as f:
        for line in f:
            stripped = line.strip()
            # Skip blank lines so a trailing empty line does not hide the last date.
            if stripped:
                last_date = stripped
    return last_date


def get_date_config(region=REG_CN, pred_date=None):
    """Return the train/valid/test/backtest date ranges for a region.

    Args:
        region: Key of the region table to return (REG_CN or REG_US).
        pred_date: End date of the "test" range. When None, the last date of
            the region's calendar file is used.

    Returns:
        A dict with "train", "valid", "test" and "backtest" entries, each a
        ``{"start": ..., "end": ...}`` dict. The "backtest" range ends at
        ``get_diff_date(pred_date, -1)``.

    Raises:
        KeyError: If ``region`` is neither REG_CN nor REG_US.
    """
    if pred_date is None:
        pred_date = get_last_date_from_calendar(region)
    backtest_end = get_diff_date(pred_date, -1)

    def _segments():
        # Both regions currently share the same split boundaries.
        return {
            "train": {"start": "2006-01-01", "end": "2016-12-31"},
            "valid": {"start": "2017-01-01", "end": "2018-12-31"},
            "test": {"start": "2019-01-01", "end": pred_date},
            "backtest": {"start": "2019-01-01", "end": backtest_end},
        }

    per_region = {REG_CN: _segments(), REG_US: _segments()}
    return per_region[region]


def get_backtest_config(region=REG_CN, instruments="csi300", deal_price="open"):
    """Return the benchmark and exchange settings used for backtesting.

    Args:
        region: Key of the region table to return (REG_CN or REG_US).
        instruments: Instrument pool name; must be a key of BECHMARK_PARAMS.
            It is also passed through as the exchange "codes".
        deal_price: Value stored under the exchange "deal_price" key.

    Returns:
        A dict with "benchmark" and "exchange_kwargs" entries for the region.

    Raises:
        KeyError: If ``instruments`` is not in BECHMARK_PARAMS or ``region``
            is not a known region.
    """
    benchmark = BECHMARK_PARAMS[instruments]

    def _exchange(trade_unit, limit_threshold, open_cost, close_cost, min_cost):
        # Only the trading unit, price limit and cost figures differ per region.
        return {
            "codes": instruments,
            "freq": "day",
            "trade_unit": trade_unit,
            "limit_threshold": limit_threshold,
            "deal_price": deal_price,
            "open_cost": open_cost,
            "close_cost": close_cost,
            "min_cost": min_cost,
        }

    region_settings = {
        REG_CN: {
            "benchmark": benchmark,
            "exchange_kwargs": _exchange(
                trade_unit=100,
                limit_threshold=0.095,
                open_cost=0.0005,
                close_cost=0.0015,
                min_cost=5,
            ),
        },
        REG_US: {
            "benchmark": benchmark,
            # costs estimated from moomoo sg
            "exchange_kwargs": _exchange(
                trade_unit=1,
                limit_threshold=None,
                open_cost=0.003,
                close_cost=0.005,
                min_cost=0,
            ),
        },
    }
    return region_settings[region]

In [12]:
def get_all_config(
    region=REG_CN,
    instruments="csi300",
    benchmark_feature=None,
    deal_price="open",
    days_ahead=4,
    loss_type="mse",
    label_norm="CSZScoreNorm",
    model_type="default",
    strategy_type="best",
    hold_thresh=2,
    date_config=None
):
    """Assemble the full experiment config plus its derived lookup keys.

    The ``strategy_type`` argument is now stored in the config. Previously the
    value was hard-coded to "best", so the argument was silently ignored. With
    the default argument the result is unchanged.

    Args:
        region: Market region (REG_CN or REG_US).
        instruments: Instrument pool name.
        benchmark_feature: One of None, "raw", "diff", "both".
        deal_price: "open" or "close".
        days_ahead: Label horizon in days (1, 2, 3, ...).
        loss_type: One of "mse", "mse_log", "binary", "lambdarank".
        label_norm: "CSZScoreNorm" or "CSRankNorm".
        model_type: "default" or "best".
        strategy_type: Strategy variant recorded in the config and config key.
        hold_thresh: Holding threshold (1, 2, 3, ...).
        date_config: Optional pre-built date config; when None it is produced
            by ``get_date_config(region=region)``.

    Returns:
        The config dict, extended with "date", "backtest", "dataset_key",
        "model_key" and "config_key". Each key is a "-"-joined ``name=value``
        string over the scalar settings: "dataset_key" omits the model and
        strategy settings, "model_key" omits the strategy settings, and
        "config_key" covers all of them.
    """
    config = {
        # market
        # [REG_CN, REG_US]
        "region": region,
        # ["csi300", "csi500", "csi800", "csiall", "all"]
        "instruments": instruments,

        # feature
        # [None, "raw", "diff", "both"]
        "benchmark_feature": benchmark_feature,

        # label
        # ["open", "close"]
        "deal_price": deal_price,
        # [1,2,3,...]
        "days_ahead": days_ahead,
        # ["mse", "mse_log", "binary", "lambdarank"]
        "loss_type": loss_type,
        # ["CSZScoreNorm", "CSRankNorm"]
        "label_norm": label_norm,

        # model
        # ["default", "best"]
        "model_type": model_type,

        # strategy
        "strategy_type": strategy_type,
        # [1,2,3,...]
        "hold_thresh": hold_thresh
    }

    def _join_settings(excluded):
        # Keys are built before "date"/"backtest" are added, so only scalar settings appear.
        return "-".join(f"{k}={v}" for k, v in config.items() if k not in excluded)

    dataset_key = _join_settings(("model_type", "strategy_type", "hold_thresh"))
    model_key = _join_settings(("strategy_type", "hold_thresh"))
    config_key = _join_settings(())

    if date_config is None:
        date_config = get_date_config(region=config["region"])
    config["date"] = date_config
    config["backtest"] = get_backtest_config(
        region=config["region"],
        instruments=config["instruments"],
        deal_price=config["deal_price"]
    )
    config.update({
        "dataset_key": dataset_key,
        "model_key": model_key,
        "config_key": config_key,
    })
    return config

In [13]:
def append_benchmark_to_pool(region, instrument, benchmark):
    """Ensure the benchmark's line from ``all.txt`` is present in a pool file.

    The instruments folder is resolved under the current user's home
    directory (``~/.qlib/qlib_data/{region}_data/instruments``), matching the
    other qlib data paths in this notebook, instead of one specific user's
    absolute path.

    Args:
        region: Region identifier used to build the data folder name.
        instrument: Pool name; ``{instrument}.txt`` is the file that is updated.
        benchmark: Benchmark code to look for (substring match per line).

    Raises:
        ValueError: If no line of ``all.txt`` mentions ``benchmark``. The
            previous version appended the last line of ``all.txt`` in that
            case, adding an unrelated instrument to the pool.
        FileNotFoundError: If the pool file or ``all.txt`` does not exist.
    """
    folder = os.path.expanduser(f"~/.qlib/qlib_data/{region}_data/instruments")
    pool_file = f"{folder}/{instrument}.txt"

    # already appended
    with open(pool_file, "r") as f:
        pool_text = f.read()
    if any(benchmark in line for line in pool_text.splitlines()):
        return

    # locate the benchmark's definition line in the full instrument list
    benchmark_line = None
    with open(f"{folder}/all.txt", "r") as f:
        for line in f:
            if benchmark in line:
                benchmark_line = line
                break
    if benchmark_line is None:
        raise ValueError(
            f"benchmark {benchmark!r} not found in {folder}/all.txt; "
            f"nothing appended to {pool_file}"
        )

    # append, keeping the pool file strictly one instrument per line
    if not benchmark_line.endswith("\n"):
        benchmark_line += "\n"
    with open(pool_file, "a") as f:
        if pool_text and not pool_text.endswith("\n"):
            f.write("\n")
        f.write(benchmark_line)

In [14]:
def load_dataset(config, label=None):
    """Build the Alpha158-based ``DatasetH`` described by ``config``.

    Args:
        config: Config dict as produced by get_all_config. The keys read here
            are "benchmark_feature", "instruments", "region", "loss_type",
            "label_norm", "days_ahead" and "date".
        label: Optional label definition; forwarded to the handler as its
            "label" kwarg only when truthy.

    Returns:
        The initialised dataset with "train", "valid" and "test" segments.

    Side effects:
        When a benchmark feature is requested, the benchmark instrument is
        appended to the instrument pool file. Feature columns containing
        non-finite values in the train or test frame are printed.
    """
    fields = []
    names = []

    # processors
    with_benchmark = config["benchmark_feature"] in ["raw", "diff", "both"]
    benchmark_code = BECHMARK_PARAMS[config["instruments"]]
    infer_processors = [{"class": "Fillna"}]
    if with_benchmark:
        infer_processors.append(
            ChangeInstrument(
                instrument=benchmark_code,
                append_type=config["benchmark_feature"],
                fields_group="feature",
            )
        )
    # The benchmark itself is always dropped from the inference data.
    infer_processors.append(DropInstrument(instruments=[benchmark_code]))
    if with_benchmark:
        append_benchmark_to_pool(config["region"], config["instruments"], benchmark_code)

    learn_processors = [{"class": "DropnaLabel"}]
    if config["loss_type"] != "lambdarank":
        # Label normalisation is skipped for the ranking loss.
        learn_processors.append(
            {"class": config["label_norm"], "kwargs": {"fields_group": "label"}}
        )

    # handler
    segment_names = ("train", "valid", "test")
    dates = config["date"]
    data_start_time = min(dates[name]["start"] for name in segment_names)
    data_end_time = max(dates[name]["end"] for name in segment_names)

    handler_kwargs = dict(
        start_time=data_start_time,
        end_time=data_end_time,
        fit_start_time=dates["train"]["start"],
        fit_end_time=dates["train"]["end"],
        instruments=config["instruments"],
        feature=(fields, names),
        learn_processors=learn_processors,
        infer_processors=infer_processors,
        loss_type=config["loss_type"],
        next_label_price_expr="$open",
        curr_label_price_expr="$open",
        days_ahead=config["days_ahead"],
        include_volume=False,
    )
    if label:
        handler_kwargs["label"] = label
    hd = init_instance_by_config({
        "class": "Alpha158",
        "module_path": "longcapital.contrib.data.handler",
        "kwargs": handler_kwargs,
    })

    # dataset
    segments = {
        name: (dates[name]["start"], dates[name]["end"]) for name in segment_names
    }
    dataset = init_instance_by_config({
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {"handler": hd, "segments": segments},
    })

    # nan check: report feature columns that contain any non-finite value
    df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
    df_test = dataset.prepare("test", col_set=["feature"], data_key=DataHandlerLP.DK_I)
    for frame in (df_train, df_test):
        non_finite = np.isfinite(frame["feature"].values).mean(axis=0) != 1
        if non_finite.sum():
            print(frame["feature"].columns[non_finite])

    return dataset

In [15]:
def train_model(dataset, config):
    """Fit an LGBModel on ``dataset`` inside a new recorder run.

    Model parameters are looked up in MODEL_PARAMS_FILE under
    ``config["config_key"]``. When no entry exists, the "default" entry is
    used and also written back under that key.

    Args:
        dataset: Dataset passed to ``model.fit`` and to SignalRecord.
        config: Config dict; only "config_key" is read here.

    Returns:
        A ``(model, rid)`` tuple: the fitted model and the recorder id.
    """
    # Close any run left open before starting a new one.
    mlflow.end_run()
    with R.start(experiment_name=EXP_NAME):
        params_key = config["config_key"]
        if get_params_from_file(MODEL_PARAMS_FILE, params_key) is not None:
            model_params = get_params_from_file(MODEL_PARAMS_FILE, params_key)
        else:
            model_params = get_params_from_file(MODEL_PARAMS_FILE, "default")
            update_params_to_file(MODEL_PARAMS_FILE, params_key, model_params)

        model_conf = {
            "class": "LGBModel",
            "module_path": "longcapital.contrib.model.gbdt",
            "kwargs": model_params,
        }
        model = init_instance_by_config(model_conf)
        model.fit(dataset)

        R.save_objects(trained_model=model)

        recorder = R.get_recorder()
        rid = recorder.id  # keep the record id for the caller

        # Inference and saving signal
        signal_record = SignalRecord(
            model, dataset, recorder, neutralize=False, riskiest_features_num=50
        )
        signal_record.generate()

    return model, rid

In [16]:
def prepare_signal(days_aheads, signal_names, hold_thresh):
    """Train one model per horizon and collect its predictions as signal columns.

    For each ``(days_ahead, signal_name)`` pair a dataset is loaded and a model
    trained. Its "valid" and "test" predictions are inserted at column
    position 0 of the feature frames taken from the first dataset, so the last
    pair ends up as the first column.

    Args:
        days_aheads: Label horizons, one per model.
        signal_names: Column names for the predictions, paired with
            ``days_aheads`` by position.
        hold_thresh: Forwarded to get_all_config.

    Returns:
        A ``(signal, rid)`` tuple: the valid and test frames concatenated
        row-wise, and the recorder id of the last trained model.
    """
    valid_frame, test_frame = None, None
    for horizon, column in zip(days_aheads, signal_names):
        config = get_all_config(days_ahead=horizon, hold_thresh=hold_thresh)
        qlib.init(provider_uri=f"~/.qlib/qlib_data/{config['region']}_data", region=config["region"])
        dataset = load_dataset(config)
        model, rid = train_model(dataset, config)
        # The feature frames are taken once, from the first dataset only.
        if valid_frame is None:
            valid_frame = dataset.prepare("valid", col_set=["feature"], data_key=DataHandlerLP.DK_L)
        if test_frame is None:
            test_frame = dataset.prepare("test", col_set=["feature"], data_key=DataHandlerLP.DK_I)
        signal_column = ("feature", column)
        valid_frame.insert(0, signal_column, model.predict(dataset, "valid"))
        test_frame.insert(0, signal_column, model.predict(dataset, "test"))
    return pd.concat([valid_frame, test_frame], axis=0), rid

In [17]:
# trading config
# Passed as `topk`, `n_drop` and `hold_thresh` to the strategies built below.
topk = 10
n_drop = 2
hold_thresh = 3
# Passed as `account` to TradeStrategySimulator and to the backtest config.
account = 100000000

In [18]:
# prepare_signal inserts each model's predictions at column position 0, one
# after another, so the LAST entry of days_aheads / signal_names ends up as
# the FIRST column of the signal frame. That first column is:
#   * the default ranking signal of TopkDropoutStrategy
#     (see: https://github.com/microsoft/qlib/blob/main/qlib/contrib/strategy/signal_strategy.py#L147)
#   * the signal used to rank the features when building the obs/state space
#     for RL training (see: TradeStrategy.get_feature)
# The trailing `[-1:]` keeps only the last horizon/name pair, i.e. a single model.
days_aheads = [2, 3, 4, 5, 6, 7, 8][-1:]
signal_names = ["signal_2", "signal_3", "signal_4", "signal_5", "signal_6", "signal_7", "signal"][-1:]
signal_key = "signal"

# model config
# alpha158 + signals + position flag + unhold flag
dim = 158 + len(signal_names) + 1 + 1
# number of stock candidates for ranking
stock_num = 20

In [19]:
# Shared config for the last horizon; reused below for dates and backtest settings.
config = get_all_config(days_ahead=days_aheads[-1], hold_thresh=hold_thresh)
# Trains the model(s) and returns the signal frame plus the last recorder id.
signal, rid = prepare_signal(days_aheads, signal_names, hold_thresh)

[67924:MainThread](2023-03-14 21:47:08,461) INFO - qlib.Initialization - [config.py:416] - default_conf: client.
[67924:MainThread](2023-03-14 21:47:08,469) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
[67924:MainThread](2023-03-14 21:47:08,470) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': PosixPath('/Users/chenglong.chen/.qlib/qlib_data/cn_data')}
[67924:MainThread](2023-03-14 21:48:17,385) INFO - qlib.timer - [log.py:128] - Time cost: 68.907s | Loading data Done
[67924:MainThread](2023-03-14 21:48:23,786) INFO - qlib.timer - [log.py:128] - Time cost: 4.756s | Fillna Done
[67924:MainThread](2023-03-14 21:48:24,461) INFO - qlib.timer - [log.py:128] - Time cost: 0.674s | DropInstrument Done
[67924:MainThread](2023-03-14 21:48:25,369) INFO - qlib.timer - [log.py:128] - Time cost: 0.412s | DropnaLabel Done
[67924:MainThread](2023-03-14 21:48:33,163) INFO - qlib.timer - [log.py:128] - Time cost: 7.79

ModuleNotFoundError. CatBoostModel are skipped. (optional: maybe installing CatBoostModel can fix it.)



Please use `line_search_wolfe2` from the `scipy.optimize` namespace, the `scipy.optimize.linesearch` namespace is deprecated.


Please use `line_search_wolfe1` from the `scipy.optimize` namespace, the `scipy.optimize.linesearch` namespace is deprecated.



Training until validation scores don't improve for 50 rounds
[20]	train's l2: 0.961997	valid's l2: 0.986656
[40]	train's l2: 0.941246	valid's l2: 0.985369
[60]	train's l2: 0.924647	valid's l2: 0.986241
[80]	train's l2: 0.909259	valid's l2: 0.98719
Early stopping, best iteration is:
[36]	train's l2: 0.945509	valid's l2: 0.985177


[67924:MainThread](2023-03-14 21:49:04,018) INFO - qlib.workflow - [record_temp.py:196] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 3


'The following are prediction results of the LGBModel model.'
                          score
datetime   instrument          
2019-01-02 SH600000   -0.077803
           SH600004   -0.108537
           SH600009    0.248091
           SH600010   -0.010179
           SH600011    0.030957


[67924:MainThread](2023-03-14 21:49:04,970) INFO - qlib.timer - [log.py:128] - Time cost: 0.000s | waiting `async_log` Done


# RL

In [20]:
from qlib.contrib.evaluate import risk_analysis
from qlib.rl.trainer import Checkpoint, EarlyStopping, MetricsWriter, train, backtest
from qlib.rl.utils.log import CsvWriter


`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.


`np.typeDict` is a deprecated alias for `np.sctypeDict`.



In [21]:
from tianshou.data import Batch


`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.



In [22]:
from longcapital.rl.order_execution.reward import (
    ExcessReturnReward, 
    EpisodeInformationRatioReward, 
    ExecutionInformationRatioReward, 
    ExcessExecutionInformationRatioReward
)
from longcapital.rl.order_execution.state import TradeStrategyInitiateState
from longcapital.rl.order_execution.strategy import (
    TopkDropoutStrategy, 
    TopkDropoutSignalStrategy, 
    TopkStrategy, 
    WeightStrategy
)
from longcapital.rl.order_execution.simulator import TradeStrategySimulator
from longcapital.rl.order_execution.policy import continuous
from longcapital.rl.order_execution.policy import discrete


Please use `spmatrix` from the `scipy.sparse` namespace, the `scipy.sparse.base` namespace is deprecated.



## Check Simulator

In [23]:
def check_simulator(trade_strategy, simulator):
    """Run a single policy step through ``simulator`` and print the outcome.

    The current simulator state is interpreted into an observation, fed to the
    strategy's policy, and the policy output is interpreted into an action.
    The action is applied with ``simulator.step`` and the resulting
    EpisodeInformationRatioReward is printed.

    Args:
        trade_strategy: Strategy providing ``state_interpreter``,
            ``action_interpreter`` and ``policy``.
        simulator: Simulator providing ``get_state`` and ``step``.
    """
    reward_fn = EpisodeInformationRatioReward(scale=1.)
    current_state = simulator.get_state()
    observation = trade_strategy.state_interpreter.interpret(current_state)
    batch = Batch([{"obs": observation, "info": {}}])

    policy_out = trade_strategy.policy(batch)

    action = trade_strategy.action_interpreter.interpret(current_state, policy_out.act)
    print(f"Action = {action}")

    simulator.step(action)
    reward_value = float(reward_fn.reward(simulator.get_state()))
    print(f"Reward = {reward_value:.6f}")

In [24]:
# Initial state used for RL training: spans the "valid" date range of `config`.
initial_states_train = [
    TradeStrategyInitiateState(
        start_time=get_diff_date(config["date"]["valid"]["start"], 7), # to avoid start_time not tradable
        end_time=config["date"]["valid"]["end"],
        sample_date=False
    )
]
# Initial state used for validation and the baseline backtest: spans the "backtest" date range.
initial_states_valid = [
    TradeStrategyInitiateState(
        start_time=config["date"]["backtest"]["start"],
        end_time=config["date"]["backtest"]["end"],
        sample_date=False
    )
]

In [25]:
# Top-k dropout strategy with `discrete.PPO` as its policy class.
topk_dropout_strategy = TopkDropoutStrategy(
    signal=signal,
    dim=dim,
    stock_num=stock_num,
    topk=topk,
    n_drop=n_drop,
    only_tradable=True,
    hold_thresh=hold_thresh,
    signal_key="signal",
    policy_cls=discrete.PPO
)

In [26]:
# Simulator for topk_dropout_strategy, started from the training initial state.
topk_dropout_simulator = TradeStrategySimulator(
    trade_strategy=topk_dropout_strategy, 
    initial_state=initial_states_train[0], 
    account=account,
    benchmark=config["backtest"]["benchmark"],
    exchange_kwargs=config["backtest"]["exchange_kwargs"]
)


`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.

[67924:MainThread](2023-03-14 21:49:10,616) INFO - qlib.backtest caller - [__init__.py:93] - Create new exchange


start_time: 2017-01-08, end_time: 2018-12-31


In [27]:
# Smoke test: one policy step; prints the chosen action and the reward.
check_simulator(topk_dropout_strategy, topk_dropout_simulator)

Action = TopkDropoutStrategyAction(n_drop=0)
Reward = 0.000000



`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.



In [28]:
# Top-k dropout signal strategy with `continuous.MetaPPO` as its policy class.
topk_dropout_signal_strategy = TopkDropoutSignalStrategy(
    signal=signal,
    dim=dim,
    stock_num=stock_num,
    topk=topk,
    n_drop=n_drop,
    only_tradable=True,
    hold_thresh=hold_thresh,
    signal_key="signal",
    policy_cls=continuous.MetaPPO
)


`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.



In [29]:
# Simulator for topk_dropout_signal_strategy, started from the training initial state.
topk_dropout_signal_simulator = TradeStrategySimulator(
    trade_strategy=topk_dropout_signal_strategy, 
    initial_state=initial_states_train[0], 
    account=account,
    benchmark=config["backtest"]["benchmark"],
    exchange_kwargs=config["backtest"]["exchange_kwargs"]
)

[67924:MainThread](2023-03-14 21:49:23,136) INFO - qlib.backtest caller - [__init__.py:93] - Create new exchange


start_time: 2017-01-08, end_time: 2018-12-31


In [30]:
# Smoke test: one policy step; prints the chosen action and the reward.
check_simulator(topk_dropout_signal_strategy, topk_dropout_signal_simulator)

Action = TopkDropoutSignalStrategyAction(signal=instrument
SH600519   -1.822446
SH600703    0.139540
SZ000568    1.389984
SH600660    1.147429
SZ000858    0.956798
SZ002304   -1.639224
SH600196    0.550667
SZ300070    0.182697
SZ000768    0.021002
SH600023   -0.250455
SH600297   -1.442085
SH601888    1.721602
SZ002475   -0.771709
SH600383   -0.241921
SZ002008   -1.223077
SZ000063   -0.950663
SH601333   -0.308749
SH600372   -0.384635
SH600688   -0.279528
SH601939   -0.582056
Name: (feature, signal), dtype: float64)
Reward = 0.000000



`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.



In [31]:
# Top-k strategy with `discrete.MetaPPO` as its policy class (no topk / n_drop arguments).
topk_strategy = TopkStrategy(
    signal=signal,
    dim=dim,
    stock_num=stock_num,
    signal_key="signal",
    policy_cls=discrete.MetaPPO
)


`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.



In [32]:
# Simulator for topk_strategy, started from the training initial state.
topk_simulator = TradeStrategySimulator(
    trade_strategy=topk_strategy, 
    initial_state=initial_states_train[0], 
    account=account,
    benchmark=config["backtest"]["benchmark"],
    exchange_kwargs=config["backtest"]["exchange_kwargs"]
)

[67924:MainThread](2023-03-14 21:49:34,131) INFO - qlib.backtest caller - [__init__.py:93] - Create new exchange


start_time: 2017-01-08, end_time: 2018-12-31


In [33]:
# Smoke test: one policy step; prints the chosen action and the reward.
check_simulator(topk_strategy, topk_simulator)

Action = WeightStrategyAction(target_weight_position={'SH600519': 0.16666666666666666, 'SZ000568': 0.16666666666666666, 'SH600660': 0.16666666666666666, 'SH600196': 0.16666666666666666, 'SH600297': 0.16666666666666666, 'SH600383': 0.16666666666666666})
Reward = 0.000000



`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.



In [34]:
# Weight strategy with `continuous.MetaPPO` as its policy class.
weight_strategy = WeightStrategy(
    signal=signal,
    dim=dim,
    stock_num=stock_num,
    topk=topk,
    signal_key="signal",
    policy_cls=continuous.MetaPPO
)


`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.



In [35]:
# Simulator for weight_strategy, started from the training initial state.
weight_strategy_simulator = TradeStrategySimulator(
    trade_strategy=weight_strategy, 
    initial_state=initial_states_train[0], 
    account=account,
    benchmark=config["backtest"]["benchmark"],
    exchange_kwargs=config["backtest"]["exchange_kwargs"]
)

[67924:MainThread](2023-03-14 21:49:45,104) INFO - qlib.backtest caller - [__init__.py:93] - Create new exchange


start_time: 2017-01-08, end_time: 2018-12-31


In [36]:
# Smoke test: one policy step; prints the chosen action and the reward.
check_simulator(weight_strategy, weight_strategy_simulator)

Action = WeightStrategyAction(target_weight_position={'SZ000768': 0.1, 'SH600519': 0.1, 'SH600688': 0.1, 'SZ000063': 0.1, 'SZ002008': 0.1, 'SH601939': 0.1, 'SH600297': 0.1, 'SZ000858': 0.1, 'SH600372': 0.1, 'SH600023': 0.1})
Reward = 0.000000



`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.



## Train Policy

In [37]:
def train_trade_strategy(trade_strategy, max_iters=2, concurrency=1, episode_per_iter=1):
    """Run a baseline backtest, then train the strategy's policy.

    All artifacts (checkpoints, metrics, baseline log) are written below
    ``../records/EpisodeInformationRatioReward/<strategy>/<policy>/<timestamp>``.

    Args:
        trade_strategy: Strategy providing the policy, the state interpreter,
            the action interpreter and the baseline action interpreter.
        max_iters: Maximum number of training iterations.
        concurrency: Trainer concurrency.
        episode_per_iter: Episodes collected per training iteration.
    """
    reward = EpisodeInformationRatioReward(scale=1.)

    strategy_dir = f"../records/EpisodeInformationRatioReward/{trade_strategy}/{trade_strategy.policy}"
    output_dir = f'{strategy_dir}/{datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")}'

    checkpoint_cb = Checkpoint(
        dirpath=Path(f"{output_dir}/checkpoints"),
        every_n_iters=1,
        save_latest="copy",
    )
    early_stopping_cb = EarlyStopping(
        monitor="reward",
        min_delta=0.0,
        patience=1000,
        restore_best_weights=True
    )
    metrics_cb = MetricsWriter(dirpath=Path(f"{output_dir}/metrics"))
    trainer_kwargs = {
        "max_iters": max_iters,
        "finite_env_type": "dummy",
        "concurrency": concurrency,
        "val_every_n_iters": 1,
        "callbacks": [checkpoint_cb, early_stopping_cb, metrics_cb],
    }

    # DDPG / TD3 / SAC policies get exploration noise and 5x warm-up episodes;
    # every other policy class gets neither.
    noise_and_warmup = trade_strategy.policy_cls in [
        continuous.MetaDDPG,
        continuous.MetaTD3,
        continuous.MetaSAC,
    ]
    vessel_kwargs = {
        "update_kwargs": {"batch_size": 64, "repeat": 5},
        "episode_per_iter": episode_per_iter,
        "val_initial_states": initial_states_valid,
        "exploration_noise": noise_and_warmup,
        "start_episodes": 5 * episode_per_iter if noise_and_warmup else None,
    }

    def simulator_fn(initial_state):
        # One fresh simulator per initial state, sharing the notebook-level settings.
        return TradeStrategySimulator(
            trade_strategy=trade_strategy,
            initial_state=initial_state,
            account=account,
            benchmark=config["backtest"]["benchmark"],
            exchange_kwargs=config["backtest"]["exchange_kwargs"]
        )

    # baseline: same policy, but actions come from the baseline action interpreter
    baseline_logger = CsvWriter(output_dir=Path(f"{output_dir}/baseline"))
    backtest(
        simulator_fn=simulator_fn,
        state_interpreter=trade_strategy.state_interpreter,
        action_interpreter=trade_strategy.baseline_action_interpreter,
        policy=trade_strategy.policy,
        reward=reward,
        initial_states=initial_states_valid,
        finite_env_type=trainer_kwargs["finite_env_type"],
        logger=[baseline_logger]
    )
    del baseline_logger

    # train
    train(
        simulator_fn=simulator_fn,
        state_interpreter=trade_strategy.state_interpreter,
        action_interpreter=trade_strategy.action_interpreter,
        policy=trade_strategy.policy,
        reward=reward,
        initial_states=initial_states_train,
        trainer_kwargs=trainer_kwargs,
        vessel_kwargs=vessel_kwargs
    )


`should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.



In [38]:
# train_trade_strategy(trade_strategy=topk_dropout_strategy, max_iters=1, concurrency=1)

In [None]:
# Train the TopkDropoutSignalStrategy policy for up to 1000 iterations with concurrency 1.
train_trade_strategy(trade_strategy=topk_dropout_signal_strategy, max_iters=1000, concurrency=1)

[67924:MainThread](2023-03-14 21:49:56,065) INFO - qlib.rl.trainer.vessel - [vessel.py:163] - Testing initial states collection size: 1
[67924:MainThread](2023-03-14 21:49:56,067) INFO - qlib.rl.utils.data_queue - [data_queue.py:70] - Automatically set data queue maxsize to 12 to avoid overwhelming.
[67924:MainThread](2023-03-14 21:49:57,249) INFO - qlib.backtest caller - [__init__.py:93] - Create new exchange


start_time: 2019-01-01, end_time: 2023-03-13


[67924:MainThread](2023-03-14 21:50:28,361) INFO - qlib.rl.trainer.trainer - [trainer.py:211] - 
2023-03-14 21:50:28	Train iteration 1/1000
[67924:MainThread](2023-03-14 21:50:28,363) INFO - qlib.rl.trainer.vessel - [vessel.py:148] - Training initial states collection size: 1
[67924:MainThread](2023-03-14 21:50:28,365) INFO - qlib.rl.utils.data_queue - [data_queue.py:70] - Automatically set data queue maxsize to 12 to avoid overwhelming.
[67924:MainThread](2023-03-14 21:50:28,371) INFO - qlib.backtest caller - [__init__.py:93] - Create new exchange


start_time: 2017-01-08, end_time: 2018-12-31


[67924:MainThread](2023-03-14 21:50:48,417) INFO - qlib.backtest caller - [__init__.py:93] - Create new exchange
[67924:MainThread](2023-03-14 21:50:59,764) INFO - qlib.backtest caller - [__init__.py:93] - Create new exchange


start_time: 2017-01-08, end_time: 2018-12-31




start_time: 2017-01-08, end_time: 2018-12-31


[67924:MainThread](2023-03-14 21:51:13,157) INFO - qlib.rl.trainer.vessel - [vessel.py:83] - [Iter 1] n/ep = 1
[67924:MainThread](2023-03-14 21:51:13,159) INFO - qlib.rl.trainer.vessel - [vessel.py:83] - [Iter 1] n/st = 483
[67924:MainThread](2023-03-14 21:51:13,160) INFO - qlib.rl.trainer.vessel - [vessel.py:83] - [Iter 1] rews = 0.7329425479242712
[67924:MainThread](2023-03-14 21:51:13,162) INFO - qlib.rl.trainer.vessel - [vessel.py:83] - [Iter 1] lens = 483.0
[67924:MainThread](2023-03-14 21:51:13,163) INFO - qlib.rl.trainer.vessel - [vessel.py:83] - [Iter 1] idxs = 0.0
[67924:MainThread](2023-03-14 21:51:13,165) INFO - qlib.rl.trainer.vessel - [vessel.py:83] - [Iter 1] rew = 0.7329425479242712
[67924:MainThread](2023-03-14 21:51:13,166) INFO - qlib.rl.trainer.vessel - [vessel.py:83] - [Iter 1] len = 483.0
[67924:MainThread](2023-03-14 21:51:13,168) INFO - qlib.rl.trainer.vessel - [vessel.py:83] - [Iter 1] rew_std = 0.0
[67924:MainThread](2023-03-14 21:51:13,169) INFO - qlib.rl.trai

start_time: 2019-01-01, end_time: 2023-03-13


[67924:MainThread](2023-03-14 21:51:44,781) INFO - qlib.rl.trainer.callbacks - [callbacks.py:161] - #0 current reward: 0.7329, best reward: 0.7329 in #0
[67924:MainThread](2023-03-14 21:51:44,806) INFO - qlib.rl.trainer.trainer - [trainer.py:211] - 
2023-03-14 21:51:44	Train iteration 2/1000
[67924:MainThread](2023-03-14 21:51:44,808) INFO - qlib.rl.trainer.vessel - [vessel.py:148] - Training initial states collection size: 1
[67924:MainThread](2023-03-14 21:51:44,809) INFO - qlib.rl.utils.data_queue - [data_queue.py:70] - Automatically set data queue maxsize to 12 to avoid overwhelming.
[67924:MainThread](2023-03-14 21:51:45,817) INFO - qlib.backtest caller - [__init__.py:93] - Create new exchange


start_time: 2017-01-08, end_time: 2018-12-31


In [None]:
# train_trade_strategy(trade_strategy=topk_strategy, max_iters=1, concurrency=1)

In [None]:
# train_trade_strategy(trade_strategy=weight_strategy, max_iters=1, concurrency=1)

## Backtest Policy

In [None]:
def get_port_analysis_config(config, strategy_config):
    """Wrap a strategy config into a full portfolio-analysis config.

    Args:
        config: Config dict as produced by get_all_config; its "date" and
            "backtest" entries supply the backtest window, benchmark and
            exchange settings.
        strategy_config: Strategy section, stored as-is under "strategy".

    Returns:
        A dict with "executor", "strategy" and "backtest" sections. The
        backtest account size comes from the notebook-level ``account``.
    """
    executor_section = {
        "class": "SimulatorExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "day",
            "generate_portfolio_metrics": True,
        },
    }
    backtest_section = {
        "start_time": config["date"]["backtest"]["start"],
        "end_time": config["date"]["backtest"]["end"],
        "account": account,
        "benchmark": config["backtest"]["benchmark"],
        "exchange_kwargs": config["backtest"]["exchange_kwargs"]
    }
    return {
        "executor": executor_section,
        "strategy": strategy_config,
        "backtest": backtest_section,
    }

In [None]:
def get_topk_dropout_port_analysis_config(config):
    """Portfolio-analysis config for qlib's own TopkDropoutStrategy.

    The strategy is configured from the notebook-level ``signal``, ``topk``,
    ``n_drop`` and ``hold_thresh`` values.

    Args:
        config: Config dict forwarded to get_port_analysis_config.

    Returns:
        The portfolio-analysis config dict.
    """
    strategy_kwargs = {
        "signal": signal,
        "topk": topk,
        "n_drop": n_drop,
        "risk_degree": 0.95,
        "only_tradable": True,
        "hold_thresh": hold_thresh,
    }
    return get_port_analysis_config(
        config,
        {
            "class": "TopkDropoutStrategy",
            "module_path": "qlib.contrib.strategy.signal_strategy",
            "kwargs": strategy_kwargs,
        },
    )

In [None]:
def get_latest_checkpoint_path(strategy):
    """Return the path of ``latest.pth`` in the strategy's newest run folder.

    Run folders are searched under
    ``../records/EpisodeInformationRatioReward/<strategy>/<strategy.policy>/``
    and the one that sorts last by name is used. The resulting path is also
    printed.

    Args:
        strategy: Object whose string form and ``policy`` attribute name the
            record folders.

    Returns:
        The checkpoint path as a string.

    Raises:
        IndexError: If no run folder exists.
    """
    pattern = f"../records/EpisodeInformationRatioReward/{strategy}/{strategy.policy}/*"
    run_dirs = glob.glob(pattern)
    run_dirs.sort()
    newest_run = run_dirs[-1]
    checkpoint_path = f"./{newest_run}/checkpoints/latest.pth"
    print(checkpoint_path)
    return checkpoint_path


def get_best_checkpoint_path(strategy):
    """Return (and print) the checkpoint path with the highest validation reward.

    The lexicographically last run directory under
    ``../records/EpisodeInformationRatioReward/{strategy}/{strategy.policy}/``
    is selected, and its ``metrics/validation_result.csv`` is read with the
    first column as index. The row with the largest ``val/reward`` wins; the
    checkpoint file number is that row's index label plus one.

    Rows whose ``val/reward`` is NaN are ignored when ranking.

    Args:
        strategy: Object whose string form and ``policy`` attribute name the
            record sub-directories.

    Returns:
        The checkpoint path as a string, prefixed with ``./``.

    Raises:
        IndexError: If no run directory matches, or if no row has a
            non-NaN ``val/reward``.
    """
    output_dir = sorted(glob.glob(f"../records/EpisodeInformationRatioReward/{strategy}/{strategy.policy}/*"))[-1]
    df_valid = pd.read_csv(f"{output_dir}/metrics/validation_result.csv", index_col=0)
    # sort_values() places NaN last by default, so without this filter a
    # NaN reward would land in the final position and be picked as "best".
    df_scored = df_valid.dropna(subset=["val/reward"])
    best = df_scored.sort_values("val/reward", ascending=True).iloc[-1]
    # Checkpoint files are numbered one above the row's index label.
    epoch = int(best.name) + 1
    checkpoint_path = f"./{output_dir}/checkpoints/{epoch:03d}.pth"
    print(checkpoint_path)
    print(best["val/reward"])
    return checkpoint_path

In [None]:
def get_topk_dropout_strategy_port_analysis_config(config, checkpoint_path=None):
    """Build the portfolio-analysis config for the RL ``TopkDropoutStrategy``.

    Args:
        config: Workflow config forwarded to ``get_port_analysis_config``.
        checkpoint_path: Policy checkpoint to load. When ``None``, the path is
            resolved with ``get_best_checkpoint_path(topk_dropout_strategy)``.

    Returns:
        The dict produced by ``get_port_analysis_config``.

    Note:
        ``signal``, ``dim``, ``stock_num``, ``topk``, ``hold_thresh`` and
        ``topk_dropout_strategy`` are read from the notebook namespace.
    """
    resolved_checkpoint = (
        get_best_checkpoint_path(topk_dropout_strategy)
        if checkpoint_path is None
        else checkpoint_path
    )
    strategy_kwargs = {
        "signal": signal,
        "dim": dim,
        "stock_num": stock_num,
        "topk": topk,
        "risk_degree": 0.95,
        "only_tradable": True,
        "hold_thresh": hold_thresh,
        "signal_key": "signal",
        "policy_cls": topk_dropout_strategy.policy_cls,
        "checkpoint_path": resolved_checkpoint,
    }
    return get_port_analysis_config(
        config,
        {
            "class": "TopkDropoutStrategy",
            "module_path": "longcapital.rl.order_execution.strategy",
            "kwargs": strategy_kwargs,
        },
    )

In [None]:
def get_topk_dropout_signal_strategy_port_analysis_config(config, checkpoint_path=None):
    """Build the portfolio-analysis config for the RL ``TopkDropoutSignalStrategy``.

    Args:
        config: Workflow config forwarded to ``get_port_analysis_config``.
        checkpoint_path: Policy checkpoint to load. When ``None``, the path is
            resolved with ``get_best_checkpoint_path(topk_dropout_signal_strategy)``.

    Returns:
        The dict produced by ``get_port_analysis_config``.

    Note:
        ``signal``, ``dim``, ``stock_num``, ``topk``, ``n_drop``,
        ``hold_thresh`` and ``topk_dropout_signal_strategy`` are read from the
        notebook namespace.
    """
    resolved_checkpoint = (
        get_best_checkpoint_path(topk_dropout_signal_strategy)
        if checkpoint_path is None
        else checkpoint_path
    )
    strategy_kwargs = {
        "signal": signal,
        "dim": dim,
        "stock_num": stock_num,
        "topk": topk,
        "n_drop": n_drop,
        "risk_degree": 0.95,
        "only_tradable": True,
        "hold_thresh": hold_thresh,
        "signal_key": "signal",
        "policy_cls": topk_dropout_signal_strategy.policy_cls,
        "checkpoint_path": resolved_checkpoint,
    }
    return get_port_analysis_config(
        config,
        {
            "class": "TopkDropoutSignalStrategy",
            "module_path": "longcapital.rl.order_execution.strategy",
            "kwargs": strategy_kwargs,
        },
    )

In [None]:
def get_topk_strategy_port_analysis_config(config, checkpoint_path=None):
    """Build the portfolio-analysis config for the RL ``TopkStrategy``.

    Args:
        config: Workflow config forwarded to ``get_port_analysis_config``.
        checkpoint_path: Policy checkpoint to load. When ``None``, the path is
            resolved with ``get_best_checkpoint_path(topk_strategy)``.

    Returns:
        The dict produced by ``get_port_analysis_config``.

    Note:
        ``signal``, ``dim``, ``stock_num`` and ``topk_strategy`` are read from
        the notebook namespace.
    """
    resolved_checkpoint = (
        get_best_checkpoint_path(topk_strategy)
        if checkpoint_path is None
        else checkpoint_path
    )
    strategy_kwargs = {
        "signal": signal,
        "dim": dim,
        "stock_num": stock_num,
        "policy_cls": topk_strategy.policy_cls,
        "checkpoint_path": resolved_checkpoint,
    }
    return get_port_analysis_config(
        config,
        {
            "class": "TopkStrategy",
            "module_path": "longcapital.rl.order_execution.strategy",
            "kwargs": strategy_kwargs,
        },
    )

In [None]:
def get_weight_strategy_port_analysis_config(config, checkpoint_path=None):
    """Build the portfolio-analysis config for the RL ``WeightStrategy``.

    Args:
        config: Workflow config forwarded to ``get_port_analysis_config``.
        checkpoint_path: Policy checkpoint to load. When ``None``, the path is
            resolved with ``get_best_checkpoint_path(weight_strategy)``.

    Returns:
        The dict produced by ``get_port_analysis_config``.

    Note:
        ``signal``, ``dim``, ``stock_num``, ``topk`` and ``weight_strategy``
        are read from the notebook namespace.
    """
    resolved_checkpoint = (
        get_best_checkpoint_path(weight_strategy)
        if checkpoint_path is None
        else checkpoint_path
    )
    strategy_kwargs = {
        "signal": signal,
        "dim": dim,
        "stock_num": stock_num,
        "topk": topk,
        "signal_key": "signal",
        "policy_cls": weight_strategy.policy_cls,
        "checkpoint_path": resolved_checkpoint,
    }
    return get_port_analysis_config(
        config,
        {
            "class": "WeightStrategy",
            "module_path": "longcapital.rl.order_execution.strategy",
            "kwargs": strategy_kwargs,
        },
    )

In [None]:
def run_trade_strategy_backtest(rid, port_analysis_config, start_time=None, end_time=None):
    """Run a portfolio backtest on an existing recorder and show the results.

    Ends any active mlflow run, resumes recorder ``rid`` in experiment
    ``EXP_NAME``, regenerates the ``PortAnaRecord`` artifacts
    (``skip_existing=False``), then prints the analysis frame and renders the
    position report graph.

    Args:
        rid: Recorder id to resume.
        port_analysis_config: Portfolio-analysis config with a ``"backtest"``
            section. It is not modified by this function.
        start_time: Optional override for ``backtest.start_time``; ignored
            when falsy.
        end_time: Optional override for ``backtest.end_time``; ignored when
            falsy.

    Returns:
        None.
    """
    mlflow.end_run()
    with R.start(experiment_name=EXP_NAME, recorder_id=rid, resume=True):
        if start_time or end_time:
            # Apply the overrides to copies so the caller's dict (and its
            # "backtest" section) keep their original window.
            backtest_section = dict(port_analysis_config["backtest"])
            if start_time:
                backtest_section["start_time"] = start_time
            if end_time:
                backtest_section["end_time"] = end_time
            port_analysis_config = {**port_analysis_config, "backtest": backtest_section}

        rec = R.get_recorder()
        par = PortAnaRecord(rec, port_analysis_config, skip_existing=False)
        par.generate()

        analysis_df = rec.load_object("portfolio_analysis/port_analysis_1day.pkl")
        report_normal_df = rec.load_object("portfolio_analysis/report_normal_1day.pkl")
        print(analysis_df)
        analysis_position.report_graph(report_normal_df)

In [None]:
# Backtest qlib's built-in TopkDropoutStrategy configuration on recorder `rid`.
run_trade_strategy_backtest(
    rid, port_analysis_config=get_topk_dropout_port_analysis_config(config=config)
)

In [None]:
# run_trade_strategy_backtest(
#     rid, 
#     port_analysis_config=get_topk_dropout_strategy_port_analysis_config(
#         config,
#     )
# )

In [None]:
# Backtest the RL TopkDropoutSignalStrategy configuration on recorder `rid`.
run_trade_strategy_backtest(
    rid, port_analysis_config=get_topk_dropout_signal_strategy_port_analysis_config(config)
)

In [None]:
# run_trade_strategy_backtest(
#     rid, 
#     port_analysis_config=get_topk_strategy_port_analysis_config(
#         config, 
#     )
# )

In [None]:
# run_trade_strategy_backtest(
#     rid, 
#     port_analysis_config=get_weight_strategy_port_analysis_config(
#         config, 
#     )
# )

# Trade

In [None]:
# One initial state whose start and end are both the end date of the
# configured test period, with date sampling switched off.
# NOTE(review): presumably this pins the simulator to that single day —
# confirm against TradeStrategyInitiateState.
initial_states_test = [
    TradeStrategyInitiateState(
        start_time=config["date"]["test"]["end"],
        end_time=config["date"]["test"]["end"],
        sample_date=False
    )
]

In [None]:
# Account snapshot: free cash plus one entry per held instrument.
# Each row below is (instrument, amount, price, count_day).
account = {
    "cash": 17392,
    **{
        instrument: {"amount": amount, "price": price, "count_day": count_day}
        for instrument, amount, price, count_day in (
            ("SH601985", 4000, 6.49, 1),
            ("SH601225", 1200, 19.34, 6),
            ("SH603833", 200, 129.69, 4),
            ("SH600188", 700, 32.46, 6),
            ("SZ002032", 500, 56.09, 6),
            ("SH603986", 200, 101.03, 2),
        )
    },
}

In [None]:
# Rebuild the signal strategy with the best checkpoint found on disk.
#
# NOTE: `get_best_checkpoint_path(topk_dropout_signal_strategy)` is evaluated
# BEFORE this assignment rebinds the name, so `topk_dropout_signal_strategy`
# must already be bound by an earlier cell; otherwise this cell raises
# NameError on a fresh kernel.
#
# topk / n_drop are fixed literals here, whereas the backtest config for this
# strategy reads the notebook-level `topk` / `n_drop` variables.
topk_dropout_signal_strategy = TopkDropoutSignalStrategy(
    signal=signal,
    dim=dim,
    stock_num=stock_num,
    topk=10,
    n_drop=2,
    only_tradable=False, # tradability is not known before trading (the backtest configs pass True)
    hold_thresh=hold_thresh,
    signal_key="signal",
    policy_cls=continuous.MetaPPO,
    checkpoint_path=get_best_checkpoint_path(topk_dropout_signal_strategy)
)

In [None]:
# Simulator built from the signal strategy, the first (only) test initial
# state and the notebook-level `account`; no benchmark is supplied. Exchange
# settings are taken from config["backtest"]["exchange_kwargs"].
topk_dropout_signal_simulator = TradeStrategySimulator(
    trade_strategy=topk_dropout_signal_strategy, 
    initial_state=initial_states_test[0], 
    account=account,
    benchmark=None,
    exchange_kwargs=config["backtest"]["exchange_kwargs"]
)

In [None]:
# Ask the strategy for its trade decision.
# NOTE(review): presumably relies on the TradeStrategySimulator cell above
# having been run first — confirm.
decision = topk_dropout_signal_strategy.trade()

In [None]:
# Display the object returned by trade().
decision