In [1]:
import json
import os
from typing import Optional, Tuple, List
from datetime import datetime
from pathlib import Path
from openai import OpenAI
import fire
import pandas as pd

import numpy as np
from sb3_contrib.ppo_mask import MaskablePPO
from stable_baselines3.common.callbacks import BaseCallback

from alphagen.data.expression import *
from alphagen.data.parser import ExpressionParser
from alphagen.models.linear_alpha_pool import LinearAlphaPool, MseAlphaPool
from alphagen.rl.env.wrapper import AlphaEnv
from alphagen.rl.policy import LSTMSharedNet
from alphagen.utils import reseed_everything, get_logger
from alphagen.rl.env.core import AlphaEnvCore
from alphagen_qlib.calculator import QLibStockDataCalculator
from alphagen_qlib.stock_data import initialize_qlib
from alphagen_llm.client import ChatClient, OpenAIClient, ChatConfig
from alphagen_llm.prompts.system_prompt import EXPLAIN_WITH_TEXT_DESC
from alphagen_llm.prompts.interaction import InterativeSession, DefaultInteraction

In [2]:
instruments: str = "csi300"
device = torch.device("cuda:0")


def get_dataset(start: str, end: str) -> StockData:
    return StockData(
        instrument=instruments,
        start_time=start,
        end_time=end,
        device=device
    )

segments = [
    ("2012-01-01", "2019-12-31"),
    ("2022-01-01", "2022-06-30"),
    ("2022-07-01", "2022-12-31"),
    ("2023-01-01", "2023-06-30")
]


datasets = [get_dataset(*s) for s in segments]

[23512:MainThread](2025-04-24 12:33:13,832) INFO - qlib.Initialization - [config.py:420] - default_conf: client.
[23512:MainThread](2025-04-24 12:33:14,950) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
[23512:MainThread](2025-04-24 12:33:14,952) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': WindowsPath('C:/Users/tywat/.qlib/qlib_data/cn_data')}


In [3]:
close = Feature(FeatureType.CLOSE)
target = Ref(close, -20) / close - 1
calculators = [QLibStockDataCalculator(d, target) for d in datasets]

In [4]:
from alphagen.data.expression import Operators
from alphagen.data.parser import ExpressionParser

def load_linear_alpha_pool_from_json(json_path: str, 
                                     calculator: QLibStockDataCalculator,
                                     single_alpha: bool = False) -> LinearAlphaPool | list[LinearAlphaPool]:
    # Load the JSON file
    parser = ExpressionParser(Operators)
    with open(json_path, 'r') as f:
        pool_data = json.load(f)

    # Extract expressions and weights from the loaded data
    expressions = pool_data['exprs']
    weights = pool_data['weights']

    # Create an instance of LinearAlphaPool
    alpha_pool = MseAlphaPool(
        capacity=len(expressions),  # Set the capacity based on the number of expressions
        calculator=calculator
    )

    # Load the expressions into the pool
    expres = []
    if single_alpha:
        alpha_pools = []

        for expression,weight in zip(expressions,weights):
            alpha_pool = MseAlphaPool(
                capacity=1,
                calculator=calculator
                )
            expre = parser.parse(expression)
            alpha_pool.force_load_exprs([expre], [weight])
            alpha_pools.append(alpha_pool)

        return  alpha_pools
    else:
        for expression in expressions:
            expre = parser.parse(expression)
            expres.append(expre)
        
        
        alpha_pool.force_load_exprs(expres, weights)

        return alpha_pool

alpha_pools = load_linear_alpha_pool_from_json('out/results/csi300_20_0_20250208124320_rl/251904_steps_pool.json', calculators[1])
alpha_pool = load_linear_alpha_pool_from_json('out/results/csi300_20_0_20250208124320_rl/251904_steps_pool.json', calculators[1], single_alpha=True)

In [5]:
ic_value, rank_ic_value = alpha_pools.test_ensemble(calculators[2])
print(alpha_pools.exprs)
print(ic_value, rank_ic_value)

[Greater(Div(Div(-1.0,$high),EMA($open,10d)),-2.0), Delta(Log($vwap),1d), Mul($volume,Mul(Cov($close,Mul(5.0,Min(Mul($high,-30.0),40d)),40d),-0.01)), Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0),10d), Abs(Sub(2.0,Div($close,Add(Greater(2.0,Delta(Log($low),5d)),30.0)))), Mad(Add(2.0,Mean($vwap,20d)),10d), Corr($close,$low,10d), Abs(Log(Mad(Sub(-0.5,$close),20d))), Mad(Log(Log($volume)),40d), Mul(0.5,Corr(Log($volume),WMA(Log($volume),40d),40d)), Mul(Mul($volume,Mul(Add(Mean($high,20d),30.0),$high)),0.5), Mul(WMA(Log(Abs(Var($low,5d))),20d),-2.0), Abs(Mul(5.0,Sub($open,30.0))), Mean(Less(Sub(-2.0,Corr($volume,$high,20d)),1.0),10d), Sub(Less(1.0,$low),5.0), Add(Corr(Sub(-1.0,$high),$volume,10d),0.01), WMA(Div(Std(WMA(Div(Div($vwap,30.0),$low),40d),20d),-5.0),10d), WMA(Sub(-1.0,Div($low,$close)),20d), Less(Div($close,$vwap),$volume), Sub(Mad(Mean(Log($low),20d),40d),5.0), None]
0.06614601612091064 0.0644562840461731


In [6]:
alpha_index = 3

ic_value, rank_ic_value = alpha_pool[alpha_index].test_ensemble(calculators[2])
print(alpha_pool[alpha_index].exprs)
print(ic_value, rank_ic_value)

[Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0),10d), None]
0.010267447680234909 0.010892813093960285


In [7]:
for alpha in alpha_pool:
    print(alpha.exprs)

[Greater(Div(Div(-1.0,$high),EMA($open,10d)),-2.0), None]
[Delta(Log($vwap),1d), None]
[Mul($volume,Mul(Cov($close,Mul(5.0,Min(Mul($high,-30.0),40d)),40d),-0.01)), None]
[Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0),10d), None]
[Abs(Sub(2.0,Div($close,Add(Greater(2.0,Delta(Log($low),5d)),30.0)))), None]
[Mad(Add(2.0,Mean($vwap,20d)),10d), None]
[Corr($close,$low,10d), None]
[Abs(Log(Mad(Sub(-0.5,$close),20d))), None]
[Mad(Log(Log($volume)),40d), None]
[Mul(0.5,Corr(Log($volume),WMA(Log($volume),40d),40d)), None]
[Mul(Mul($volume,Mul(Add(Mean($high,20d),30.0),$high)),0.5), None]
[Mul(WMA(Log(Abs(Var($low,5d))),20d),-2.0), None]
[Abs(Mul(5.0,Sub($open,30.0))), None]
[Mean(Less(Sub(-2.0,Corr($volume,$high,20d)),1.0),10d), None]
[Sub(Less(1.0,$low),5.0), None]
[Add(Corr(Sub(-1.0,$high),$volume,10d),0.01), None]
[WMA(Div(Std(WMA(Div(Div($vwap,30.0),$low),40d),20d),-5.0),10d), None]
[WMA(Sub(-1.0,Div($low,$close)),20d), None]
[Less(Div($close,$vwap),$volume), None]
[Sub(Mad(Mean(Log($low),

In [8]:
ics = []
rank_ics = []
alphas = []

for alpha in alpha_pool:
    ic_value, rank_ic_value = alpha.test_ensemble(calculators[2])

    ics.append(ic_value)
    rank_ics.append(rank_ic_value)
    alphas.append(alpha.exprs)

df_ic_ind = pd.DataFrame({'alpha': alphas, 'ic': ics, 'rank_ic': rank_ics})
df_ic_ind

Unnamed: 0,alpha,ic,rank_ic
0,"[Greater(Div(Div(-1.0,$high),EMA($open,10d)),-...",0.055708,0.084874
1,"[Delta(Log($vwap),1d), None]",-0.02456,-0.012933
2,"[Mul($volume,Mul(Cov($close,Mul(5.0,Min(Mul($h...",-0.036502,-0.035545
3,"[Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0...",0.010267,0.010893
4,"[Abs(Sub(2.0,Div($close,Add(Greater(2.0,Delta(...",-0.061593,-0.092103
5,"[Mad(Add(2.0,Mean($vwap,20d)),10d), None]",-0.006636,0.043511
6,"[Corr($close,$low,10d), None]",0.056122,0.063845
7,"[Abs(Log(Mad(Sub(-0.5,$close),20d))), None]",-0.081099,-0.097603
8,"[Mad(Log(Log($volume)),40d), None]",-0.025405,-0.042726
9,"[Mul(0.5,Corr(Log($volume),WMA(Log($volume),40...",-0.040783,-0.046877


In [9]:
for p in Path("out/gp").iterdir():
    seed = int(p.name)

with open(p / "40.json") as f:
    report = json.load(f)


state = report["res"]["res"]["pool_state"]
state["exprs"]

['Std(EMA(Min(Mul(5.0,$high),30d),50d),10d)',
 'Std(EMA(Min(Log($vwap),30d),40d),10d)',
 'Sum(Mean(Abs(Corr($low,$high,20d)),40d),20d)',
 'Std(Med(Min(Mul(5.0,$high),30d),10d),10d)',
 'Mad(Min($low,20d),20d)',
 'Std(Std(Min(Mul($vwap,2.0),30d),10d),10d)',
 'Mad(Med($close,50d),10d)',
 'Std(Cov(Corr(Var($volume,40d),$high,20d),$close,30d),10d)',
 'Std(Med(Ref(Mul(10.0,$high),30d),10d),10d)',
 'Std(Min(Sum(Mul(10.0,$high),40d),50d),10d)',
 'Mad(Min($high,30d),10d)',
 'Std(Std(Med(Mul(0.5,$high),20d),20d),10d)',
 'Mad(Ref(Min($high,30d),10d),10d)',
 'Std(Abs(WMA(Cov(0.01,$high,50d),10d)),10d)',
 'Std(Min(WMA($high,20d),50d),10d)',
 'Log(Var(Sum($low,30d),40d))',
 'Std(EMA(Min(Mul(Std($high,10d),$high),30d),50d),10d)',
 'Std(Max(Min(Mul(5.0,$high),30d),20d),10d)',
 'Std(Min(Mean(Corr(5.0,$high,30d),50d),10d),10d)',
 'Std(EMA(WMA($vwap,10d),40d),10d)']

# main

In [2]:
ex_num = "51-5"

In [3]:
import pickle

file_path = f'out/backtests/{ex_num}/gp/2-graph.pkl'

with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [4]:
import pickle

file_path = f'out/backtests/{ex_num}/gp/2-report.pkl'

with open(file_path, 'rb') as file:
    gp_report = pickle.load(file)
gp_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-04,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,-0.004575
2022-01-05,9.985801e+07,-1.932494e-16,9.465720e+07,0.946572,1.419858e+05,0.001420,9.465720e+07,5.200816e+06,-0.010096
2022-01-06,9.877754e+07,-1.057961e-02,1.106686e+08,0.160341,1.660029e+05,0.000241,9.822515e+07,5.523929e+05,-0.010248
2022-01-07,9.805608e+07,-7.140884e-03,1.214039e+08,0.108682,1.821058e+05,0.000163,9.776123e+07,2.948506e+05,0.000859
2022-01-10,9.829718e+07,2.678393e-03,1.357596e+08,0.146404,2.036395e+05,0.000220,9.792557e+07,3.716071e+05,0.004496
...,...,...,...,...,...,...,...,...,...
2023-06-26,8.111740e+07,-1.395978e-02,3.850987e+09,0.197453,5.776481e+06,0.000296,8.069213e+07,4.252634e+05,-0.014060
2023-06-27,8.159544e+07,6.206125e-03,3.867907e+09,0.208583,5.801860e+06,0.000313,8.113946e+07,4.559780e+05,0.009379
2023-06-28,8.122846e+07,-4.247665e-03,3.881505e+09,0.166656,5.822258e+06,0.000250,8.086546e+07,3.629927e+05,-0.001204
2023-06-29,8.091265e+07,-3.631036e-03,3.895410e+09,0.171183,5.843115e+06,0.000257,8.054653e+07,3.661228e+05,-0.004937


In [5]:
import pickle

file_path = f'out/backtests/{ex_num}/rl/0-graph.pkl'

with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [6]:
import pickle

file_path = f'out/backtests/{ex_num}/rl/0-report.pkl'

with open(file_path, 'rb') as file:
    alphaGen_report = pickle.load(file)
alphaGen_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-04,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,-0.004575
2022-01-05,9.985790e+07,-1.495937e-16,9.473230e+07,0.947323,1.420985e+05,0.001421,9.473230e+07,5.125597e+06,-0.010096
2022-01-06,9.867374e+07,-1.151152e-02,1.178293e+08,0.231299,1.767440e+05,0.000347,9.795229e+07,7.214464e+05,-0.010248
2022-01-07,9.792732e+07,-7.249342e-03,1.385661e+08,0.210155,2.078491e+05,0.000315,9.736150e+07,5.658137e+05,0.000859
2022-01-10,9.838605e+07,5.008039e-03,1.596947e+08,0.215758,2.395420e+05,0.000324,9.767537e+07,7.106754e+05,0.004496
...,...,...,...,...,...,...,...,...,...
2023-06-26,7.993807e+07,-1.078031e-02,6.425894e+09,0.200286,9.638841e+06,0.000300,7.950136e+07,4.367164e+05,-0.014060
2023-06-27,8.036758e+07,5.665147e-03,6.441465e+09,0.194791,9.662197e+06,0.000292,7.996248e+07,4.050960e+05,0.009379
2023-06-28,8.035523e+07,1.592819e-04,6.458232e+09,0.208624,9.687347e+06,0.000313,7.991798e+07,4.372509e+05,-0.001204
2023-06-29,7.985099e+07,-5.961809e-03,6.475015e+09,0.208864,9.712522e+06,0.000313,7.940481e+07,4.461809e+05,-0.004937


In [7]:
alphaGen_report["cum_return"] = alphaGen_report["return"].cumsum()
alphaGen_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench,cum_return
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2022-01-04,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,-0.004575,0.000000e+00
2022-01-05,9.985790e+07,-1.495937e-16,9.473230e+07,0.947323,1.420985e+05,0.001421,9.473230e+07,5.125597e+06,-0.010096,-1.495937e-16
2022-01-06,9.867374e+07,-1.151152e-02,1.178293e+08,0.231299,1.767440e+05,0.000347,9.795229e+07,7.214464e+05,-0.010248,-1.151152e-02
2022-01-07,9.792732e+07,-7.249342e-03,1.385661e+08,0.210155,2.078491e+05,0.000315,9.736150e+07,5.658137e+05,0.000859,-1.876086e-02
2022-01-10,9.838605e+07,5.008039e-03,1.596947e+08,0.215758,2.395420e+05,0.000324,9.767537e+07,7.106754e+05,0.004496,-1.375282e-02
...,...,...,...,...,...,...,...,...,...,...
2023-06-26,7.993807e+07,-1.078031e-02,6.425894e+09,0.200286,9.638841e+06,0.000300,7.950136e+07,4.367164e+05,-0.014060,-8.167115e-02
2023-06-27,8.036758e+07,5.665147e-03,6.441465e+09,0.194791,9.662197e+06,0.000292,7.996248e+07,4.050960e+05,0.009379,-7.600601e-02
2023-06-28,8.035523e+07,1.592819e-04,6.458232e+09,0.208624,9.687347e+06,0.000313,7.991798e+07,4.372509e+05,-0.001204,-7.584672e-02
2023-06-29,7.985099e+07,-5.961809e-03,6.475015e+09,0.208864,9.712522e+06,0.000313,7.940481e+07,4.461809e+05,-0.004937,-8.180853e-02


In [8]:
import pickle

file_path = f'out/backtests/{ex_num}/boot/0-graph.pkl'

with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [9]:
import pickle

file_path = f'out/backtests/{ex_num}/boot/0-report.pkl'

with open(file_path, 'rb') as file:
    boot_report = pickle.load(file)
boot_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-04,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,-0.004575
2022-01-05,9.985784e+07,-4.685717e-17,9.477453e+07,0.947745,1.421618e+05,0.001422,9.477453e+07,5.083310e+06,-0.010096
2022-01-06,9.931892e+07,-5.053295e-03,1.176492e+08,0.229073,1.764738e+05,0.000344,9.858742e+07,7.314989e+05,-0.010248
2022-01-07,9.879635e+07,-4.948388e-03,1.383807e+08,0.208736,2.075711e+05,0.000313,9.824187e+07,5.544813e+05,0.000859
2022-01-10,9.898657e+07,2.240015e-03,1.591064e+08,0.209782,2.386596e+05,0.000315,9.843016e+07,5.564034e+05,0.004496
...,...,...,...,...,...,...,...,...,...
2023-06-26,7.849863e+07,-8.309499e-03,6.148887e+09,0.198578,9.223330e+06,0.000298,7.808066e+07,4.179763e+05,-0.014060
2023-06-27,7.931224e+07,1.066094e-02,6.164393e+09,0.197535,9.246589e+06,0.000296,7.891114e+07,4.011036e+05,0.009379
2023-06-28,7.959322e+07,3.844564e-03,6.180356e+09,0.201266,9.270534e+06,0.000302,7.917715e+07,4.160684e+05,-0.001204
2023-06-29,7.895593e+07,-7.716007e-03,6.195784e+09,0.193845,9.293677e+06,0.000291,7.854235e+07,4.135834e+05,-0.004937


In [10]:
import pickle

file_path = f'out/backtests/{ex_num}/mcts/0-graph.pkl'

with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [11]:
import pickle

file_path = f'out/backtests/{ex_num}/mcts/0-report.pkl'

with open(file_path, 'rb') as file:
    riskminer_report = pickle.load(file)
riskminer_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-04,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,-0.004575
2022-01-05,9.985787e+07,-1.396984e-16,9.475190e+07,0.947519,1.421278e+05,0.001421,9.475190e+07,5.105976e+06,-0.010096
2022-01-06,9.836685e+07,-1.458778e-02,1.176290e+08,0.229097,1.764435e+05,0.000344,9.764311e+07,7.237414e+05,-0.010248
2022-01-07,9.774858e+07,-5.996342e-03,1.365833e+08,0.192689,2.048749e+05,0.000289,9.724949e+07,4.990902e+05,0.000859
2022-01-10,9.761100e+07,-1.089328e-03,1.573129e+08,0.212071,2.359693e+05,0.000318,9.705353e+07,5.574702e+05,0.004496
...,...,...,...,...,...,...,...,...,...
2023-06-26,7.768480e+07,-1.097149e-02,6.307066e+09,0.201749,9.460598e+06,0.000303,7.726172e+07,4.230821e+05,-0.014060
2023-06-27,7.808161e+07,5.415937e-03,6.323021e+09,0.205386,9.484531e+06,0.000308,7.766038e+07,4.212311e+05,0.009379
2023-06-28,7.828846e+07,2.959400e-03,6.339170e+09,0.206829,9.508756e+06,0.000310,7.786032e+07,4.281329e+05,-0.001204
2023-06-29,7.726493e+07,-1.277198e-02,6.354926e+09,0.201246,9.532388e+06,0.000302,7.685421e+07,4.107186e+05,-0.004937


In [12]:
import pickle

file_path = f'out/backtests/{ex_num}/emcts/0-report.pkl'

with open(file_path, 'rb') as file:
    eminer_report = pickle.load(file)
eminer_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-04,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,-0.004575
2022-01-05,9.985789e+07,-5.078618e-16,9.474326e+07,0.947433,1.421149e+05,0.001421,9.474326e+07,5.114629e+06,-0.010096
2022-01-06,9.901192e+07,-8.232953e-03,1.106339e+08,0.159133,1.659509e+05,0.000239,9.845875e+07,5.531732e+05,-0.010248
2022-01-07,9.841586e+07,-5.731825e-03,1.296672e+08,0.192232,1.945007e+05,0.000288,9.788524e+07,5.306121e+05,0.000859
2022-01-10,9.887094e+07,4.913160e-03,1.486307e+08,0.192687,2.229460e+05,0.000289,9.837464e+07,4.963039e+05,0.004496
...,...,...,...,...,...,...,...,...,...
2023-06-26,8.205441e+07,-7.219266e-03,6.155745e+09,0.198015,9.233618e+06,0.000297,8.160489e+07,4.495254e+05,-0.014060
2023-06-27,8.284422e+07,9.813522e-03,6.166035e+09,0.125400,9.249052e+06,0.000188,8.257281e+07,2.714065e+05,0.009379
2023-06-28,8.263715e+07,-2.324510e-03,6.175699e+09,0.116659,9.263549e+06,0.000175,8.238715e+07,2.500000e+05,-0.001204
2023-06-29,8.238238e+07,-2.791651e-03,6.191750e+09,0.194233,9.287625e+06,0.000291,8.195627e+07,4.261092e+05,-0.004937


In [13]:
import pickle

file_path = f'out/backtests/{ex_num}/oracle/0-report.pkl'

with open(file_path, 'rb') as file:
    oracle_report = pickle.load(file)
oracle_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-04,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,-0.004575
2022-01-05,9.985764e+07,-1.085573e-16,9.490754e+07,0.949075,1.423613e+05,0.001424,9.490754e+07,4.950096e+06,-0.010096
2022-01-06,9.916367e+07,-6.612844e-03,1.173255e+08,0.224499,1.759883e+05,0.000337,9.847566e+07,6.880104e+05,-0.010248
2022-01-07,9.999130e+07,8.647576e-03,1.372567e+08,0.200993,2.058851e+05,0.000301,9.946274e+07,5.285551e+05,0.000859
2022-01-10,1.008635e+08,9.021050e-03,1.571427e+08,0.198877,2.357140e+05,0.000298,1.003431e+08,5.204199e+05,0.004496
...,...,...,...,...,...,...,...,...,...
2023-06-21,7.036718e+07,-2.494274e-02,5.595411e+09,0.190190,8.393116e+06,0.000285,7.001329e+07,3.538898e+05,-0.015343
2023-06-26,6.917242e+07,-1.668397e-02,5.609246e+09,0.196618,8.413870e+06,0.000295,6.880646e+07,3.659580e+05,-0.014060
2023-06-27,7.005104e+07,1.299443e-02,5.622738e+09,0.195038,8.434106e+06,0.000293,6.969155e+07,3.594831e+05,0.009379
2023-06-28,6.967217e+07,-5.089191e-03,5.637649e+09,0.212868,8.456474e+06,0.000319,6.928689e+07,3.852816e+05,-0.001204


In [14]:
import pandas as pd
df_com = pd.DataFrame()

df_com["GP"] = gp_report["return"].cumsum()
df_com["Alpha Gen"] = alphaGen_report["return"].cumsum()
df_com["Bootstrapped DQN"] = boot_report["return"].cumsum()
df_com["Oracle"] = oracle_report["return"].cumsum()
df_com["MCTS"] = riskminer_report["return"].cumsum()
df_com["EMCTS"] = eminer_report["return"].cumsum()
df_com["Benchmark"] = boot_report["bench"].cumsum()


df_com.head()

Unnamed: 0_level_0,GP,Alpha Gen,Bootstrapped DQN,Oracle,MCTS,EMCTS,Benchmark
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2022-01-04,0.0,0.0,0.0,0.0,0.0,0.0,-0.004575
2022-01-05,-1.932494e-16,-1.495937e-16,-4.685717e-17,-1.085573e-16,-1.396984e-16,-5.078618e-16,-0.014671
2022-01-06,-0.01057961,-0.01151152,-0.005053295,-0.006612844,-0.01458778,-0.008232953,-0.024919
2022-01-07,-0.0177205,-0.01876086,-0.01000168,0.002034732,-0.02058412,-0.01396478,-0.02406
2022-01-10,-0.0150421,-0.01375282,-0.007761668,0.01105578,-0.02167345,-0.009051618,-0.019564


In [15]:
import pickle

rmse_files = []
df_rmse = pd.DataFrame()

for model in ["boot","gp","rl","mcts","emcts"]:
    file_path = f'out/backtests/{ex_num}/{model}/0-rmse.pkl'

    if model == "gp":
        file_path = f'out/backtests/{ex_num}/{model}/2-rmse.pkl'

    with open(file_path, 'rb') as file:
        rmse = pickle.load(file)
    
    df_rmse[model] = rmse["rmse"]

df_rmse.rename(columns={"boot": "Bootstrapped DQN", "gp": "GP", "rl": "Alpha Gen", "mcts":"RiskMiner", "emcts":"EMCTS"}, inplace=True)
df_rmse

Unnamed: 0_level_0,Bootstrapped DQN,GP,Alpha Gen,RiskMiner,EMCTS
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2022-01-04,167.555522,177.514336,168.887766,172.971212,160.120509
2022-01-05,153.699473,157.663870,161.421907,159.527285,156.557074
2022-01-06,160.441807,177.222474,166.946331,173.892005,162.415278
2022-01-07,153.615997,164.946472,155.676237,157.988903,154.546200
2022-01-10,165.314948,168.037218,169.726879,173.577764,168.801580
...,...,...,...,...,...
2023-06-26,157.765002,161.253701,153.123041,159.269853,159.360967
2023-06-27,141.666193,165.917080,148.172313,150.018681,149.135333
2023-06-28,165.064869,159.339544,170.727022,166.705318,161.005124
2023-06-29,146.962941,158.124217,152.987266,154.433207,150.961112


In [16]:
df_rmse_ma = df_rmse.rolling(30).mean()

In [17]:
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(rows=2, cols=1,subplot_titles=["Cumulative Return", "RMSE"])


for col in df_com.columns:
    fig.add_trace(
        go.Scatter(
            x=df_com.index,   
            y=df_com[col],    
            mode='lines',
            name=col,
            legendgroup = '1',
        ),
        row=1, 
        col=1
    )

for col in df_rmse_ma.columns:
    fig.add_trace(
        go.Scatter(
            x=df_rmse_ma.index,
            y=df_rmse_ma[col],    
            mode='lines',
            name=col,
            legendgroup = '2',
        ),
        row=2, 
        col=1
    )

# Update the layout to add the title and template
fig.update_layout(
    template='seaborn',
    autosize=False,
    width=1200,
    height=1200,
    legend_tracegroupgap=580,
    legend_groupclick="toggleitem"
)

fig.show()


# fig = px.line(df_com, y=["Bootstrapped DQN","Alpha Gen","GP","Benchmark"], 
#             #   x="lifeExp", 
#               title='Cumulative Return',
#               template="seaborn",
#               )
# fig.show()

# Test

In [72]:
from alphagen_qlib.stock_data import StockData

data = StockData(
        instrument="csi300",
        start_time="2020-01-01",
        end_time="2022-01-01"
    )
data

<alphagen_qlib.stock_data.StockData at 0x17e97367da0>

In [None]:
from qlib.data import D

instruments = data.stock_ids.tolist()

# Determine the proper start and end times for fetching price data.
# Here we use the same dates as in your StockData instance.
start_time = data._dates[data.max_backtrack_days].strftime("%Y-%m-%d")
end_time = data._dates[-data.max_future_days - 1].strftime("%Y-%m-%d")

# Query Qlib to get the closing price for each instrument.
# The field '$close' is used here (adjust if your field naming is different)
price_df = D.features(
    instruments=instruments,
    fields=["$close"],
    start_time="2020-01-01",
    end_time="2022-01-01"
)

price_df = price_df.reorder_levels(order=[1, 0])
price_df

In [None]:
def compute_oracle_scores(price_df: pd.DataFrame) -> pd.DataFrame:
    # price_df is expected to be a MultiIndex DataFrame with (date, instrument)
    # Unstack to get dates as rows and instruments as columns
    price_unstacked = price_df.unstack(level=1)
    # Compute daily percentage returns and shift so that prediction on day t 
    # is compared with return from t to t+1
    oracle_signal = price_unstacked.pct_change().shift(-1)
    # Stack back to a MultiIndex DataFrame
    return oracle_signal.stack()


oracle_scores = compute_oracle_scores(price_df)
oracle_scores


The default fill_method='pad' in DataFrame.pct_change is deprecated and will be removed in a future version. Either fill in any non-leading NA values prior to calling pct_change or specify 'fill_method=None' to not fill NA values.





Unnamed: 0_level_0,Unnamed: 1_level_0,$close
datetime,instrument,Unnamed: 2_level_1
2020-01-02,SH600000,0.010406
2020-01-02,SH600004,-0.007961
2020-01-02,SH600009,-0.000765
2020-01-02,SH600010,0.008123
2020-01-02,SH600011,0.000000
...,...,...
2021-12-30,SZ300782,-0.024249
2021-12-30,SZ300866,-0.033949
2021-12-30,SZ300888,0.008192
2021-12-30,SZ300896,-0.012722


In [17]:
from alphagen_qlib.utils import load_alpha_pool_by_path

calc = QLibStockDataCalculator(data, None)

for p in Path("out/boot_dqn").iterdir():
        inst, size, seed, time, ver = p.name.split('_', 4)
        size, seed = int(size), int(seed)
        if inst != "csi300" or size != 20 or time < "20240923" or ver == "llm_d5":
            continue
        try:
            exprs, weights = load_alpha_pool_by_path(str(p / "249500_steps_pool.json"))
        except:
            continue

boot_score = data.make_dataframe(calc.make_ensemble_alpha(exprs, weights))
boot_score

Unnamed: 0_level_0,Unnamed: 1_level_0,0
datetime,instrument,Unnamed: 2_level_1
2020-01-02,SH600000,0.029499
2020-01-02,SH600004,0.065036
2020-01-02,SH600009,-0.070193
2020-01-02,SH600010,0.009689
2020-01-02,SH600011,0.194273
...,...,...
2021-12-31,SZ300782,-0.121243
2021-12-31,SZ300866,-0.022676
2021-12-31,SZ300888,0.029904
2021-12-31,SZ300896,-0.026693


In [18]:
def normalize_series(series: pd.Series) -> pd.Series:
    return (series - series.mean()) / series.std()

def rank_series_per_date(series: pd.Series) -> pd.Series:
    """
    Rank the series for each date (assumed to be the first level of the MultiIndex).
    The highest value is assigned rank 1.
    """
    return series.groupby(level=0).rank(ascending=False, method='min')
def compute_rmse_per_date(model_scores: pd.Series, oracle_scores: pd.Series) -> pd.DataFrame:
    """
    Compute the RMSE across stocks for each date.
    
    Parameters:
      model_scores: pd.Series with MultiIndex (date, instrument) containing your model's prediction scores.
      oracle_scores: pd.Series with MultiIndex (date, instrument) containing the oracle's prediction scores.
      
    Returns:
      A DataFrame with the date as the index and a column 'rmse' containing the RMSE for that date.
    """
    # normalize
    # model_scores = normalize_series(model_scores)
    # oracle_scores = normalize_series(oracle_scores)

    # rank the scores
    model_scores = rank_series_per_date(model_scores)
    oracle_scores = rank_series_per_date(oracle_scores)

    # Combine both series into one DataFrame
    df = pd.DataFrame({
        "model": model_scores,
        "oracle": oracle_scores
    })
    # Group by the date level. If your MultiIndex doesn't have names,
    # you can group by level=0 (assuming the first level is the date).
    rmse_series = df.groupby(level=0).apply(
        lambda group: np.sqrt(((group["oracle"] - group["model"]) ** 2).mean())
    )
    rmse_df = rmse_series.to_frame(name="rmse")
    # Ensure the index is named "date" (or adjust as needed)
    rmse_df.index.name = "date"
    return rmse_df

rmse_df = compute_rmse_per_date(boot_score.iloc[:,0], oracle_scores.iloc[:,0])
rmse_df

Unnamed: 0_level_0,rmse
date,Unnamed: 1_level_1
2020-01-02,175.373923
2020-01-03,172.032975
2020-01-06,158.534962
2020-01-07,143.087359
2020-01-08,158.674098
...,...
2021-12-27,165.069690
2021-12-28,162.478913
2021-12-29,165.498930
2021-12-30,164.577087


In [19]:
import plotly.express as px
fig = px.line(rmse_df, y=["rmse"], 
            #   x="lifeExp", 
              title='RMSE',
              template="seaborn",
              )
fig.show()