# TabM增强版：性早熟预测模型

**使用TabM进行全方位优化**

本笔记本展示了TabM的多种增强技术：
- 基础TabM模型（PiecewiseLinearEmbeddings）
- 超参数优化（Optuna HPO）
- 不同架构变体（tabm / tabm-mini）
- 独立批次训练策略
- 不同数值嵌入方式对比
- SHAP可解释性分析

## 1. 导入必要的库

In [1]:
import os
import math
import random
from copy import deepcopy

import pandas as pd
import numpy as np
import warnings
import joblib

warnings.filterwarnings("ignore")

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score,
    roc_curve,
    f1_score,
    accuracy_score,
    precision_score,
    recall_score,
    confusion_matrix,
)
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import QuantileTransformer

# TabM相关库
import tabm
import rtdl_num_embeddings
import torch
import torch.nn as nn
import optuna

import matplotlib.pyplot as plt
import seaborn as sns

# Matplotlib defaults used by every figure in this notebook: the SimHei font
# (the plots carry Chinese labels) and the `axes.unicode_minus` switch turned off.
plt.rcParams.update(
    {
        "font.sans-serif": ["SimHei"],
        "axes.unicode_minus": False,
    }
)

# Report the runtime environment so that saved outputs are traceable.
tabm_version = getattr(tabm, "__version__", "N/A")
cuda_ready = torch.cuda.is_available()

print(f"TabM版本: {tabm_version}")
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {cuda_ready}")
if cuda_ready:
    print(f"GPU设备: {torch.cuda.get_device_name(0)}")
print("所有库导入完成")

TabM版本: 0.0.3
PyTorch版本: 2.9.1+cu130
CUDA可用: True
GPU设备: NVIDIA GeForce RTX 3080 Laptop GPU
所有库导入完成


## 2. 设置路径和参数

In [2]:
# Create every output directory used later in the notebook (idempotent).
for output_dir in (
    "./output",
    "./output/models",
    "./output/tabm_enhanced",
    "./output/tabm_enhanced/models",
):
    os.makedirs(output_dir, exist_ok=True)

# Seed every random number generator the notebook relies on.
RANDOM_SEED = 825
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

# Prefer the GPU when one is present; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"随机种子设置: {RANDOM_SEED}")
print(f"计算设备: {DEVICE}")
print("输出目录: ./output/tabm_enhanced/")

随机种子设置: 825
计算设备: cuda
输出目录: ./output/tabm_enhanced/


## 3. 读取数据

In [3]:
# Load both cohorts and tag each row with its class label:
# "N" for the normal group, "Y" for the confirmed precocious-puberty group.
normal_data = pd.read_csv("./input/性早熟数据激发试验正常组_new.csv").assign(group="N")
disease_data = pd.read_csv("./input/激发试验确诊性早熟组数据_new.csv").assign(group="Y")

n_normal, n_disease = len(normal_data), len(disease_data)
print(f"正常组: {n_normal} 行, 早熟组: {n_disease} 行")

正常组: 8970 行, 早熟组: 10654 行


## 4. 数据类型处理和合并

In [4]:
# Stack the two cohorts into one frame with a fresh 0..n-1 index and store the
# label column as a categorical.
data = pd.concat((normal_data, disease_data), axis=0, ignore_index=True).astype(
    {"group": "category"}
)

n_rows, n_cols = data.shape
group_counts = data["group"].value_counts()
print(f"合并后数据: {n_rows} 行 x {n_cols} 列")
print(f"分组统计:\n{group_counts}")

合并后数据: 19624 行 x 40 列
分组统计:
group
Y    10654
N     8970
Name: count, dtype: int64


## 5. 划分训练集和验证集

In [5]:
# Stratified 70/30 hold-out split, reproducible through RANDOM_SEED.
# NOTE(review): later cells use this same hold-out split for early stopping,
# for the Optuna search and for the final reported metrics, so those metrics
# are optimistic. Consider carving out a separate, untouched test split.
split_options = {
    "test_size": 0.3,
    "stratify": data["group"],
    "random_state": RANDOM_SEED,
}
train_data, validation_data = train_test_split(data, **split_options)

print(f"训练集: {len(train_data)} 行, 验证集: {len(validation_data)} 行")

训练集: 13736 行, 验证集: 5888 行


## 6. 特征工程

In [6]:
# Columns that must never be used as predictors: the label, the patient id and
# the leftover CSV index column.
exclude_cols = ["group", "患者编号", "Unnamed: 0"]
feature_cols = list(
    filter(lambda name: name not in exclude_cols, train_data.columns)
)

# Independent copies so later in-place edits do not touch the source frames.
X_train, X_validation = (
    part[feature_cols].copy() for part in (train_data, validation_data)
)
y_train, y_validation = (
    part["group"].copy() for part in (train_data, validation_data)
)

# Binary target: 1 for the disease group ("Y"), 0 for the normal group.
y_train_binary = y_train.eq("Y").astype(int)
y_validation_binary = y_validation.eq("Y").astype(int)

train_balance = y_train_binary.value_counts().to_dict()
validation_balance = y_validation_binary.value_counts().to_dict()
print(f"使用 {len(feature_cols)} 个特征")
print(f"训练集正负样本: {train_balance}")
print(f"验证集正负样本: {validation_balance}")

使用 38 个特征
训练集正负样本: {1: 7457, 0: 6279}
验证集正负样本: {1: 3197, 0: 2691}


## 7. 数据类型转换

In [7]:
print("数据类型转换中...")

# Force every feature column to a numeric dtype; values that cannot be parsed
# become NaN and are imputed in the next step.
for frame in (X_train, X_validation):
    for col in feature_cols:
        frame[col] = pd.to_numeric(frame[col], errors="coerce")

dtype_summary = X_train.dtypes.value_counts().to_dict()
print(f"转换完成 - 训练集: {dtype_summary}")

数据类型转换中...
转换完成 - 训练集: {dtype('float64'): 38}


## 8. 数据预处理（缺失值填充）

In [8]:
print("数据预处理开始...")

# Median imputation, fitted on the training split only and reused unchanged on
# the validation split to avoid leakage.
imputer = SimpleImputer(strategy="median")

X_train_processed = imputer.fit_transform(X_train)
X_validation_processed = imputer.transform(X_validation)

# With its default settings SimpleImputer itself discards columns that are
# entirely NaN in the training data, so the imputed arrays can already be
# narrower than `feature_cols`. Checking for NaN *after* imputation (as the
# previous version did) can never find such a column and mis-aligns the feature
# names with the array columns. The imputer marks dropped columns with a NaN
# statistic, which gives the correct mask over the original columns.
valid_features = ~np.isnan(imputer.statistics_)

# Names of the features that survived imputation, in array-column order.
feature_cols_processed = [
    col for col, valid in zip(feature_cols, valid_features) if valid
]

# Guard against any silent name/column misalignment.
if len(feature_cols_processed) != X_train_processed.shape[1]:
    raise ValueError(
        "Feature names are out of sync with the imputed matrix: "
        f"{len(feature_cols_processed)} names vs "
        f"{X_train_processed.shape[1]} columns"
    )

print(f"预处理完成！")
print(f"  原始特征数: {X_train.shape[1]}")
print(f"  处理后特征数: {X_train_processed.shape[1]}")
print(f"  训练集样本: {X_train_processed.shape[0]}")
print(f"  验证集样本: {X_validation_processed.shape[0]}")
print(f"  缺失值: {np.isnan(X_train_processed).sum()} (应为0)")

数据预处理开始...
预处理完成！
  原始特征数: 38
  处理后特征数: 38
  训练集样本: 13736
  验证集样本: 5888
  缺失值: 0 (应为0)


## 9. QuantileTransformer预处理（TabM推荐）

In [9]:
print("应用QuantileTransformer...")

# Small Gaussian jitter, added to the data only while *fitting* the transformer.
# The transform calls below receive the un-jittered arrays.
noise_rng = np.random.default_rng(RANDOM_SEED)
noise = noise_rng.normal(0.0, 1e-5, X_train_processed.shape).astype(
    X_train_processed.dtype
)

# Roughly one quantile per 30 training rows, clamped to the range [10, 1000].
n_quantiles = max(min(len(X_train_processed) // 30, 1000), 10)
quantile_transformer = QuantileTransformer(
    n_quantiles=n_quantiles,
    output_distribution="normal",
    subsample=10**9,
)
quantile_transformer.fit(X_train_processed + noise)

X_train_transformed = quantile_transformer.transform(X_train_processed)
X_validation_transformed = quantile_transformer.transform(X_validation_processed)

# Move features (float32) and labels (int64) to the compute device once, up front.
X_train_tensor = torch.tensor(X_train_transformed, dtype=torch.float32, device=DEVICE)
y_train_tensor = torch.tensor(y_train_binary.values, dtype=torch.long, device=DEVICE)
X_val_tensor = torch.tensor(
    X_validation_transformed, dtype=torch.float32, device=DEVICE
)
y_val_tensor = torch.tensor(
    y_validation_binary.values, dtype=torch.long, device=DEVICE
)

print(f"训练集张量: {X_train_tensor.shape}")
print(f"验证集张量: {X_val_tensor.shape}")
print("QuantileTransformer预处理完成！")

应用QuantileTransformer...
训练集张量: torch.Size([13736, 38])
验证集张量: torch.Size([5888, 38])
QuantileTransformer预处理完成！


---
# 模型训练与优化

## 10. 定义训练和评估函数

In [10]:
def train_tabm_model(
    model,
    X_train,
    y_train,
    X_val,
    y_val,
    n_epochs=500,
    batch_size=256,
    lr=2e-3,
    weight_decay=3e-4,
    patience=32,
    gradient_clipping_norm=1.0,
    share_training_batches=True,
    verbose=True,
):
    """Train a TabM model with AdamW, mixed precision and F1-based early stopping.

    After every epoch the model is scored on ``(X_val, y_val)``; the weights
    with the best F1 (at a 0.5 threshold on the ensemble-averaged positive-class
    probability) are restored before returning.

    Args:
        model: TabM model called as ``model(x_num, None)``; must expose
            ``model.backbone.k`` (the ensemble size).
        X_train: Training features (tensor). Its device decides where batch
            indices are created and whether CUDA mixed precision is used.
        y_train: Integer class labels aligned with ``X_train``.
        X_val: Validation features, evaluated in a single forward pass.
        y_val: Validation labels (tensor).
        n_epochs: Maximum number of epochs.
        batch_size: Number of rows per training batch.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        patience: Training stops once more than ``patience`` consecutive
            epochs bring no F1 improvement.
        gradient_clipping_norm: Max gradient norm, or ``None`` to disable
            clipping.
        share_training_batches: If True, all ``k`` ensemble members see the same
            batch; otherwise each member gets its own random permutation.
        verbose: Print progress every 50 epochs and on every improvement.

    Returns:
        dict with ``best_f1``, ``best_auc`` and ``best_epoch``.

    Note:
        The validation data drives early stopping, so metrics computed later on
        the same data are optimistically biased.
    """
    # Follow the data's device instead of a module-level global, so the function
    # also works when tensors live somewhere other than the notebook default.
    device = X_train.device
    k = model.backbone.k
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Mixed precision only makes sense on CUDA. Query bf16 support only there.
    amp_enabled = device.type == "cuda"
    use_fp16 = amp_enabled and not torch.cuda.is_bf16_supported()
    amp_dtype = torch.float16 if use_fp16 else torch.bfloat16
    # fp16 has a narrow dynamic range: without loss scaling small gradients
    # underflow to zero. The scaler is a transparent no-op when disabled.
    grad_scaler = torch.amp.GradScaler("cuda", enabled=use_fp16)

    y_val_np = y_val.cpu().numpy()

    best_f1 = 0
    best_auc = 0
    best_epoch = 0
    best_state = None
    remaining_patience = patience

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0

        if share_training_batches:
            # One permutation shared by all ensemble members.
            batches = torch.randperm(len(X_train), device=device).split(batch_size)
        else:
            # k independent permutations: each index batch has shape (batch, k).
            batches = (
                torch.rand((len(X_train), k), device=device)
                .argsort(dim=0)
                .split(batch_size, dim=0)
            )

        for batch_idx in batches:
            optimizer.zero_grad()

            with torch.autocast(
                device_type=device.type, enabled=amp_enabled, dtype=amp_dtype
            ):
                logits = model(X_train[batch_idx], None)
                # (batch, k, n_classes) -> (batch * k, n_classes): every
                # ensemble member is trained on its own prediction.
                y_pred = logits.flatten(0, 1)

                if share_training_batches:
                    y_true = y_train[batch_idx].repeat_interleave(k)
                else:
                    y_true = y_train[batch_idx].flatten(0, 1)

                loss = nn.functional.cross_entropy(y_pred, y_true)

            grad_scaler.scale(loss).backward()

            if gradient_clipping_norm is not None:
                # Clip in true gradient units, i.e. after undoing the loss scale.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), gradient_clipping_norm
                )

            grad_scaler.step(optimizer)
            grad_scaler.update()
            total_loss += loss.item()

        # Validation: average the k members' class probabilities.
        model.eval()
        with torch.no_grad():
            val_logits = model(X_val, None)
            val_proba = (
                torch.softmax(val_logits, dim=-1).mean(dim=1)[:, 1].cpu().numpy()
            )
        val_pred = (val_proba >= 0.5).astype(int)
        auc = roc_auc_score(y_val_np, val_proba)
        f1 = f1_score(y_val_np, val_pred)

        # Always checkpoint the first epoch so a state exists even if F1 stays 0.
        improved = best_state is None or f1 > best_f1

        if verbose and (epoch % 50 == 0 or improved):
            mean_loss = total_loss / len(batches)
            marker = " *" if improved else ""
            print(
                f"Epoch {epoch:3d}: Loss={mean_loss:.4f}, F1={f1:.4f}, AUC={auc:.4f}{marker}"
            )

        if improved:
            best_f1, best_auc, best_epoch = f1, auc, epoch
            best_state = deepcopy(model.state_dict())
            remaining_patience = patience
        else:
            remaining_patience -= 1

        if remaining_patience < 0:
            if verbose:
                print(f"Early stopping at epoch {epoch}")
            break

    # Restore the best checkpoint (absent only when no epoch was run).
    if best_state is not None:
        model.load_state_dict(best_state)

    return {
        "best_f1": best_f1,
        "best_auc": best_auc,
        "best_epoch": best_epoch,
    }


def evaluate_model(model, X, y):
    """Score a trained TabM model on ``(X, y)``.

    The positive-class probability is the mean over the ensemble dimension of
    the per-member softmax outputs; hard predictions use a 0.5 threshold.

    Args:
        model: Model called as ``model(X, None)``.
        X: Feature tensor.
        y: True labels, either a tensor (moved to CPU) or an array-like.

    Returns:
        dict with ``auc``, ``f1``, ``accuracy``, ``precision``, ``recall`` plus
        the raw ``y_pred`` and ``y_proba`` arrays.
    """
    model.eval()
    with torch.no_grad():
        member_probs = torch.softmax(model(X, None), dim=-1)
        proba = member_probs.mean(dim=1)[:, 1].cpu().numpy()

    pred = (proba >= 0.5).astype(int)
    y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y

    # AUC is threshold-free; the remaining metrics use the hard predictions.
    metrics = {"auc": roc_auc_score(y_np, proba)}
    for metric_name, scorer in (
        ("f1", f1_score),
        ("accuracy", accuracy_score),
        ("precision", precision_score),
        ("recall", recall_score),
    ):
        metrics[metric_name] = scorer(y_np, pred)
    metrics["y_pred"] = pred
    metrics["y_proba"] = proba
    return metrics


# Confirmation message: the training and evaluation helpers are now defined.
print("训练和评估函数定义完成！")

训练和评估函数定义完成！


## 11. 模型1：基础TabM（PiecewiseLinearEmbeddings）

In [11]:
banner = "=" * 70
print(banner)
print("训练基础TabM模型（PiecewiseLinearEmbeddings）")
print(banner)

# Piecewise-linear numeric embeddings; bin edges are computed from the training
# tensor (48 bins requested per feature).
bins_basic = rtdl_num_embeddings.compute_bins(X_train_tensor, n_bins=48)
num_embeddings_basic = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
    bins_basic,
    d_embedding=16,
    activation=False,
    version="B",
)

# TabM with library-default hyperparameters: numeric features only, two output
# logits (binary classification).
tabm_basic = tabm.TabM.make(
    n_num_features=X_train_tensor.shape[1],
    cat_cardinalities=[],
    d_out=2,
    num_embeddings=num_embeddings_basic,
)
tabm_basic = tabm_basic.to(DEVICE)

n_params_basic = sum(p.numel() for p in tabm_basic.parameters())
print(f"模型参数量: {n_params_basic:,}")
print(f"集成数量k: {tabm_basic.backbone.k}")
print("\n开始训练...\n")

result_basic = train_tabm_model(
    tabm_basic,
    X_train_tensor,
    y_train_tensor,
    X_val_tensor,
    y_val_tensor,
    n_epochs=500,
    patience=32,
)

# Final evaluation of the restored best checkpoint.
metrics_basic = evaluate_model(tabm_basic, X_val_tensor, y_val_tensor)

print("\n基础TabM性能:")
for prefix, metric_key in (("AUC: ", "auc"), ("F1:  ", "f1"), ("ACC: ", "accuracy")):
    print(f"  {prefix}{metrics_basic[metric_key]:.4f}")

# Persist the weights together with the fitted preprocessing objects.
basic_model_path = "./output/tabm_enhanced/models/tabm_basic.pt"
torch.save(tabm_basic.state_dict(), basic_model_path)
joblib.dump(
    {"imputer": imputer, "quantile_transformer": quantile_transformer},
    "./output/tabm_enhanced/models/tabm_basic_preprocessors.pkl",
)
print(f"\n模型已保存: {basic_model_path}")

训练基础TabM模型（PiecewiseLinearEmbeddings）
模型参数量: 738,048
集成数量k: 32

开始训练...

Epoch   0: Loss=0.4282, F1=0.8902, AUC=0.9496 *
Epoch   1: Loss=0.2916, F1=0.8968, AUC=0.9589 *
Epoch   3: Loss=0.2458, F1=0.9034, AUC=0.9632 *
Epoch   4: Loss=0.2361, F1=0.9065, AUC=0.9640 *
Epoch   6: Loss=0.2236, F1=0.9098, AUC=0.9661 *


KeyboardInterrupt: 

## 12. 模型2：TabM-Mini架构

In [None]:
banner = "=" * 70
print(banner)
print("训练TabM-Mini架构")
print(banner)

# Same piecewise-linear embeddings as the basic model (48 bins, d_embedding=16).
bins_mini = rtdl_num_embeddings.compute_bins(X_train_tensor, n_bins=48)
num_embeddings_mini = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
    bins_mini,
    d_embedding=16,
    activation=False,
    version="B",
)

# The only difference from the basic model is arch_type="tabm-mini".
# NOTE(review): the original author describes this variant as more strongly
# regularized; that claim is not verifiable from this notebook.
tabm_mini = tabm.TabM.make(
    n_num_features=X_train_tensor.shape[1],
    cat_cardinalities=[],
    d_out=2,
    num_embeddings=num_embeddings_mini,
    arch_type="tabm-mini",
)
tabm_mini = tabm_mini.to(DEVICE)

n_params_mini = sum(p.numel() for p in tabm_mini.parameters())
print(f"模型参数量: {n_params_mini:,}")
print(f"集成数量k: {tabm_mini.backbone.k}")
print("\n开始训练...\n")

result_mini = train_tabm_model(
    tabm_mini,
    X_train_tensor,
    y_train_tensor,
    X_val_tensor,
    y_val_tensor,
    n_epochs=500,
    patience=32,
)

# Final evaluation of the restored best checkpoint.
metrics_mini = evaluate_model(tabm_mini, X_val_tensor, y_val_tensor)

print("\nTabM-Mini性能:")
for prefix, metric_key in (("AUC: ", "auc"), ("F1:  ", "f1"), ("ACC: ", "accuracy")):
    print(f"  {prefix}{metrics_mini[metric_key]:.4f}")

# Persist the weights.
mini_model_path = "./output/tabm_enhanced/models/tabm_mini.pt"
torch.save(tabm_mini.state_dict(), mini_model_path)
print(f"\n模型已保存: {mini_model_path}")

## 13. 模型3：独立批次训练

In [None]:
banner = "=" * 70
print(banner)
print("训练TabM（独立批次策略）")
print(banner)
print("独立批次训练：k个子模型在不同批次上训练，增加多样性")

# Same piecewise-linear embeddings as the basic model (48 bins, d_embedding=16).
bins_indep = rtdl_num_embeddings.compute_bins(X_train_tensor, n_bins=48)
num_embeddings_indep = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
    bins_indep,
    d_embedding=16,
    activation=False,
    version="B",
)

# Architecture is identical to the basic model; only the batching differs.
# NOTE(review): the model is built with default settings while training feeds
# it (batch, k)-indexed inputs. Confirm that the installed tabm version accepts
# that input shape during training.
tabm_indep = tabm.TabM.make(
    n_num_features=X_train_tensor.shape[1],
    cat_cardinalities=[],
    d_out=2,
    num_embeddings=num_embeddings_indep,
)
tabm_indep = tabm_indep.to(DEVICE)

n_params_indep = sum(p.numel() for p in tabm_indep.parameters())
print(f"模型参数量: {n_params_indep:,}")
print("\n开始训练...\n")

# share_training_batches=False: every ensemble member gets its own batch order.
result_indep = train_tabm_model(
    tabm_indep,
    X_train_tensor,
    y_train_tensor,
    X_val_tensor,
    y_val_tensor,
    n_epochs=500,
    patience=32,
    share_training_batches=False,
)

# Final evaluation of the restored best checkpoint.
metrics_indep = evaluate_model(tabm_indep, X_val_tensor, y_val_tensor)

print("\nTabM（独立批次）性能:")
for prefix, metric_key in (("AUC: ", "auc"), ("F1:  ", "f1"), ("ACC: ", "accuracy")):
    print(f"  {prefix}{metrics_indep[metric_key]:.4f}")

# Persist the weights.
indep_model_path = "./output/tabm_enhanced/models/tabm_indep.pt"
torch.save(tabm_indep.state_dict(), indep_model_path)
print(f"\n模型已保存: {indep_model_path}")

## 14. 模型4：PeriodicEmbeddings

In [None]:
banner = "=" * 70
print(banner)
print("训练TabM（PeriodicEmbeddings）")
print(banner)

# Periodic numeric embeddings instead of piecewise-linear ones; no bins needed.
n_features_periodic = X_train_tensor.shape[1]
num_embeddings_periodic = rtdl_num_embeddings.PeriodicEmbeddings(
    n_features=n_features_periodic,
    d_embedding=16,
    lite=False,
)

# TabM with library-default hyperparameters on top of the periodic embeddings.
tabm_periodic = tabm.TabM.make(
    n_num_features=n_features_periodic,
    cat_cardinalities=[],
    d_out=2,
    num_embeddings=num_embeddings_periodic,
)
tabm_periodic = tabm_periodic.to(DEVICE)

n_params_periodic = sum(p.numel() for p in tabm_periodic.parameters())
print(f"模型参数量: {n_params_periodic:,}")
print("\n开始训练...\n")

result_periodic = train_tabm_model(
    tabm_periodic,
    X_train_tensor,
    y_train_tensor,
    X_val_tensor,
    y_val_tensor,
    n_epochs=500,
    patience=32,
)

# Final evaluation of the restored best checkpoint.
metrics_periodic = evaluate_model(tabm_periodic, X_val_tensor, y_val_tensor)

print("\nTabM（PeriodicEmbeddings）性能:")
for prefix, metric_key in (("AUC: ", "auc"), ("F1:  ", "f1"), ("ACC: ", "accuracy")):
    print(f"  {prefix}{metrics_periodic[metric_key]:.4f}")

# Persist the weights.
periodic_model_path = "./output/tabm_enhanced/models/tabm_periodic.pt"
torch.save(tabm_periodic.state_dict(), periodic_model_path)
print(f"\n模型已保存: {periodic_model_path}")

## 15. 模型5：超参数优化（Optuna HPO）

In [None]:
# Section banner for the Optuna hyperparameter search.
print("=" * 70)
print("TabM超参数优化（Optuna）")
print("=" * 70)


def objective(trial):
    """Optuna objective: train one TabM configuration and return its best F1.

    Samples the architecture (``n_blocks``, ``d_block``, ``dropout``), the
    piecewise-linear embedding settings (``n_bins``, ``d_embedding``) and the
    optimizer settings (``lr``, ``weight_decay``), then trains with a shortened
    schedule (200 epochs, patience 20).

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        Best validation F1 reached during training, or 0.0 if building or
        training the model raised an exception (the error is printed).
    """
    # Search space. The sampling order is kept fixed on purpose.
    hp = {
        "n_blocks": trial.suggest_int("n_blocks", 1, 4),
        "d_block": trial.suggest_int("d_block", 64, 512, step=64),
        "n_bins": trial.suggest_int("n_bins", 16, 96, step=16),
        "d_embedding": trial.suggest_int("d_embedding", 8, 32, step=4),
        "lr": trial.suggest_float("lr", 1e-4, 5e-3, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-5, 1e-1, log=True),
        "dropout": trial.suggest_float("dropout", 0.0, 0.3, step=0.05),
    }

    try:
        # Numeric embeddings for this trial.
        trial_bins = rtdl_num_embeddings.compute_bins(
            X_train_tensor, n_bins=hp["n_bins"]
        )
        num_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
            trial_bins,
            d_embedding=hp["d_embedding"],
            activation=False,
            version="B",
        )

        # Model for this trial.
        model = tabm.TabM.make(
            n_num_features=X_train_tensor.shape[1],
            cat_cardinalities=[],
            d_out=2,
            num_embeddings=num_embeddings,
            n_blocks=hp["n_blocks"],
            d_block=hp["d_block"],
            dropout=hp["dropout"],
        ).to(DEVICE)

        # Shortened, silent training run.
        result = train_tabm_model(
            model,
            X_train_tensor,
            y_train_tensor,
            X_val_tensor,
            y_val_tensor,
            n_epochs=200,
            patience=20,
            lr=hp["lr"],
            weight_decay=hp["weight_decay"],
            verbose=False,
        )
        return result["best_f1"]
    except Exception as e:
        # Best effort: a failed configuration scores 0.0 so the study continues.
        print(f"Trial failed: {e}")
        return 0.0


# 运行优化
study = optuna.create_study(direction="maximize", study_name="tabm_hpo")
study.optimize(objective, n_trials=50, show_progress_bar=True)

print(f"\n最佳参数:")
for key, value in study.best_params.items():
    print(f"  {key}: {value}")
print(f"\n最佳F1: {study.best_value:.4f}")

In [None]:
# 使用最佳参数训练最终模型
print("\n使用最佳参数训练最终模型...")

best_params = study.best_params

num_embeddings_hpo = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
    rtdl_num_embeddings.compute_bins(X_train_tensor, n_bins=best_params["n_bins"]),
    d_embedding=best_params["d_embedding"],
    activation=False,
    version="B",
)

tabm_hpo = tabm.TabM.make(
    n_num_features=X_train_tensor.shape[1],
    cat_cardinalities=[],
    d_out=2,
    num_embeddings=num_embeddings_hpo,
    n_blocks=best_params["n_blocks"],
    d_block=best_params["d_block"],
    dropout=best_params["dropout"],
).to(DEVICE)

result_hpo = train_tabm_model(
    tabm_hpo,
    X_train_tensor,
    y_train_tensor,
    X_val_tensor,
    y_val_tensor,
    n_epochs=500,
    patience=32,
    lr=best_params["lr"],
    weight_decay=best_params["weight_decay"],
)

# 最终评估
metrics_hpo = evaluate_model(tabm_hpo, X_val_tensor, y_val_tensor)

print(f"\nTabM-HPO性能:")
print(f"  AUC: {metrics_hpo['auc']:.4f}")
print(f"  F1:  {metrics_hpo['f1']:.4f}")
print(f"  ACC: {metrics_hpo['accuracy']:.4f}")

# 保存模型和参数
torch.save(tabm_hpo.state_dict(), "./output/tabm_enhanced/models/tabm_hpo.pt")
joblib.dump(
    {
        "best_params": best_params,
        "imputer": imputer,
        "quantile_transformer": quantile_transformer,
    },
    "./output/tabm_enhanced/models/tabm_hpo_config.pkl",
)
print(f"\n模型已保存: ./output/tabm_enhanced/models/tabm_hpo.pt")
print("=" * 70)

## 16. 性能对比汇总

In [None]:
banner = "=" * 70
print(banner)
print("TabM变体性能对比")
print(banner)

# (display name, metrics dict) for every trained variant, in training order.
variant_metrics = [
    ("TabM-Basic", metrics_basic),
    ("TabM-Mini", metrics_mini),
    ("TabM-IndepBatch", metrics_indep),
    ("TabM-Periodic", metrics_periodic),
    ("TabM-HPO", metrics_hpo),
]

# One row per variant, one column per metric.
results_df = pd.DataFrame(
    {
        "模型": [name for name, _ in variant_metrics],
        "AUC": [scores["auc"] for _, scores in variant_metrics],
        "F1": [scores["f1"] for _, scores in variant_metrics],
        "Accuracy": [scores["accuracy"] for _, scores in variant_metrics],
    }
)

# Best AUC first.
results_df = results_df.sort_values("AUC", ascending=False)
print(results_df.to_string(index=False))

# Save the table.
results_path = "./output/tabm_enhanced/性能对比.csv"
results_df.to_csv(results_path, index=False)
print(f"\n结果已保存: {results_path}")

## 17. ROC曲线对比

In [None]:
plt.figure(figsize=(10, 8))

# (legend name, positive-class probabilities, AUC) for each variant.
models_info = [
    (label, scores["y_proba"], scores["auc"])
    for label, scores in (
        ("TabM-Basic", metrics_basic),
        ("TabM-Mini", metrics_mini),
        ("TabM-IndepBatch", metrics_indep),
        ("TabM-Periodic", metrics_periodic),
        ("TabM-HPO", metrics_hpo),
    )
]

y_val_np = y_validation_binary.values

for model_name, y_proba, auc_score in models_info:
    fpr, tpr, _ = roc_curve(y_val_np, y_proba)
    # Draw the tuned model with a thicker line so it stands out.
    linewidth = 2 if model_name != "TabM-HPO" else 3
    curve_label = f"{model_name} (AUC = {auc_score:.4f})"
    plt.plot(fpr, tpr, label=curve_label, linewidth=linewidth)

# Diagonal reference line for a random classifier.
plt.plot([0, 1], [0, 1], "k--", linewidth=1, label="随机猜测")
plt.xlabel("假阳性率 (1-特异度)", fontsize=12)
plt.ylabel("真阳性率 (灵敏度)", fontsize=12)
plt.title("TabM变体ROC曲线对比", fontsize=14, fontweight="bold")
plt.legend(loc="lower right", fontsize=10)
plt.grid(alpha=0.3)
plt.tight_layout()

# Raster copy (300 dpi) and vector copy of the same figure.
plt.savefig("./output/tabm_enhanced/ROC曲线对比.png", dpi=300, bbox_inches="tight")
plt.savefig("./output/tabm_enhanced/ROC曲线对比.pdf", bbox_inches="tight")
plt.show()
print("ROC曲线已保存")

## 18. 选择最佳模型并复制到标准位置

In [None]:
banner = "=" * 70
print(banner)
print("选择最佳模型")
print(banner)

# Candidates keyed by the suffix used in their saved file names.
all_metrics = {
    "basic": metrics_basic,
    "mini": metrics_mini,
    "indep": metrics_indep,
    "periodic": metrics_periodic,
    "hpo": metrics_hpo,
}

# The winner is the variant with the highest validation AUC.
best_model_name, best_metrics = max(
    all_metrics.items(), key=lambda entry: entry[1]["auc"]
)

print(f"最佳模型: TabM-{best_model_name}")
print(f"  AUC: {best_metrics['auc']:.4f}")
print(f"  F1:  {best_metrics['f1']:.4f}")

# Copy the winning weights to the shared "standard" location.
import shutil

src_model = f"./output/tabm_enhanced/models/tabm_{best_model_name}.pt"
dst_model = "./output/models/tabm_best.pt"
shutil.copy(src_model, dst_model)

# Save the fitted preprocessors next to it, together with which variant won and
# its metrics.
# NOTE(review): the state dict alone does not describe the architecture; a
# consumer must rebuild the model from `best_model_name` (and, for "hpo", from
# the separately saved tabm_hpo_config.pkl).
preprocessor_bundle = {
    "imputer": imputer,
    "quantile_transformer": quantile_transformer,
    "best_model_name": best_model_name,
    "metrics": best_metrics,
}
joblib.dump(preprocessor_bundle, "./output/models/tabm_preprocessors.pkl")

print(f"\n最佳模型已复制到: {dst_model}")
print(banner)

## 19. 显存使用统计

In [None]:
# GPU memory report; skipped entirely on CPU-only machines.
if torch.cuda.is_available():
    bytes_per_gib = 1024**3
    peak_gib = torch.cuda.max_memory_allocated() / bytes_per_gib
    current_gib = torch.cuda.memory_allocated() / bytes_per_gib
    print(f"显存峰值使用: {peak_gib:.2f} GB")
    print(f"当前显存使用: {current_gib:.2f} GB")

    # Release cached blocks held by PyTorch's CUDA allocator.
    torch.cuda.empty_cache()
    print("显存已清理")