In [1]:
import json
import os
from typing import Optional, Tuple, List
from datetime import datetime
from pathlib import Path
from openai import OpenAI
import fire
import pandas as pd

import numpy as np
from sb3_contrib.ppo_mask import MaskablePPO
from stable_baselines3.common.callbacks import BaseCallback

from alphagen.data.expression import *
from alphagen.data.parser import ExpressionParser
from alphagen.models.linear_alpha_pool import LinearAlphaPool, MseAlphaPool
from alphagen.rl.env.wrapper import AlphaEnv
from alphagen.rl.policy import LSTMSharedNet
from alphagen.utils import reseed_everything, get_logger
from alphagen.rl.env.core import AlphaEnvCore
from alphagen_qlib.calculator import QLibStockDataCalculator
from alphagen_qlib.stock_data import initialize_qlib
from alphagen_llm.client import ChatClient, OpenAIClient, ChatConfig
from alphagen_llm.prompts.system_prompt import EXPLAIN_WITH_TEXT_DESC
from alphagen_llm.prompts.interaction import InterativeSession, DefaultInteraction

In [2]:
instruments: str = "csi300"
device = torch.device("cuda:0")


def get_dataset(start: str, end: str) -> StockData:
    return StockData(
        instrument=instruments,
        start_time=start,
        end_time=end,
        device=device
    )

segments = [
    ("2012-01-01", "2019-12-31"),
    ("2022-01-01", "2022-06-30"),
    ("2022-07-01", "2022-12-31"),
    ("2023-01-01", "2023-06-30")
]


datasets = [get_dataset(*s) for s in segments]

[23512:MainThread](2025-04-24 12:33:13,832) INFO - qlib.Initialization - [config.py:420] - default_conf: client.
[23512:MainThread](2025-04-24 12:33:14,950) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
[23512:MainThread](2025-04-24 12:33:14,952) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': WindowsPath('C:/Users/tywat/.qlib/qlib_data/cn_data')}


In [3]:
close = Feature(FeatureType.CLOSE)
target = Ref(close, -20) / close - 1
calculators = [QLibStockDataCalculator(d, target) for d in datasets]

In [4]:
from alphagen.data.expression import Operators
from alphagen.data.parser import ExpressionParser

def load_linear_alpha_pool_from_json(json_path: str, 
                                     calculator: QLibStockDataCalculator,
                                     single_alpha: bool = False) -> LinearAlphaPool | list[LinearAlphaPool]:
    # Load the JSON file
    parser = ExpressionParser(Operators)
    with open(json_path, 'r') as f:
        pool_data = json.load(f)

    # Extract expressions and weights from the loaded data
    expressions = pool_data['exprs']
    weights = pool_data['weights']

    # Create an instance of LinearAlphaPool
    alpha_pool = MseAlphaPool(
        capacity=len(expressions),  # Set the capacity based on the number of expressions
        calculator=calculator
    )

    # Load the expressions into the pool
    expres = []
    if single_alpha:
        alpha_pools = []

        for expression,weight in zip(expressions,weights):
            alpha_pool = MseAlphaPool(
                capacity=1,
                calculator=calculator
                )
            expre = parser.parse(expression)
            alpha_pool.force_load_exprs([expre], [weight])
            alpha_pools.append(alpha_pool)

        return  alpha_pools
    else:
        for expression in expressions:
            expre = parser.parse(expression)
            expres.append(expre)
        
        
        alpha_pool.force_load_exprs(expres, weights)

        return alpha_pool

alpha_pools = load_linear_alpha_pool_from_json('out/results/csi300_20_0_20250208124320_rl/251904_steps_pool.json', calculators[1])
alpha_pool = load_linear_alpha_pool_from_json('out/results/csi300_20_0_20250208124320_rl/251904_steps_pool.json', calculators[1], single_alpha=True)

In [5]:
ic_value, rank_ic_value = alpha_pools.test_ensemble(calculators[2])
print(alpha_pools.exprs)
print(ic_value, rank_ic_value)

[Greater(Div(Div(-1.0,$high),EMA($open,10d)),-2.0), Delta(Log($vwap),1d), Mul($volume,Mul(Cov($close,Mul(5.0,Min(Mul($high,-30.0),40d)),40d),-0.01)), Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0),10d), Abs(Sub(2.0,Div($close,Add(Greater(2.0,Delta(Log($low),5d)),30.0)))), Mad(Add(2.0,Mean($vwap,20d)),10d), Corr($close,$low,10d), Abs(Log(Mad(Sub(-0.5,$close),20d))), Mad(Log(Log($volume)),40d), Mul(0.5,Corr(Log($volume),WMA(Log($volume),40d),40d)), Mul(Mul($volume,Mul(Add(Mean($high,20d),30.0),$high)),0.5), Mul(WMA(Log(Abs(Var($low,5d))),20d),-2.0), Abs(Mul(5.0,Sub($open,30.0))), Mean(Less(Sub(-2.0,Corr($volume,$high,20d)),1.0),10d), Sub(Less(1.0,$low),5.0), Add(Corr(Sub(-1.0,$high),$volume,10d),0.01), WMA(Div(Std(WMA(Div(Div($vwap,30.0),$low),40d),20d),-5.0),10d), WMA(Sub(-1.0,Div($low,$close)),20d), Less(Div($close,$vwap),$volume), Sub(Mad(Mean(Log($low),20d),40d),5.0), None]
0.06614601612091064 0.0644562840461731


In [6]:
alpha_index = 3

ic_value, rank_ic_value = alpha_pool[alpha_index].test_ensemble(calculators[2])
print(alpha_pool[alpha_index].exprs)
print(ic_value, rank_ic_value)

[Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0),10d), None]
0.010267447680234909 0.010892813093960285


In [7]:
for alpha in alpha_pool:
    print(alpha.exprs)

[Greater(Div(Div(-1.0,$high),EMA($open,10d)),-2.0), None]
[Delta(Log($vwap),1d), None]
[Mul($volume,Mul(Cov($close,Mul(5.0,Min(Mul($high,-30.0),40d)),40d),-0.01)), None]
[Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0),10d), None]
[Abs(Sub(2.0,Div($close,Add(Greater(2.0,Delta(Log($low),5d)),30.0)))), None]
[Mad(Add(2.0,Mean($vwap,20d)),10d), None]
[Corr($close,$low,10d), None]
[Abs(Log(Mad(Sub(-0.5,$close),20d))), None]
[Mad(Log(Log($volume)),40d), None]
[Mul(0.5,Corr(Log($volume),WMA(Log($volume),40d),40d)), None]
[Mul(Mul($volume,Mul(Add(Mean($high,20d),30.0),$high)),0.5), None]
[Mul(WMA(Log(Abs(Var($low,5d))),20d),-2.0), None]
[Abs(Mul(5.0,Sub($open,30.0))), None]
[Mean(Less(Sub(-2.0,Corr($volume,$high,20d)),1.0),10d), None]
[Sub(Less(1.0,$low),5.0), None]
[Add(Corr(Sub(-1.0,$high),$volume,10d),0.01), None]
[WMA(Div(Std(WMA(Div(Div($vwap,30.0),$low),40d),20d),-5.0),10d), None]
[WMA(Sub(-1.0,Div($low,$close)),20d), None]
[Less(Div($close,$vwap),$volume), None]
[Sub(Mad(Mean(Log($low),

In [8]:
ics = []
rank_ics = []
alphas = []

for alpha in alpha_pool:
    ic_value, rank_ic_value = alpha.test_ensemble(calculators[2])

    ics.append(ic_value)
    rank_ics.append(rank_ic_value)
    alphas.append(alpha.exprs)

df_ic_ind = pd.DataFrame({'alpha': alphas, 'ic': ics, 'rank_ic': rank_ics})
df_ic_ind

Unnamed: 0,alpha,ic,rank_ic
0,"[Greater(Div(Div(-1.0,$high),EMA($open,10d)),-...",0.055708,0.084874
1,"[Delta(Log($vwap),1d), None]",-0.02456,-0.012933
2,"[Mul($volume,Mul(Cov($close,Mul(5.0,Min(Mul($h...",-0.036502,-0.035545
3,"[Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0...",0.010267,0.010893
4,"[Abs(Sub(2.0,Div($close,Add(Greater(2.0,Delta(...",-0.061593,-0.092103
5,"[Mad(Add(2.0,Mean($vwap,20d)),10d), None]",-0.006636,0.043511
6,"[Corr($close,$low,10d), None]",0.056122,0.063845
7,"[Abs(Log(Mad(Sub(-0.5,$close),20d))), None]",-0.081099,-0.097603
8,"[Mad(Log(Log($volume)),40d), None]",-0.025405,-0.042726
9,"[Mul(0.5,Corr(Log($volume),WMA(Log($volume),40...",-0.040783,-0.046877


In [9]:
for p in Path("out/gp").iterdir():
    seed = int(p.name)

with open(p / "40.json") as f:
    report = json.load(f)


state = report["res"]["res"]["pool_state"]
state["exprs"]

['Std(EMA(Min(Mul(5.0,$high),30d),50d),10d)',
 'Std(EMA(Min(Log($vwap),30d),40d),10d)',
 'Sum(Mean(Abs(Corr($low,$high,20d)),40d),20d)',
 'Std(Med(Min(Mul(5.0,$high),30d),10d),10d)',
 'Mad(Min($low,20d),20d)',
 'Std(Std(Min(Mul($vwap,2.0),30d),10d),10d)',
 'Mad(Med($close,50d),10d)',
 'Std(Cov(Corr(Var($volume,40d),$high,20d),$close,30d),10d)',
 'Std(Med(Ref(Mul(10.0,$high),30d),10d),10d)',
 'Std(Min(Sum(Mul(10.0,$high),40d),50d),10d)',
 'Mad(Min($high,30d),10d)',
 'Std(Std(Med(Mul(0.5,$high),20d),20d),10d)',
 'Mad(Ref(Min($high,30d),10d),10d)',
 'Std(Abs(WMA(Cov(0.01,$high,50d),10d)),10d)',
 'Std(Min(WMA($high,20d),50d),10d)',
 'Log(Var(Sum($low,30d),40d))',
 'Std(EMA(Min(Mul(Std($high,10d),$high),30d),50d),10d)',
 'Std(Max(Min(Mul(5.0,$high),30d),20d),10d)',
 'Std(Min(Mean(Corr(5.0,$high,30d),50d),10d),10d)',
 'Std(EMA(WMA($vwap,10d),40d),10d)']

# main

In [2]:
ex_num = "51-5"

In [3]:
import pickle

file_path = f'out/backtests/{ex_num}/gp/2-graph.pkl'

with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [4]:
import pickle

file_path = f'out/backtests/{ex_num}/gp/2-report.pkl'

with open(file_path, 'rb') as file:
    gp_report = pickle.load(file)
gp_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-17,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.008579
2022-01-18,9.985792e+07,3.440073e-16,9.472152e+07,0.947215,1.420823e+05,0.001421,9.472152e+07,5.136396e+06,0.009664
2022-01-19,9.823109e+07,-1.605432e-02,1.105078e+08,0.158088,1.657617e+05,0.000237,9.770305e+07,5.280410e+05,-0.006850
2022-01-20,9.809313e+07,-1.173111e-03,1.256552e+08,0.154202,1.884829e+05,0.000231,9.768412e+07,4.090134e+05,0.009022
2022-01-21,9.662296e+07,-1.476015e-02,1.405203e+08,0.151540,2.107804e+05,0.000227,9.623456e+07,3.884015e+05,-0.009163
...,...,...,...,...,...,...,...,...,...
2023-12-25,6.882185e+07,6.959775e-03,4.937287e+09,0.076303,7.405930e+06,0.000114,6.868150e+07,1.403542e+05,0.003062
2023-12-26,6.827072e+07,-7.887862e-03,4.942801e+09,0.080129,7.414202e+06,0.000120,6.812927e+07,1.414534e+05,-0.006769
2023-12-27,6.820923e+07,-7.808258e-04,4.948257e+09,0.079921,7.422386e+06,0.000120,6.806690e+07,1.423272e+05,0.003480
2023-12-28,7.084888e+07,3.882441e-02,4.953944e+09,0.083370,7.430916e+06,0.000125,7.068959e+07,1.592930e+05,0.023433


In [5]:
import pickle

file_path = f'out/backtests/{ex_num}/rl/0-graph.pkl'

with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [6]:
import pickle

file_path = f'out/backtests/{ex_num}/rl/0-report.pkl'

with open(file_path, 'rb') as file:
    alphaGen_report = pickle.load(file)
alphaGen_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-17,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.008579
2022-01-18,9.985790e+07,9.313226e-18,9.473071e+07,0.947307,1.420961e+05,0.001421,9.473071e+07,5.127197e+06,0.009664
2022-01-19,9.843702e+07,-1.388077e-02,1.179144e+08,0.232167,1.768716e+05,0.000348,9.770213e+07,7.348894e+05,-0.006850
2022-01-20,9.878546e+07,3.854546e-03,1.385763e+08,0.209900,2.078645e+05,0.000315,9.823097e+07,5.544945e+05,0.009022
2022-01-21,9.768470e+07,-1.079329e-02,1.616035e+08,0.233103,2.424053e+05,0.000350,9.708252e+07,6.021788e+05,-0.009163
...,...,...,...,...,...,...,...,...,...
2023-12-25,6.625953e+07,9.087229e-03,8.225611e+09,0.198388,1.233842e+07,0.000298,6.592145e+07,3.380756e+05,0.003062
2023-12-26,6.578737e+07,-6.835672e-03,8.238435e+09,0.193533,1.235765e+07,0.000290,6.545466e+07,3.327101e+05,-0.006769
2023-12-27,6.563969e+07,-1.939265e-03,8.251830e+09,0.203616,1.237775e+07,0.000305,6.529417e+07,3.455249e+05,0.003480
2023-12-28,6.812087e+07,3.810828e-02,8.265325e+09,0.205584,1.239799e+07,0.000308,6.775430e+07,3.665647e+05,0.023433


In [7]:
alphaGen_report["cum_return"] = alphaGen_report["return"].cumsum()
alphaGen_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench,cum_return
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2022-01-17,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.008579,0.000000e+00
2022-01-18,9.985790e+07,9.313226e-18,9.473071e+07,0.947307,1.420961e+05,0.001421,9.473071e+07,5.127197e+06,0.009664,9.313226e-18
2022-01-19,9.843702e+07,-1.388077e-02,1.179144e+08,0.232167,1.768716e+05,0.000348,9.770213e+07,7.348894e+05,-0.006850,-1.388077e-02
2022-01-20,9.878546e+07,3.854546e-03,1.385763e+08,0.209900,2.078645e+05,0.000315,9.823097e+07,5.544945e+05,0.009022,-1.002623e-02
2022-01-21,9.768470e+07,-1.079329e-02,1.616035e+08,0.233103,2.424053e+05,0.000350,9.708252e+07,6.021788e+05,-0.009163,-2.081952e-02
...,...,...,...,...,...,...,...,...,...,...
2023-12-25,6.625953e+07,9.087229e-03,8.225611e+09,0.198388,1.233842e+07,0.000298,6.592145e+07,3.380756e+05,0.003062,-2.295889e-01
2023-12-26,6.578737e+07,-6.835672e-03,8.238435e+09,0.193533,1.235765e+07,0.000290,6.545466e+07,3.327101e+05,-0.006769,-2.364246e-01
2023-12-27,6.563969e+07,-1.939265e-03,8.251830e+09,0.203616,1.237775e+07,0.000305,6.529417e+07,3.455249e+05,0.003480,-2.383639e-01
2023-12-28,6.812087e+07,3.810828e-02,8.265325e+09,0.205584,1.239799e+07,0.000308,6.775430e+07,3.665647e+05,0.023433,-2.002556e-01


In [8]:
import pickle

file_path = f'out/backtests/{ex_num}/boot/0-graph.pkl'

with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [9]:
import pickle

file_path = f'out/backtests/{ex_num}/boot/0-report.pkl'

with open(file_path, 'rb') as file:
    boot_report = pickle.load(file)
boot_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-17,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.008579
2022-01-18,9.985777e+07,-6.737537e-16,9.482124e+07,0.948212,1.422319e+05,0.001422,9.482124e+07,5.036526e+06,0.009664
2022-01-19,9.933598e+07,-4.879632e-03,1.178322e+08,0.230437,1.767483e+05,0.000346,9.860550e+07,7.304778e+05,-0.006850
2022-01-20,9.940954e+07,1.056455e-03,1.387534e+08,0.210611,2.081301e+05,0.000316,9.886241e+07,5.471360e+05,0.009022
2022-01-21,9.835870e+07,-1.028530e-02,1.576792e+08,0.190382,2.365188e+05,0.000286,9.783618e+07,5.225181e+05,-0.009163
...,...,...,...,...,...,...,...,...,...
2023-12-25,7.477640e+07,7.261069e-03,8.103524e+09,0.206788,1.215529e+07,0.000310,7.437622e+07,4.001800e+05,0.003062
2023-12-26,7.405064e+07,-9.399896e-03,8.118773e+09,0.203930,1.217816e+07,0.000306,7.363963e+07,4.110062e+05,-0.006769
2023-12-27,7.422740e+07,2.692181e-03,8.133834e+09,0.203385,1.220075e+07,0.000305,7.383762e+07,3.897888e+05,0.003480
2023-12-28,7.603595e+07,2.466694e-02,8.148779e+09,0.201339,1.222317e+07,0.000302,7.564364e+07,3.923131e+05,0.023433


In [10]:
import pickle

file_path = f'out/backtests/{ex_num}/mcts/0-graph.pkl'

with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [11]:
import pickle

file_path = f'out/backtests/{ex_num}/mcts/0-report.pkl'

with open(file_path, 'rb') as file:
    riskminer_report = pickle.load(file)
riskminer_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-17,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.008579
2022-01-18,9.985792e+07,5.276524e-16,9.471718e+07,0.947172,1.420758e+05,0.001421,9.471718e+07,5.140748e+06,0.009664
2022-01-19,9.830926e+07,-1.516030e-02,1.179089e+08,0.232247,1.768633e+05,0.000348,9.756510e+07,7.441643e+05,-0.006850
2022-01-20,9.808012e+07,-2.066984e-03,1.352020e+08,0.175906,2.028030e+05,0.000264,9.761245e+07,4.676710e+05,0.009022
2022-01-21,9.693093e+07,-1.142416e-02,1.543368e+08,0.195093,2.315051e+05,0.000293,9.642279e+07,5.081464e+05,-0.009163
...,...,...,...,...,...,...,...,...,...
2023-12-25,6.622625e+07,5.127617e-03,8.088594e+09,0.208115,1.213289e+07,0.000312,6.587274e+07,3.535089e+05,0.003062
2023-12-26,6.579185e+07,-6.262087e-03,8.101716e+09,0.198140,1.215257e+07,0.000297,6.545096e+07,3.408963e+05,-0.006769
2023-12-27,6.573416e+07,-5.747871e-04,8.114965e+09,0.201384,1.217245e+07,0.000302,6.538642e+07,3.477483e+05,0.003480
2023-12-28,6.824126e+07,3.845504e-02,8.128776e+09,0.210091,1.219316e+07,0.000315,6.787854e+07,3.627177e+05,0.023433


In [12]:
import pickle

file_path = f'out/backtests/{ex_num}/emcts/0-report.pkl'

with open(file_path, 'rb') as file:
    eminer_report = pickle.load(file)
eminer_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-17,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.008579
2022-01-18,9.985794e+07,-2.235174e-16,9.470822e+07,0.947082,1.420623e+05,0.001421,9.470822e+07,5.149720e+06,0.009664
2022-01-19,9.798705e+07,-1.838813e-02,1.178302e+08,0.231549,1.767454e+05,0.000347,9.722904e+07,7.580115e+05,-0.006850
2022-01-20,9.763399e+07,-3.262142e-03,1.401072e+08,0.227346,2.101608e+05,0.000341,9.704958e+07,5.844081e+05,0.009022
2022-01-21,9.647633e+07,-1.153851e-02,1.608503e+08,0.212457,2.412754e+05,0.000319,9.583127e+07,6.450545e+05,-0.009163
...,...,...,...,...,...,...,...,...,...
2023-12-25,7.030376e+07,8.316102e-03,8.280945e+09,0.216444,1.242142e+07,0.000325,6.990930e+07,3.944575e+05,0.003062
2023-12-26,6.961452e+07,-9.508043e-03,8.294809e+09,0.197200,1.244221e+07,0.000296,6.924558e+07,3.689319e+05,-0.006769
2023-12-27,6.953325e+07,-8.823275e-04,8.308040e+09,0.190066,1.246206e+07,0.000285,6.918707e+07,3.461713e+05,0.003480
2023-12-28,7.233961e+07,4.066272e-02,8.322074e+09,0.201832,1.248311e+07,0.000303,7.197557e+07,3.640308e+05,0.023433


In [13]:
import pickle

file_path = f'out/backtests/{ex_num}/oracle/0-report.pkl'

with open(file_path, 'rb') as file:
    oracle_report = pickle.load(file)
oracle_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2022-01-17,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.008579
2022-01-18,9.985769e+07,9.720679e-17,9.487265e+07,0.948727,1.423090e+05,0.001423,9.487265e+07,4.985040e+06,0.009664
2022-01-19,9.961194e+07,-2.123548e-03,1.173382e+08,0.224975,1.760072e+05,0.000337,9.891135e+07,7.005863e+05,-0.006850
2022-01-20,1.003097e+08,7.310458e-03,1.376503e+08,0.203913,2.064755e+05,0.000306,9.975109e+07,5.585863e+05,0.009022
2022-01-21,9.886452e+07,-1.411082e-02,1.574587e+08,0.197472,2.361880e+05,0.000296,9.832724e+07,5.372747e+05,-0.009163
...,...,...,...,...,...,...,...,...,...
2023-12-22,6.285156e+07,-1.217461e-03,7.306094e+09,0.195835,1.095914e+07,0.000294,6.252698e+07,3.245785e+05,0.001909
2023-12-25,6.311659e+07,4.491769e-03,7.317618e+09,0.183353,1.097643e+07,0.000275,6.281197e+07,3.046240e+05,0.003062
2023-12-26,6.290240e+07,-3.124861e-03,7.328923e+09,0.179119,1.099338e+07,0.000269,6.261231e+07,2.900887e+05,-0.006769
2023-12-27,6.296502e+07,1.302571e-03,7.341798e+09,0.204681,1.101270e+07,0.000307,6.259027e+07,3.747525e+05,0.003480


In [14]:
import pandas as pd
df_com = pd.DataFrame()

df_com["GP"] = gp_report["return"].cumsum()
df_com["Alpha Gen"] = alphaGen_report["return"].cumsum()
df_com["Bootstrapped DQN"] = boot_report["return"].cumsum()
df_com["Oracle"] = oracle_report["return"].cumsum()
df_com["MCTS"] = riskminer_report["return"].cumsum()
df_com["EMCTS"] = eminer_report["return"].cumsum()
df_com["Benchmark"] = boot_report["bench"].cumsum()


df_com.head()

Unnamed: 0_level_0,GP,Alpha Gen,Bootstrapped DQN,Oracle,MCTS,EMCTS,Benchmark
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2022-01-17,0.0,0.0,0.0,0.0,0.0,0.0,0.008579
2022-01-18,3.440073e-16,9.313226e-18,-6.737537e-16,9.720679e-17,5.276524e-16,-2.235174e-16,0.018243
2022-01-19,-0.01605432,-0.01388077,-0.004879632,-0.002123548,-0.0151603,-0.01838813,0.011393
2022-01-20,-0.01722743,-0.01002623,-0.003823177,0.00518691,-0.01722729,-0.02165027,0.020415
2022-01-21,-0.03198759,-0.02081952,-0.01410847,-0.008923909,-0.02865145,-0.03318878,0.011252


In [15]:
import pickle

rmse_files = []
df_rmse = pd.DataFrame()

for model in ["boot","gp","rl","mcts","emcts"]:
    file_path = f'out/backtests/{ex_num}/{model}/0-rmse.pkl'

    if model == "gp":
        file_path = f'out/backtests/{ex_num}/{model}/2-rmse.pkl'

    with open(file_path, 'rb') as file:
        rmse = pickle.load(file)
    
    df_rmse[model] = rmse["rmse"]

df_rmse.rename(columns={"boot": "Bootstrapped DQN", "gp": "GP", "rl": "Alpha Gen", "mcts":"RiskMiner", "emcts":"EMCTS"}, inplace=True)
df_rmse

Unnamed: 0_level_0,Bootstrapped DQN,GP,Alpha Gen,RiskMiner,EMCTS
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2022-01-17,156.382996,156.364229,155.590751,158.300835,153.579192
2022-01-18,158.007572,175.600771,175.201298,178.838977,178.648416
2022-01-19,162.555693,165.098265,167.156532,169.727847,167.463312
2022-01-20,153.544893,154.753435,153.756633,156.809519,153.584007
2022-01-21,156.745926,147.903446,141.124356,147.783289,140.416406
...,...,...,...,...,...
2023-12-25,174.864163,164.284786,159.584728,144.135415,155.800677
2023-12-26,158.374260,172.296495,167.520754,168.427782,170.531530
2023-12-27,160.737663,136.173317,135.451127,146.702907,133.993746
2023-12-28,151.563421,168.036628,162.556994,163.955759,158.251061


In [16]:
df_rmse_ma = df_rmse.rolling(30).mean()

In [17]:
df_com.columns

Index(['GP', 'Alpha Gen', 'Bootstrapped DQN', 'Oracle', 'MCTS', 'EMCTS',
       'Benchmark'],
      dtype='object')

In [18]:
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(rows=2, cols=1,subplot_titles=["Cumulative Return", "RMSE"])


for col in ['GP', 'Alpha Gen', 'Bootstrapped DQN', 'MCTS', 'EMCTS', 'Benchmark']:#df_com.columns:
    fig.add_trace(
        go.Scatter(
            x=df_com.index,   
            y=df_com[col],    
            mode='lines',
            name=col,
            legendgroup = '1',
        ),
        row=1, 
        col=1
    )

for col in ['GP', 'Alpha Gen', 'Bootstrapped DQN', 'RiskMiner', 'EMCTS']:#df_rmse_ma.columns:
    fig.add_trace(
        go.Scatter(
            x=df_rmse_ma.index,
            y=df_rmse_ma[col],    
            mode='lines',
            name=col,
            legendgroup = '2',
        ),
        row=2, 
        col=1
    )

# Update the layout to add the title and template
fig.update_layout(
    template='seaborn',
    autosize=False,
    width=1200,
    height=1200,
    legend_tracegroupgap=580,
    legend_groupclick="toggleitem"
)

fig.show()


# fig = px.line(df_com, y=["Bootstrapped DQN","Alpha Gen","GP","Benchmark"], 
#             #   x="lifeExp", 
#               title='Cumulative Return',
#               template="seaborn",
#               )
# fig.show()

# Test

In [1]:
from alphagen_qlib.stock_data import StockData

data = StockData(
        instrument="csi300",
        start_time="2020-01-01",
        end_time="2022-01-01"
    )
data

[32156:MainThread](2025-09-13 08:42:51,074) INFO - qlib.Initialization - [config.py:420] - default_conf: client.
[32156:MainThread](2025-09-13 08:42:52,138) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
[32156:MainThread](2025-09-13 08:42:52,139) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': WindowsPath('C:/Users/tywat/.qlib/qlib_data/cn_data')}


<alphagen_qlib.stock_data.StockData at 0x20c287f4ad0>

In [2]:
from qlib.data import D

instruments = data.stock_ids.tolist()

# Determine the proper start and end times for fetching price data.
# Here we use the same dates as in your StockData instance.
start_time = data._dates[data.max_backtrack_days].strftime("%Y-%m-%d")
end_time = data._dates[-data.max_future_days - 1].strftime("%Y-%m-%d")

# Query Qlib to get the closing price for each instrument.
# The field '$close' is used here (adjust if your field naming is different)
price_df = D.features(
    instruments=instruments,
    fields=["$close"],
    start_time="2020-01-01",
    end_time="2022-01-01"
)

price_df = price_df.reorder_levels(order=[1, 0])
price_df

Unnamed: 0_level_0,Unnamed: 1_level_0,$close
datetime,instrument,Unnamed: 2_level_1
2020-01-02,SH600000,14.791045
2020-01-03,SH600000,14.944963
2020-01-06,SH600000,14.778918
2020-01-07,SH600000,14.826492
2020-01-08,SH600000,14.612873
...,...,...
2021-12-27,SZ300999,1.155714
2021-12-28,SZ300999,1.140357
2021-12-29,SZ300999,1.116786
2021-12-30,SZ300999,1.124286


In [3]:
def compute_oracle_scores(price_df: pd.DataFrame) -> pd.DataFrame:
    # price_df is expected to be a MultiIndex DataFrame with (date, instrument)
    # Unstack to get dates as rows and instruments as columns
    price_unstacked = price_df.unstack(level=1)
    # Compute daily percentage returns and shift so that prediction on day t 
    # is compared with return from t to t+1
    oracle_signal = price_unstacked.pct_change().shift(-1)
    # Stack back to a MultiIndex DataFrame
    return oracle_signal.stack()


oracle_scores = compute_oracle_scores(price_df)
oracle_scores

NameError: name 'pd' is not defined

In [None]:
from alphagen_qlib.utils import load_alpha_pool_by_path

calc = QLibStockDataCalculator(data, None)

for p in Path("out/boot_dqn").iterdir():
        inst, size, seed, time, ver = p.name.split('_', 4)
        size, seed = int(size), int(seed)
        if inst != "csi300" or size != 20 or time < "20240923" or ver == "llm_d5":
            continue
        try:
            exprs, weights = load_alpha_pool_by_path(str(p / "249500_steps_pool.json"))
        except:
            continue

boot_score = data.make_dataframe(calc.make_ensemble_alpha(exprs, weights))
boot_score

Unnamed: 0_level_0,Unnamed: 1_level_0,0
datetime,instrument,Unnamed: 2_level_1
2020-01-02,SH600000,0.029499
2020-01-02,SH600004,0.065036
2020-01-02,SH600009,-0.070193
2020-01-02,SH600010,0.009689
2020-01-02,SH600011,0.194273
...,...,...
2021-12-31,SZ300782,-0.121243
2021-12-31,SZ300866,-0.022676
2021-12-31,SZ300888,0.029904
2021-12-31,SZ300896,-0.026693


In [None]:
def normalize_series(series: pd.Series) -> pd.Series:
    return (series - series.mean()) / series.std()

def rank_series_per_date(series: pd.Series) -> pd.Series:
    """
    Rank the series for each date (assumed to be the first level of the MultiIndex).
    The highest value is assigned rank 1.
    """
    return series.groupby(level=0).rank(ascending=False, method='min')
def compute_rmse_per_date(model_scores: pd.Series, oracle_scores: pd.Series) -> pd.DataFrame:
    """
    Compute the RMSE across stocks for each date.
    
    Parameters:
      model_scores: pd.Series with MultiIndex (date, instrument) containing your model's prediction scores.
      oracle_scores: pd.Series with MultiIndex (date, instrument) containing the oracle's prediction scores.
      
    Returns:
      A DataFrame with the date as the index and a column 'rmse' containing the RMSE for that date.
    """
    # normalize
    # model_scores = normalize_series(model_scores)
    # oracle_scores = normalize_series(oracle_scores)

    # rank the scores
    model_scores = rank_series_per_date(model_scores)
    oracle_scores = rank_series_per_date(oracle_scores)

    # Combine both series into one DataFrame
    df = pd.DataFrame({
        "model": model_scores,
        "oracle": oracle_scores
    })
    # Group by the date level. If your MultiIndex doesn't have names,
    # you can group by level=0 (assuming the first level is the date).
    rmse_series = df.groupby(level=0).apply(
        lambda group: np.sqrt(((group["oracle"] - group["model"]) ** 2).mean())
    )
    rmse_df = rmse_series.to_frame(name="rmse")
    # Ensure the index is named "date" (or adjust as needed)
    rmse_df.index.name = "date"
    return rmse_df

rmse_df = compute_rmse_per_date(boot_score.iloc[:,0], oracle_scores.iloc[:,0])
rmse_df

Unnamed: 0_level_0,rmse
date,Unnamed: 1_level_1
2020-01-02,175.373923
2020-01-03,172.032975
2020-01-06,158.534962
2020-01-07,143.087359
2020-01-08,158.674098
...,...
2021-12-27,165.069690
2021-12-28,162.478913
2021-12-29,165.498930
2021-12-30,164.577087


In [None]:
import plotly.express as px
fig = px.line(rmse_df, y=["rmse"], 
            #   x="lifeExp", 
              title='RMSE',
              template="seaborn",
              )
fig.show()