In [1]:
import deap
import json
import operator
import deap.tools as tools
import deap.base as base
import numpy as np
import random
import inspect
import pandas as pd
from functools import partial
import sys
from deap import base, creator, tools, gp, algorithms
from sklearn.linear_model import LinearRegression
import multiprocessing
from joblib import Parallel, delayed
from tqdm import tqdm
from numba.typed import List 
from npeet.entropy_estimators import mi
import deap.tools as tools
import warnings
import swifter
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import mutual_info_regression
import cupyx
import torch
import numba
from numba import jit
from numba import njit

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
warnings.simplefilter("ignore")

def load_data(data):
    data = pd.read_pickle(f'D:redata/DEAP/ad_hoc_prod/fields_full/{data}.pk')
    data = data.loc['2017':]
    data = data.stack()
    data = data.replace(0,np.nan)
    return data
    
adj_return = load_data("adj_return")
adj_close = load_data("adj_close")
adj_open = load_data("adj_open")
adj_high = load_data("adj_high")
adj_low = load_data("adj_low")
close = load_data("close")
open = load_data("open")
high = load_data("high")
low = load_data("low")
num_trades = load_data("num_trades")
total_turnover = load_data("total_turnover")
turnover_rate = load_data("turnover_rate")
circulation_a = load_data("circulation_a")
total_a = load_data("total_a")
circulation_market_value = load_data("circulation_market_value")
volume = load_data("volume")

data_origin_factors = pd.concat([open, high, low, close, volume, total_turnover, turnover_rate, adj_return], axis=1)
data_origin_factors .columns = ['open', 'high', 'low', 'close', 'volume', 'total_turnover', 'turnover_rate', 'adj_return']

def calculate_twap(df):
    df['twap'] = (df['high'] + df['low'] + df['close'] + df['open']) / 4
    return df['twap']

def calculate_vwap(df):
    df['vwap'] = df['total_turnover']/df['volume']
    return df['vwap']

data_origin_factors['twap'] = data_origin_factors.groupby('ticker', group_keys=False).apply(calculate_twap)
data_origin_factors['vwap'] = data_origin_factors.groupby('ticker', group_keys=False).apply(calculate_vwap)

# 设置target为隔夜收益率,下一天的收盘买，下二天的收盘卖
adj_close_stk = pd.read_pickle('D:redata/DEAP/ad_hoc_prod/fields_full/adj_close.pk')
adj_open_stk = pd.read_pickle('D:redata/DEAP/ad_hoc_prod/fields_full/adj_open.pk')

limit_down = pd.read_pickle('D:redata/DEAP/ad_hoc_prod/fields_full/limit_down.pk') #跌停
limit_down = limit_down[(limit_down.index >= '2017-01-01')]
limit_up = pd.read_pickle('D:redata/DEAP/ad_hoc_prod/fields_full/limit_up.pk') #涨停
limit_up = limit_up[(limit_up.index >= '2017-01-01')]

halt_status = pd.read_pickle('D:redata/DEAP/ad_hoc_prod/fields_full/halt_status.pk') #停牌
halt_status = halt_status[(halt_status.index >= '2017-01-01')]
st_status = pd.read_pickle('D:redata/DEAP/ad_hoc_prod/fields_full/st_status.pk') #st
st_status = st_status[(st_status.index >= '2017-01-01')]
st_status = st_status.replace(True,np.nan).replace(False,True)
halt_status = halt_status.replace(True,np.nan).replace(False,True)

adj_target_ret_o1o2 = adj_open_stk.shift(-2)/adj_open_stk.shift(-1) - 1
adj_target_ret_o1o2 = adj_target_ret_o1o2.replace(np.inf,np.nan).replace(-np.inf,np.nan)

# # 设置target为隔夜收益率,下一天的收盘买，下五天的收盘卖
# adj_target_ret_o1o5 = adj_open_stk.shift(-5)/adj_open_stk.shift(-1) - 1
# adj_target_ret_o1o5 = adj_target_ret_o1o5.replace(np.inf,np.nan).replace(-np.inf,np.nan)

target_pre = pd.DataFrame(adj_target_ret_o1o2.stack())
target_pre.columns = ['target']

data_merged = pd.merge(data_origin_factors, target_pre, on=['date', 'ticker'], how='left')
data_merged = data_merged.dropna(axis = 0)

In [2]:
def compute_condition(new_stk):
    drop_new_stk = adj_open.rolling(window=new_stk, min_periods=new_stk).mean()
    drop_new_stk = ~np.isnan(drop_new_stk)
    drop_new_stk = drop_new_stk.unstack()
    df = st_status * halt_status * drop_new_stk
    df = df.replace(True, 1).replace(False, np.nan)
    return df.loc['2017':]

final_universe = compute_condition(new_stk = 20) ##剔除20日内新股，st，halt

# === 数据集拆分 (7:3) ===
full_data = data_merged.loc['2017-01-03': '2024-12-31']
# 选择2022年以前的数据作为训练集
train_data = full_data[full_data.index.get_level_values('date') <= '2021-06-30']

# 选择2022年及以后的数据作为验证集
val_data = full_data[full_data.index.get_level_values('date') > '2021-06-30']

# Define data and target
train_data_set = train_data.iloc[:, :-1]
train_data_target = train_data.iloc[:, -1:]

val_data_set = val_data.iloc[:, :-1]
val_data_target = val_data.iloc[:, -1:]

target = full_data.iloc[:, -1:]
target_clean = target['target'].unstack()
target_clean = target_clean.replace([np.inf, -np.inf], np.nan)
target_clean = (target_clean * final_universe)
target_clean = target_clean.dropna(how='all',axis=0)
target_clean = target_clean.drop_duplicates()

train_clean_target = target_clean[target_clean.index.get_level_values('date') <= '2021-06-30']
val_clean_target = target_clean[target_clean.index.get_level_values('date') > '2021-06-30']
train_clean_rank_target = train_clean_target.rank(ascending=True, axis=1)
val_clean_rank_target = val_clean_target.rank(ascending=True, axis=1)

#临时，没做分层
full_data_factors = full_data.iloc[:, :-1]
full_data_clean_rank_target = target_clean.rank(axis=1, ascending=True)

In [34]:
from math import pow
def ensure_multiindex(x: pd.Series, template: pd.Series) -> pd.Series:
    if not isinstance(x.index, pd.MultiIndex):
        x.index = template.index
    else:
        x = x.reindex(template.index)
    return x

# -------------------- Numba核心滚动函数 --------------------

@njit
def rolling_mean(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)
    
    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]
        
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        if (i - start_pos + 1) >= window:
            window_start = i - window + 1
            s = 0.0
            count = 0
            for j in range(window_start, i + 1):
                if not np.isnan(values[j]):
                    s += values[j]
                    count += 1
            if count > 0:
                result[i] = s / count
            else:
                result[i] = np.nan
        else:
            result[i] = np.nan

    return result
    
@njit
def rolling_std(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)

    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]
        
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        if (i - start_pos + 1) >= window:
            window_start = i - window + 1
            s = 0.0
            count = 0
            for j in range(window_start, i + 1):
                if not np.isnan(values[j]):
                    s += values[j]
                    count += 1

            if count > 1:
                mean = s / count
                var = 0.0
                for j in range(window_start, i + 1):
                    if not np.isnan(values[j]):
                        var += (values[j] - mean) ** 2
                result[i] = (var / (count - 1)) ** 0.5
            else:
                result[i] = np.nan
        else:
            result[i] = np.nan

    return result
    
@njit
def rolling_skew(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)
    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        if i - start_pos + 1 >= window:
            count = 0
            mean = 0.0
            for j in range(i - window + 1, i + 1):
                val = values[j]
                if not np.isnan(val):
                    mean += val
                    count += 1
            if count == 0:
                result[i] = np.nan
                continue
            mean /= count

            # 计算方差和三阶矩
            m2 = 0.0
            m3 = 0.0
            for j in range(i - window + 1, i + 1):
                val = values[j]
                if not np.isnan(val):
                    diff = val - mean
                    m2 += diff ** 2
                    m3 += diff ** 3
            m2 /= count
            m3 /= count

            if m2 > 0:
                skew = m3 / (m2 ** 1.5)
                result[i] = skew
            else:
                result[i] = 0.0
        else:
            result[i] = np.nan

    return result

@njit
def rolling_kurt(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)
    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        if i - start_pos + 1 >= window:
            count = 0
            mean = 0.0
            for j in range(i - window + 1, i + 1):
                val = values[j]
                if not np.isnan(val):
                    mean += val
                    count += 1
            if count == 0:
                result[i] = np.nan
                continue
            mean /= count

            # 计算方差和四阶矩
            m2 = 0.0
            m4 = 0.0
            for j in range(i - window + 1, i + 1):
                val = values[j]
                if not np.isnan(val):
                    diff = val - mean
                    m2 += diff ** 2
                    m4 += diff ** 4
            m2 /= count
            m4 /= count

            if m2 > 0:
                kurt = m4 / (m2 ** 2)
                result[i] = kurt
            else:
                result[i] = 0.0
        else:
            result[i] = np.nan

    return result

@njit
def rolling_prod(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)
    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        # ticker 切换，重置起点
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        # 判断窗口是否满足
        if i - start_pos + 1 >= window:
            p = 1.0
            valid_count = 0

            # 计算乘积
            for j in range(i - window + 1, i + 1):
                val = values[j]
                if not np.isnan(val):
                    p *= val
                    valid_count += 1

            if valid_count == window:
                result[i] = p
            else:
                result[i] = np.nan
        else:
            result[i] = np.nan

    return result

@njit
def rolling_argmin(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)
    start_pos = 0
    last_ticker = tickers[0]
    window_values = List.empty_list(np.float64)

    for i in range(n):
        current_ticker = tickers[i]

        # ticker 切换，重置起点
        if current_ticker != last_ticker:
            start_pos, last_ticker = i, current_ticker
            window_values.clear()

        # 收集有效值
        if not np.isnan(values[i]):
            window_values.append(values[i])

        # 判断窗口是否超出，清理旧值
        if i - start_pos + 1 > window:
            old_value = values[i - window]
            if not np.isnan(old_value):
                for idx in range(len(window_values)):
                    if window_values[idx] == old_value:
                        window_values.pop(idx)
                        break

        # 计算最小值及其索引
        if len(window_values) >= window:
            min_val = window_values[0]
            min_idx = 0
            for j in range(1, len(window_values)):
                if window_values[j] < min_val:
                    min_val = window_values[j]
                    min_idx = j
            result[i] = min_idx
        else:
            result[i] = np.nan

    return result

@njit
def rolling_argmax(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)
    start_pos = 0
    last_ticker = tickers[0]
    window_values = List.empty_list(np.float64)

    for i in range(n):
        current_ticker = tickers[i]

        # ticker 切换，重置起点
        if current_ticker != last_ticker:
            start_pos, last_ticker = i, current_ticker
            window_values.clear()

        # 收集有效值
        if not np.isnan(values[i]):
            window_values.append(values[i])

        # 判断窗口是否超出，清理旧值
        if i - start_pos + 1 > window:
            old_value = values[i - window]
            if not np.isnan(old_value):
                for idx in range(len(window_values)):
                    if window_values[idx] == old_value:
                        window_values.pop(idx)
                        break

        # 计算最大值及其索引
        if len(window_values) >= window:
            max_val = window_values[0]
            max_idx = 0
            for j in range(1, len(window_values)):
                if window_values[j] > max_val:
                    max_val = window_values[j]
                    max_idx = j
            result[i] = max_idx
        else:
            result[i] = np.nan

    return result

@njit
def ts_zscore(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)
    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        # ticker 切换，重置起点
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        # 判断窗口是否满足
        if i - start_pos + 1 >= window:
            sum_val = 0.0
            count = 0

            # 计算窗口内的均值
            for j in range(i - window + 1, i + 1):
                val = values[j]
                if not np.isnan(val):
                    sum_val += val
                    count += 1

            if count > 1:
                mean = sum_val / count
                var_sum = 0.0

                # 计算窗口内的方差
                for j in range(i - window + 1, i + 1):
                    val = values[j]
                    if not np.isnan(val):
                        var_sum += (val - mean) ** 2

                std = (var_sum / (count - 1)) ** 0.5

                if std > 0.0 and not np.isnan(values[i]):
                    result[i] = (values[i] - mean) / std
                else:
                    result[i] = np.nan
            else:
                result[i] = np.nan
        else:
            result[i] = np.nan

    return result

@njit
def rolling_sum(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)
    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        # ticker 切换，重置起点
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        # 判断窗口是否满足
        if i - start_pos + 1 >= window:
            s = 0.0
            valid_count = 0

            # 计算窗口内的和
            for j in range(i - window + 1, i + 1):
                val = values[j]
                if not np.isnan(val):
                    s += val
                    valid_count += 1

            # 如果窗口内有效值足够，则返回总和
            if valid_count == window:
                result[i] = s
            else:
                result[i] = np.nan
        else:
            result[i] = np.nan

    return result

@njit
def rolling_max(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)
    
    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        # ticker切换，重置起点
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        # 窗口已满
        if i - start_pos + 1 >= window:
            max_value = -np.inf
            for j in range(i - window + 1, i + 1):
                val = values[j]
                if not np.isnan(val):
                    max_value = max(max_value, val)
            result[i] = max_value if max_value != -np.inf else np.nan
        else:
            result[i] = np.nan

    return result

@njit
def rolling_min(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)

    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        # ticker切换，重置起点
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        # 窗口已满
        if i - start_pos + 1 >= window:
            min_value = np.inf
            for j in range(i - window + 1, i + 1):
                val = values[j]
                if not np.isnan(val):
                    min_value = min(min_value, val)
            result[i] = min_value if min_value != np.inf else np.nan
        else:
            result[i] = np.nan

    return result

@njit
def safe_delay(values, tickers, window):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)

    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        # ticker切换，重置起点
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        # 窗口是否满足
        if i - start_pos >= window:
            result[i] = values[i - window]
        else:
            result[i] = np.nan

    return result

@njit
def ts_argmaxmin(values, tickers, window):
    result_argmax = rolling_argmax(values, tickers, window)
    result_argmin = rolling_argmax(values, tickers, window)
    result = result_argmax - result_argmin
    result[np.isclose(result, 0.0)] = np.nan
    return result

@njit
def safe_delta(values, tickers, d):
    n = len(values)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)
    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        # ticker切换，重置起始位置
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        # 如果距离不足d天，跳过
        if i - start_pos < d:
            result[i] = np.nan
            continue

        v_now = values[i]
        v_before = values[i - d]

        if not np.isnan(v_now) and not np.isnan(v_before):
            result[i] = v_now - v_before
        else:
            result[i] = np.nan

    return result

@njit
def ts_rank(values, dates, window):
    n = len(values)
    result = np.full(n, np.nan)

    start_pos = 0
    last_date = dates[0]

    for i in range(n):
        current_date = dates[i]

        # 日期切换，重置起点
        if current_date != last_date:
            start_pos = i
            last_date = current_date

        # 判断窗口是否满足
        if i - start_pos + 1 >= window:
            valid_count = 0
            window_data = np.empty(window)

            for j in range(i - window + 1, i + 1):
                v = values[j]
                if not np.isnan(v):
                    window_data[valid_count] = v
                    valid_count += 1

            if valid_count > 0:
                sub_data = window_data[:valid_count]
                sorted_data = np.sort(sub_data)
                last_value = values[i]

                if not np.isnan(last_value):
                    # 使用 searchsorted 算 rank（从 1 开始）
                    rank = np.searchsorted(sorted_data, last_value, side='right')
                    result[i] = rank
                else:
                    result[i] = np.nan
            else:
                result[i] = np.nan
        else:
            result[i] = np.nan

    return result

@njit
def ts_cov(x, y, tickers, window):
    n = len(x)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)

    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        # ticker 切换，重置起点
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        if i - start_pos + 1 >= window:
            valid_count = 0
            sum_x = 0.0
            sum_y = 0.0

            for j in range(i - window + 1, i + 1):
                val_x = x[j]
                val_y = y[j]
                if not np.isnan(val_x) and not np.isnan(val_y):
                    sum_x += val_x
                    sum_y += val_y
                    valid_count += 1

            if valid_count >= window:
                mean_x = sum_x / valid_count
                mean_y = sum_y / valid_count

                cov_sum = 0.0
                for j in range(i - window + 1, i + 1):
                    val_x = x[j]
                    val_y = y[j]
                    if not np.isnan(val_x) and not np.isnan(val_y):
                        cov_sum += (val_x - mean_x) * (val_y - mean_y)

                result[i] = cov_sum / (valid_count - 1)
            else:
                result[i] = np.nan
        else:
            result[i] = np.nan

    return result

@njit
def ts_corr(x, y, tickers, window):
    n = len(x)
    result = np.full(n, np.nan)
    tickers = tickers.astype(np.int64)

    start_pos = 0
    last_ticker = tickers[0]

    for i in range(n):
        current_ticker = tickers[i]

        # ticker切换，重置起点
        if current_ticker != last_ticker:
            start_pos = i
            last_ticker = current_ticker

        if i - start_pos + 1 >= window:
            valid_count = 0
            sum_x = 0.0
            sum_y = 0.0

            for j in range(i - window + 1, i + 1):
                val_x = x[j]
                val_y = y[j]
                if not np.isnan(val_x) and not np.isnan(val_y):
                    sum_x += val_x
                    sum_y += val_y
                    valid_count += 1

            if valid_count >= window:
                mean_x = sum_x / valid_count
                mean_y = sum_y / valid_count

                cov_sum = 0.0
                var_x_sum = 0.0
                var_y_sum = 0.0

                for j in range(i - window + 1, i + 1):
                    val_x = x[j]
                    val_y = y[j]
                    if not np.isnan(val_x) and not np.isnan(val_y):
                        dx = val_x - mean_x
                        dy = val_y - mean_y
                        cov_sum += dx * dy
                        var_x_sum += dx * dx
                        var_y_sum += dy * dy

                if var_x_sum > 0.0 and var_y_sum > 0.0:
                    std_x = (var_x_sum / (valid_count - 1)) ** 0.5
                    std_y = (var_y_sum / (valid_count - 1)) ** 0.5
                    result[i] = (cov_sum / (valid_count - 1)) / (std_x * std_y)
                else:
                    result[i] = np.nan
            else:
                result[i] = np.nan
        else:
            result[i] = np.nan

    return result
    
@njit
def safe_absolute(arr):
    result = np.abs(arr)
    return result

@njit
def safe_sqrt(arr):
    result = np.sqrt(np.abs(arr))
    return result

@njit
def safe_inverse(arr):
    result = np.empty_like(arr)
    for i in range(len(arr)):
        if arr[i] == 0 or np.isnan(arr[i]):
            result[i] = np.nan
        else:
            result[i] = 1.0 / arr[i]
    return result

@njit
def safe_log(arr):
    result = np.empty_like(arr)
    for i in range(len(arr)):
        if np.isnan(arr[i]):
            result[i] = np.nan
        else:
            result[i] = np.log1p(np.abs(arr[i]))
    return result

@njit
def x_rank(values, dates):
    result = np.empty(len(values))
    result[:] = np.nan

    start = 0
    n = len(values)

    for i in range(1, n):
        if dates[i] != dates[i - 1]:
            # rank from start to i-1
            idx = np.argsort(values[start:i])
            ranks = np.empty(i - start)
            for j in range(len(idx)):
                ranks[idx[j]] = j + 1  # rank starts from 1
            result[start:i] = ranks
            start = i

    # 最后一个 date 段落别忘了 rank
    idx = np.argsort(values[start:n])
    ranks = np.empty(n - start)
    for j in range(len(idx)):
        ranks[idx[j]] = j + 1
    result[start:n] = ranks

    return result

@njit
def safe_add(x, y):
    n = len(x)
    result = np.empty(n, dtype=np.float64)
    for i in range(n):
        a, b = x[i], y[i]
        if np.isnan(a) or np.isnan(b):
            result[i] = np.nan
        else:
            result[i] = a + b
    return result

@njit
def safe_sub(x, y):
    n = len(x)
    result = np.empty(n, dtype=np.float64)
    for i in range(n):
        a, b = x[i], y[i]
        if np.isnan(a) or np.isnan(b):
            result[i] = np.nan
        else:
            result[i] = a - b
    return result

@njit
def safe_mul(x, y):
    n = len(x)
    result = np.empty(n, dtype=np.float64)
    for i in range(n):
        a, b = x[i], y[i]
        if np.isnan(a) or np.isnan(b):
            result[i] = np.nan
        else:
            result[i] = a * b
    return result

@njit
def safe_div(x, y):
    n = len(x)
    result = np.empty(n, dtype=np.float64)
    for i in range(n):
        a, b = x[i], y[i]
        if np.isnan(a) or np.isnan(b) or b == 0:
            result[i] = np.nan
        else:
            result[i] = a / b
    return result
    
def rank_sub(x, y, cross_dates):
    if callable(x): x = x()
    result_x = x_rank(x, cross_dates)  # 只传需要的参数

    if callable(y): y = y()
    result_y = x_rank(y, cross_dates)  # 只传需要的参数

    result = result_x - result_y
    result[np.isclose(result, 0.0)] = np.nan

    return result

def rank_div(x, y, cross_dates):
    if callable(x): x = x()
    result_x = x_rank(x, cross_dates)  # 只传需要的参数

    if callable(y): y = y()
    result_y = x_rank(y, cross_dates)  # 只传需要的参数

    result = result_x / result_y
    result[np.isclose(result, 0.0)] = np.nan

    return result

@njit
def sigmoid(x):
    result = np.empty_like(x)
    for i in range(len(x)):
        result[i] = 1 / (1.0 + np.exp(-x[i]))
    return result


def ts_zscore_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = ts_zscore(x.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)
    
def rolling_skew_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = rolling_skew(x.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rolling_mean_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = rolling_mean(x.values, ts_tickers, window)
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rolling_std_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = rolling_std(x.values, ts_tickers, window)
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rolling_kurt_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = rolling_kurt(x.values, ts_tickers, window)
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rolling_prod_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = rolling_prod(x.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rolling_argmin_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = rolling_argmin(x.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rolling_argmax_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = rolling_argmax(x.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rolling_max_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = rolling_max(x.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rolling_min_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = rolling_min(x.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rolling_sum_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = rolling_sum(x.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def safe_delay_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = safe_delay(x.values, ts_tickers, window)
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def ts_argmaxmin_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = ts_argmaxmin(x.values, ts_tickers, window)
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def safe_delta_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = safe_delta(x.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def ts_rank_expr(x, window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = ts_rank(x.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def ts_corr_expr(x,y,window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    if callable(y): y = y()
    y = y.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = ts_corr(x.values, y.values, ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def ts_cov_expr(x,y,window):
    if callable(x): x = x()
    x = x.sort_index(level=['ticker', 'date'], sort_remaining=True)
    if callable(y): y = y()
    y = y.sort_index(level=['ticker', 'date'], sort_remaining=True)
    result = ts_cov(x.values, y.values,  ts_tickers, window)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def safe_add_expr(x, y):
    if callable(x): x = x()
    if callable(y): y = y()
    x_arr = x.values.astype(np.float64)
    y_arr = y.values.astype(np.float64)
    return pd.Series(safe_add(x_arr, y_arr), index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def safe_sub_expr(x, y):
    if callable(x): x = x()
    if callable(y): y = y()
    x_arr = x.values.astype(np.float64)
    y_arr = y.values.astype(np.float64)
    return pd.Series(safe_sub(x_arr, y_arr), index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def safe_mul_expr(x, y):
    if callable(x): x = x()
    if callable(y): y = y()
    x_arr = x.values.astype(np.float64)
    y_arr = y.values.astype(np.float64)
    return pd.Series(safe_mul(x_arr, y_arr), index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def safe_div_expr(x, y):
    if callable(x): x = x()
    if callable(y): y = y()
    x_arr = x.values.astype(np.float64)
    y_arr = y.values.astype(np.float64)
    return pd.Series(safe_div(x_arr, y_arr), index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rank_div_expr(x, y):
    if callable(x): x = x()
    if callable(y): y = y()
    
    # 保持 Pandas Series 结构
    result = rank_div(x.values, y.values, cross_dates)
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def rank_sub_expr(x, y):
    if callable(x): x = x()
    if callable(y): y = y()
    
    # 保持 Pandas Series 结构
    result = rank_sub(x.values, y.values, cross_dates)
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def safe_absolute_expr(x):
    if callable(x): x = x()
    result = safe_absolute(x.values)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def safe_sqrt_expr(x):
    if callable(x): x = x()
    result = safe_sqrt(x.values)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def safe_inverse_expr(x):
    if callable(x): x = x()
    result = safe_inverse(x.values)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def safe_log_expr(x):
    if callable(x): x = x()
    result = safe_log(x.values)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def x_rank_expr(x):
    if callable(x): x = x()
    result = x_rank(x.values, cross_dates)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

def sigmoid_expr(x):
    if callable(x): x = x()
    result = sigmoid(x.values)  # 只传需要的参数
    return pd.Series(result, index=x.index).sort_index(level=['date', 'ticker'], sort_remaining=True)

In [35]:
class Expr: pass

TERMINALS =full_data_factors.columns.tolist()
terminal_dict = {name: full_data_factors[name] for name in TERMINALS}

pset = gp.PrimitiveSetTyped("MAIN", [], Expr)

# Updated context_wrapper
def context_wrapper(data, name):
    def wrapped(**kwargs):
        # print(f"Attempting to access: {name}")
        if name in data.columns:
            return data[name]  # Return the column data
        else:
            raise ValueError(f"Column '{name}' not found in the data.")
    wrapped.__name__ = name  # Ensure the callable has a unique name
    return wrapped
    
cross_tickers = full_data_factors.index.get_level_values('ticker').values.astype(np.int64)
cross_dates = full_data_factors.index.get_level_values('date').values.astype(np.int64)

full_data_factors_ts = full_data_factors.sort_index(level=['ticker', 'date'], sort_remaining=True)
ts_tickers = full_data_factors_ts.index.get_level_values('ticker').values.astype(np.int64)
ts_dates = full_data_factors_ts.index.get_level_values('date').values.astype(np.int64)

pset.context.update({"factor_columns": full_data_factors, "np": np, "random": random})
for name in TERMINALS:
    pset.addTerminal(context_wrapper(full_data_factors, name), Expr, name=name)

# ================================================= 测试成功 ===============================================#
"""
时序
"""
pset.addPrimitive(rolling_skew_expr, [Expr, int], Expr) 
pset.addPrimitive(rolling_mean_expr, [Expr, int], Expr) 
pset.addPrimitive(rolling_std_expr, [Expr, int], Expr) 
pset.addPrimitive(rolling_sum_expr, [Expr, int], Expr) 
pset.addPrimitive(rolling_max_expr, [Expr, int], Expr) 
pset.addPrimitive(rolling_min_expr, [Expr, int], Expr) 
pset.addPrimitive(rolling_kurt_expr, [Expr, int], Expr) 
pset.addPrimitive(rolling_prod_expr, [Expr, int], Expr) 
pset.addPrimitive(rolling_argmin_expr, [Expr, int], Expr) 
pset.addPrimitive(rolling_argmax_expr, [Expr, int], Expr) 
pset.addPrimitive(safe_delay_expr, [Expr, int], Expr) 
pset.addPrimitive(ts_zscore_expr, [Expr, int], Expr) 
pset.addPrimitive(ts_argmaxmin_expr, [Expr, int], Expr) 
pset.addPrimitive(safe_delta_expr, [Expr, int], Expr) 
pset.addPrimitive(ts_rank_expr, [Expr, int], Expr) 

pset.addPrimitive(ts_cov_expr, [Expr, Expr, int], Expr)
pset.addPrimitive(ts_corr_expr, [Expr, Expr, int], Expr)

"""
截面
"""
pset.addPrimitive(safe_absolute_expr, [Expr], Expr) 
pset.addPrimitive(safe_sqrt_expr, [Expr], Expr)
pset.addPrimitive(safe_inverse_expr, [Expr], Expr)
pset.addPrimitive(safe_log_expr, [Expr], Expr)
pset.addPrimitive(x_rank_expr, [Expr], Expr)
pset.addPrimitive(sigmoid_expr, [Expr], Expr)

pset.addPrimitive(rank_sub_expr, [Expr, Expr], Expr)
pset.addPrimitive(rank_div_expr, [Expr, Expr], Expr)

pset.addPrimitive(safe_add_expr, [Expr, Expr], Expr)
pset.addPrimitive(safe_sub_expr, [Expr, Expr], Expr)
pset.addPrimitive(safe_mul_expr, [Expr, Expr], Expr)
pset.addPrimitive(safe_div_expr, [Expr, Expr], Expr)

# ================================================= 测试成功 ===============================================#

# ================================================= 待测试 ===============================================#

# ================================================= 待测试 ===============================================#

def validate_expression(expr):
    terminals = [str(term) for term in expr]  # Convert terminals to string
    return len(terminals) == len(set(terminals))  # Return True if no duplicates

def generate_valid_expr(pset=pset, type_=Expr):  
    expr = gp.genHalfAndHalf(pset=pset, min_=1, max_=4, type_=type_)

    while not validate_expression(expr):
        expr = gp.genHalfAndHalf(pset=pset, min_=1, max_=4, type_=type_)

    return expr
    
def rand_window():
    return random.randint(2, 30)
pset.addEphemeralConstant("rand_window", rand_window, int)

def const_3() -> int: return 3 
def const_5() -> int: return 5
def const_10() -> int: return 10  
def const_20() -> int: return 20  
def const_60() -> int: return 60 

pset.addPrimitive(const_3, [], int)
pset.addPrimitive(const_5, [], int)
pset.addPrimitive(const_10, [], int)
pset.addPrimitive(const_20, [], int)
pset.addPrimitive(const_60, [], int)

# 注册工具箱
for name in ['FitnessMulti', 'Individual']:
    if hasattr(creator, name):
        delattr(creator, name)
        
'''Multi Fitnesses'''       
# creator.create("FitnessMulti", base.Fitness, weights=(1.0,1.0,))
# creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMulti)

'''Signle Fitness'''   
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMax)

toolbox = base.Toolbox()
toolbox.register("expr", generate_valid_expr)
toolbox.register("individual", tools.initIterate, creator.Individual, toolbox.expr)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

In [36]:
####### ICIR #########  
def cal_icir(factor_clean_rank, clean_target):

    # factor_rank_train = factor_values.unstack().rank(ascending=True, axis=1)
    IC = (factor_clean_rank.corrwith(clean_target, axis=1)).dropna()
    icir = IC.mean()/IC.std()
    return abs(icir)

####### Mutual Info #########    
def calc_mutual_info(factor_clean_rank, clean_target):
    aligned_factor, aligned_return = factor_clean_rank.align(clean_target, join='inner', axis=0)
    aligned_factor, aligned_return = aligned_factor.align(aligned_return, join='inner', axis=1)

    x = aligned_factor.values.flatten()
    y = aligned_return.values.flatten()
    
    mask = np.isfinite(x) & np.isfinite(y)
    x_clean = x[mask].reshape(-1, 1)
    y_clean = y[mask]

    mi = mutual_info_regression(x_clean, y_clean, discrete_features=False)
    return mi[0]

####### Turnover Mean ######### 
@njit
def calculate_turnover_long_mode(df_clean, top_pct=0.1):
    n_days, n_stocks = factor_values.shape
    turnover = np.zeros(n_days)

    prev_signal = np.zeros(n_stocks)

    for i in range(n_days):
        row = factor_values[i, :]
        valid_mask = ~np.isnan(row)
        valid_values = row[valid_mask]

        n_valid = valid_values.size
        if n_valid == 0:
            turnover[i] = 0.0
            continue

        n_top = max(1, int(n_valid * top_pct))

        # 排名：小的排前，取top n_top 个索引
        sorted_indices = np.argsort(valid_values)
        selected_indices = np.flatnonzero(valid_mask)[sorted_indices[:n_top]]

        # 构建当日信号
        signal = np.zeros(n_stocks)
        signal[selected_indices] = 1

        # 换手率 = 当日买卖信号变化数量 / 2
        turnover[i] = np.abs(signal - prev_signal).sum() / 2

        # 更新前一日信号
        prev_signal = signal

    turnover_mean = turnover.mean()
    return turnover_mean

@njit
def calculate_turnover_weight_mode(factor_values, n_range=10):
    n_days, n_stocks = factor_values.shape
    turnover_series = np.empty(n_days)
    turnover_series[:] = np.nan

    weight_all = np.zeros((n_days, n_stocks))

    for day in range(n_days):
        # 手动计算每行的最大值，忽略NaN
        row_max = np.nanmax(factor_values[day])

        if np.isnan(row_max):
            continue

        for layer in range(1, n_range+1):
            lower = (1/n_range) * (layer-1) * row_max
            upper = (1/n_range) * layer * row_max if layer < n_range else row_max + 1

            # mask layer
            mask = (factor_values[day] >= lower) & (factor_values[day] < upper)

            stock_count = np.sum(mask)
            if stock_count > 0:
                weight_all[day, mask] = 1.0 / stock_count

    # 换手率 = 权重变动绝对值之和
    turnover_series[1:] = np.sum(np.abs(weight_all[1:] - weight_all[:-1]), axis=1)

    return np.nanmean(turnover_series)

In [37]:
alpha = 0.005

def evaluateFitness(individual): #  train_data_set, train_clean_rank_target, val_data_set, val_clean_rank_target
    ###===== 生成公式及因子
    print(f"Compiled expression: {individual}")
    func = gp.compile(expr=individual, pset=pset)
    factor_values = func
    
    ###===== 对因子值做截面排序
    factor_values = factor_values.replace(0, np.nan).unstack().replace([np.inf, -np.inf], np.nan)
    df_clean = (factor_values * final_universe).dropna(how='all', axis=0)
    full_data_clean_rank_factor = df_clean.rank(ascending=True, axis=1, method='first')

    ###===== 加入complexity penalty
    # fitness = len(individual)
    # penalty = alpha * len(individual)
    
    ###===== 1. 验证集ICIR fitness
    val_clean_rank_factor = full_data_clean_rank_factor[(full_data_clean_rank_factor.index > '2021-06-30')]

    icir_val = cal_icir(val_clean_rank_factor, val_clean_rank_target)

    factor_array = df_clean.to_numpy(dtype=np.float64)
    to_mean_val = calculate_turnover_weight_mode(factor_array)
    
    fitness1 = to_mean_val
    fitness2 = icir_val
    fitness = icir_val/to_mean_val
    
    return (fitness,)
   
ind1 = toolbox.individual()
print (f"适应度\n：", evaluateFitness(ind1))

Compiled expression: rolling_kurt_expr(ts_rank_expr(rank_sub_expr(open, twap), const_60()), const_3())
适应度
： (0.2546448182748728,)


In [41]:
def eaMuPlusLambda_with_early_stopping(
    population, toolbox, mu, lambda_, cxpb, mutpb, ngen, 
    stats=None, halloffame=None, verbose=__debug__,
    patience=5, delta=1e-4):

    logbook = tools.Logbook()
    logbook.header = ['gen', 'nevals'] + (stats.fields if stats else [])

    best_fitness = -np.inf
    generations_no_improve = 0

    # 去重函数
    def deduplicate_population(pop):
        expr_set = set()
        unique_pop = []
        for ind in pop:
            expr_str = str(ind)
            if expr_str not in expr_set:
                expr_set.add(expr_str)
                unique_pop.append(ind)
        return unique_pop

    # 初始评估
    fitnesses = list(toolbox.map(toolbox.evaluate, population))
    for ind, fit in zip(population, fitnesses):
        ind.fitness.values = fit

    # 更新 HOF
    if halloffame is not None:
        halloffame.update(population)

    # 初始去重
    population = deduplicate_population(population)

    # 记录初代统计
    valid_population = [ind for ind in population if ind.fitness.valid]
    record = stats.compile(valid_population) if stats and len(valid_population) > 0 else {}
    logbook.record(gen=0, nevals=len(population), **record)
    if verbose:
        print(logbook.stream)

    # 主迭代
    for gen in range(1, ngen + 1):

        # 生成 offspring
        offspring = toolbox.select(population, lambda_)
        offspring = list(map(toolbox.clone, offspring))

        # 交叉
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if np.random.rand() < cxpb:
                toolbox.mate(child1, child2)
                del child1.fitness.values
                del child2.fitness.values

        # 变异
        for mutant in offspring:
            if np.random.rand() < mutpb:
                toolbox.mutate(mutant)
                del mutant.fitness.values

        # 评估 offspring 中的无效个体
        invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
        for ind in invalid_ind:
            try:
                fit = toolbox.evaluate(ind)
                ind.fitness.values = fit
            except Exception as e:
                print(f"[Invalid Expression in Evolution] {str(ind)} | Error: {str(e)}")
                ind.fitness.values = (0.0, )

        # 合并选择下一代 population
        population[:] = toolbox.select(population + offspring, mu)

        # 去重
        population = deduplicate_population(population)

        # 更新 valid_population
        valid_population = [ind for ind in population if ind.fitness.valid]

        # 更新 HOF
        if halloffame is not None:
            try:
                halloffame.update(valid_population)
            except Exception as e:
                print(f"[HOF Update Error] {str(e)}")

        # 记录统计
        record = stats.compile(valid_population) if stats and len(valid_population) > 0 else {}
        logbook.record(gen=gen, nevals=len(invalid_ind), **record)
        if verbose:
            print(logbook.stream)

        # Early Stopping 检查
        if valid_population:
            current_best = max(ind.fitness.values[0] for ind in valid_population)
            if current_best - best_fitness > delta:
                best_fitness = current_best
                generations_no_improve = 0
            else:
                generations_no_improve += 1

            if generations_no_improve >= patience:
                print(f"Early stopping triggered at generation {gen}. No significant improvement in {patience} generations.")
                break

    return population, logbook
    
def initialize_population_with_screening(toolbox, n_population, fitness_threshold=0, length_threshold=10):
    raw_population = toolbox.population(n=3 * n_population)
    fitness_and_len = []

    for ind in raw_population:
        try:
            fit = toolbox.evaluate(ind)
            ind.fitness.values = fit
            expr_len = len(ind)
            fitness_and_len.append((ind, fit[0], expr_len))
        except Exception as e:
            print(f"[Invalid Expression in Init] {str(ind)} | Error: {str(e)}")
            ind.fitness.values = (0.0,)
            fitness_and_len.append((ind, 0.0, 999))  # 长度设大，保证筛掉

    filtered_population = [
        ind for ind, rot, length in fitness_and_len
        if (rot >= fitness_threshold and length <= length_threshold)
    ]

    print(f"筛选后剩余 {len(filtered_population)} 个个体")

    # if len(filtered_population) >= n_population:
    #     population = random.sample(filtered_population, n_population)
    # else:
    #     population = filtered_population.copy()
    #     while len(population) < n_population:
    #         new_ind = toolbox.individual()
    #         try:
    #             fit = toolbox.evaluate(new_ind)
    #             new_ind.fitness.values = fit
    #             if fit[0] >= fitness_threshold and fit[1] <= length_threshold:
    #                 population.append(new_ind)
    #         except Exception as e:
    #             print(f"[Invalid Expression in Refill] {str(new_ind)} | Error: {str(e)}")
    #             new_ind.fitness.values = (0.0,)

    # print(f"最终初始化种群数量: {len(population)}")
    return population


def evaluate_individual(ind):
    try:
        return evaluateFitness(ind)
    except Exception as e:
        print(f"[Invalid Individual] {str(ind)} | Error: {str(e)}")
        return (0.0,)

toolbox.register("evaluate", evaluate_individual)
toolbox.register("mate", gp.cxOnePointLeafBiased, termpb=0.2)
toolbox.register("mutate", gp.mutUniform, expr=toolbox.expr, pset=pset)
toolbox.register("select", tools.selTournament, tournsize=20)

## 初始化种群
population = initialize_population_with_screening(toolbox, n_population=500, # 筛选初始化种群 informed init population
                                                  fitness_threshold=0, length_threshold=9) # ideal ：icir=2 turnovers=0.3 fitness = 2/0.3
hof = tools.HallOfFame(100, similar=np.array_equal)

stats_multi = tools.Statistics(lambda ind: ind.fitness.values)

stats_multi.register("ICIR/TO_avg", lambda values: np.mean([v[0] for v in values]))
stats_multi.register("ICIR/TO_max", lambda values: np.max([v[0] for v in values]))
stats_multi.register("ICI/TO_min", lambda values: np.min([v[0] for v in values]))

population, logbook = eaMuPlusLambda_with_early_stopping(
    population, toolbox,
    mu=200, lambda_=150,
    cxpb=0.4, mutpb=0.2,
    ngen=100,  
    stats=stats_multi,
    halloffame=hof,
    verbose=True,
    patience=10,  # 连续 10 代无明显提升就停
    delta=1e-4   # max fitness 改进小于 0.0001 判定为无明显提升
)

Compiled expression: rolling_sum_expr(vwap, 8)
Compiled expression: safe_delta_expr(ts_corr_expr(sigmoid_expr(turnover_rate), rolling_min_expr(high, 16), const_60()), const_10())
Compiled expression: safe_mul_expr(rolling_min_expr(low, 24), safe_inverse_expr(high))
Compiled expression: rolling_std_expr(ts_argmaxmin_expr(rolling_sum_expr(vwap, 25), const_5()), 25)
Compiled expression: rolling_prod_expr(volume, 28)
Compiled expression: rolling_skew_expr(total_turnover, 3)
Compiled expression: rolling_kurt_expr(rolling_prod_expr(rank_div_expr(rolling_std_expr(twap, 16), rolling_mean_expr(total_turnover, 20)), const_5()), const_10())
Compiled expression: safe_sqrt_expr(ts_rank_expr(rolling_max_expr(adj_return, 23), const_20()))
Compiled expression: rolling_mean_expr(ts_corr_expr(rolling_kurt_expr(high, 23), safe_add_expr(vwap, twap), const_20()), const_3())
Compiled expression: rolling_min_expr(close, const_60())
Compiled expression: safe_sub_expr(safe_inverse_expr(vwap), rolling_prod_expr

ValueError: not enough values to unpack (expected 3, got 2)

In [None]:
if not logbook:
    print("Logbook is empty, no generations recorded.")
else:
    print("Logbook last entry:", logbook[-1])
    # print("[{}] best_score: {}".format(logbook[-1]['gen'], logbook[-1]['max','N/A']))  # Use 'max' instead of 'min'

# Print all individuals stored in Hall of Fame (hof)
print("Best Individuals in Hall of Fame:")
for ind in hof:
    print("Individual:", ind)
    print("Fitness:", ind.fitness.values)
    print("----------")

In [None]:
def plot_convergence(logbook, patience, delta):
    gen = logbook.select("gen")
    max_fitness = logbook.select("max")

    plt.figure(figsize=(12, 6))
    plt.plot(gen, max_fitness, label="Max Fitness (Validation Set)", marker='o')

    # 标记 early stopping 发生点
    best_fitness = -np.inf
    no_improve_count = 0
    stop_gen = None

    for i in range(len(max_fitness)):
        if max_fitness[i] > best_fitness + delta:
            best_fitness = max_fitness[i]
            no_improve_count = 0
        else:
            no_improve_count += 1

        if no_improve_count >= patience:
            stop_gen = gen[i]
            break

    if stop_gen is not None:
        plt.axvline(x=stop_gen, color='r', linestyle='--', label=f'Early Stopping at Gen {stop_gen}')
        plt.scatter(stop_gen, max_fitness[gen.index(stop_gen)], color='red', zorder=5)

    plt.xlabel("Generation")
    plt.ylabel("Max Fitness (Validation Set ICIR)")
    plt.title("GP Convergence on Validation Set with Early Stopping")
    plt.legend()
    plt.grid(True)
    plt.show()
    
plot_convergence(logbook, patience=10, delta=1e-4)