<a href="https://colab.research.google.com/github/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Workflow of Qlib

## Importing Dependencies

In [15]:
from datetime import datetime
import os
from pathlib import Path
import pickle
import sys
import qlib
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
from qlib.contrib.model.pytorch_master_ts import MASTERModel
from qlib.contrib.data.dataset import MASTERTSDatasetH
from qlib.contrib.data.handler import Alpha158

## Initializing Qlib

In [16]:
from qlib.tests.data import GetData


provider_uri = "~/QuantProject/.qlib/qlib_data/cn_data"
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)

	If downloading is required: `exists_skip=False` or `change target_dir`[0m
[3908378:MainThread](2025-04-07 00:07:38,874) INFO - qlib.Initialization - [config.py:420] - default_conf: client.
[3908378:MainThread](2025-04-07 00:07:38,879) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
[3908378:MainThread](2025-04-07 00:07:38,880) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': PosixPath('/home/24039378g/QuantProject/.qlib/qlib_data/cn_data')}


## Configuration of Workflow

In [17]:
# 配置参数
market = "csi300"
benchmark = "SH000300"

# 数据处理器配置
data_handler_config = {
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": market,
    "infer_processors": [
        {
            "class": "RobustZScoreNorm",
            "kwargs": {
                "fields_group": "feature",
                "clip_outlier": True
            }
        },
        {
            "class": "Fillna",
            "kwargs": {
                "fields_group": "feature"
            }
        }
    ],
    "learn_processors": [
        {"class": "DropnaLabel"},
        {
            "class": "CSRankNorm",
            "kwargs": {
                "fields_group": "label"
            }
        }
    ],
    "label": ["Ref($close, -5) / Ref($close, -1) - 1"]
}

market_data_handler_config = {
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": market,
    "infer_processors": [
        {
            "class": "RobustZScoreNorm",
            "kwargs": {
                "fields_group": "feature",
                "clip_outlier": True
            }
        },
        {
            "class": "Fillna",
            "kwargs": {
                "fields_group": "feature"
            }
        }
    ]
}

# 模型配置
model_config = {
    "class": "MASTERModel",
    "module_path": "qlib.contrib.model.pytorch_master_ts",
    "kwargs": {
        "seed": 0,
        "n_epochs": 1,
        "lr": 0.000008,
        "train_stop_loss_thred": 0.95,
        "market": market,
        "benchmark": benchmark,
        "save_prefix": market
    }
}

# 数据集配置
dataset_config = {
    "class": "MASTERTSDatasetH",
    "module_path": "qlib.contrib.data.dataset",
    "kwargs": {
        "handler": {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": data_handler_config
        },
        "segments": {
            "train": ["2008-01-01", "2014-12-31"],
            "valid": ["2015-01-01", "2016-12-31"],
            "test": ["2017-01-01", "2020-08-01"]
        },
        "step_len": 8,
        "market_data_handler_config": market_data_handler_config
    }
}

# 投资组合分析配置
port_analysis_config = {
    "strategy": {
        "class": "TopkDropoutStrategy",
        "module_path": "qlib.contrib.strategy",
        "kwargs": {
            "signal": "<PRED>",
            "topk": 30,
            "n_drop": 30
        }
    },
    "backtest": {
        "start_time": "2017-01-01",
        "end_time": "2020-08-01",
        "account": 100000000,
        "benchmark": benchmark,
        "exchange_kwargs": {
            "deal_price": "close"
        }
    }
}

## Creating Model and Dataset Instances

In [18]:
model = init_instance_by_config(model_config)
dataset = init_instance_by_config(dataset_config)


模型参数量统计:
--------------------------------------------------
特征门控层 (Gate): 10,112 参数
输入映射层 (x2y): 40,704 参数
时间注意力层 (TAttention): 329,216 参数
空间注意力层 (SAttention): 329,216 参数
时序注意力层 (TemporalAttention): 65,536 参数
解码器 (Decoder): 257 参数
--------------------------------------------------
总参数量: 775,041


[3908378:MainThread](2025-04-07 00:08:00,546) INFO - qlib.timer - [log.py:127] - Time cost: 11.463s | Loading data Done
  return function_base._ureduce(a, func=_nanmedian, keepdims=keepdims,
[3908378:MainThread](2025-04-07 00:08:05,505) INFO - qlib.timer - [log.py:127] - Time cost: 4.480s | RobustZScoreNorm Done
[3908378:MainThread](2025-04-07 00:08:06,106) INFO - qlib.timer - [log.py:127] - Time cost: 0.600s | Fillna Done
[3908378:MainThread](2025-04-07 00:08:06,502) INFO - qlib.timer - [log.py:127] - Time cost: 0.217s | DropnaLabel Done
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[cols] = t
[3908378:MainThread](2025-04-07 00:08:06,724) INFO - qlib.timer - [log.py:127] - Time cost: 0.220s | CSRankNorm Done
[3908378:MainThread](2025-04-07 00:08:06,726) INFO - qlib.ti

## Loading or Training Model

In [21]:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"{market}-{model_config["class"]}-model"
model_file_path = Path(f"./model/{model_name}.pkl")

if not os.path.exists('./model'):
    os.makedirs('./model')
    
with R.start(experiment_name="train_model"):
    if not model_file_path.exists():
        R.log_params(**model_config["kwargs"])
        # print("文件列表:", list(R.get_recorder().list_artifacts()))
        
        # 方法1：使用 sys.stdout.write
        def custom_print(*args, **kwargs):
            msg = ' '.join(map(str, args)) + '\n'
            import sys
            sys.stdout.write(msg)
        
        # 临时替换 print
        import builtins
        orig_print = builtins.print
        builtins.print = custom_print
        
        try:
            model.fit(dataset)  # 训练模型
        finally:
            pkl_path = os.path.join("./model", f"{model_name}.pkl")
            with open(pkl_path, "wb") as f:
                pickle.dump(model, f)
            builtins.print = orig_print  # 确保恢复原始 print
        
        R.save_objects(trained_model=model)
        
    else:
        model.load_model(f"./model/{model_name}.pkl")
        R.save_objects(trained_model=model)
        
    rid = R.get_recorder().id

[3908378:MainThread](2025-04-07 00:09:05,376) INFO - qlib.workflow - [exp.py:258] - Experiment 566184983307789232 starts running ...
[3908378:MainThread](2025-04-07 00:09:05,434) INFO - qlib.workflow - [recorder.py:345] - Recorder 11c2d47286df4a929bdb1648d74150ed starts running under Experiment 566184983307789232 ...


[3908378:MainThread](2025-04-07 00:09:05,521) INFO - qlib.timer - [log.py:127] - Time cost: 0.004s | waiting `async_log` Done


ValueError: Model not found.

In [12]:
all_metrics = {
    k: []
    for k in [
        "IC",
        "ICIR",
        "Rank IC",
        "Rank ICIR",
        "1day.excess_return_without_cost.annualized_return",
        "1day.excess_return_without_cost.information_ratio",
    ]
}

In [14]:
# backtest and analysis
import numpy as np


print(f"[Status]: Model Training/ Loading finished".upper())
with R.start(experiment_name="backtest_analysis"):
    recorder = R.get_recorder(recorder_id=rid, experiment_name="train_model")
    model = recorder.load_object("trained_model")

    # prediction
    recorder = R.get_recorder()
    ba_rid = recorder.id
    sr = SignalRecord(model, dataset, recorder)
    sr.generate()
    
    # Signal Analysis
    sar = SigAnaRecord(recorder)
    sar.generate()

    # backtest & analysis
    par = PortAnaRecord(recorder, port_analysis_config, "day")
    par.generate()
    
    metrics = recorder.list_metrics()
    print(f"Metrics: {metrics}")
    for k in all_metrics.keys():
        all_metrics[k].append(metrics[k])
    print(f"All metrics: {all_metrics}")
    print(f"Available metrics: {metrics.keys()}")
    
for k in all_metrics.keys():
        print(f"{k}: {np.mean(all_metrics[k])} +- {np.std(all_metrics[k])}")

recorder = R.get_recorder(recorder_id=ba_rid, experiment_name="backtest_analysis")

[3908378:MainThread](2025-04-06 22:25:07,912) INFO - qlib.workflow - [exp.py:258] - Experiment 362767161650010529 starts running ...
[3908378:MainThread](2025-04-06 22:25:07,977) INFO - qlib.workflow - [recorder.py:345] - Recorder 8d84fab322dd47c782da32d0322d5857 starts running under Experiment 362767161650010529 ...


[STATUS]: MODEL TRAINING/ LOADING FINISHED


[3908378:MainThread](2025-04-06 22:25:23,084) INFO - qlib.workflow - [record_temp.py:198] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 362767161650010529


'The following are prediction results of the MASTERModel model.'
                              0
datetime   instrument          
2017-01-03 SH600000   -0.094200
           SH600008    0.099505
           SH600009    0.298842
           SH600010   -0.023810
           SH600015    0.067495


[3908378:MainThread](2025-04-06 22:25:23,819) INFO - qlib.backtest caller - [__init__.py:93] - Create new exchange


{'IC': 0.03648069232457443,
 'ICIR': 0.2675256965377252,
 'Rank IC': 0.04273447517543348,
 'Rank ICIR': 0.29957152985370383}




backtest loop:   0%|          | 0/871 [00:00<?, ?it/s]

  return np.nanmean(self.data)
  return np.nanmean(self.data)
  return np.nanmean(self.data)
[3908378:MainThread](2025-04-06 22:25:59,916) INFO - qlib.workflow - [record_temp.py:515] - Portfolio analysis record 'port_analysis_1day.pkl' has been saved as the artifact of the Experiment 362767161650010529
[3908378:MainThread](2025-04-06 22:25:59,922) INFO - qlib.workflow - [record_temp.py:540] - Indicator analysis record 'indicator_analysis_1day.pkl' has been saved as the artifact of the Experiment 362767161650010529


'The following are analysis results of benchmark return(1day).'
                       risk
mean               0.000477
std                0.012295
annualized_return  0.113561
information_ratio  0.598699
max_drawdown      -0.370479
'The following are analysis results of the excess return without cost(1day).'
                       risk
mean               0.000280
std                0.005396
annualized_return  0.066689
information_ratio  0.801159
max_drawdown      -0.119840
'The following are analysis results of the excess return with cost(1day).'
                       risk
mean              -0.001737
std                0.005399
annualized_return -0.413358
information_ratio -4.963131
max_drawdown      -1.526544
'The following are analysis results of indicators(1day).'
     value
ffr    1.0
pa     0.0
pos    0.0


[3908378:MainThread](2025-04-06 22:26:02,053) INFO - qlib.timer - [log.py:127] - Time cost: 0.001s | waiting `async_log` Done


Metrics: {'1day.ffr': 1.0, 'Rank ICIR': 0.29957152985370383, '1day.excess_return_with_cost.information_ratio': -4.963131106783359, '1day.excess_return_with_cost.annualized_return': -0.4133577531694415, '1day.excess_return_without_cost.information_ratio': 0.8011587221738969, '1day.pos': 0.0, 'IC': 0.03648069232457443, 'ICIR': 0.2675256965377252, '1day.pa': 0.0, '1day.excess_return_without_cost.std': 0.005395668816402809, 'Rank IC': 0.04273447517543348, '1day.excess_return_with_cost.mean': -0.001736797282224544, '1day.excess_return_without_cost.mean': 0.000280204671652662, '1day.excess_return_with_cost.max_drawdown': -1.5265442023044526, '1day.excess_return_without_cost.max_drawdown': -0.11983952371055057, '1day.excess_return_with_cost.std': 0.005398608841853511, '1day.excess_return_without_cost.annualized_return': 0.06668871185333357}
All metrics: {'IC': [0.03648069232457443], 'ICIR': [0.2675256965377252], 'Rank IC': [0.04273447517543348], 'Rank ICIR': [0.29957152985370383], '1day.exces

In [7]:
pred_df = recorder.load_object("pred.pkl")

report_normal_df = recorder.load_object("portfolio_analysis/report_normal_1day.pkl")

positions = recorder.load_object("portfolio_analysis/positions_normal_1day.pkl")

analysis_df = recorder.load_object("portfolio_analysis/port_analysis_1day.pkl")

In [None]:
from qlib.contrib.report import analysis_model, analysis_position
import pandas as pd

save_dir = "figure_results"
os.makedirs(save_dir, exist_ok=True)

figs = analysis_position.report_graph(report_normal_df, show_notebook=False)
if not figs:
    raise ValueError("No figures were generated by `report_graph`. Please check the input data.")

for i, _fig in enumerate(figs):
    fig_path = f"{save_dir}/报告图表{i}.png"
    try:
        _fig.write_image(fig_path)
        print(f"Saved figure {i} to {fig_path}")
    except Exception as e:
        print(f"Error saving figure {i}: {e}")

figs = analysis_position.risk_analysis_graph(analysis_df, report_normal_df, show_notebook=False)
if not figs:
    raise ValueError("No figures were generated by `risk_analysis_graph`. Please check the input data.")

for i, _fig in enumerate(figs):
    fig_path = f"{save_dir}/风险分析图表{i}.png"
    try:
        _fig.write_image(fig_path)
        print(f"Saved figure {i} to {fig_path}")
    except Exception as e:
        print(f"Error saving figure {i}: {e}")

# Step 1: Retrieve TSDataSampler and extract the underlying DataFrame
label_sampler = dataset.prepare(segments="test", col_set="label", only_label=True)
label_df = label_sampler.idx_df  # Extract the DataFrame

# Step 2: Reshape label_df to long format
label_df = label_df.reset_index()  # Include datetime as a column
label_df_long = pd.melt(
    label_df,
    id_vars=["datetime"],  # Keep datetime as is
    var_name="instrument",  # Former column names become instrument names
    value_name="label"  # Values in the DataFrame become the label column
)
label_df_long = label_df_long.dropna(subset=["label"])  # Drop NaN labels

# Debugging: Check the reshaped DataFrame
print("Reshaped label DataFrame:")
print(label_df_long.head())

# Step 3: Combine with predictions
# Ensure pred_df is in the long format with columns: datetime, instrument, prediction
pred_label = pd.merge(label_df_long, pred_df, on=["datetime", "instrument"], how="inner")

# Step 4: Rename the prediction column to 'score'
pred_label = pred_label.rename(columns={0: "score"})

# Drop rows with NaNs in the `label` or `score` columns
pred_label = pred_label.dropna(subset=["label", "score"])

# Verify there are no NaNs remaining
print(pred_label.isna().sum())

# Convert `label` and `score` columns to numeric, coercing errors to NaN
pred_label["label"] = pd.to_numeric(pred_label["label"], errors="coerce")
pred_label["score"] = pd.to_numeric(pred_label["score"], errors="coerce")

# Step 5: Set datetime and instrument as a multi-level index
pred_label = pred_label.set_index(["datetime", "instrument"])  # Set the multi-level index

# Verify the final structure

figs = analysis_position.score_ic_graph(pred_label, show_notebook=False)
for i, _fig in enumerate(figs):
    fig_path = f"{save_dir}/IC图表{i}.png"
    try:
        _fig.write_image(fig_path)
        print(f"Saved figure {i} to {fig_path}")
    except Exception as e:
        print(f"Error saving figure {i}: {e}")

figs = analysis_model.model_performance_graph(pred_label, show_notebook=False)
for i, _fig in enumerate(figs):
    fig_path = f"{save_dir}/模型性能图表{i}.png"
    try:
        _fig.write_image(fig_path)
        print(f"Saved figure {i} to {fig_path}")
    except Exception as e:
        print(f"Error saving figures {i}: {e}")


print(f"[Status]: Mission Completed".upper())

Saved figure 0 to figure_results/报告图表0.png












Saved figure 0 to figure_results/风险分析图表0.png
Saved figure 1 to figure_results/风险分析图表1.png
Saved figure 2 to figure_results/风险分析图表2.png
Saved figure 3 to figure_results/风险分析图表3.png
Saved figure 4 to figure_results/风险分析图表4.png
Reshaped label DataFrame:
    datetime instrument label
0 2016-12-21   SH600000     0
1 2016-12-22   SH600000   300
2 2016-12-23   SH600000   600
3 2016-12-26   SH600000   900
4 2016-12-27   SH600000  1200
datetime      0
instrument    0
label         0
score         0
dtype: int64
Saved figure 0 to figure_results/IC图表0.png



'M' is deprecated and will be removed in a future version, please use 'ME' instead.



Saved figure 0 to figure_results/模型性能图表0.png
Saved figure 1 to figure_results/模型性能图表1.png
Saved figure 2 to figure_results/模型性能图表2.png
Saved figure 3 to figure_results/模型性能图表3.png
Saved figure 4 to figure_results/模型性能图表4.png
Saved figure 5 to figure_results/模型性能图表5.png
[STATUS]: MISSION COMPLETED
