In [1]:
import json
import os
from typing import Optional, Tuple, List
from datetime import datetime
from pathlib import Path
from openai import OpenAI
import fire
import pandas as pd

import numpy as np
from sb3_contrib.ppo_mask import MaskablePPO
from stable_baselines3.common.callbacks import BaseCallback

from alphagen.data.expression import *
from alphagen.data.parser import ExpressionParser
from alphagen.models.linear_alpha_pool import LinearAlphaPool, MseAlphaPool
from alphagen.rl.env.wrapper import AlphaEnv
from alphagen.rl.policy import LSTMSharedNet
from alphagen.utils import reseed_everything, get_logger
from alphagen.rl.env.core import AlphaEnvCore
from alphagen_qlib.calculator import QLibStockDataCalculator
from alphagen_qlib.stock_data import initialize_qlib
from alphagen_llm.client import ChatClient, OpenAIClient, ChatConfig
from alphagen_llm.prompts.system_prompt import EXPLAIN_WITH_TEXT_DESC
from alphagen_llm.prompts.interaction import InterativeSession, DefaultInteraction

In [2]:
# Instrument universe passed to every StockData built by get_dataset().
instruments: str = "csi300"
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs end-to-end on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_dataset(start: str, end: str) -> StockData:
    """Build a ``StockData`` for the notebook-level ``instruments`` and ``device``.

    Parameters:
      start: value forwarded as ``start_time``.
      end: value forwarded as ``end_time``.

    Returns:
      The newly constructed ``StockData``.
    """
    stock_data_kwargs = dict(
        instrument=instruments,
        start_time=start,
        end_time=end,
        device=device,
    )
    return StockData(**stock_data_kwargs)

# (start, end) date pairs, one StockData per pair. The list order matters:
# later cells pick datasets/calculators by position (e.g. calculators[1], calculators[2]).
segments = [
    ("2012-01-01", "2019-12-31"),
    ("2022-01-01", "2022-06-30"),
    ("2022-07-01", "2022-12-31"),
    ("2023-01-01", "2023-06-30")
]


# Same order as `segments`.
datasets = [get_dataset(*s) for s in segments]

[12856:MainThread](2025-04-01 07:01:41,665) INFO - qlib.Initialization - [config.py:420] - default_conf: client.
[12856:MainThread](2025-04-01 07:01:42,811) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
[12856:MainThread](2025-04-01 07:01:42,812) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': WindowsPath('C:/Users/tywat/.qlib/qlib_data/cn_data')}


In [3]:
close = Feature(FeatureType.CLOSE)
# Target expression shared by every calculator below.
# NOTE(review): presumably the 20-day forward return (Ref with a negative offset
# looking ahead) — confirm against alphagen's Ref semantics.
target = Ref(close, -20) / close - 1
# One calculator per dataset, in the same order as `datasets`.
calculators = [QLibStockDataCalculator(d, target) for d in datasets]

In [4]:
from alphagen.data.expression import Operators
from alphagen.data.parser import ExpressionParser

def load_linear_alpha_pool_from_json(json_path: str,
                                     calculator: QLibStockDataCalculator,
                                     single_alpha: bool = False) -> LinearAlphaPool | list[LinearAlphaPool]:
    """Rebuild alpha pool(s) from a saved pool JSON file.

    The JSON document must contain an ``'exprs'`` list (expression strings) and a
    ``'weights'`` list of the same length.

    Parameters:
      json_path: path of the pool JSON file to read.
      calculator: calculator handed to every ``MseAlphaPool`` that is created.
      single_alpha: when False (default) return ONE pool holding all expressions
        with their weights; when True return a LIST with one capacity-1 pool per
        expression, each loaded with that expression and its own weight.

    Returns:
      A single ``MseAlphaPool`` or a list of ``MseAlphaPool`` objects, see ``single_alpha``.

    Raises:
      ValueError: if the file lists a different number of expressions and weights.
    """
    with open(json_path, 'r') as f:
        pool_data = json.load(f)

    expressions = pool_data['exprs']
    weights = pool_data['weights']
    # zip() below would silently drop the surplus entries, so fail loudly instead.
    if len(expressions) != len(weights):
        raise ValueError(
            f"{json_path}: {len(expressions)} expressions but {len(weights)} weights"
        )

    parser = ExpressionParser(Operators)
    parsed_exprs = [parser.parse(expression) for expression in expressions]

    if single_alpha:
        single_pools = []
        for expr, weight in zip(parsed_exprs, weights):
            pool = MseAlphaPool(capacity=1, calculator=calculator)
            pool.force_load_exprs([expr], [weight])
            single_pools.append(pool)
        return single_pools

    # Only build the full-capacity pool when it is actually the result.
    combined_pool = MseAlphaPool(capacity=len(parsed_exprs), calculator=calculator)
    combined_pool.force_load_exprs(parsed_exprs, weights)
    return combined_pool

# NOTE: the names read the wrong way round. `alpha_pools` is ONE combined pool
# (single_alpha defaults to False), while `alpha_pool` is a LIST of
# single-expression pools (single_alpha=True). Both use calculators[1].
alpha_pools = load_linear_alpha_pool_from_json('out/results/csi300_20_0_20250208124320_rl/251904_steps_pool.json', calculators[1])
alpha_pool = load_linear_alpha_pool_from_json('out/results/csi300_20_0_20250208124320_rl/251904_steps_pool.json', calculators[1], single_alpha=True)

In [5]:
# Score the combined pool (`alpha_pools` is a single pool here) on calculators[2];
# test_ensemble returns the pair unpacked below.
ic_value, rank_ic_value = alpha_pools.test_ensemble(calculators[2])
print(alpha_pools.exprs)
print(ic_value, rank_ic_value)

[Greater(Div(Div(-1.0,$high),EMA($open,10d)),-2.0), Delta(Log($vwap),1d), Mul($volume,Mul(Cov($close,Mul(5.0,Min(Mul($high,-30.0),40d)),40d),-0.01)), Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0),10d), Abs(Sub(2.0,Div($close,Add(Greater(2.0,Delta(Log($low),5d)),30.0)))), Mad(Add(2.0,Mean($vwap,20d)),10d), Corr($close,$low,10d), Abs(Log(Mad(Sub(-0.5,$close),20d))), Mad(Log(Log($volume)),40d), Mul(0.5,Corr(Log($volume),WMA(Log($volume),40d),40d)), Mul(Mul($volume,Mul(Add(Mean($high,20d),30.0),$high)),0.5), Mul(WMA(Log(Abs(Var($low,5d))),20d),-2.0), Abs(Mul(5.0,Sub($open,30.0))), Mean(Less(Sub(-2.0,Corr($volume,$high,20d)),1.0),10d), Sub(Less(1.0,$low),5.0), Add(Corr(Sub(-1.0,$high),$volume,10d),0.01), WMA(Div(Std(WMA(Div(Div($vwap,30.0),$low),40d),20d),-5.0),10d), WMA(Sub(-1.0,Div($low,$close)),20d), Less(Div($close,$vwap),$volume), Sub(Mad(Mean(Log($low),20d),40d),5.0), None]
0.06614601612091064 0.0644562840461731


In [6]:
# Position of the single-expression pool to inspect inside the `alpha_pool` list.
alpha_index = 3

# Score only that one-expression pool on calculators[2].
ic_value, rank_ic_value = alpha_pool[alpha_index].test_ensemble(calculators[2])
print(alpha_pool[alpha_index].exprs)
print(ic_value, rank_ic_value)

[Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0),10d), None]
0.010267447680234909 0.010892813093960285


In [7]:
# `alpha_pool` is the list of single-expression pools; print each pool's exprs container.
for alpha in alpha_pool:
    print(alpha.exprs)

[Greater(Div(Div(-1.0,$high),EMA($open,10d)),-2.0), None]
[Delta(Log($vwap),1d), None]
[Mul($volume,Mul(Cov($close,Mul(5.0,Min(Mul($high,-30.0),40d)),40d),-0.01)), None]
[Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0),10d), None]
[Abs(Sub(2.0,Div($close,Add(Greater(2.0,Delta(Log($low),5d)),30.0)))), None]
[Mad(Add(2.0,Mean($vwap,20d)),10d), None]
[Corr($close,$low,10d), None]
[Abs(Log(Mad(Sub(-0.5,$close),20d))), None]
[Mad(Log(Log($volume)),40d), None]
[Mul(0.5,Corr(Log($volume),WMA(Log($volume),40d),40d)), None]
[Mul(Mul($volume,Mul(Add(Mean($high,20d),30.0),$high)),0.5), None]
[Mul(WMA(Log(Abs(Var($low,5d))),20d),-2.0), None]
[Abs(Mul(5.0,Sub($open,30.0))), None]
[Mean(Less(Sub(-2.0,Corr($volume,$high,20d)),1.0),10d), None]
[Sub(Less(1.0,$low),5.0), None]
[Add(Corr(Sub(-1.0,$high),$volume,10d),0.01), None]
[WMA(Div(Std(WMA(Div(Div($vwap,30.0),$low),40d),20d),-5.0),10d), None]
[WMA(Sub(-1.0,Div($low,$close)),20d), None]
[Less(Div($close,$vwap),$volume), None]
[Sub(Mad(Mean(Log($low),

In [8]:
# Per-expression evaluation: score every single-expression pool on calculators[2]
# and collect the results column-wise.
ics = []
rank_ics = []
alphas = []

for alpha in alpha_pool:
    ic_value, rank_ic_value = alpha.test_ensemble(calculators[2])

    ics.append(ic_value)
    rank_ics.append(rank_ic_value)
    # Stores the pool's whole `exprs` container, not just the expression itself.
    alphas.append(alpha.exprs)

# One row per single-expression pool, in `alpha_pool` order.
df_ic_ind = pd.DataFrame({'alpha': alphas, 'ic': ics, 'rank_ic': rank_ics})
df_ic_ind

Unnamed: 0,alpha,ic,rank_ic
0,"[Greater(Div(Div(-1.0,$high),EMA($open,10d)),-...",0.055708,0.084874
1,"[Delta(Log($vwap),1d), None]",-0.02456,-0.012933
2,"[Mul($volume,Mul(Cov($close,Mul(5.0,Min(Mul($h...",-0.036502,-0.035545
3,"[Sum(Mul(Corr(Div($vwap,-0.5),$close,5d),-10.0...",0.010267,0.010893
4,"[Abs(Sub(2.0,Div($close,Add(Greater(2.0,Delta(...",-0.061593,-0.092103
5,"[Mad(Add(2.0,Mean($vwap,20d)),10d), None]",-0.006636,0.043511
6,"[Corr($close,$low,10d), None]",0.056122,0.063845
7,"[Abs(Log(Mad(Sub(-0.5,$close),20d))), None]",-0.081099,-0.097603
8,"[Mad(Log(Log($volume)),40d), None]",-0.025405,-0.042726
9,"[Mul(0.5,Corr(Log($volume),WMA(Log($volume),40...",-0.040783,-0.046877


In [9]:
# Path.iterdir() yields entries in arbitrary, OS-dependent order. Sort so that
# "the last entry" — the run whose report is read below — is reproducible.
gp_run_dirs = sorted(Path("out/gp").iterdir())
if not gp_run_dirs:
    raise FileNotFoundError("no GP run directories found under out/gp")

# As before, the last entry is the run that gets inspected; its name is the seed.
p = gp_run_dirs[-1]
seed = int(p.name)

with open(p / "40.json") as f:
    report = json.load(f)


state = report["res"]["res"]["pool_state"]
state["exprs"]

['Std(EMA(Min(Mul(5.0,$high),30d),50d),10d)',
 'Std(EMA(Min(Log($vwap),30d),40d),10d)',
 'Sum(Mean(Abs(Corr($low,$high,20d)),40d),20d)',
 'Std(Med(Min(Mul(5.0,$high),30d),10d),10d)',
 'Mad(Min($low,20d),20d)',
 'Std(Std(Min(Mul($vwap,2.0),30d),10d),10d)',
 'Mad(Med($close,50d),10d)',
 'Std(Cov(Corr(Var($volume,40d),$high,20d),$close,30d),10d)',
 'Std(Med(Ref(Mul(10.0,$high),30d),10d),10d)',
 'Std(Min(Sum(Mul(10.0,$high),40d),50d),10d)',
 'Mad(Min($high,30d),10d)',
 'Std(Std(Med(Mul(0.5,$high),20d),20d),10d)',
 'Mad(Ref(Min($high,30d),10d),10d)',
 'Std(Abs(WMA(Cov(0.01,$high,50d),10d)),10d)',
 'Std(Min(WMA($high,20d),50d),10d)',
 'Log(Var(Sum($low,30d),40d))',
 'Std(EMA(Min(Mul(Std($high,10d),$high),30d),50d),10d)',
 'Std(Max(Min(Mul(5.0,$high),30d),20d),10d)',
 'Std(Min(Mean(Corr(5.0,$high,30d),50d),10d),10d)',
 'Std(EMA(WMA($vwap,10d),40d),10d)']

In [2]:
# Experiment identifier interpolated into the out/backtests/{ex_num}/... paths below.
ex_num = "51-5"

In [3]:
import pickle

# Pickled chart object saved under the `gp` backtest folder of experiment `ex_num`.
file_path = f'out/backtests/{ex_num}/gp/2-graph.pkl'

# NOTE: pickle.load can execute arbitrary code — only open artifacts produced by your own runs.
with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [4]:
import pickle

# Pickled backtest report saved under the `gp` backtest folder of experiment `ex_num`.
file_path = f'out/backtests/{ex_num}/gp/2-report.pkl'

# NOTE: pickle.load can execute arbitrary code — only open artifacts produced by your own runs.
with open(file_path, 'rb') as file:
    gp_report = pickle.load(file)
gp_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2020-01-02,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.013587
2020-01-03,9.985773e+07,-1.242734e-16,9.484930e+07,0.948493,1.422740e+05,0.001423,9.484930e+07,5.008421e+06,-0.001753
2020-01-06,9.968490e+07,-1.380534e-03,1.181589e+08,0.233428,1.772383e+05,0.000350,9.897084e+07,7.140661e+05,-0.003778
2020-01-07,1.005576e+08,9.067840e-03,1.389572e+08,0.208641,2.084358e+05,0.000313,1.000074e+08,5.502283e+05,0.007490
2020-01-08,9.948802e+07,-1.045052e-02,1.514441e+08,0.124176,2.271661e+05,0.000186,9.916298e+07,3.250440e+05,-0.011516
...,...,...,...,...,...,...,...,...,...
2021-12-27,1.502713e+08,-1.131759e-04,7.002606e+09,0.182514,1.050391e+07,0.000274,1.495571e+08,7.141347e+05,-0.000410
2021-12-28,1.515162e+08,8.461025e-03,7.020299e+09,0.117739,1.053045e+07,0.000177,1.510521e+08,4.641126e+05,0.007448
2021-12-29,1.487223e+08,-1.827460e-02,7.036949e+09,0.109891,1.055542e+07,0.000165,1.482908e+08,4.314993e+05,-0.014625
2021-12-30,1.505484e+08,1.245313e-02,7.054229e+09,0.116189,1.058134e+07,0.000174,1.500913e+08,4.571258e+05,0.007787


In [5]:
import pickle

# Pickled chart object saved under the `rl` backtest folder of experiment `ex_num`.
file_path = f'out/backtests/{ex_num}/rl/0-graph.pkl'

# NOTE: pickle.load can execute arbitrary code — only open artifacts produced by your own runs.
with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [6]:
import pickle

# Pickled backtest report saved under the `rl` backtest folder of experiment `ex_num`.
file_path = f'out/backtests/{ex_num}/rl/0-report.pkl'

# NOTE: pickle.load can execute arbitrary code — only open artifacts produced by your own runs.
with open(file_path, 'rb') as file:
    alphaGen_report = pickle.load(file)
alphaGen_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2020-01-02,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.013587
2020-01-03,9.985766e+07,-4.208414e-16,9.489186e+07,0.948919,1.423378e+05,0.001423,9.489186e+07,4.965801e+06,-0.001753
2020-01-06,9.944416e+07,-3.792852e-03,1.180644e+08,0.232056,1.770966e+05,0.000348,9.873570e+07,7.084600e+05,-0.003778
2020-01-07,1.004760e+08,1.074233e-02,1.423650e+08,0.244364,2.135475e+05,0.000367,9.981191e+07,6.640592e+05,0.007490
2020-01-08,9.957497e+07,-8.648418e-03,1.637261e+08,0.212599,2.455891e+05,0.000319,9.901555e+07,5.594233e+05,-0.011516
...,...,...,...,...,...,...,...,...,...
2021-12-27,1.516544e+08,2.265028e-03,1.294456e+10,0.213837,1.941685e+07,0.000321,1.507879e+08,8.664104e+05,-0.000410
2021-12-28,1.528029e+08,7.880469e-03,1.297562e+10,0.204780,1.946343e+07,0.000307,1.520019e+08,8.010219e+05,0.007448
2021-12-29,1.515947e+08,-7.603732e-03,1.300651e+10,0.202148,1.950976e+07,0.000303,1.507677e+08,8.269453e+05,-0.014625
2021-12-30,1.535894e+08,1.346866e-02,1.303790e+10,0.207057,1.955685e+07,0.000311,1.527239e+08,8.654367e+05,0.007787


In [7]:
# Adds the running (additive) sum of the per-row "return" column as "cum_return".
# This mutates `alphaGen_report` in place, so later displays of it include the new column.
alphaGen_report["cum_return"] = alphaGen_report["return"].cumsum()
alphaGen_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench,cum_return
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2020-01-02,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.013587,0.000000e+00
2020-01-03,9.985766e+07,-4.208414e-16,9.489186e+07,0.948919,1.423378e+05,0.001423,9.489186e+07,4.965801e+06,-0.001753,-4.208414e-16
2020-01-06,9.944416e+07,-3.792852e-03,1.180644e+08,0.232056,1.770966e+05,0.000348,9.873570e+07,7.084600e+05,-0.003778,-3.792852e-03
2020-01-07,1.004760e+08,1.074233e-02,1.423650e+08,0.244364,2.135475e+05,0.000367,9.981191e+07,6.640592e+05,0.007490,6.949482e-03
2020-01-08,9.957497e+07,-8.648418e-03,1.637261e+08,0.212599,2.455891e+05,0.000319,9.901555e+07,5.594233e+05,-0.011516,-1.698936e-03
...,...,...,...,...,...,...,...,...,...,...
2021-12-27,1.516544e+08,2.265028e-03,1.294456e+10,0.213837,1.941685e+07,0.000321,1.507879e+08,8.664104e+05,-0.000410,6.139802e-01
2021-12-28,1.528029e+08,7.880469e-03,1.297562e+10,0.204780,1.946343e+07,0.000307,1.520019e+08,8.010219e+05,0.007448,6.218607e-01
2021-12-29,1.515947e+08,-7.603732e-03,1.300651e+10,0.202148,1.950976e+07,0.000303,1.507677e+08,8.269453e+05,-0.014625,6.142570e-01
2021-12-30,1.535894e+08,1.346866e-02,1.303790e+10,0.207057,1.955685e+07,0.000311,1.527239e+08,8.654367e+05,0.007787,6.277256e-01


In [8]:
import pickle

# Pickled chart object saved under the `boot` backtest folder of experiment `ex_num`.
file_path = f'out/backtests/{ex_num}/boot/0-graph.pkl'

# NOTE: pickle.load can execute arbitrary code — only open artifacts produced by your own runs.
with open(file_path, 'rb') as file:
    chart = pickle.load(file)
chart.show()

In [9]:
import pickle

# Pickled backtest report saved under the `boot` backtest folder of experiment `ex_num`.
file_path = f'out/backtests/{ex_num}/boot/0-report.pkl'

# NOTE: pickle.load can execute arbitrary code — only open artifacts produced by your own runs.
with open(file_path, 'rb') as file:
    boot_report = pickle.load(file)
boot_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2020-01-02,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.013587
2020-01-03,9.985775e+07,-9.953510e-17,9.483513e+07,0.948351,1.422527e+05,0.001423,9.483513e+07,5.022612e+06,-0.001753
2020-01-06,9.977135e+07,-5.193526e-04,1.178589e+08,0.230565,1.767883e+05,0.000346,9.906912e+07,7.022307e+05,-0.003778
2020-01-07,1.006189e+08,8.795304e-03,1.378570e+08,0.200440,2.067855e+05,0.000301,1.000959e+08,5.229807e+05,0.007490
2020-01-08,9.956444e+07,-1.016828e-02,1.587298e+08,0.207444,2.380948e+05,0.000311,9.900905e+07,5.553955e+05,-0.011516
...,...,...,...,...,...,...,...,...,...
2021-12-27,1.439737e+08,3.339626e-03,1.304548e+10,0.184451,1.956823e+07,0.000277,1.432824e+08,6.913612e+05,-0.000410
2021-12-28,1.448650e+08,6.496408e-03,1.307482e+10,0.203729,1.961222e+07,0.000306,1.441061e+08,7.589047e+05,0.007448
2021-12-29,1.439949e+08,-5.705631e-03,1.310384e+10,0.200356,1.965576e+07,0.000301,1.432106e+08,7.843054e+05,-0.014625
2021-12-30,1.459380e+08,1.379976e-02,1.313319e+10,0.203847,1.969979e+07,0.000306,1.451419e+08,7.961016e+05,0.007787


In [10]:
import pickle

# Pickled backtest report saved under the `oracle` backtest folder of experiment `ex_num`.
file_path = f'out/backtests/{ex_num}/oracle/0-report.pkl'

# NOTE: pickle.load can execute arbitrary code — only open artifacts produced by your own runs.
with open(file_path, 'rb') as file:
    oracle_report = pickle.load(file)
oracle_report

Unnamed: 0_level_0,account,return,total_turnover,turnover,total_cost,cost,value,cash,bench
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2020-01-02,1.000000e+08,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,0.000000,0.000000e+00,1.000000e+08,0.013587
2020-01-03,9.985761e+07,-1.094304e-16,9.492996e+07,0.949300,1.423949e+05,0.001424,9.492996e+07,4.927641e+06,-0.001753
2020-01-06,1.014676e+08,1.646398e-02,1.176398e+08,0.227422,1.764597e+05,0.000341,1.007688e+08,6.987897e+05,-0.003778
2020-01-07,1.020412e+08,5.962166e-03,1.385171e+08,0.205754,2.077757e+05,0.000309,1.014985e+08,5.427770e+05,0.007490
2020-01-08,1.011334e+08,-8.619890e-03,1.573386e+08,0.184450,2.360079e+05,0.000277,1.006407e+08,4.927372e+05,-0.011516
...,...,...,...,...,...,...,...,...,...
2021-12-24,1.518477e+08,2.474845e-03,1.289808e+10,0.193284,1.934711e+07,0.000290,1.510853e+08,7.623950e+05,-0.005537
2021-12-27,1.521458e+08,2.249083e-03,1.292704e+10,0.190733,1.939056e+07,0.000286,1.514002e+08,7.455896e+05,-0.000410
2021-12-28,1.502278e+08,-1.231781e-02,1.295630e+10,0.192308,1.943445e+07,0.000288,1.494662e+08,7.615853e+05,0.007448
2021-12-29,1.485774e+08,-1.070355e-02,1.298454e+10,0.188020,1.947681e+07,0.000282,1.478406e+08,7.368233e+05,-0.014625


In [11]:
import pandas as pd
# Cumulative (summed) returns of every strategy plus the benchmark, side by side.
# The GP series fixes the row index; assign() aligns every other series to it,
# exactly as column-by-column assignment into an empty frame would.
df_com = gp_report["return"].cumsum().to_frame("GP").assign(
    **{
        "Alpha Gen": alphaGen_report["return"].cumsum(),
        "Bootstrapped DQN": boot_report["return"].cumsum(),
        "Oracle": oracle_report["return"].cumsum(),
        # The benchmark column is read from the boot report.
        "Benchmark": boot_report["bench"].cumsum(),
    }
)

df_com.head()

Unnamed: 0_level_0,GP,Alpha Gen,Bootstrapped DQN,Oracle,Benchmark
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2020-01-02,0.0,0.0,0.0,0.0,0.013587
2020-01-03,-1.242734e-16,-4.208414e-16,-9.95351e-17,-1.094304e-16,0.011834
2020-01-06,-0.001380534,-0.003792852,-0.0005193526,0.01646398,0.008056
2020-01-07,0.007687306,0.006949482,0.008275951,0.02242615,0.015546
2020-01-08,-0.002763217,-0.001698936,-0.001892334,0.01380626,0.00403


In [12]:
import pickle

# backtest folder -> (file-name prefix of its rmse pickle, column label in df_rmse).
# The gp results are stored under prefix 2, the others under prefix 0.
rmse_sources = {
    "boot": (0, "Bootstrapped DQN"),
    "gp": (2, "GP"),
    "rl": (0, "Alpha Gen"),
}

df_rmse = pd.DataFrame()

for model, (file_prefix, label) in rmse_sources.items():
    file_path = f'out/backtests/{ex_num}/{model}/{file_prefix}-rmse.pkl'

    # NOTE: pickle.load can execute arbitrary code — only open artifacts produced by your own runs.
    with open(file_path, 'rb') as file:
        rmse = pickle.load(file)

    # Store under the display label directly instead of renaming in place afterwards.
    df_rmse[label] = rmse["rmse"]

df_rmse

Unnamed: 0_level_0,Bootstrapped DQN,GP,Alpha Gen
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2020-01-02,175.373923,174.847309,177.452293
2020-01-03,172.032975,166.355893,172.897538
2020-01-06,158.534962,168.291778,167.316096
2020-01-07,143.087359,164.291739,146.301059
2020-01-08,158.674098,158.414047,160.185328
...,...,...,...
2021-12-27,165.069690,167.879158,167.118457
2021-12-28,162.478913,182.799636,167.501603
2021-12-29,165.498930,166.687620,165.948379
2021-12-30,164.577087,168.267198,167.694103


In [38]:
# 30-row rolling mean of every RMSE column, used to smooth the series before plotting.
# With the default min_periods the first 29 rows are NaN.
df_rmse_ma = df_rmse.rolling(30).mean()

In [39]:
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(rows=2, cols=1, subplot_titles=["Cumulative Return", "RMSE"])

# (frame to plot, subplot row, legend group): one line trace per column of each frame.
# Separate legend groups keep the two panels' legend entries apart.
panels = [
    (df_com, 1, '1'),
    (df_rmse_ma, 2, '2'),
]

for frame, row, legend_group in panels:
    for col in frame.columns:
        fig.add_trace(
            go.Scatter(
                x=frame.index,
                y=frame[col],
                mode='lines',
                name=col,
                legendgroup=legend_group,
            ),
            row=row,
            col=1
        )

# Fixed-size figure with the seaborn template.
# NOTE(review): legend_tracegroupgap=580 presumably pushes the second legend group
# down next to the second subplot for this 1200px height — retune if the size changes.
fig.update_layout(
    template='seaborn',
    autosize=False,
    width=1200,
    height=1200,
    legend_tracegroupgap=580,
    legend_groupclick="toggleitem"
)

fig.show()

In [14]:
from alphagen_qlib.stock_data import StockData

# StockData for the window used by the oracle / RMSE cells below.
# No `device` is passed here, so whatever default StockData defines is used.
data = StockData(
        instrument="csi300",
        start_time="2020-01-01",
        end_time="2022-01-01"
    )
data

[29000:MainThread](2025-04-05 08:46:52,432) INFO - qlib.Initialization - [config.py:420] - default_conf: client.
[29000:MainThread](2025-04-05 08:46:53,465) INFO - qlib.Initialization - [__init__.py:74] - qlib successfully initialized based on client settings.
[29000:MainThread](2025-04-05 08:46:53,466) INFO - qlib.Initialization - [__init__.py:76] - data_path={'__DEFAULT_FREQ': WindowsPath('C:/Users/tywat/.qlib/qlib_data/cn_data')}


<alphagen_qlib.stock_data.StockData at 0x2838b522110>

In [15]:
from qlib.data import D

# NOTE: rebinds the notebook-level `instruments` name to a list of stock ids.
instruments = data.stock_ids.tolist()

# Derive the query window from the StockData instance itself so the prices always
# cover the same dates as `data`.
# NOTE(review): assumes `data._dates` is padded with `max_backtrack_days` leading and
# `max_future_days` trailing entries around the requested window — confirm in StockData.
start_time = data._dates[data.max_backtrack_days].strftime("%Y-%m-%d")
end_time = data._dates[-data.max_future_days - 1].strftime("%Y-%m-%d")

# Query Qlib for the closing price of each instrument.
# The field '$close' is used here (adjust if your field naming is different).
price_df = D.features(
    instruments=instruments,
    fields=["$close"],
    start_time=start_time,
    end_time=end_time
)

# Swap the two index levels so the index reads (datetime, instrument).
price_df = price_df.reorder_levels(order=[1, 0])
price_df

Unnamed: 0_level_0,Unnamed: 1_level_0,$close
datetime,instrument,Unnamed: 2_level_1
2020-01-02,SH600000,14.791045
2020-01-03,SH600000,14.944963
2020-01-06,SH600000,14.778918
2020-01-07,SH600000,14.826492
2020-01-08,SH600000,14.612873
...,...,...
2021-12-27,SZ300999,1.155714
2021-12-28,SZ300999,1.140357
2021-12-29,SZ300999,1.116786
2021-12-30,SZ300999,1.124286


In [16]:
def compute_oracle_scores(price_df: pd.DataFrame) -> pd.DataFrame:
    """Turn prices into a look-ahead "oracle" signal: the next period's return.

    Parameters:
      price_df: DataFrame with a two-level (date, instrument) MultiIndex holding prices.

    Returns:
      A DataFrame stacked back to a (date, instrument) index where the value at
      date t is the percentage change from t to the next row's date.
    """
    # Dates as rows, instruments as columns.
    price_unstacked = price_df.unstack(level=1)
    # Forward-fill gaps explicitly, then disable pct_change's own filling. This is
    # what the deprecated default fill_method='pad' did, without the FutureWarning.
    returns = price_unstacked.ffill().pct_change(fill_method=None)
    # Shift up one row so the signal on day t is the return from t to t+1.
    oracle_signal = returns.shift(-1)
    # Stack back to a MultiIndex DataFrame.
    return oracle_signal.stack()


# Next-period return per (date, instrument), computed from the close prices in `price_df`.
oracle_scores = compute_oracle_scores(price_df)
oracle_scores


The default fill_method='pad' in DataFrame.pct_change is deprecated and will be removed in a future version. Either fill in any non-leading NA values prior to calling pct_change or specify 'fill_method=None' to not fill NA values.





Unnamed: 0_level_0,Unnamed: 1_level_0,$close
datetime,instrument,Unnamed: 2_level_1
2020-01-02,SH600000,0.010406
2020-01-02,SH600004,-0.007961
2020-01-02,SH600009,-0.000765
2020-01-02,SH600010,0.008123
2020-01-02,SH600011,0.000000
...,...,...
2021-12-30,SZ300782,-0.024249
2021-12-30,SZ300866,-0.033949
2021-12-30,SZ300888,0.008192
2021-12-30,SZ300896,-0.012722


In [17]:
from alphagen_qlib.utils import load_alpha_pool_by_path

calc = QLibStockDataCalculator(data, None)

exprs = weights = None
# Path.iterdir() yields entries in arbitrary order; sort so the run that ends up
# selected (the LAST one that matches and loads) is reproducible.
for p in sorted(Path("out/boot_dqn").iterdir()):
    name_parts = p.name.split('_', 4)
    if len(name_parts) != 5:
        continue  # not named <inst>_<size>_<seed>_<time>_<ver>
    inst, size, seed, timestamp, ver = name_parts
    if not (size.isdigit() and seed.isdigit()):
        continue
    size, seed = int(size), int(seed)
    # Timestamps are compared as strings, as in the directory names.
    if inst != "csi300" or size != 20 or timestamp < "20240923" or ver == "llm_d5":
        continue
    try:
        exprs, weights = load_alpha_pool_by_path(str(p / "249500_steps_pool.json"))
    except Exception:
        # Best effort: skip runs whose checkpoint is missing or unreadable, but no
        # longer swallow KeyboardInterrupt/SystemExit the way a bare `except:` does.
        continue

if exprs is None:
    raise FileNotFoundError(
        "no matching run under out/boot_dqn provided a loadable 249500_steps_pool.json"
    )

# Ensemble alpha of the selected pool, converted to a DataFrame by `data`.
boot_score = data.make_dataframe(calc.make_ensemble_alpha(exprs, weights))
boot_score

Unnamed: 0_level_0,Unnamed: 1_level_0,0
datetime,instrument,Unnamed: 2_level_1
2020-01-02,SH600000,0.029499
2020-01-02,SH600004,0.065036
2020-01-02,SH600009,-0.070193
2020-01-02,SH600010,0.009689
2020-01-02,SH600011,0.194273
...,...,...
2021-12-31,SZ300782,-0.121243
2021-12-31,SZ300866,-0.022676
2021-12-31,SZ300888,0.029904
2021-12-31,SZ300896,-0.026693


In [18]:
def normalize_series(series: pd.Series) -> pd.Series:
    """Z-score ``series``: subtract its mean, then divide by its standard deviation."""
    centre = series.mean()
    spread = series.std()
    centred = series - centre
    return centred / spread

def rank_series_per_date(series: pd.Series) -> pd.Series:
    """
    Rank values within each group of the first index level (assumed to be the date).

    Ranking is descending, so the largest value of a date gets rank 1; tied
    values share the lowest rank of their tie (``method='min'``).
    """
    per_date = series.groupby(level=0)
    return per_date.rank(method='min', ascending=False)
def compute_rmse_per_date(model_scores: pd.Series, oracle_scores: pd.Series) -> pd.DataFrame:
    """
    Compute, for each date, the RMSE between the model's and the oracle's per-date ranks.

    Both inputs are first converted to within-date ranks via ``rank_series_per_date``;
    the error is therefore measured in rank positions, not in raw score units.

    Parameters:
      model_scores: pd.Series with MultiIndex (date, instrument) containing your model's prediction scores.
      oracle_scores: pd.Series with MultiIndex (date, instrument) containing the oracle's prediction scores.

    Returns:
      A DataFrame with the date as the index (named "date") and a column 'rmse'
      containing the RMSE for that date.
    """
    # Rank the scores within each date.
    model_ranks = rank_series_per_date(model_scores)
    oracle_ranks = rank_series_per_date(oracle_scores)

    # Building one frame aligns both rank series on the union of their indexes;
    # entries present on only one side become NaN and are skipped by the mean.
    ranks = pd.DataFrame({
        "model": model_ranks,
        "oracle": oracle_ranks
    })
    squared_error = (ranks["oracle"] - ranks["model"]) ** 2
    # Vectorised per-date mean over the first index level, instead of a Python
    # lambda run once per group through groupby.apply.
    rmse_series = np.sqrt(squared_error.groupby(level=0).mean())
    rmse_df = rmse_series.to_frame(name="rmse")
    # Ensure the index is named "date".
    rmse_df.index.name = "date"
    return rmse_df

# Both inputs are one-column DataFrames; .iloc[:, 0] extracts that column as a Series.
rmse_df = compute_rmse_per_date(boot_score.iloc[:,0], oracle_scores.iloc[:,0])
rmse_df

Unnamed: 0_level_0,rmse
date,Unnamed: 1_level_1
2020-01-02,175.373923
2020-01-03,172.032975
2020-01-06,158.534962
2020-01-07,143.087359
2020-01-08,158.674098
...,...
2021-12-27,165.069690
2021-12-28,162.478913
2021-12-29,165.498930
2021-12-30,164.577087


In [19]:
import plotly.express as px
# Line chart of the per-date RMSE computed above.
fig = px.line(rmse_df, y=["rmse"], 
            # no x given: the frame's index is used for the x axis
              title='RMSE',
              template="seaborn",
              )
fig.show()