In [1]:
import gc
import os
import sys

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm

def add_project_root(project_name: str = "extrema_lab"):
    """Make the package directory ``project_name`` importable from the current working directory.

    Starting at the current working directory and moving towards the filesystem
    root, the first (deepest) path whose last component equals ``project_name``
    is located; the directory containing it is appended to ``sys.path`` unless it
    is already listed. A message is printed in both the found and not-found case.

    Args:
        project_name: Directory name to look for among the ancestors of the
            current working directory (the working directory itself included).

    Returns:
        None.
    """
    segments = os.getcwd().split(os.sep)
    depth = len(segments)

    while depth > 0:
        candidate = os.sep.join(segments[:depth])
        depth -= 1
        if os.path.basename(candidate) != project_name:
            continue

        # The *parent* of the project directory goes on sys.path so that
        # `import <project_name>...` resolves.
        parent_dir = os.path.dirname(candidate)
        if parent_dir not in sys.path:
            sys.path.append(parent_dir)
        print(f"已添加项目根目录到 sys.path: {parent_dir}")
        return

    print(f"未找到项目 {project_name}，请确认路径是否正确")


# Put the parent of the `extrema_lab` directory on sys.path so the package import below resolves.
add_project_root("extrema_lab")
# NOTE(review): wildcard import. Names used later in this notebook without a visible
# import (e.g. `pl`, `process_all_symbols`, `build_long_cross_sections_fast`,
# `clean_df_drop_nulls`, `split_df_by_month`) presumably come from here - confirm, and
# consider importing them explicitly.
from extrema_lab.feature_eng.operator.utils_tools import *

已添加项目根目录到 sys.path: C:\quant\work


In [2]:
import json

# Window lengths shared by every symbol.
feat_cal_window = 5000
feat_norm_window = 2000
feat_norm_rolling_mean_window = 500

# BTC and BNB get their own threshold; every other symbol uses the default.
# NOTE(review): the threshold is passed as a string - presumably because it is embedded
# in a data file name by `process_all_symbols`; confirm.
special_tokens = {"BTCUSDT", "BNBUSDT"}
default_threshold = 0.0067

with open("../../symbols.json", "r", encoding="utf-8") as f:
    symbols_list = json.load(f)

# Append a trailing "T" to every symbol that does not already end with one.
symbols_list_usdt = [s + "T" if not s.endswith("T") else s for s in symbols_list]

_window_params = {
    "feat_cal_window": feat_cal_window,
    "feat_norm_window": feat_norm_window,
    "feat_norm_rolling_mean_window": feat_norm_rolling_mean_window,
}

# One independent parameter dict per symbol ("threshold" first, then the windows).
symbol_params = {
    sym: {
        "threshold": str(default_threshold if sym not in special_tokens else 0.0031),
        **_window_params,
    }
    for sym in symbols_list_usdt
}

symbol_dfs = process_all_symbols(symbol_params)
long_df = build_long_cross_sections_fast(symbol_dfs)


FileNotFoundError: 系统找不到指定的文件。 (os error 2): C:\quant\work\extrema_lab\data_proc\resampled_data\AXSUSDT_merged_thr0.0067.parquet

This error occurred with the following context stack:
	[1] 'parquet scan'
	[2] 'sink'


In [None]:
# `long_df` has already been built from these frames, so release them to save memory.
del symbol_dfs  # drop the variable reference
gc.collect()

In [None]:
# Print `long_df` for a quick visual check.
print(long_df)

In [None]:
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, roc_auc_score
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.mixture import GaussianMixture
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.preprocessing import LabelEncoder

import lightgbm as lgb
from catboost import CatBoostClassifier
import xgboost as xgb

from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from pytorch_tabnet.tab_model import TabNetRegressor
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.callbacks import Callback
from pytorch_tabnet.callbacks import History
from pytorch_tabnet.callbacks import EarlyStopping
from sklearn.preprocessing import LabelEncoder
import torch

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Label horizon: later cells build the label from px.shift(-N) within each symbol.
N = 3000
# Name of the label column (log return N rows ahead).
target_col = f"future_return_{N}"
# NOTE(review): presumably column-name prefixes to leave out of the model features;
# confirm where this list is consumed.
exclude_prefixes = ['px', 'timestamp', 'timestamp_dt', 'symbol']


In [None]:
# Upper-case the tickers first so the encoder sees one spelling per symbol.
long_df = long_df.with_columns(pl.col('symbol').str.to_uppercase())

# Fit the encoder on the sorted set of distinct symbols.
all_symbols = set(long_df["symbol"].unique())
le_symbol = LabelEncoder()
le_symbol.fit(sorted(all_symbols))

# Attach the integer code of every row as `enc_cat_symbol`.
symbol_encoded = le_symbol.transform(long_df['symbol'].to_list())
long_df = long_df.with_columns(pl.Series('enc_cat_symbol', symbol_encoded))

# Lookup tables in both directions (symbol -> id and id -> symbol).
symbol_to_id = {
    name: code
    for name, code in zip(le_symbol.classes_, le_symbol.transform(le_symbol.classes_))
}
id_to_symbol = {code: name for name, code in symbol_to_id.items()}

# Add a `row_nr` index column so the original row order can be restored later,
# then release the un-indexed frame.
long_temp_df = long_df.with_row_index(name="row_nr")
del long_df
gc.collect()

In [None]:
# Forward-looking label per symbol: the price N rows ahead and the log return to it.
# A window expression (`.over("symbol")`) keeps the shift inside each symbol without
# routing every group through a Python lambda.
# The frame is sorted by (symbol, timestamp) so "N rows ahead" means N observations
# later for that symbol; `row_nr` then restores the original row order and is dropped.
_px_ahead = pl.col("px").shift(-N).over("symbol")

df_with_future = (
    long_temp_df.sort(["symbol", "timestamp"])
    .with_columns([
        _px_ahead.alias("px_future"),
        (_px_ahead / pl.col("px")).log().alias(f"future_return_{N}"),
    ])
    .sort("row_nr")
    .drop("row_nr")
)


In [None]:
# Cross-sectional standardisation in long format: every feature is z-scored across
# the rows that share a timestamp.
# (A cross-sectional rank variant - rank("average") / len over timestamp - was sketched
# here previously and is currently disabled.)
# NOTE(review): `time_series_feature_cols` is not defined in any visible cell - confirm
# where it comes from.
def _cross_sectional_zscore(col_name):
    """Return the expression `(x - mean) / std` of `col_name` within each timestamp.

    The divisor falls back to 1 whenever the cross-sectional std is not greater
    than 1e-9; the result is aliased `<col_name>_zscore_cs`.
    """
    value = pl.col(col_name)
    cs_mean = value.mean().over("timestamp")
    cs_std = value.std().over("timestamp")
    divisor = pl.when(cs_std > 1e-9).then(cs_std).otherwise(1)
    return ((value - cs_mean) / divisor).alias(f"{col_name}_zscore_cs")


df_with_future = df_with_future.with_columns(
    [_cross_sectional_zscore(c) for c in time_series_feature_cols]
)

In [None]:
# Reorder the columns so that `enc_cat_symbol` is the last one.
cols = [name for name in df_with_future.columns if name != "enc_cat_symbol"]
cols.append("enc_cat_symbol")
df_with_future = df_with_future.select(cols)

In [None]:
# Clean the frame with the project helper (by its name it drops rows containing nulls - confirm).
df_with_future = clean_df_drop_nulls(df_with_future)
# NOTE(review): the helper name says "by month", while later cells iterate
# `weekly_dataframes`, which no visible cell defines - confirm the intended split list.
split_dataframes = split_df_by_month(df_with_future)  # only the list is kept

In [None]:
# Show both symbol lookup tables (symbol -> id, id -> symbol).
print(symbol_to_id) 
print(id_to_symbol)

In [None]:
# Print the last element of the split list as a sanity check.
print(split_dataframes[-1])

In [None]:
# Use the first split as the schema reference for the model inputs.
sample_df = split_dataframes[0]

# Feature columns: everything except columns whose name starts with one of
# `exclude_prefixes` (prices, timestamps, the raw symbol string) or with
# "future_return_" (the forward-looking label).
# NOTE(review): this definition was missing from the notebook and is reconstructed
# from `exclude_prefixes` / `target_col`; confirm it matches the intended feature set.
feature_cols = [
    col for col in sample_df.columns
    if not col.startswith(tuple(exclude_prefixes))
    and not col.startswith("future_return_")
]

# Position of the categorical symbol column inside the feature matrix.
cat_idxs = [feature_cols.index('enc_cat_symbol')]
# Size of the embedding table = number of classes known to the encoder. Counting the
# distinct ids of the first split alone would be too small whenever a symbol only
# appears in a later split, which breaks the embedding lookup.
cat_dims = [len(le_symbol.classes_)]
cat_emb_dim = 16
print(len(feature_cols), cat_idxs, cat_dims)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_cross_section(symbols, y_true, y_binary, y_pred_prob, px, alpha=1.0):
    """Plot one cross-section with four stacked y-axes sharing the symbol x-axis.

    Args:
        symbols: Tick labels, one per x position.
        y_true: Continuous future returns, drawn as bars on the main axis.
        y_binary: Binary labels, drawn as 'x' markers on a third axis.
        y_pred_prob: Predicted values, drawn as a line on a fourth axis.
        px: Prices, drawn as a line on a second axis.
        alpha: Accepted for interface compatibility; not used by the function.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    positions = np.arange(len(symbols))
    bar_width = 0.2

    fig, return_ax = plt.subplots(figsize=(12, 6))

    # Main axis: realised future return (continuous), bars shifted left by one width.
    return_ax.bar(positions - bar_width, y_true, width=bar_width, label='Future Return', alpha=0.6)
    return_ax.set_ylabel('Future Return')

    # Second axis: price line.
    price_ax = return_ax.twinx()
    price_ax.plot(positions, px, label='Price', color='tab:blue', marker='o')
    price_ax.set_ylabel('Price')

    # Third axis: binary class label, spine pushed outward so it does not overlap.
    label_ax = return_ax.twinx()
    label_ax.spines.right.set_position(("outward", 60))
    label_ax.scatter(positions, y_binary, label='GMM Label', color='tab:orange', marker='x')
    label_ax.set_ylim(-0.1, 1.1)
    label_ax.set_ylabel('Binary Label')

    # Fourth axis: predicted probability, pushed out even further.
    prob_ax = return_ax.twinx()
    prob_ax.spines.right.set_position(("outward", 120))
    prob_ax.plot(positions, y_pred_prob, label='Predicted Prob', color='tab:green', marker='^')
    prob_ax.set_ylim(-0.05, 1.05)
    prob_ax.set_ylabel('Predicted Probability')

    return_ax.set_xticks(positions)
    return_ax.set_xticklabels(symbols, rotation=45)
    return_ax.set_xlabel('Symbols')

    # One combined legend for all four axes.
    handles, texts = [], []
    for axis in (return_ax, price_ax, label_ax, prob_ax):
        axis_handles, axis_texts = axis.get_legend_handles_labels()
        handles = handles + axis_handles
        texts = texts + axis_texts
    return_ax.legend(handles, texts, loc='upper left')

    plt.title("Cross-Section Comparison at One Timestamp")
    plt.tight_layout()
    plt.show()


In [None]:
class MinEpochsEarlyStopping(EarlyStopping):
    """Early stopping that stays inactive until `min_epochs` epochs have passed.

    For epochs below `min_epochs` the callback does nothing. From then on it records
    the epoch in `stopped_epoch`, delegates to the parent early-stopping check and,
    the first time it is active, re-bases `best_loss` / `best_epoch` on that epoch.

    Args:
        early_stopping_metric: Key of the monitored metric in the epoch logs.
        patience: Forwarded to the parent callback.
        min_epochs: First epoch index at which the early-stopping logic runs.
        is_maximize: Forwarded to the parent callback.
        tol: Forwarded to the parent callback.
    """

    def __init__(self, early_stopping_metric, patience, min_epochs=5, is_maximize=False, tol=0.0):
        super().__init__(
            early_stopping_metric=early_stopping_metric,
            patience=patience,
            is_maximize=is_maximize,
            tol=tol,
        )
        self.min_epochs = min_epochs
        # Becomes True once the baseline has been re-based, so it happens only once.
        self._callback_reset_flag = False

    def _rebase_baseline(self, epoch, logs):
        """Set `best_loss` / `best_epoch` from the current epoch's logged metric."""
        print(f"[Check] Current best_epoch={self.best_epoch}, best_loss={self.best_loss:.6f}")

        if self.early_stopping_metric not in logs:
            raise KeyError(f"Metric '{self.early_stopping_metric}' not found in logs keys={list(logs.keys())}")

        self.best_loss = logs[self.early_stopping_metric]
        self.best_epoch = epoch

        # Guard against running the reset (and its log line) a second time.
        self._callback_reset_flag = True
        print(f"[MinEpochsEarlyStopping] Reset best_score at epoch {epoch} to {self.best_loss:.6f}")

    def on_epoch_end(self, epoch, logs):
        # Only active once the warm-up period is over.
        if epoch >= self.min_epochs:
            self.stopped_epoch = epoch

            # Regular early-stopping check of the parent class.
            super().on_epoch_end(epoch, logs)

            # First active epoch: re-base the baseline on this epoch.
            if not self._callback_reset_flag:
                self._rebase_baseline(epoch, logs)

            print(f"[DEBUG] After super(): best_loss={self.best_loss:.6f}, best_epoch={self.best_epoch}, wait={self.wait}")


In [None]:
n_train_weeks = 8  # number of consecutive splits used for training (configurable)
n_val_weeks = 1    # splits used for validation / early stopping
n_test_weeks = 1   # splits held out after validation for out-of-sample prediction

tabnet = None

all_preds = []  # one dict of test-window arrays per fold, filled by the loop below


def process_df_np(df):
    """Turn one polars frame into the arrays used for training and evaluation.

    Rows are sorted by `timestamp`, and rows with a null in any feature column,
    the target column or `px` are dropped.

    Returns:
        Tuple `(X, y, px, ts, symbol_enc)`: the feature matrix, the target reshaped
        to a column vector, the `px` and `timestamp` columns as numpy arrays, and
        the `enc_cat_symbol` column, which is still a one-column polars DataFrame.
    """
    df = df.sort('timestamp').drop_nulls(subset=feature_cols + [target_col, 'px'])
    X = df.select(feature_cols).to_numpy()
    y = df.select(target_col).to_numpy().reshape(-1, 1)
    px = df.select('px').to_numpy()
    ts = df.select('timestamp').to_numpy()
    symbol_enc = df.select("enc_cat_symbol")
    return X, y, px, ts, symbol_enc


# NOTE(review): `weekly_dataframes` is not defined in any visible cell (an earlier cell
# builds `split_dataframes`) - confirm which list is meant.
for i in range(len(weekly_dataframes) - n_train_weeks - n_val_weeks - n_test_weeks + 1):
    # Split boundaries of this fold: [i, val_start) train, [val_start, test_start) val,
    # [test_start, test_end) test.
    val_start = i + n_train_weeks
    test_start = val_start + n_val_weeks
    test_end = test_start + n_test_weeks

    train_dfs = weekly_dataframes[i:val_start]
    val_dfs = weekly_dataframes[val_start:test_start]
    test_dfs = weekly_dataframes[test_start:test_end]

    train_df = pl.concat(train_dfs)
    val_df = pl.concat(val_dfs)
    test_df = pl.concat(test_dfs)

    X_train, y_train, px_train, ts_train, sb_train = process_df_np(train_df)
    X_val, y_val, px_val, ts_val, sb_val = process_df_np(val_df)
    X_test, y_test, px_test, ts_test, sb_test = process_df_np(test_df)

    print("=" * 60)
    # Index ranges are derived from the configured sizes so the banner stays correct
    # when the validation or test span is longer than one split.
    print(
        f"Fold {i}: Train {i}~{val_start - 1}, "
        f"Val {val_start}~{test_start - 1}, Test {test_start}~{test_end - 1}"
    )
    print("Train:", train_df['timestamp_dt'][0], "to", train_df['timestamp_dt'][-1])
    print("Val:", val_df['timestamp_dt'][0], "to", val_df['timestamp_dt'][-1])
    print("Test:", test_df['timestamp_dt'][0], "to", test_df['timestamp_dt'][-1])

    params = {
        # Model architecture
        "n_d": 8,                      # width of the decision output
        "n_a": 8,                      # width of the attention embedding
        "n_steps": 3,                  # number of decision steps
        "gamma": 1.3,                  # feature-reuse relaxation (> 1)
        "n_independent": 3,            # independent feature-transformer layers per step
        "n_shared": 2,                 # shared feature-transformer layers per step

        # Categorical embeddings
        "cat_idxs": cat_idxs,          # column indices of the categorical features
        "cat_dims": cat_dims,          # number of classes of each categorical feature
        "cat_emb_dim": cat_emb_dim,    # embedding width (or a list of widths)

        # Regularisation and numerical stability
        "lambda_sparse": 1e-5,         # sparsity regularisation
        "epsilon": 1e-15,              # stabiliser
        "momentum": 0.03,              # batch-norm momentum
        "clip_value": 3.0,             # gradient clipping

        # Attention mask type
        "mask_type": "sparsemax",      # sparsemax or entmax

        # Optimiser (the optimiser class itself is left at the library default)
        # "optimizer_fn": torch.optim.Adam,
        "optimizer_params": {"lr": 5e-5},

        # Learning-rate scheduler (optional)
        "scheduler_fn": None,          # e.g. torch.optim.lr_scheduler.StepLR
        "scheduler_params": {},        # e.g. {"step_size": 20, "gamma": 0.95}

        # Decoder structure used for pre-training (normally irrelevant here)
        "n_shared_decoder": 1,
        "n_indep_decoder": 1,

        # Environment and debugging
        "seed": 7,
        "verbose": 2,
        "device_name": "cuda",         # auto / cpu / cuda
    }

    init_fit_params = {
        "eval_metric": ['mae'],
        "max_epochs": 20,
        # "patience": 5,
        "batch_size": 2048,
        "virtual_batch_size": 512,
        "compute_importance": False,
    }

    my_early_stopping = MinEpochsEarlyStopping(
        early_stopping_metric='val_0_mae',
        patience=3,
        min_epochs=3,
    )

    # A fresh model is trained from scratch for every fold.
    tabnet = TabNetRegressor(**params)
    tabnet.fit(
        X_train=X_train,
        y_train=y_train,
        eval_set=[(X_val, y_val)],
        callbacks=[my_early_stopping],  # custom early-stopping callback
        **init_fit_params,
    )

    y_pred = tabnet.predict(X_test).squeeze()
    print(ts_test.shape, y_test.shape, y_pred.shape, px_test.shape)

    print(f"MSE: {mean_squared_error(y_test, y_pred):.6f}")
    print(f"MAE: {mean_absolute_error(y_test, y_pred):.6f}")
    current_window_results = {
        'timestamp': ts_test,
        'symbol_enc': sb_test,      # encoded symbol id of every test row (polars DataFrame)
        'true_label': y_test,       # realised target
        'predicted_prob': y_pred,   # regressor output (key name kept for downstream cells)
        'px': px_test,              # prices, needed by the backtest
    }

    all_preds.append(current_window_results)



In [None]:
# Display the collected per-fold prediction dicts.
# NOTE(review): this renders every array in full; consider showing a summary instead.
all_preds

In [None]:
# Number of collected folds (one entry per walk-forward window).
print(f"all_preds length: {len(all_preds)}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_cross_section_comparison(symbols, true_labels, pred_probs, prices, std_array=None, alpha=1):
    """Plot true labels, predictions and prices of one cross-section on three y-axes.

    Args:
        symbols: Symbol names or codes, one per x position (used as tick labels).
        true_labels: True label of every symbol (scatter on the main axis).
        pred_probs: Predicted value of every symbol (line on a second axis fixed to [0, 1]).
        prices: Price of every symbol (dashed line on a third axis).
        std_array: Optional per-symbol price dispersion; when given, a band of
            ``prices +/- alpha * std_array`` is shaded.
        alpha: Multiplier applied to ``std_array`` for the band width.

    Returns:
        None. The figure is shown with ``plt.show()``.

    Raises:
        ValueError: If ``true_labels`` is empty.
    """
    x = np.arange(len(symbols))
    labels_arr = np.asarray(true_labels, dtype=float)
    # Arrays (not lists) so the band arithmetic below works for any sequence input.
    prices_arr = np.asarray(prices, dtype=float)

    fig, ax1 = plt.subplots(figsize=(16, 7))

    # True labels as points.
    ax1.scatter(x, true_labels, label="True Label", color='tab:blue', marker='o', s=50, alpha=0.7)
    ax1.set_ylabel("True Label", color='tab:blue')
    ax1.tick_params(axis='y', labelcolor='tab:blue')

    # Pad the data range by 10% on both sides. Scaling min and max by 1.1 would move
    # the lower limit *above* the smallest point when all labels are positive (and the
    # upper limit below the largest one when all are negative), clipping data.
    lower, upper = np.nanmin(labels_arr), np.nanmax(labels_arr)
    padding = 0.1 * (upper - lower)
    if not padding > 0:
        # Degenerate range (all labels equal): fall back to a relative or fixed margin.
        padding = 0.1 * abs(upper) if upper != 0 else 0.1
    ax1.set_ylim(lower - padding, upper + padding)

    # Predicted values.
    ax2 = ax1.twinx()
    ax2.plot(x, pred_probs, label="Predicted Probability", color='tab:green', marker='x', linestyle='-', alpha=0.7)
    ax2.set_ylabel("Predicted Probability", color='tab:green')
    ax2.tick_params(axis='y', labelcolor='tab:green')
    ax2.set_ylim(0, 1)

    # Prices, on an axis whose spine is pushed outward.
    ax3 = ax1.twinx()
    ax3.spines.right.set_position(("outward", 60))
    ax3.plot(x, prices_arr, label="Price", color='tab:red', linestyle='--', alpha=0.7)

    # Optional price band.
    if std_array is not None:
        half_width = alpha * np.asarray(std_array, dtype=float)
        ax3.fill_between(x, prices_arr - half_width, prices_arr + half_width,
                         color='tab:red', alpha=0.15, label="Price ± std")

    ax3.set_ylabel("Price", color='tab:red')
    ax3.tick_params(axis='y', labelcolor='tab:red')

    # Symbols on the x-axis.
    ax1.set_xticks(x)
    ax1.set_xticklabels(symbols, rotation=45, ha='right')

    plt.title("Cross-Sectional Comparison: True Label, Prediction & Price")
    # One combined legend for all three axes.
    lines_1, labels_1 = ax1.get_legend_handles_labels()
    lines_2, labels_2 = ax2.get_legend_handles_labels()
    lines_3, labels_3 = ax3.get_legend_handles_labels()
    ax1.legend(lines_1 + lines_2 + lines_3, labels_1 + labels_2 + labels_3, loc='upper left')

    plt.tight_layout()
    plt.show()


In [None]:
import pandas as pd

# Stack the per-fold prediction dicts into one long pandas frame.
df_list = []
for i, result in enumerate(all_preds):
    try:
        # `symbol_enc` is stored as a one-column polars DataFrame ('enc_cat_symbol');
        # anything else (ndarray / list) is used as-is.
        symbol_enc_array = result['symbol_enc']
        if hasattr(symbol_enc_array, "to_pandas"):
            symbol_enc_array = symbol_enc_array.to_pandas()['enc_cat_symbol'].values

        # reshape(-1) always yields a 1-D array: unlike squeeze() it does not collapse
        # a single-row window to a 0-d scalar, and np.asarray also accepts plain lists.
        df = pd.DataFrame({
            'timestamp': np.asarray(result['timestamp']).reshape(-1),
            'symbol_enc': np.asarray(symbol_enc_array).reshape(-1),
            'true_label': np.asarray(result['true_label']).reshape(-1),
            'predicted_prob': np.asarray(result['predicted_prob']).reshape(-1),
            'px': np.asarray(result['px']).reshape(-1),
        })

        if df.empty:
            print(f"Warning: Empty dataframe at index {i}")
            continue
        df_list.append(df)
    except Exception as e:
        # Best effort: report the broken fold and continue with the remaining ones.
        print(f"Error at index {i}: {e}")

if not df_list:
    # Fail with an explicit message instead of pandas' "No objects to concatenate".
    raise ValueError("No usable prediction windows in all_preds; cannot build full_df")

full_df = pd.concat(df_list, ignore_index=True)
print(full_df.shape)
print(full_df.head())


In [None]:
# Display the stacked out-of-sample prediction frame.
full_df

In [None]:
def bin_analysis(factor_name, weekly_dataframes, target_col, num_bins=5):
    """Average target value per factor quantile bin, averaged over the given frames.

    For every frame that contains both columns and has at least `num_bins` non-null
    rows, the factor is cut into `num_bins` quantile bins (duplicate edges dropped)
    and the mean of `target_col` is taken per bin. Each bin's means are then averaged
    over the frames in which that bin was populated.

    Args:
        factor_name: Column holding the factor values.
        weekly_dataframes: Iterable of frames exposing `columns`, `select`,
            `drop_nulls` and `to_pandas` (polars-style).
        target_col: Column holding the target values.
        num_bins: Number of quantile bins.

    Returns:
        List of length `num_bins`; bins that were never populated yield 0.
    """
    totals = [0.0 for _ in range(num_bins)]
    hits = [0 for _ in range(num_bins)]

    for frame in weekly_dataframes:
        has_both = factor_name in frame.columns and target_col in frame.columns
        if not has_both:
            continue

        pairs = frame.select([factor_name, target_col]).drop_nulls().to_pandas()
        if len(pairs) < num_bins:
            continue

        # Integer bin code per row; fewer than num_bins codes may occur when
        # duplicate quantile edges are dropped.
        pairs["bin"] = pd.qcut(pairs[factor_name], q=num_bins, labels=False, duplicates="drop")
        for bin_idx in range(num_bins):
            targets_in_bin = pairs.loc[pairs["bin"] == bin_idx, target_col]
            if targets_in_bin.empty:
                continue
            totals[bin_idx] += targets_in_bin.mean()
            hits[bin_idx] += 1

    avg_returns = []
    for total, count in zip(totals, hits):
        avg_returns.append(total / count if count > 0 else 0)
    return avg_returns


In [None]:
import pandas as pd
from scipy.stats import spearmanr

def calc_ic_per_factor(weekly_dataframes, feature_cols, target_col):
    """Rank-IC summary (mean, std, IR) of every feature against the target.

    For each feature, the Spearman correlation with `target_col` is computed
    separately in every frame that contains both columns and has at least five
    non-null rows; undefined (NaN) correlations are skipped.

    Args:
        weekly_dataframes: Iterable of frames exposing `columns`, `select`,
            `drop_nulls` and `to_pandas` (polars-style).
        feature_cols: Names of the feature columns to evaluate.
        target_col: Name of the target column.

    Returns:
        pandas DataFrame with columns `factor`, `IC Mean`, `IC Std`, `IC IR`,
        sorted by `IC IR` in descending order. Features without any valid IC are
        omitted; if no feature has one, an empty frame with these columns is
        returned. `IC IR` is 0 when the IC std is not greater than 1e-6 or is
        undefined (a single observation).
    """
    result_columns = ['factor', 'IC Mean', 'IC Std', 'IC IR']
    ic_records = []

    for feature in tqdm(feature_cols, unit="factor"):

        ic_list = []
        for df in weekly_dataframes:
            if feature not in df.columns or target_col not in df.columns:
                continue

            sub_df = df.select([feature, target_col]).drop_nulls().to_pandas()
            if len(sub_df) < 5:
                continue

            rank_ic, _ = spearmanr(sub_df[feature], sub_df[target_col])
            if pd.notna(rank_ic):
                ic_list.append(rank_ic)

        if not ic_list:
            continue

        ic_mean = sum(ic_list) / len(ic_list)
        ic_std = pd.Series(ic_list).std()
        # NaN std (one observation) compares False and therefore also maps to IR = 0.
        ic_ir = ic_mean / ic_std if ic_std > 1e-6 else 0
        ic_records.append({
            'factor': feature,
            'IC Mean': ic_mean,
            'IC Std': ic_std,
            'IC IR': ic_ir
        })

    # Explicit columns keep the sort key present even when there are no records;
    # without them sorting an empty frame raises KeyError('IC IR').
    ic_df = pd.DataFrame(ic_records, columns=result_columns).sort_values(by='IC IR', ascending=False)
    return ic_df


In [None]:
# Rank-IC table of all features, best information ratio first.
ic_df = calc_ic_per_factor(
    weekly_dataframes=weekly_dataframes,
    feature_cols=feature_cols,
    target_col=target_col,
)
print("Top10 IC 因子：")
print(ic_df.head(10))

# Names of the ten factors at the top of the table.
top_factors = list(ic_df.head(10)['factor'])
for factor in top_factors:
    # NOTE(review): `returns` is computed but never displayed or stored here.
    returns = bin_analysis(
        factor_name=factor,
        weekly_dataframes=weekly_dataframes,
        target_col=target_col,
        num_bins=50,
    )
    print(f"Factor: {factor}")


In [None]:
# Walk the whole IC table and print each factor with its mean IC.
for factor, ic in zip(ic_df['factor'], ic_df['IC Mean']):
    # NOTE(review): `returns` is computed for every factor but never used in this cell.
    returns = bin_analysis(
        factor_name=factor,
        weekly_dataframes=weekly_dataframes,
        target_col=target_col,
        num_bins=50,
    )
    print(f"Factor: {factor}, IC: {ic:.4f}")


In [None]:
print(ic_df)

# Names of the ten factors at the bottom of the IC table.
tail_factors = list(ic_df.tail(10)['factor'])
for factor in tail_factors:
    # NOTE(review): `returns` is computed but never displayed or stored here.
    returns = bin_analysis(
        factor_name=factor,
        weekly_dataframes=weekly_dataframes,
        target_col=target_col,
        num_bins=50,
    )
    print(f"Factor: {factor}")


In [None]:
import matplotlib.pyplot as plt

# 5 x 2 grid: one 50-bin mean-return bar chart per factor in `top_factors`.
fig, axes = plt.subplots(nrows=5, ncols=2, figsize=(16, 20))
axes = axes.flatten()

for i, factor in enumerate(top_factors):
    returns = bin_analysis(factor, weekly_dataframes, target_col, num_bins=50)
    panel = axes[i]
    bin_numbers = range(1, len(returns) + 1)  # bins are numbered from 1 on the x-axis
    panel.bar(bin_numbers, returns)
    panel.set_title(f"Factor: {factor}")
    panel.set_xlabel("Bin")
    panel.set_ylabel("Mean Return")

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

# 5 x 2 grid: one 50-bin mean-return bar chart per factor in `tail_factors`.
fig, axes = plt.subplots(nrows=5, ncols=2, figsize=(16, 20))
axes = axes.flatten()

for i, factor in enumerate(tail_factors):
    returns = bin_analysis(factor, weekly_dataframes, target_col, num_bins=50)
    panel = axes[i]
    bin_numbers = range(1, len(returns) + 1)  # bins are numbered from 1 on the x-axis
    panel.bar(bin_numbers, returns)
    panel.set_title(f"Factor: {factor}")
    panel.set_xlabel("Bin")
    panel.set_ylabel("Mean Return")

plt.tight_layout()
plt.show()


In [None]:
# Total number of out-of-sample rows in `full_df`.
print("总样本数：", len(full_df))


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Expects `full_df` with symbol_enc, timestamp, true_label, predicted_prob and px;
# an optional `position` column is plotted as well when present.
# `df` is the same object as `full_df`, so the timestamp conversion below (epoch
# milliseconds -> datetime) also changes `full_df`.
df = full_df
df["timestamp"] = pd.to_datetime(full_df["timestamp"], unit="ms")

# One subplot row per symbol, in order of first appearance.
symbols = df["symbol_enc"].unique()
fig, axs = plt.subplots(len(symbols), 1, figsize=(14, 3 * len(symbols)), sharex=True)

if len(symbols) == 1:
    # A single subplot is returned as a bare Axes; wrap it so indexing works.
    axs = [axs]

# Secondary series: (column, legend label, axis label, colour, spine position, optional).
_overlays = [
    ("true_label", "Label", "Label", "blue", None, False),
    ("predicted_prob", "Pred Prob", "Predicted", "orange", 1.1, False),
    ("position", "Position", "Position", "green", 1.2, True),
]

for i, sym in enumerate(symbols):
    sym_df = df[df["symbol_enc"] == sym].copy()

    ax = axs[i]
    sym_str = id_to_symbol[int(sym)]
    ax.set_title(f"{sym_str}")

    # Main axis: price.
    ax.plot(sym_df["timestamp"], sym_df["px"], label="Price", color="black")
    ax.set_ylabel("Price", color="black")
    ax.tick_params(axis='y', labelcolor='black')

    # Every further series gets its own twin axis; spines are shifted right to avoid overlap.
    for column, legend_label, axis_label, colour, spine_pos, optional in _overlays:
        if optional and column not in sym_df.columns:
            continue
        twin = ax.twinx()
        if spine_pos is not None:
            twin.spines.right.set_position(("axes", spine_pos))
        twin.plot(sym_df["timestamp"], sym_df[column], label=legend_label, color=colour, alpha=0.6)
        twin.set_ylabel(axis_label, color=colour)
        twin.tick_params(axis='y', labelcolor=colour)

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr

ic_list = []
time_list = []

# Rank IC (Spearman) between prediction and realised label, per timestamp cross-section.
for ts, group in tqdm(full_df.groupby('timestamp'), total=full_df['timestamp'].nunique(), desc="Calculating IC"):
    if len(group) < 2:
        continue
    # A constant column has no defined rank correlation; skip such cross-sections.
    if group['predicted_prob'].nunique() < 2 or group['true_label'].nunique() < 2:
        continue
    ic, _ = spearmanr(group['predicted_prob'], group['true_label'])
    ic_list.append(ic)
    time_list.append(ts)

ic_array = np.array(ic_list)
time_array = np.array(time_list)

# Sort by time (groupby already yields sorted keys; kept as a safeguard).
sorted_idx = np.argsort(time_array)
time_array = time_array[sorted_idx]
ic_array = ic_array[sorted_idx]

# Cumulative IC.
cum_ic = np.cumsum(ic_array)

# Information ratio = mean IC / std of IC. Both statistics are NaN when no valid
# cross-section exists, and IR is NaN when the std is zero or undefined, instead of
# dividing by zero.
ic_mean = np.mean(ic_array) if ic_array.size else np.nan
ic_std = np.std(ic_array) if ic_array.size else np.nan
ir = ic_mean / ic_std if ic_std > 0 else np.nan

print(f"平均IC: {ic_mean:.4f}")
print(f"平均IC std: {ic_std:.4f}")

print(f"信息比率 IR: {ir:.4f}")

# Down-sample for plotting: keep every `step`-th point.
step = 500
time_array_sampled = time_array[::step]
ic_array_sampled = ic_array[::step]
cum_ic_sampled = cum_ic[::step]

plt.figure(figsize=(14, 6))

# Panel 1: sampled IC time series.
plt.subplot(2, 1, 1)
plt.plot(time_array_sampled, ic_array_sampled, marker='o', label='IC')
plt.axhline(0, color='gray', linestyle='--')
# The title follows `step` so it cannot drift from the actual sampling interval.
plt.title(f"Information Coefficient (IC) over Time (sampled every {step} points)")
plt.ylabel("IC")
plt.legend()

# Panel 2: sampled cumulative IC.
plt.subplot(2, 1, 2)
plt.plot(time_array_sampled, cum_ic_sampled, color='orange', label='Cumulative IC')
plt.title(f"Cumulative IC (IR={ir:.4f})")
plt.xlabel("Time")
plt.ylabel("Cumulative IC")
plt.legend()

plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm # 用于显示进度条

# --- Columns expected in full_df ---
# - 'timestamp': raw timestamp (epoch milliseconds or already datetime)
# - 'symbol_enc': encoded symbol id
# - 'px': price at that timestamp
# - 'predicted_prob': model output; a larger value means a more attractive symbol

# --- Configuration ---
N_INTERVAL = N  # rebalance once every N_INTERVAL distinct timestamps

# --- Pre-processing ---
# 1. Datetime column for filtering and plotting. `bt_df` is `full_df` itself at this
#    point, so the 'dt' column is added to `full_df` as well.
bt_df = full_df
bt_df['dt'] = pd.to_datetime(full_df['timestamp'], unit='ms')

# 2. Deterministic order: by time, then by symbol.
bt_df = bt_df.sort_values(['dt', 'symbol_enc']).reset_index(drop=True)

# 3. All distinct timestamps in ascending order.
timestamps_sorted = bt_df['dt'].drop_duplicates().sort_values().to_list()

# 4. Rebalance points: every N_INTERVAL-th timestamp.
rebalance_times = timestamps_sorted[::N_INTERVAL]

# --- Long / short pick at every rebalance point ---
# One pass over the rebalance cross-sections instead of rescanning the whole frame
# for each timestamp.
rebalance_signals = {}
_rebalance_rows = bt_df[bt_df['dt'].isin(rebalance_times)]
for t, current_snapshot in tqdm(_rebalance_rows.groupby('dt'),
                                total=len(rebalance_times),
                                desc="Identifying Rebalance Signals"):
    # Long the symbol with the highest prediction, short the one with the lowest.
    long_symbol = current_snapshot.loc[current_snapshot['predicted_prob'].idxmax(), 'symbol_enc']
    short_symbol = current_snapshot.loc[current_snapshot['predicted_prob'].idxmin(), 'symbol_enc']

    rebalance_signals[t] = {'long': long_symbol, 'short': short_symbol}

# --- Price lookup at period boundaries ---
# Only rebalance times and the final timestamp are ever priced. If a (time, symbol)
# pair occurs more than once, the first row wins.
_boundary_times = set(rebalance_times) | set(timestamps_sorted[-1:])
_boundary_px = (
    bt_df[bt_df['dt'].isin(_boundary_times)]
    .drop_duplicates(['dt', 'symbol_enc'])
    .set_index(['dt', 'symbol_enc'])['px']
    .to_dict()
)

# --- Per-period strategy returns ---
period_returns = []

# Positions carried into a period; None until the first signal is seen.
current_long_symbol = None
current_short_symbol = None

for i in tqdm(range(len(rebalance_times)), desc="Calculating Period Returns"):
    start_time = rebalance_times[i]
    # A period runs from its rebalance point to the next one (inclusive on both ends,
    # since positions are valued at the end timestamp); the last period ends at the
    # final timestamp.
    end_time = rebalance_times[i + 1] if i + 1 < len(rebalance_times) else timestamps_sorted[-1]

    # Take the new signal if there is one; otherwise keep the previous position.
    signal = rebalance_signals.get(start_time)
    if signal is not None:
        current_long_symbol = signal['long']
        current_short_symbol = signal['short']

    # No position yet: flat period.
    if current_long_symbol is None or current_short_symbol is None:
        period_returns.append({'dt': start_time, 'strategy_log_return': 0.0})
        continue

    long_start_px = _boundary_px.get((start_time, current_long_symbol))
    long_end_px = _boundary_px.get((end_time, current_long_symbol))
    short_start_px = _boundary_px.get((start_time, current_short_symbol))
    short_end_px = _boundary_px.get((end_time, current_short_symbol))

    # A leg without a price at either boundary cannot be valued: report it and book a
    # flat period instead of failing the whole backtest with an IndexError.
    if None in (long_start_px, long_end_px, short_start_px, short_end_px):
        print(f"Warning: missing boundary price in period starting {start_time}; return set to 0")
        period_returns.append({'dt': start_time, 'strategy_log_return': 0.0})
        continue

    # Log return of each leg over the period.
    long_log_ret = np.log(long_end_px) - np.log(long_start_px)
    short_log_ret = np.log(short_end_px) - np.log(short_start_px)

    # Equal capital on both legs: average of the long return and the negated short return.
    strategy_period_log_return = (long_log_ret - short_log_ret) / 2

    period_returns.append({
        'dt': start_time,
        'strategy_log_return': strategy_period_log_return
    })

# Per-period returns as a frame / series indexed by period start.
strategy_returns_df = pd.DataFrame(period_returns, columns=['dt', 'strategy_log_return']).set_index('dt')
strategy_returns_series = strategy_returns_df['strategy_log_return']

# --- 绩效指标函数 ---
def perf_stats(return_series, periods_per_year):
    """Summarise a series of per-period log returns.

    Args:
        return_series: pandas Series of log returns, one per period.
        periods_per_year: Number of such periods in a year (used for annualisation).

    Returns:
        Dict with:
            'Cumulative Return': simple return over the whole series.
            'Annualized Return': geometric annualised simple return.
            'Annualized Volatility': std of the log returns scaled by
                sqrt(periods_per_year).
            'Sharpe Ratio': annualised return / annualised volatility, NaN when the
                volatility is zero or undefined.
            'Max Drawdown': most negative peak-to-trough decline of the equity curve,
                measured from the starting capital of 1.0 onwards.
        Every value is NaN for an empty series.
    """
    if return_series.empty:
        return {
            'Cumulative Return': np.nan, 'Annualized Return': np.nan,
            'Annualized Volatility': np.nan, 'Sharpe Ratio': np.nan, 'Max Drawdown': np.nan
        }

    # Equity curve relative to a starting capital of 1.0.
    cum_ret = np.exp(return_series.cumsum())
    final_equity = cum_ret.iloc[-1]
    total_return = final_equity - 1

    # Geometric annualisation (the series is non-empty here, so the length is >= 1).
    num_periods = len(return_series)
    ann_return = final_equity ** (periods_per_year / num_periods) - 1

    ann_vol = return_series.std() * np.sqrt(periods_per_year)
    sharpe = ann_return / ann_vol if ann_vol > 0 else np.nan

    # The running peak starts at the initial capital (1.0). Without the floor, a loss
    # in the very first period would not count as a drawdown at all.
    running_max = cum_ret.cummax().clip(lower=1.0)
    drawdown = (cum_ret - running_max) / running_max
    max_dd = drawdown.min()
    return {
        'Cumulative Return': total_return,
        'Annualized Return': ann_return,
        'Annualized Volatility': ann_vol,
        'Sharpe Ratio': sharpe,
        'Max Drawdown': max_dd
    }

# --- Performance summary ---
# Annualisation assumes every timestamp is a 10-minute bar, i.e. one rebalance period
# lasts N_INTERVAL * 10 minutes and a year has 365 * 24 * 60 minutes.
# NOTE(review): confirm the 10-minute assumption against the actual bar spacing; an
# alternative is (number of periods) / (years spanned by the data).
_minutes_per_year = 365 * 24 * 60
_minutes_per_period = N_INTERVAL * 10
periods_per_year_for_annualization = _minutes_per_year / _minutes_per_period

stats = perf_stats(strategy_returns_series, periods_per_year_for_annualization)
print("\n--- Strategy Performance Statistics ---")
print(pd.Series(stats))

# --- Equity curve ---
# Cumulative log return mapped back to a growth factor.
_equity_curve = np.exp(strategy_returns_series.cumsum())

_fig, _ax = plt.subplots(figsize=(12, 6))
_equity_curve.plot(ax=_ax)
_ax.set_title(f"Long-Short Strategy Cumulative Return (Rebalance every {N_INTERVAL} bars)")
_ax.set_xlabel("Rebalance Timestamp")
_ax.set_ylabel("Cumulative Return")
_ax.grid(True)
_fig.tight_layout()
plt.show()

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns # 用于绘制热力图

# --- Columns expected in full_df ---
# - 'timestamp' (convertible to datetime)
# - 'symbol_enc' (encoded symbol id)
# - 'px' (price at that timestamp)
# - 'predicted_prob' (model output; a larger value means a more attractive symbol)

# --- Configuration ---
# `bt_df` is bound to the `full_df` object built by an earlier cell (no copy is made).
bt_df = full_df

# `N` is the label horizon defined in an earlier cell; the rebalance period is a
# multiple of it.
N_INTERVAL = N * 1 # rebalance once every N_INTERVAL distinct timestamps
TRANSACTION_COST_PER_TRADE = 0.0007 # cost charged per unit traded

# Per-symbol stop-loss parameter
STOP_LOSS_PERCENT = 0.15 # stop-loss level per symbol, e.g. 0.15 means a 15% loss

# ======================================================================
# ===== Weighting functions =====
# ======================================================================

# Exponent scale passed as `beta` to exp_weights when the rebalance weights are built.
n_beta = 5.0
def dollar_neutralize(weights):
    """
    保证多空资金平衡：多头资金=空头资金=0.5
    """
    long_sum = weights[weights > 0].sum()
    short_sum = -weights[weights < 0].sum()
    if long_sum == 0 or short_sum == 0:
        return weights / np.sum(np.abs(weights))
    w = weights.copy()
    w[w > 0] /= long_sum
    w[w < 0] /= short_sum
    return w / 2  # 多头合计=0.5，空头合计=-0.5

def exp_weights(scores, beta=5.0):
    """
    Turn scores into exponentially scaled, dollar-neutralised weights.

    Each side is exponentiated separately on its clipped magnitude, so the raw
    weight of a score s is exp(beta * max(s, 0)) - exp(beta * max(-s, 0)):
    positive for s > 0, negative for s < 0 and exactly 0 for s == 0.
    The raw weights are then passed through dollar_neutralize.

    NOTE(review): the rebalance loop passes the 'predicted_prob' column as
    scores. If those values are never negative, no short weights can be
    produced - confirm the scores are centred around zero.
    """
    long_leg = np.exp(np.clip(scores, 0, None) * beta)
    short_leg = np.exp(np.clip(-scores, 0, None) * beta)
    # The inactive side contributes exp(0) == 1, which cancels in the difference.
    return dollar_neutralize(long_leg - short_leg)

# --- Data preparation ---
# Parse the millisecond timestamps, order rows by time and symbol, and pick
# every N_INTERVAL-th distinct timestamp as a rebalance time.
bt_df['dt'] = pd.to_datetime(bt_df['timestamp'], unit='ms')
bt_df = bt_df.sort_values(['dt', 'symbol_enc']).reset_index(drop=True)
timestamps_sorted = bt_df['dt'].drop_duplicates().sort_values().to_list()
rebalance_times = timestamps_sorted[::N_INTERVAL]

# --- Long/short signal at every rebalance time ---
# rebalance_weights maps each rebalance timestamp to {symbol_enc: weight};
# a timestamp with no rows, or whose weight calculation fails, maps to {}.
rebalance_weights = {}
for t in tqdm(rebalance_times, desc="Calculating Rebalance Weights"):
    rebalance_weights[t] = {}
    current_snapshot = bt_df.loc[bt_df['dt'] == t]
    if not current_snapshot.empty:
        try:
            current_snapshot = current_snapshot.set_index('symbol_enc')
            # Exponential weighting of the model score, then dollar-neutralised
            weights = exp_weights(current_snapshot['predicted_prob'], beta=n_beta)
            rebalance_weights[t] = weights.to_dict()
        except Exception as e:
            print(f"Error calculating weights at {t}: {e}")

# --- Build per-period holdings and compute per-period returns ---
# For every rebalance period this loop (1) runs a stop-loss check on the weights
# carried in `current_weights`, (2) switches to the period's target weights and
# charges the turnover cost, (3) computes the gross log return from the start and
# end prices, and (4) records the results.
#
# NOTE(review): at the top of each iteration `current_weights` still holds the
# target weights assigned at the end of the PREVIOUS iteration, while the stop-loss
# check compares the price at this period's start with the min/max price over this
# period's bars (start_time..end_time). Those bars lie after the rebalance instant,
# so the check uses prices not yet known at start_time - confirm this is intended.
# NOTE(review): every weight change made by the stop-loss and linkage steps is
# replaced by `current_weights = target_weights.copy()` in Step 2; only the
# transaction cost they accumulated (and the lower turnover base they leave for the
# Step 2 cost loop) survives. The Step 3 return is not capped at the stop level.
period_results = []
all_positions_by_period = []

current_weights = {} # {symbol: weight}

for i in tqdm(range(len(rebalance_times)), desc="Calculating Period Returns"):
    start_time = rebalance_times[i]
    # The last period runs to the final timestamp; if the last rebalance time is
    # itself the final timestamp, start_time == end_time and the return is zero.
    end_time = rebalance_times[i+1] if i + 1 < len(rebalance_times) else timestamps_sorted[-1]

    # All rows of the current period (both boundaries inclusive, so the boundary
    # bar belongs to two consecutive periods)
    period_data = bt_df[(bt_df['dt'] >= start_time) & (bt_df['dt'] <= end_time)].copy()
    if period_data.empty:
        continue

    # Price snapshot at the period start: {symbol_enc: px}
    current_moment_prices = bt_df[bt_df['dt'] == start_time].set_index('symbol_enc')['px'].to_dict()

    # --- Step 1: per-symbol stop-loss check and linked adjustment ---
    current_period_transaction_cost = 0.0
    long_positions_to_close = {}
    short_positions_to_close = {}

    # Walk the currently held positions and test each for a stop-loss hit
    for symbol, weight in list(current_weights.items()):
        open_px = current_moment_prices.get(symbol)
        if open_px is None or weight == 0:
            continue

        symbol_data = period_data[period_data['symbol_enc'] == symbol]
        if symbol_data.empty:
            continue

        # Simplification: use the period's lowest/highest price to decide whether the stop was touched
        min_px = symbol_data['px'].min()
        max_px = symbol_data['px'].max()

        triggered_stop_loss = False
        if weight > 0 and (open_px - min_px) / open_px >= STOP_LOSS_PERCENT: # long stop-loss
            triggered_stop_loss = True
        elif weight < 0 and (max_px - open_px) / open_px >= STOP_LOSS_PERCENT: # short stop-loss
            triggered_stop_loss = True

        if triggered_stop_loss:
            print(f"Stop-loss triggered for {symbol} ({'LONG' if weight > 0 else 'SHORT'}) at {start_time.strftime('%Y-%m-%d %H:%M')}.")

            if weight > 0:
                long_positions_to_close[symbol] = weight
            else:
                short_positions_to_close[symbol] = weight

            # Close the position outright and charge the closing fee
            current_weights[symbol] = 0.0
            current_period_transaction_cost += abs(weight) * TRANSACTION_COST_PER_TRADE

    # --- Linked adjustment of the remaining positions to restore dollar neutrality ---

    # Longs were stopped out: shrink every short position proportionally
    if long_positions_to_close:
        total_long_loss_weight = sum(long_positions_to_close.values())
        total_short_weight = -sum(w for w in current_weights.values() if w < 0)

        if total_short_weight > 0:
            # Fraction of the short book to remove; it is not clamped, so a value
            # above 1 would flip the sign of the remaining shorts.
            adjustment_ratio = total_long_loss_weight / total_short_weight
            print(f"  -> Linkage: Adjusting short positions by ratio {adjustment_ratio:.2%}.")
            for symbol, weight in list(current_weights.items()):
                if weight < 0:
                    new_weight = weight * (1 - adjustment_ratio)
                    current_period_transaction_cost += abs(new_weight - weight) * TRANSACTION_COST_PER_TRADE
                    current_weights[symbol] = new_weight

    # Shorts were stopped out: shrink every long position proportionally
    if short_positions_to_close:
        total_short_loss_weight = -sum(short_positions_to_close.values())
        total_long_weight = sum(w for w in current_weights.values() if w > 0)

        if total_long_weight > 0:
            adjustment_ratio = total_short_loss_weight / total_long_weight
            print(f"  -> Linkage: Adjusting long positions by ratio {adjustment_ratio:.2%}.")
            for symbol, weight in list(current_weights.items()):
                if weight > 0:
                    new_weight = weight * (1 - adjustment_ratio)
                    current_period_transaction_cost += abs(new_weight - weight) * TRANSACTION_COST_PER_TRADE
                    current_weights[symbol] = new_weight

    # --- Step 2: move to the rebalance targets and charge the turnover cost ---
    # Target weights for this period ({} when none were computed)
    target_weights = rebalance_weights.get(start_time, {})

    # Cost of trading from the current holdings (after the stop-loss adjustment) to the targets
    all_symbols = set(current_weights.keys()) | set(target_weights.keys())
    for symbol in all_symbols:
        old_weight = current_weights.get(symbol, 0)
        new_weight = target_weights.get(symbol, 0)
        current_period_transaction_cost += abs(new_weight - old_weight) * TRANSACTION_COST_PER_TRADE

    # Replace the holdings with the targets (this discards the Step 1 weight edits)
    current_weights = target_weights.copy()

    # --- Step 3: gross strategy return of this period ---
    gross_period_log_return = 0.0
    if not current_weights:
        gross_period_log_return = 0.0
    else:
        end_prices = period_data[period_data['dt'] == end_time].set_index('symbol_enc')['px'].to_dict()
        start_prices = period_data[period_data['dt'] == start_time].set_index('symbol_enc')['px'].to_dict()

        for symbol, weight in current_weights.items():
            start_px = start_prices.get(symbol)
            end_px = end_prices.get(symbol)
            # Truthiness check: skips a missing (None) or zero price; a NaN price
            # is truthy and would pass through into the log.
            if start_px and end_px:
                log_return_per_symbol = np.log(end_px) - np.log(start_px)
                gross_period_log_return += weight * log_return_per_symbol

    # The cost is subtracted directly from the weighted sum of log returns
    net_period_log_return = gross_period_log_return - current_period_transaction_cost

    # --- Record the position of every symbol for this period (for the heatmap) ---
    # NOTE(review): the unique-symbol list does not change inside the loop and could
    # be computed once before it.
    all_unique_symbols = bt_df['symbol_enc'].unique()
    current_period_positions_for_heatmap = {symbol: 0.0 for symbol in all_unique_symbols}
    for symbol, weight in current_weights.items():
        current_period_positions_for_heatmap[symbol] = weight

    all_positions_by_period.append({'dt': start_time, **current_period_positions_for_heatmap})

    period_results.append({
        'dt': start_time,
        'gross_strategy_log_return': gross_period_log_return,
        'transaction_cost': current_period_transaction_cost, # total cost charged in this period
        'net_strategy_log_return': net_period_log_return,
        'num_long_positions': sum(1 for w in current_weights.values() if w > 0),
        'num_short_positions': sum(1 for w in current_weights.values() if w < 0)
    })

# Convert the per-period records to a DataFrame indexed by period start time
# (set_index('dt') needs at least one recorded period to find the 'dt' column)
strategy_results_df = pd.DataFrame(period_results).set_index('dt')

# --- Prepare the position heatmap data: one row per period, one column per symbol ---
positions_heatmap_df = pd.DataFrame(all_positions_by_period).set_index('dt')
positions_heatmap_df = positions_heatmap_df[sorted(positions_heatmap_df.columns)]


# --- 绩效指标函数 (与之前相同) ---
def perf_stats(return_series, periods_per_year):
    """
    Summarise a series of per-period log returns.

    Parameters
    ----------
    return_series : pandas.Series
        Log return of each period, in chronological order.
    periods_per_year : float
        Number of such periods in one year, used for annualisation.

    Returns
    -------
    dict
        'Cumulative Return'     : final equity - 1 (equity starts at 1.0).
        'Annualized Return'     : final equity ** (periods_per_year / n) - 1.
        'Annualized Volatility' : std of the log returns * sqrt(periods_per_year).
        'Sharpe Ratio'          : annualized return / annualized volatility,
                                  NaN when the volatility is not positive.
        'Max Drawdown'          : most negative drop of the equity curve below
                                  its running peak (<= 0).
        Every value is NaN for an empty series.
    """
    if return_series.empty:
        return {
            'Cumulative Return': np.nan, 'Annualized Return': np.nan,
            'Annualized Volatility': np.nan, 'Sharpe Ratio': np.nan, 'Max Drawdown': np.nan
        }

    # Summing log returns and exponentiating gives the equity curve (start = 1.0).
    cum_ret = np.exp(return_series.cumsum())
    final_equity = cum_ret.iloc[-1]
    total_return = final_equity - 1

    # The series is non-empty here, so the period count is always positive.
    num_periods = len(return_series)
    ann_return = final_equity ** (periods_per_year / num_periods) - 1

    ann_vol = return_series.std() * np.sqrt(periods_per_year)
    sharpe = ann_return / ann_vol if ann_vol > 0 else np.nan

    # The running peak includes the starting equity of 1.0; otherwise a strategy
    # that loses from its very first period would report no drawdown at all.
    running_max = cum_ret.cummax().clip(lower=1.0)
    drawdown = (cum_ret - running_max) / running_max
    max_dd = drawdown.min()
    return {
        'Cumulative Return': total_return,
        'Annualized Return': ann_return,
        'Annualized Volatility': ann_vol,
        'Sharpe Ratio': sharpe,
        'Max Drawdown': max_dd
    }

# --- Compute and display performance ---
# The factor is (minutes per year) / (N_INTERVAL * 10), i.e. it treats every bar as
# spanning 10 minutes.
# NOTE(review): the original comment said "10ms per bar", which does not match this
# minute-based formula; confirm the real bar duration before trusting annualised figures.
periods_per_year_for_annualization = (365 * 24 * 60) / (N_INTERVAL * 10) # assumes one bar per 10 minutes

print("\n--- Strategy Performance Statistics (Gross) ---")
gross_stats = perf_stats(strategy_results_df['gross_strategy_log_return'], periods_per_year_for_annualization)
print(pd.Series(gross_stats))

print("\n--- Strategy Performance Statistics (Net of Costs) ---")
net_stats = perf_stats(strategy_results_df['net_strategy_log_return'], periods_per_year_for_annualization)
print(pd.Series(net_stats))

# Sum of the per-period transaction costs recorded by the backtest loop
print(f"\nTotal Transaction Cost (Sum of individual costs): {strategy_results_df['transaction_cost'].sum():.6f}")

# --- Plots (muted colour palette) ---
# Four stacked panels: cumulative gross/net return, cumulative transaction cost,
# number of long/short positions, and a per-symbol position-weight heatmap.

# The style sheet has to be applied BEFORE the figure is created: it only
# affects figures and axes created afterwards.
plt.style.use('seaborn-v0_8-whitegrid')

fig, axes = plt.subplots(4, 1, figsize=(16, 22), sharex=False, gridspec_kw={'height_ratios': [0.35, 0.2, 0.2, 0.25]})

# Muted line colours
COLOR_GROSS_RETURN = sns.color_palette("Paired")[1]
COLOR_NET_RETURN = sns.color_palette("Paired")[3]
COLOR_TRANSACTION_COST = sns.color_palette("Paired")[5]
COLOR_LONG_POSITIONS = sns.color_palette("Paired")[7]
COLOR_SHORT_POSITIONS = sns.color_palette("Paired")[9]

# Heatmap colour map (muted diverging palette)
HEATMAP_CMAP = sns.diverging_palette(240, 10, as_cmap=True, s=70, l=60, sep=1)

# Panel 1: cumulative return (gross vs. net)
strategy_results_df['gross_strategy_log_return'].cumsum().apply(np.exp).plot(ax=axes[0], label='Cumulative Gross Return', color=COLOR_GROSS_RETURN)
strategy_results_df['net_strategy_log_return'].cumsum().apply(np.exp).plot(ax=axes[0], label='Cumulative Net Return', color=COLOR_NET_RETURN)
axes[0].set_title(f"Cumulative Strategy Returns (Rebalance every {N_INTERVAL} bars, Exponential Weighted)")
axes[0].set_ylabel("Cumulative Return")
axes[0].legend()
axes[0].grid(True, linestyle=':', alpha=0.7)
axes[0].set_xlabel("")

# Panel 2: cumulative transaction cost
strategy_results_df['transaction_cost'].cumsum().plot(ax=axes[1], label='Cumulative Transaction Cost', color=COLOR_TRANSACTION_COST)
axes[1].set_title("Cumulative Transaction Cost Over Time")
axes[1].set_ylabel("Total Cost")
axes[1].legend()
axes[1].grid(True, linestyle=':', alpha=0.7)
axes[1].set_xlabel("")

# Panel 3: number of positions (kept as a quick overview)
strategy_results_df['num_long_positions'].plot(ax=axes[2], label='Number of Long Positions', color=COLOR_LONG_POSITIONS, drawstyle='steps-post')
strategy_results_df['num_short_positions'].plot(ax=axes[2], label='Number of Short Positions', color=COLOR_SHORT_POSITIONS, drawstyle='steps-post')
axes[2].set_title("Number of Long and Short Positions Over Time (Overview)")
axes[2].set_ylabel("Count")
axes[2].set_xlabel("")
axes[2].legend()
axes[2].grid(True, linestyle=':', alpha=0.7)
axes[2].set_ylim(bottom=0)

# Panel 4: detailed position heatmap
if not positions_heatmap_df.empty:
    # Show at most the first 50 symbol columns so the chart stays readable
    unique_symbols = positions_heatmap_df.columns
    if len(unique_symbols) > 50:
        positions_heatmap_df_display = positions_heatmap_df.iloc[:, :50]
    else:
        positions_heatmap_df_display = positions_heatmap_df

    sns.heatmap(
        positions_heatmap_df_display.T,
        cmap=HEATMAP_CMAP,
        center=0,  # flat (zero) weight maps to the neutral middle of the diverging palette
        cbar_kws={'label': 'Position Weight'},
        ax=axes[3],
        yticklabels=True,
        xticklabels=True,
        linewidths=0.5,
        linecolor='lightgray'
    )
    axes[3].set_title("Detailed Position Weight Heatmap (Per Symbol)")
    axes[3].set_xlabel("Rebalance Timestamp")
    axes[3].set_ylabel("Symbol")

    # Thin out the x tick labels to avoid overlap. Heatmap cell i spans [i, i + 1],
    # so ticks are placed at i + 0.5 to sit on the cell centres.
    num_ticks = 10
    num_columns = len(positions_heatmap_df_display.index)
    if num_columns > num_ticks:
        tick_interval = num_columns // num_ticks
        axes[3].set_xticks(np.arange(0, num_columns, tick_interval) + 0.5)
        axes[3].set_xticklabels(positions_heatmap_df_display.index[::tick_interval].strftime('%Y-%m-%d %H:%M'))
    else:
        axes[3].set_xticks(np.arange(num_columns) + 0.5)
        axes[3].set_xticklabels(positions_heatmap_df_display.index.strftime('%Y-%m-%d %H:%M'))

plt.tight_layout()
plt.show()

In [None]:
import os
import json
import pickle
from datetime import datetime
from sklearn.preprocessing import LabelEncoder

def save_tabnet_checkpoint(
    model,
    base_save_dir: str,
    model_params: dict,
    feature_names: list[str],
    training_meta: dict,
    unique_id: str = None,
):
    """
    Save a model together with its parameters, metadata and feature names.

    Files written into the checkpoint directory:
      * ``tabnet_model``        - written by ``model.save_model(path)``
      * ``model_metadata.json`` - ``{"model_params": ..., "meta_info": ...}``
      * ``auxiliary.pkl``       - pickled ``{"feature_names": feature_names}``

    Parameters
    ----------
    model : object exposing ``save_model(path)``.
    base_save_dir : parent directory of the checkpoint directory.
    model_params : JSON-serialisable model parameters.
    feature_names : feature column names, stored in ``auxiliary.pkl``.
    training_meta : JSON-serialisable training metadata. Values exposing
        ``tolist()`` (numpy scalars and arrays) are converted to plain Python
        objects. Dictionary keys must still be JSON-compatible.
    unique_id : checkpoint directory name. When None the name is
        ``tabnet_<YYYY-mm-dd_HH-MM>`` from the current local time; two saves
        within the same minute therefore share (and overwrite) a directory.

    Returns
    -------
    str
        Path of the checkpoint directory.

    Raises
    ------
    TypeError
        If the metadata contains a value that cannot be serialised to JSON.
        This is raised before any file or directory is created.
    """
    def _json_default(obj):
        # numpy scalars / arrays expose tolist(), which yields built-in objects
        if hasattr(obj, "tolist"):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    if unique_id is None:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M")
        save_dir = os.path.join(base_save_dir, f"tabnet_{timestamp}")
    else:
        save_dir = os.path.join(base_save_dir, unique_id)

    # Serialise the metadata first so that bad metadata fails before the model
    # is written, instead of leaving a half-finished checkpoint behind.
    metadata_text = json.dumps({
        "model_params": model_params,
        "meta_info": training_meta,
    }, indent=4, default=_json_default)

    os.makedirs(save_dir, exist_ok=True)

    # Save the model
    model_path = os.path.join(save_dir, "tabnet_model")
    model.save_model(model_path)

    # Save model parameters and metadata
    config_path = os.path.join(save_dir, "model_metadata.json")
    with open(config_path, "w") as f:
        f.write(metadata_text)

    # Save auxiliary objects.
    # NOTE: this is a pickle file - only load it back from trusted locations.
    aux_path = os.path.join(save_dir, "auxiliary.pkl")
    with open(aux_path, "wb") as f:
        pickle.dump({
            "feature_names": feature_names,
        }, f)

    print(f"✅ 模型和元信息已保存到: {save_dir}")
    return save_dir


In [None]:
# Debug output: feature_cols, cat_idxs and cat_dims all come from earlier cells.
print(feature_cols, cat_idxs, cat_dims)
# Re-bind N_INTERVAL to N (the backtest cell above sets it to N * 1 as well).
# NOTE(review): N must already exist in the kernel; consider keeping a single
# definition of N_INTERVAL so the value does not depend on cell execution order.
N_INTERVAL = N

In [None]:
# Debug output: list the keys of symbol_to_id, then every key/value pair.
print(symbol_to_id.keys())
# The loop variable was called `df`, which mislabels the mapped value and would
# overwrite any notebook-level variable named `df`; presumably the values are
# integer ids (they are passed through int() elsewhere in this notebook).
for sym, sym_id in symbol_to_id.items():
    print(sym)
    print(sym_id)

In [None]:
def convert_dict_np_to_builtin(d):
    """Return a new dict with every key passed through str() and every value through int()."""
    converted = {}
    for key, value in d.items():
        converted[str(key)] = int(value)
    return converted

def convert_dict_keys_to_str_and_values_to_builtin(d):
    """
    Return a copy of ``d`` whose keys and values are plain Python objects.

    Keys:   numpy integers become ``int``; every other key becomes ``str(key)``.
    Values: any numpy scalar (``np.str_``, ``np.int64``, ``np.float64``, ...)
            is unwrapped to its built-in equivalent via ``.item()``; other
            values are kept unchanged. Previously only ``np.str_`` values were
            converted, so e.g. an ``np.int64`` value still could not be dumped
            to JSON.
    """
    new_d = {}
    for k, v in d.items():
        # the key may be an np.int64
        new_key = int(k) if isinstance(k, np.integer) else str(k)
        # the value may be any numpy scalar type
        new_val = v.item() if isinstance(v, np.generic) else v
        new_d[new_key] = new_val
    return new_d

In [None]:
# Scratch check: convert symbol_to_id to str keys / int values and look up one symbol.
# Raises KeyError if "BNBUSDT" is not a key of the mapping.
a  =convert_dict_np_to_builtin(symbol_to_id)
print(a["BNBUSDT"])

In [None]:
# Scratch check: convert id_to_symbol (defined in an earlier cell) to built-in
# key/value types and print the whole mapping.
b = convert_dict_keys_to_str_and_values_to_builtin(id_to_symbol)
print(b)

In [None]:
   
# save_tabnet_checkpoint(
#     model=tabnet,
#     base_save_dir="./saved_models/tabnet_crosec",
#     model_params=params,
#     feature_names=feature_cols,
#     training_meta={
#         "symbol_to_id": convert_dict_np_to_builtin(symbol_to_id),
#         "id_to_symbol": convert_dict_keys_to_str_and_values_to_builtin(id_to_symbol),
#         "train_n_week": n_train_weeks,
#         "fit_params": init_fit_params,
#         "label_window": N,
#         "saved_timestamp": str(pd.Timestamp.now()),
#         "feat_cal_window": int(feat_cal_window),
#         "feat_norm_window": feat_norm_window,
#         "feat_norm_rolling_mean_window": feat_norm_rolling_mean_window,
#         "n_interval": N_INTERVAL,
#         "n_beta": n_beta,
#         "sl_percent": STOP_LOSS_PERCENT,
#     },
# )