# TabM增强版：性早熟预测模型

**使用TabM进行全方位优化**

- 基础TabM模型（PiecewiseLinearEmbeddings）
- 超参数优化（Optuna HPO）
- 不同架构变体（tabm / tabm-mini）
- 独立批次训练策略
- 不同数值嵌入方式对比

## 1. 导入必要的库

In [21]:
import os
import math
import random
from copy import deepcopy

import pandas as pd
import numpy as np
import warnings
import joblib

warnings.filterwarnings("ignore")

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score,
    roc_curve,
    f1_score,
    accuracy_score,
    precision_score,
    recall_score,
    confusion_matrix,
)
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.ensemble import RandomForestRegressor

# TabM相关库
import tabm
import rtdl_num_embeddings
import torch
import torch.nn as nn
import optuna

import matplotlib.pyplot as plt
import seaborn as sns

# Use a CJK-capable font for the Chinese plot labels and the ASCII minus sign
# (presumably because SimHei lacks the Unicode minus glyph — confirm locally).
plt.rcParams.update({"font.sans-serif": ["SimHei"], "axes.unicode_minus": False})

# Report library versions and GPU availability for reproducibility.
tabm_version = getattr(tabm, "__version__", "N/A")
cuda_ready = torch.cuda.is_available()
print(f"TabM版本: {tabm_version}")
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {cuda_ready}")
if cuda_ready:
    print(f"GPU设备: {torch.cuda.get_device_name(0)}")
print("所有库导入完成")

TabM版本: 0.0.3
PyTorch版本: 2.9.1+cu130
CUDA可用: True
GPU设备: NVIDIA GeForce RTX 3080 Laptop GPU
所有库导入完成


## 2. 设置路径和参数

In [22]:
# Create every output directory used by the notebook (safe to re-run).
for out_dir in (
    "./output",
    "./output/models",
    "./output/tabm_enhanced",
    "./output/tabm_enhanced/models",
):
    os.makedirs(out_dir, exist_ok=True)

# Seed every random number generator the notebook relies on.
RANDOM_SEED = 825
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"随机种子设置: {RANDOM_SEED}")
print(f"计算设备: {DEVICE}")
print("输出目录: ./output/tabm_enhanced/")

随机种子设置: 825
计算设备: cuda
输出目录: ./output/tabm_enhanced/


## 3. 读取数据

In [23]:
# Load the two cohorts and tag every row with its label:
# "N" for the normal group, "Y" for the confirmed precocious-puberty group.
normal_data = pd.read_csv("./input/性早熟数据激发试验正常组_new.csv")
disease_data = pd.read_csv("./input/激发试验确诊性早熟组数据_new.csv")
for cohort, label in ((normal_data, "N"), (disease_data, "Y")):
    cohort["group"] = label

n_normal, n_disease = len(normal_data), len(disease_data)
print(f"正常组: {n_normal} 行, 早熟组: {n_disease} 行")

正常组: 299 行, 早熟组: 364 行


## 4. 数据类型处理和合并

In [24]:
# Group-wise imputation: each cohort is imputed separately, then both are merged.
# NOTE(review): the imputers are fitted per label group and on all rows before
# the train/validation split below, so label information and validation rows
# take part in the imputation; downstream validation metrics may be optimistic.
exclude_cols = ["group", "患者编号", "Unnamed: 0"]
feature_cols = [col for col in normal_data.columns if col not in exclude_cols]

# Categorical features mapped to the inclusive (min, max) range their imputed
# values are clipped to.
categorical_info = {
    "Tanner分期": (1, 5),
    "乳晕色素沉着": (0, 2),
    # "乳核": (0, 1),  # excluded: per the original note it has a single unique
    # value once samples with missing breast data are filtered out
    "有无阴毛": (0, 1),
    "有无腋毛": (0, 1),
}
categorical_cols = [c for c in categorical_info if c in feature_cols]
numerical_cols = [
    c for c in feature_cols if c != "乳核" and c not in categorical_cols
]

print(f"使用 {len(categorical_cols) + len(numerical_cols)} 个特征（排除乳核）")
print(f"分类特征 ({len(categorical_cols)}个): {categorical_cols}")
print(f"数值特征 ({len(numerical_cols)}个)")

# ===== Group-wise imputation =====
from sklearn.ensemble import RandomForestClassifier

print("\n使用MissForest方法进行分组填补...")


def _make_forest_imputer(forest_cls):
    """Return an IterativeImputer whose estimator is a small forest of `forest_cls`."""
    forest = forest_cls(
        n_estimators=10, max_depth=10, n_jobs=-1, random_state=RANDOM_SEED
    )
    return IterativeImputer(
        estimator=forest, max_iter=10, random_state=RANDOM_SEED, verbose=0
    )


# One categorical (classifier) and one numerical (regressor) imputer per cohort.
cat_imputer_normal = _make_forest_imputer(RandomForestClassifier)
num_imputer_normal = _make_forest_imputer(RandomForestRegressor)
cat_imputer_disease = _make_forest_imputer(RandomForestClassifier)
num_imputer_disease = _make_forest_imputer(RandomForestRegressor)

# Normal cohort.
if categorical_cols:
    normal_cat = cat_imputer_normal.fit_transform(normal_data[categorical_cols])
else:
    normal_cat = None
normal_num = num_imputer_normal.fit_transform(normal_data[numerical_cols])

# Precocious-puberty cohort.
if categorical_cols:
    disease_cat = cat_imputer_disease.fit_transform(disease_data[categorical_cols])
else:
    disease_cat = None
disease_num = num_imputer_disease.fit_transform(disease_data[numerical_cols])

# Clip imputed categorical values into their declared ranges.
if categorical_cols:
    for col_idx, col_name in enumerate(categorical_cols):
        lower, upper = categorical_info[col_name]
        for cat_block in (normal_cat, disease_cat):
            cat_block[:, col_idx] = np.clip(cat_block[:, col_idx], lower, upper)

# Reassemble the imputed frames, categorical columns first.
if categorical_cols:
    imputed_columns = categorical_cols + numerical_cols
    normal_imputed = pd.DataFrame(
        np.hstack([normal_cat, normal_num]), columns=imputed_columns
    )
    disease_imputed = pd.DataFrame(
        np.hstack([disease_cat, disease_num]), columns=imputed_columns
    )
else:
    normal_imputed = pd.DataFrame(normal_num, columns=numerical_cols)
    disease_imputed = pd.DataFrame(disease_num, columns=numerical_cols)

print(f"正常组填补完成: {normal_imputed.shape}")
print(f"性早熟组填补完成: {disease_imputed.shape}")

# Re-attach the label column (the imputers only saw feature columns).
normal_imputed["group"] = "N"
disease_imputed["group"] = "Y"

# Stack both cohorts into a single frame with a fresh index.
data = pd.concat([normal_imputed, disease_imputed], axis=0, ignore_index=True)
data["group"] = data["group"].astype("category")
n_rows, n_cols = data.shape
print(f"\n合并后数据: {n_rows} 行 x {n_cols} 列")
print(f"分组统计:\n{data['group'].value_counts()}")

使用 38 个特征（排除乳核）
分类特征 (4个): ['Tanner分期', '乳晕色素沉着', '有无阴毛', '有无腋毛']
数值特征 (34个)

使用MissForest方法进行分组填补...
正常组填补完成: (299, 38)
性早熟组填补完成: (364, 38)

合并后数据: 663 行 x 39 列
分组统计:
group
Y    364
N    299
Name: count, dtype: int64


## 5. 划分训练集和验证集

In [25]:
# Stratified hold-out split: 30% of the rows become the validation set and the
# label proportions are preserved in both parts.
train_data, validation_data = train_test_split(
    data,
    stratify=data["group"],
    test_size=0.3,
    random_state=RANDOM_SEED,
)

n_train_rows, n_validation_rows = len(train_data), len(validation_data)
print(f"训练集: {n_train_rows} 行, 验证集: {n_validation_rows} 行")

训练集: 464 行, 验证集: 199 行


## 6. 特征工程

In [26]:
# Columns that are identifiers or the label, never model inputs.
exclude_cols = ["group", "患者编号", "Unnamed: 0"]
feature_cols = [col for col in train_data.columns if col not in exclude_cols]

# Feature frames and label series for both splits (copies, so later edits do
# not write back into train_data / validation_data).
X_train = train_data[feature_cols].copy()
y_train = train_data["group"].copy()
X_validation = validation_data[feature_cols].copy()
y_validation = validation_data["group"].copy()

# Binary targets: 1 for the "Y" group, 0 otherwise.
y_train_binary = (y_train == "Y").astype(int)
y_validation_binary = (y_validation == "Y").astype(int)

# The pipeline never creates missing-value indicator columns, so the feature
# count is reported without claiming that it includes them.
print(f"使用 {len(feature_cols)} 个特征")
print(f"训练集正负样本: {y_train_binary.value_counts().to_dict()}")
print(f"验证集正负样本: {y_validation_binary.value_counts().to_dict()}")

使用 38 个特征（含缺失指示器）
训练集正负样本: {1: 255, 0: 209}
验证集正负样本: {1: 109, 0: 90}


## 7. 数据预处理（数值特征标准化）

In [27]:
# Imputation already happened in the group-wise step; this cell only
# standardises the numerical features.
from sklearn.preprocessing import StandardScaler

# Fixed column order: categorical columns first, numerical columns after.
feature_cols_ordered = categorical_cols + numerical_cols

X_train_features = train_data[feature_cols_ordered].values
X_validation_features = validation_data[feature_cols_ordered].values

cat_count = len(categorical_cols)
num_count = len(numerical_cols)
has_categorical = cat_count > 0

# Split each matrix into its categorical and numerical parts.
X_train_cat = X_train_features[:, :cat_count] if has_categorical else None
X_validation_cat = X_validation_features[:, :cat_count] if has_categorical else None
X_train_num = X_train_features[:, cat_count:] if has_categorical else X_train_features
X_validation_num = (
    X_validation_features[:, cat_count:] if has_categorical else X_validation_features
)

# Fit the scaler on the training rows only and reuse it for validation.
scaler = StandardScaler()
X_train_num_scaled = scaler.fit_transform(X_train_num)
X_validation_num_scaled = scaler.transform(X_validation_num)
print("数值特征已标准化 (StandardScaler)")

# Put the untouched categorical columns back in front of the scaled numerical ones.
if has_categorical:
    X_train_processed = np.hstack([X_train_cat, X_train_num_scaled])
    X_validation_processed = np.hstack([X_validation_cat, X_validation_num_scaled])
else:
    X_train_processed = X_train_num_scaled
    X_validation_processed = X_validation_num_scaled

feature_cols_processed = feature_cols_ordered

n_train_samples, n_features = X_train_processed.shape
print("\n预处理完成！")
print(f"  特征数: {n_features}")
print(f"  训练集样本: {n_train_samples}")
print(f"  验证集样本: {X_validation_processed.shape[0]}")

数值特征已标准化 (StandardScaler)

预处理完成！
  特征数: 38
  训练集样本: 464
  验证集样本: 199


## 8. 数据转换为PyTorch张量

In [28]:
print("转换数据为PyTorch张量...")

# Features become float32 tensors and labels int64 tensors, all moved to DEVICE.
X_train_tensor, X_val_tensor = (
    torch.tensor(matrix, dtype=torch.float32).to(DEVICE)
    for matrix in (X_train_processed, X_validation_processed)
)
y_train_tensor, y_val_tensor = (
    torch.tensor(labels.values, dtype=torch.long).to(DEVICE)
    for labels in (y_train_binary, y_validation_binary)
)

print(f"训练集张量: {X_train_tensor.shape}")
print(f"验证集张量: {X_val_tensor.shape}")
print("数据转换完成！")

转换数据为PyTorch张量...
训练集张量: torch.Size([464, 38])
验证集张量: torch.Size([199, 38])
数据转换完成！


---
# 模型训练与优化

## 9. 定义训练和评估函数

In [29]:
def train_tabm_model(
    model,
    X_train,
    y_train,
    X_val,
    y_val,
    n_epochs=500,
    batch_size=256,
    lr=2e-3,
    weight_decay=3e-4,
    patience=32,
    gradient_clipping_norm=1.0,
    share_training_batches=True,
    verbose=True,
):
    """Train a TabM classifier and restore the weights of its best validation epoch.

    After every epoch the model is scored on the validation tensors. The best
    epoch is the one with the highest F1, computed at a 0.5 threshold on the
    positive-class probability averaged over axis 1 of the model output.

    Args:
        model: Module called as ``model(x, None)`` that exposes
            ``model.backbone.k``. Its output is flattened over the first two
            axes for the loss and softmax-averaged over axis 1 for prediction.
        X_train: Training feature tensor; its device is used for batch indices
            and decides whether mixed precision is enabled.
        y_train: Integer class labels for ``X_train`` on the same device.
        X_val: Validation feature tensor.
        y_val: Validation labels; must contain both classes because
            ``roc_auc_score`` is evaluated every epoch.
        n_epochs: Maximum number of epochs.
        batch_size: Number of rows per optimisation step.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        patience: Training stops once more than ``patience`` consecutive epochs
            pass without a new best F1.
        gradient_clipping_norm: Max gradient norm, or ``None`` to disable clipping.
        share_training_batches: If true, one row permutation is split into
            batches and each label is repeated ``k`` times. Otherwise an
            independent row ordering is drawn per ensemble member and the
            index batches have shape ``(batch, k)``.
        verbose: Print progress every 50 epochs and on every improvement.

    Returns:
        dict with ``best_f1``, ``best_auc`` and ``best_epoch``. ``model`` is
        left holding the best-epoch weights when any epoch improved on F1 = 0.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Take the device from the data so index tensors always live next to it.
    device = X_train.device

    # Mixed precision is only used on CUDA. The bf16 capability query is
    # skipped on machines without a GPU.
    amp_enabled = device.type == "cuda"
    if amp_enabled and not torch.cuda.is_bf16_supported():
        amp_dtype = torch.float16
    else:
        amp_dtype = torch.bfloat16
    # fp16 gradients can underflow, so the loss must be scaled; bf16 does not
    # need it. A disabled scaler is a transparent pass-through.
    grad_scaler = torch.amp.GradScaler(
        "cuda", enabled=amp_enabled and amp_dtype is torch.float16
    )

    k = model.backbone.k

    best_f1 = 0
    best_auc = 0
    best_epoch = 0
    best_state = None
    remaining_patience = patience

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0

        if share_training_batches:
            batches = torch.randperm(len(X_train), device=device).split(batch_size)
        else:
            # k independent batch sequences: argsort of random values gives an
            # independent row permutation in every column.
            batches = (
                torch.rand((len(X_train), k), device=device)
                .argsort(dim=0)
                .split(batch_size, dim=0)
            )

        for batch_idx in batches:
            optimizer.zero_grad()

            with torch.autocast(
                device_type=device.type, enabled=amp_enabled, dtype=amp_dtype
            ):
                logits = model(X_train[batch_idx], None)
                # Each ensemble member's prediction is its own training example.
                y_pred = logits.flatten(0, 1)

                if share_training_batches:
                    y_true = y_train[batch_idx].repeat_interleave(k)
                else:
                    y_true = y_train[batch_idx].flatten(0, 1)

                loss = nn.functional.cross_entropy(y_pred, y_true)

            grad_scaler.scale(loss).backward()

            if gradient_clipping_norm is not None:
                # Clip the true gradients, not the loss-scaled ones.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), gradient_clipping_norm
                )

            grad_scaler.step(optimizer)
            grad_scaler.update()
            total_loss += loss.item()

        # Validation: average the members' class probabilities over axis 1.
        model.eval()
        with torch.no_grad():
            val_logits = model(X_val, None)
            val_proba = (
                torch.softmax(val_logits, dim=-1).mean(dim=1)[:, 1].cpu().numpy()
            )
            val_pred = (val_proba >= 0.5).astype(int)

            y_val_np = y_val.cpu().numpy()
            auc = roc_auc_score(y_val_np, val_proba)
            f1 = f1_score(y_val_np, val_pred)

        improved = f1 > best_f1

        if verbose and (epoch % 50 == 0 or improved):
            print(
                f"Epoch {epoch:3d}: Loss={total_loss/len(batches):.4f}, F1={f1:.4f}, AUC={auc:.4f}{' *' if improved else ''}"
            )

        if improved:
            best_f1, best_auc, best_epoch = f1, auc, epoch
            # Snapshot the weights; state_dict() alone would alias live tensors.
            best_state = deepcopy(model.state_dict())
            remaining_patience = patience
        else:
            remaining_patience -= 1

        if remaining_patience < 0:
            if verbose:
                print(f"Early stopping at epoch {epoch}")
            break

    # Restore the best weights (absent only if F1 never exceeded 0).
    if best_state is not None:
        model.load_state_dict(best_state)

    return {
        "best_f1": best_f1,
        "best_auc": best_auc,
        "best_epoch": best_epoch,
    }


def evaluate_model(model, X, y):
    """Score `model` on `(X, y)` and return metrics together with its predictions.

    The positive-class probability is the softmax output averaged over axis 1
    of ``model(X, None)``; hard predictions use a 0.5 threshold. ``y`` may be a
    tensor (moved to CPU first) or an array-like accepted by scikit-learn.

    Returns a dict with ``auc``, ``f1``, ``accuracy``, ``precision``,
    ``recall``, ``y_pred`` and ``y_proba``.
    """
    model.eval()
    with torch.no_grad():
        class_probs = torch.softmax(model(X, None), dim=-1)
        proba = class_probs.mean(dim=1)[:, 1].cpu().numpy()

    pred = (proba >= 0.5).astype(int)
    y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y

    # AUC is computed from probabilities, the other metrics from hard labels.
    report = {"auc": roc_auc_score(y_np, proba)}
    thresholded_metrics = (
        ("f1", f1_score),
        ("accuracy", accuracy_score),
        ("precision", precision_score),
        ("recall", recall_score),
    )
    for metric_name, metric_fn in thresholded_metrics:
        report[metric_name] = metric_fn(y_np, pred)
    report["y_pred"] = pred
    report["y_proba"] = proba
    return report


print("训练和评估函数定义完成！")

训练和评估函数定义完成！


## 10. 模型1：基础TabM

In [30]:
banner = "=" * 70
print(banner)
print("训练基础TabM模型")
print(banner)

# Piecewise-linear numeric embeddings; bin edges come from the training tensor.
bins_basic = rtdl_num_embeddings.compute_bins(X_train_tensor, n_bins=48)
num_embeddings_basic = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
    bins_basic, d_embedding=16, activation=False, version="B"
)

# Every input column is fed as a numeric feature (no categorical cardinalities).
tabm_basic = tabm.TabM.make(
    n_num_features=X_train_tensor.shape[1],
    cat_cardinalities=[],
    d_out=2,
    num_embeddings=num_embeddings_basic,
)
tabm_basic = tabm_basic.to(DEVICE)

n_params_basic = sum(p.numel() for p in tabm_basic.parameters())
print(f"模型参数量: {n_params_basic:,}")
print(f"集成数量k: {tabm_basic.backbone.k}")
print("\n开始训练\n")

# patience equals the epoch budget, so early stopping can never trigger here;
# the weights of the best-F1 epoch are still restored by train_tabm_model.
n_epochs = 500
result_basic = train_tabm_model(
    tabm_basic,
    X_train_tensor,
    y_train_tensor,
    X_val_tensor,
    y_val_tensor,
    n_epochs=n_epochs,
    patience=n_epochs,
)

# Final evaluation on the validation set.
metrics_basic = evaluate_model(tabm_basic, X_val_tensor, y_val_tensor)

print("\n基础TabM性能:")
for tag, metric_key in (("AUC:", "auc"), ("F1: ", "f1"), ("ACC:", "accuracy")):
    print(f"  {tag} {metrics_basic[metric_key]:.4f}")

# Persist the weights and the fitted preprocessing objects next to each other.
basic_model_path = "./output/tabm_enhanced/models/tabm_basic.pt"
torch.save(tabm_basic.state_dict(), basic_model_path)
basic_preprocessors = {
    "cat_imputer_normal": cat_imputer_normal,
    "cat_imputer_disease": cat_imputer_disease,
    "num_imputer_normal": num_imputer_normal,
    "num_imputer_disease": num_imputer_disease,
    "scaler": scaler,
    "categorical_cols": categorical_cols,
    "numerical_cols": numerical_cols,
}
joblib.dump(
    basic_preprocessors,
    "./output/tabm_enhanced/models/tabm_basic_preprocessors.pkl",
)
print(f"\n模型已保存: {basic_model_path}")

训练基础TabM模型
模型参数量: 738,048
集成数量k: 32

开始训练

Epoch   0: Loss=0.6807, F1=0.7114, AUC=0.8172 *
Epoch   1: Loss=0.6238, F1=0.8083, AUC=0.8345 *
Epoch   2: Loss=0.5412, F1=0.8113, AUC=0.8512 *
Epoch   3: Loss=0.4605, F1=0.8288, AUC=0.8659 *
Epoch   4: Loss=0.4186, F1=0.8357, AUC=0.8886 *
Epoch   5: Loss=0.3799, F1=0.8722, AUC=0.9062 *
Epoch   6: Loss=0.3425, F1=0.8793, AUC=0.9188 *
Epoch   7: Loss=0.3297, F1=0.8811, AUC=0.9256 *
Epoch   8: Loss=0.2835, F1=0.8974, AUC=0.9305 *
Epoch  10: Loss=0.2461, F1=0.9170, AUC=0.9408 *
Epoch  12: Loss=0.2046, F1=0.9177, AUC=0.9513 *
Epoch  32: Loss=0.0392, F1=0.9211, AUC=0.9594 *
Epoch  50: Loss=0.0048, F1=0.9211, AUC=0.9620
Epoch 100: Loss=0.0007, F1=0.9211, AUC=0.9587
Epoch 150: Loss=0.0009, F1=0.9163, AUC=0.9583
Epoch 165: Loss=0.0004, F1=0.9258, AUC=0.9574 *
Epoch 200: Loss=0.0001, F1=0.9163, AUC=0.9604
Epoch 211: Loss=0.0002, F1=0.9264, AUC=0.9594 *
Epoch 212: Loss=0.0008, F1=0.9304, AUC=0.9591 *
Epoch 250: Loss=0.0005, F1=0.9211, AUC=0.9582
Epoch 3

## 11. 模型2：TabM-Mini架构

In [None]:
banner = "=" * 70
print(banner)
print("训练TabM-Mini架构")
print(banner)

# Piecewise-linear numeric embeddings; bin edges come from the training tensor.
bins_mini = rtdl_num_embeddings.compute_bins(X_train_tensor, n_bins=48)
num_embeddings_mini = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
    bins_mini, d_embedding=16, activation=False, version="B"
)

# Same setup as the basic model but with arch_type="tabm-mini"
# (the original note describes this variant as more strongly regularised).
tabm_mini = tabm.TabM.make(
    n_num_features=X_train_tensor.shape[1],
    cat_cardinalities=[],
    d_out=2,
    num_embeddings=num_embeddings_mini,
    arch_type="tabm-mini",
)
tabm_mini = tabm_mini.to(DEVICE)

n_params_mini = sum(p.numel() for p in tabm_mini.parameters())
print(f"模型参数量: {n_params_mini:,}")
print(f"集成数量k: {tabm_mini.backbone.k}")
print("\n开始训练...\n")

# patience equals the epoch budget, so early stopping can never trigger here.
n_epochs = 500
result_mini = train_tabm_model(
    tabm_mini,
    X_train_tensor,
    y_train_tensor,
    X_val_tensor,
    y_val_tensor,
    n_epochs=n_epochs,
    patience=n_epochs,
)

# Final evaluation on the validation set.
metrics_mini = evaluate_model(tabm_mini, X_val_tensor, y_val_tensor)

print("\nTabM-Mini性能:")
for tag, metric_key in (("AUC:", "auc"), ("F1: ", "f1"), ("ACC:", "accuracy")):
    print(f"  {tag} {metrics_mini[metric_key]:.4f}")

# Persist the weights.
mini_model_path = "./output/tabm_enhanced/models/tabm_mini.pt"
torch.save(tabm_mini.state_dict(), mini_model_path)
print(f"\n模型已保存: {mini_model_path}")
# Recorded from an earlier run (TabM-Mini):
#   AUC: 0.8834
#   F1:  0.8391
#   ACC: 0.8137

## 12. 模型3：独立批次训练

In [None]:
banner = "=" * 70
print(banner)
print("训练TabM（独立批次策略）")
print(banner)
print("独立批次训练：k个子模型在不同批次上训练，增加多样性")

# Piecewise-linear numeric embeddings; bin edges come from the training tensor.
bins_indep = rtdl_num_embeddings.compute_bins(X_train_tensor, n_bins=48)
num_embeddings_indep = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
    bins_indep, d_embedding=16, activation=False, version="B"
)

# Same architecture as the basic model; only the batching strategy differs.
tabm_indep = tabm.TabM.make(
    n_num_features=X_train_tensor.shape[1],
    cat_cardinalities=[],
    d_out=2,
    num_embeddings=num_embeddings_indep,
)
tabm_indep = tabm_indep.to(DEVICE)

n_params_indep = sum(p.numel() for p in tabm_indep.parameters())
print(f"模型参数量: {n_params_indep:,}")
print("\n开始训练...\n")

# share_training_batches=False makes the trainer draw an independent batch
# sequence per ensemble member. patience equals the epoch budget, so early
# stopping can never trigger here.
n_epochs = 500
result_indep = train_tabm_model(
    tabm_indep,
    X_train_tensor,
    y_train_tensor,
    X_val_tensor,
    y_val_tensor,
    n_epochs=n_epochs,
    patience=n_epochs,
    share_training_batches=False,
)

# Final evaluation on the validation set.
metrics_indep = evaluate_model(tabm_indep, X_val_tensor, y_val_tensor)

print("\nTabM（独立批次）性能:")
for tag, metric_key in (("AUC:", "auc"), ("F1: ", "f1"), ("ACC:", "accuracy")):
    print(f"  {tag} {metrics_indep[metric_key]:.4f}")

# Persist the weights.
indep_model_path = "./output/tabm_enhanced/models/tabm_indep.pt"
torch.save(tabm_indep.state_dict(), indep_model_path)
print(f"\n模型已保存: {indep_model_path}")
# Recorded from an earlier run (TabM, independent batches):
#   AUC: 0.8866
#   F1:  0.8344
#   ACC: 0.8089

## 13. 模型4：PeriodicEmbeddings

In [None]:
banner = "=" * 70
print(banner)
print("训练TabM（PeriodicEmbeddings）")
print(banner)

# Periodic numeric embeddings instead of the piecewise-linear ones.
n_input_features = X_train_tensor.shape[1]
num_embeddings_periodic = rtdl_num_embeddings.PeriodicEmbeddings(
    n_features=n_input_features, d_embedding=16, lite=False
)

# Same architecture as the basic model; only the embedding module differs.
tabm_periodic = tabm.TabM.make(
    n_num_features=n_input_features,
    cat_cardinalities=[],
    d_out=2,
    num_embeddings=num_embeddings_periodic,
)
tabm_periodic = tabm_periodic.to(DEVICE)

n_params_periodic = sum(p.numel() for p in tabm_periodic.parameters())
print(f"模型参数量: {n_params_periodic:,}")
print("\n开始训练...\n")

# patience equals the epoch budget, so early stopping can never trigger here.
n_epochs = 500
result_periodic = train_tabm_model(
    tabm_periodic,
    X_train_tensor,
    y_train_tensor,
    X_val_tensor,
    y_val_tensor,
    n_epochs=n_epochs,
    patience=n_epochs,
)

# Final evaluation on the validation set.
metrics_periodic = evaluate_model(tabm_periodic, X_val_tensor, y_val_tensor)

print("\nTabM（PeriodicEmbeddings）性能:")
for tag, metric_key in (("AUC:", "auc"), ("F1: ", "f1"), ("ACC:", "accuracy")):
    print(f"  {tag} {metrics_periodic[metric_key]:.4f}")

# Persist the weights.
periodic_model_path = "./output/tabm_enhanced/models/tabm_periodic.pt"
torch.save(tabm_periodic.state_dict(), periodic_model_path)
print(f"\n模型已保存: {periodic_model_path}")
# Recorded from an earlier run (TabM, PeriodicEmbeddings):
#   AUC: 0.8710
#   F1:  0.8255
#   ACC: 0.7945

## 14. 模型5：超参数优化（Optuna HPO）

In [None]:
banner = "=" * 70
print(banner)
print("TabM超参数优化（Optuna）")
print(banner)

# Module-level record of the best trial seen so far; objective() updates it.
best_hpo_model_state = None
best_hpo_f1 = 0.0
best_hpo_config = None


def objective(trial):
    """Optuna objective: train one TabM configuration and return its best validation F1.

    When the trial beats every earlier one, its weights and architecture
    settings are stored in the module-level ``best_hpo_*`` variables. Any
    exception raised while building or training the model is printed and the
    trial is scored 0.0.
    """
    global best_hpo_model_state, best_hpo_f1, best_hpo_config

    # Search space (the suggestion order is kept fixed).
    n_blocks = trial.suggest_int("n_blocks", 1, 4)
    d_block = trial.suggest_int("d_block", 64, 512, step=64)
    n_bins = trial.suggest_int("n_bins", 16, 96, step=16)
    d_embedding = trial.suggest_int("d_embedding", 8, 32, step=4)
    lr = trial.suggest_float("lr", 1e-4, 5e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-1, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.3, step=0.05)

    try:
        # Numeric embeddings with the sampled resolution and width.
        bins = rtdl_num_embeddings.compute_bins(X_train_tensor, n_bins=n_bins)
        num_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
            bins, d_embedding=d_embedding, activation=False, version="B"
        )

        # Backbone with the sampled depth, width and dropout.
        backbone_params = {
            "n_blocks": n_blocks,
            "d_block": d_block,
            "dropout": dropout,
        }
        model = tabm.TabM.make(
            n_num_features=X_train_tensor.shape[1],
            cat_cardinalities=[],
            d_out=2,
            num_embeddings=num_embeddings,
            **backbone_params,
        ).to(DEVICE)

        # Shorter budget than the standalone models, with early stopping.
        n_epochs = 200
        result = train_tabm_model(
            model,
            X_train_tensor,
            y_train_tensor,
            X_val_tensor,
            y_val_tensor,
            n_epochs=n_epochs,
            patience=50,
            lr=lr,
            weight_decay=weight_decay,
            verbose=False,
        )

        f1 = result["best_f1"]

        if not f1 > best_hpo_f1:
            return f1

        # New overall best: keep a private copy of the weights and the
        # settings needed to rebuild this architecture.
        best_hpo_f1 = f1
        best_hpo_model_state = deepcopy(model.state_dict())
        best_hpo_config = {
            "n_blocks": n_blocks,
            "d_block": d_block,
            "n_bins": n_bins,
            "d_embedding": d_embedding,
            "dropout": dropout,
        }
        return f1

    except Exception as e:
        print(f"Trial failed: {e}")
        return 0.0


# Run the search. The sampler is seeded explicitly: Optuna does not use the
# global random/numpy/torch seeds, so without this the proposed parameters
# would differ on every run.
study = optuna.create_study(
    direction="maximize",
    study_name="tabm_hpo",
    sampler=optuna.samplers.TPESampler(seed=RANDOM_SEED),
)
study.optimize(objective, n_trials=1000, show_progress_bar=True)

print(f"\n最佳参数:")
for key, value in study.best_params.items():
    print(f"  {key}: {value}")
print(f"\n最佳F1: {study.best_value:.4f}")

# objective() only stores weights for a trial with F1 > 0. If none did (for
# example every trial failed), there is nothing valid to save, and writing
# None would produce a checkpoint that cannot be loaded later.
if best_hpo_model_state is None:
    raise RuntimeError(
        "No HPO trial produced a model with F1 > 0; nothing to save."
    )

# Save the best weights found during the search, plus everything needed to
# rebuild the model and repeat the preprocessing.
torch.save(best_hpo_model_state, "./output/tabm_enhanced/models/tabm_hpo.pt")
joblib.dump(
    {
        "best_params": study.best_params,
        "best_config": best_hpo_config,
        "cat_imputer_normal": cat_imputer_normal,
        "cat_imputer_disease": cat_imputer_disease,
        "num_imputer_normal": num_imputer_normal,
        "num_imputer_disease": num_imputer_disease,
        "scaler": scaler,
        "categorical_cols": categorical_cols,
        "numerical_cols": numerical_cols,
    },
    "./output/tabm_enhanced/models/tabm_hpo_config.pkl",
)
print(f"\nHPO最佳模型已保存: ./output/tabm_enhanced/models/tabm_hpo.pt")

In [None]:
# Load the saved TabM models and re-evaluate them on the validation set.
print("=" * 70)
print("导入已保存的TabM模型")
print("=" * 70)

# Load the preprocessing objects (including the scaler).
# NOTE: joblib.load unpickles arbitrary objects; only load files written by
# this notebook.
preprocessors = joblib.load(
    "./output/tabm_enhanced/models/tabm_basic_preprocessors.pkl"
)
# The training cell stores four separate imputers and no "imputer" key, so
# reading preprocessors["imputer"] directly raises KeyError. Collect the four
# into one dict; a single "imputer" entry is still honoured if present.
if "imputer" in preprocessors:
    imputer = preprocessors["imputer"]
else:
    imputer = {
        key: preprocessors[key]
        for key in (
            "cat_imputer_normal",
            "cat_imputer_disease",
            "num_imputer_normal",
            "num_imputer_disease",
        )
        if key in preprocessors
    }
scaler = preprocessors.get("scaler", None)
if scaler is not None:
    print("已加载StandardScaler")

# Architecture settings needed to rebuild each saved model.
model_configs = {
    "basic": {"arch_type": "tabm", "n_bins": 48, "d_embedding": 16},
    "mini": {"arch_type": "tabm-mini", "n_bins": 48, "d_embedding": 16},
    "indep": {"arch_type": "tabm", "n_bins": 48, "d_embedding": 16},
    "periodic": {"arch_type": "tabm", "use_periodic": True, "d_embedding": 16},
}

# Add the HPO model when its configuration file exists.
hpo_config_path = "./output/tabm_enhanced/models/tabm_hpo_config.pkl"
if os.path.exists(hpo_config_path):
    hpo_data = joblib.load(hpo_config_path)
    best_params = hpo_data["best_params"]
    # Fall back to the raw Optuna parameters when "best_config" is missing or
    # was stored as None.
    best_hpo_config = hpo_data.get("best_config") or best_params
    model_configs["hpo"] = {
        "arch_type": "tabm",
        "n_bins": best_hpo_config["n_bins"],
        "d_embedding": best_hpo_config["d_embedding"],
        "n_blocks": best_hpo_config["n_blocks"],
        "d_block": best_hpo_config["d_block"],
        "dropout": best_hpo_config["dropout"],
    }
    print(f"HPO最佳参数: {best_params}")

# Rebuild, load and evaluate every model whose weights file exists.
loaded_models = {}
all_metrics = {}

for name, config in model_configs.items():
    model_path = f"./output/tabm_enhanced/models/tabm_{name}.pt"
    if not os.path.exists(model_path):
        print(f"  {name}: 模型文件不存在，跳过")
        continue

    # Best effort: a model that fails to load is reported and skipped so the
    # remaining ones can still be compared.
    try:
        # Recreate the embedding module the model was trained with.
        if config.get("use_periodic", False):
            num_embeddings = rtdl_num_embeddings.PeriodicEmbeddings(
                n_features=X_train_tensor.shape[1],
                d_embedding=config["d_embedding"],
                lite=False,
            )
        else:
            num_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
                rtdl_num_embeddings.compute_bins(
                    X_train_tensor, n_bins=config["n_bins"]
                ),
                d_embedding=config["d_embedding"],
                activation=False,
                version="B",
            )

        # Recreate the model; optional settings are passed only when the
        # configuration defines them.
        model_kwargs = {
            "n_num_features": X_train_tensor.shape[1],
            "cat_cardinalities": [],
            "d_out": 2,
            "num_embeddings": num_embeddings,
        }
        if config["arch_type"] == "tabm-mini":
            model_kwargs["arch_type"] = "tabm-mini"
        for optional_key in ("n_blocks", "d_block", "dropout"):
            if optional_key in config:
                model_kwargs[optional_key] = config[optional_key]

        model = tabm.TabM.make(**model_kwargs).to(DEVICE)
        model.load_state_dict(
            torch.load(model_path, map_location=DEVICE, weights_only=True)
        )
        model.eval()

        # Evaluate on the validation set.
        metrics = evaluate_model(model, X_val_tensor, y_val_tensor)

        loaded_models[name] = model
        all_metrics[name] = metrics

        print(f"  TabM-{name}: F1={metrics['f1']:.4f}, AUC={metrics['auc']:.4f} ✓")

    except Exception as e:
        print(f"  {name}: 加载失败 - {e}")

# Expose the loaded models under the variable names the later cells use.
if "basic" in loaded_models:
    tabm_basic = loaded_models["basic"]
    metrics_basic = all_metrics["basic"]
if "mini" in loaded_models:
    tabm_mini = loaded_models["mini"]
    metrics_mini = all_metrics["mini"]
if "indep" in loaded_models:
    tabm_indep = loaded_models["indep"]
    metrics_indep = all_metrics["indep"]
if "periodic" in loaded_models:
    tabm_periodic = loaded_models["periodic"]
    metrics_periodic = all_metrics["periodic"]
if "hpo" in loaded_models:
    tabm_hpo = loaded_models["hpo"]
    metrics_hpo = all_metrics["hpo"]

print(f"\n成功导入 {len(loaded_models)} 个模型")
print("=" * 70)

In [None]:
banner = "=" * 70
print(banner)
print("TabM变体性能对比")
print(banner)

# Display name and metrics dict of every variant, in table order.
comparison_rows = [
    ("TabM-Basic", metrics_basic),
    ("TabM-Mini", metrics_mini),
    ("TabM-IndepBatch", metrics_indep),
    ("TabM-Periodic", metrics_periodic),
    ("TabM-HPO", metrics_hpo),
]

results_df = pd.DataFrame(
    {
        "模型": [label for label, _ in comparison_rows],
        "AUC": [scores["auc"] for _, scores in comparison_rows],
        "F1": [scores["f1"] for _, scores in comparison_rows],
        "Accuracy": [scores["accuracy"] for _, scores in comparison_rows],
    }
)

# Best AUC first.
results_df = results_df.sort_values("AUC", ascending=False)
print(results_df.to_string(index=False))

# Save the comparison table.
results_csv_path = "./output/tabm_enhanced/性能对比.csv"
results_df.to_csv(results_csv_path, index=False)
print(f"\n结果已保存: {results_csv_path}")

## 16. ROC曲线对比

In [None]:
plt.figure(figsize=(10, 8))

# (display name, validation probabilities, AUC) for every variant.
models_info = [
    (label, scores["y_proba"], scores["auc"])
    for label, scores in (
        ("TabM-Basic", metrics_basic),
        ("TabM-Mini", metrics_mini),
        ("TabM-IndepBatch", metrics_indep),
        ("TabM-Periodic", metrics_periodic),
        ("TabM-HPO", metrics_hpo),
    )
]

y_val_np = y_validation_binary.values

for model_name, y_proba, auc_score in models_info:
    fpr, tpr, _ = roc_curve(y_val_np, y_proba)
    # The HPO curve is drawn thicker than the others.
    if model_name == "TabM-HPO":
        linewidth = 3
    else:
        linewidth = 2
    curve_label = f"{model_name} (AUC = {auc_score:.4f})"
    plt.plot(fpr, tpr, label=curve_label, linewidth=linewidth)

# Diagonal reference line for a random classifier.
plt.plot([0, 1], [0, 1], "k--", linewidth=1, label="随机猜测")
plt.xlabel("假阳性率 (1-特异度)", fontsize=12)
plt.ylabel("真阳性率 (灵敏度)", fontsize=12)
plt.title("TabM变体ROC曲线对比", fontsize=14, fontweight="bold")
plt.legend(loc="lower right", fontsize=10)
plt.grid(alpha=0.3)
plt.tight_layout()

# Save a raster and a vector copy before showing the figure.
plt.savefig("./output/tabm_enhanced/ROC曲线对比.png", dpi=300, bbox_inches="tight")
plt.savefig("./output/tabm_enhanced/ROC曲线对比.pdf", bbox_inches="tight")
plt.show()
print("ROC曲线已保存")

## 17. 选择最佳模型并复制到标准位置

In [None]:
print("=" * 70)
print("选择最佳模型（以F1为标准）")
print("=" * 70)

# Validation metrics of every variant, keyed by the suffix of its weights file.
all_metrics = {
    "basic": metrics_basic,
    "mini": metrics_mini,
    "indep": metrics_indep,
    "periodic": metrics_periodic,
    "hpo": metrics_hpo,
}

# The variant with the highest validation F1 wins.
best_model_name = max(all_metrics, key=lambda x: all_metrics[x]["f1"])
best_metrics = all_metrics[best_model_name]

print(f"最佳模型: TabM-{best_model_name}")
print(f"  F1:  {best_metrics['f1']:.4f}")
print(f"  AUC: {best_metrics['auc']:.4f}")

# Copy the winning weights to the standard location.
import shutil

src_model = f"./output/tabm_enhanced/models/tabm_{best_model_name}.pt"
dst_model = "./output/models/tabm_best.pt"
shutil.copy(src_model, dst_model)

# The pipeline fits four imputers (categorical/numerical for each cohort) and
# never a single one, so the export is built from those fitted objects rather
# than from a separately defined `imputer` variable.
imputer = {
    "cat_imputer_normal": cat_imputer_normal,
    "cat_imputer_disease": cat_imputer_disease,
    "num_imputer_normal": num_imputer_normal,
    "num_imputer_disease": num_imputer_disease,
}

# Save the preprocessing objects. The imputers are stored both under the
# "imputer" key and under their individual names (the layout of the other
# preprocessor files). The column lists are included so the feature order can
# be rebuilt when the model is used.
joblib.dump(
    {
        "imputer": imputer,
        **imputer,
        "scaler": scaler,
        "categorical_cols": categorical_cols,
        "numerical_cols": numerical_cols,
        "best_model_name": best_model_name,
        "metrics": best_metrics,
    },
    "./output/models/tabm_preprocessors.pkl",
)

print(f"\n最佳模型已复制到: {dst_model}")
print("=" * 70)