In [None]:
# -*- coding: utf-8 -*-
"""
WGAN-GP增强与XGBoost分类的严谨实验框架
=================================================
本脚本实现了一个完整的、符合高水平学术研究标准的机器学习流程，用于心血管疾病的二元分类预测。
核心方法学包括：
1.  **严格的数据隔离**: 在流程开始时就将数据分为独立的“开发集”和“最终测试集”，确保最终评估的无偏性。
2.  **动态嵌套式数据增强**: 在K-折交叉验证的每一折内部，动态地使用当前折的训练数据来训练WGAN-GP并生成增强样本，从根本上杜绝了数据泄露。
3.  **混合数据类型处理**: WGAN-GP流程内置了对连续特征（标准化）和分类特征（独热编码）的自动化处理与逆转换。
4.  **正确的模型与评估**: 全程使用XGBoost分类器（XGBClassifier）以及与之匹配的分类性能指标（如Accuracy, F1 Score, AUC）进行模型训练、调优和评估。
"""

# --- 基础库导入 ---
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import xgboost as xgb
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, ConfusionMatrixDisplay,
                             make_scorer)
import matplotlib.pyplot as plt
import os
import joblib
from pathlib import Path

# --- 0. Global configuration ---

# --- Path configuration ---
# Recommended layout: keep this script in a folder named "scripts"; the "data"
# and "output" folders live next to "scripts" at the same level.
# Example layout:
# /your_project_folder/
#  ├─ data/
#  │  ├─ development_set.xlsx
#  │  └─ final_test_set.xlsx
#  ├─ scripts/
#  │  └─ this_script.py
#  └─ output/
#     ├─ augmented_outputs/
#     └─ trained_models/

# Folder that holds this code. When run as a script, use the script's own
# location so the paths do not depend on the shell's working directory.
# Notebooks/REPLs define no __file__, so fall back to the working directory there.
try:
    CURRENT_DIR = Path(__file__).resolve().parent
except NameError:
    CURRENT_DIR = Path.cwd()
# The project root is assumed to be the parent of the 'scripts' folder.
PROJECT_ROOT = CURRENT_DIR.parent
DATA_DIR = PROJECT_ROOT / "data"
OUTPUT_DIR = PROJECT_ROOT / "output"

# Input data files
DEV_SET_FILE = DATA_DIR / "development_set.xlsx"
TEST_SET_FILE = DATA_DIR / "final_test_set.xlsx"
# Destinations for intermediate augmented data, the final model and the plots
AUGMENTED_DATA_OUTPUT_FOLDER = OUTPUT_DIR / "augmented_outputs"
MODEL_OUTPUT_PATH = OUTPUT_DIR / "trained_models"
OUTPUT_PLOT_PATH = OUTPUT_DIR  # plots are exported directly into OUTPUT_DIR

# Create the output folders (and any missing parents) if they do not exist yet.
for _folder in (AUGMENTED_DATA_OUTPUT_FOLDER, MODEL_OUTPUT_PATH, OUTPUT_PLOT_PATH):
    _folder.mkdir(parents=True, exist_ok=True)

# --- Experiment parameters ---
TARGET_COLUMN = 'target'   # name of the binary label column
RANDOM_STATE = 42          # seed handed to KFold, RandomizedSearchCV and XGBoost
# NOTE(review): torch is never seeded with RANDOM_STATE in this script, so the
# WGAN-GP output is not reproducible from this seed alone.
N_SPLITS_KFOLD = 5         # number of folds for K-fold cross-validation

# --- WGAN-GP default hyperparameters ---
DEFAULT_WGAN_PARAMS = dict(
    latent_dim=100,   # size of the random noise vector fed to the generator
    lambda_gp=10,     # weight of the gradient-penalty term in the critic loss
    n_critic=5,       # critic updates per generator update
    lr=0.00005,       # RMSprop learning rate for both networks
    batch_size=32,    # mini-batch size
    epochs=500,       # default number of epochs (used as-is by the CV loop)
)

# --- Search space for XGBoost RandomizedSearchCV ---
XGB_PARAM_GRID = dict(
    n_estimators=[100, 200, 300, 400],
    max_depth=[3, 5, 7],
    learning_rate=[0.01, 0.05, 0.1],
    subsample=[0.7, 0.8, 0.9, 1],
    colsample_bytree=[0.7, 0.8, 0.9, 1],
    gamma=[0, 0.1, 0.2],
    reg_alpha=[0, 0.01, 0.1],
    reg_lambda=[0.5, 1, 1.5],
    scale_pos_weight=[1, 5, 10, 20],  # up-weights the positive class for imbalanced labels
)
N_ITER_RANDOMIZED_SEARCH = 30  # number of sampled settings tried by RandomizedSearchCV

# --- Device configuration ---
# Train the WGAN-GP on the default CUDA device when PyTorch can see one, else on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"将使用设备: {device}")

# --- Feature definitions ---
# Continuous feature columns (standardised before they enter the WGAN-GP).
CONTINUOUS_COLS = ['age', 'restingBP', 'serumcholestrol', 'maxheartrate', 'oldpeak']
# Categorical columns (one-hot encoded for the WGAN-GP). The label column is
# included on purpose so the GAN learns the joint feature/label distribution;
# it is referenced through TARGET_COLUMN so renaming the label stays in sync
# with the later `drop(columns=[TARGET_COLUMN])` calls.
CATEGORICAL_COLS = ['gender', 'fastingbloodsugar', 'chestpain',
                    'restingrelectro', 'exerciseangia', 'slope',
                    'noofmajorvessels', TARGET_COLUMN]

# Legal values of each categorical column. After generation, every decoded
# value is snapped to the nearest entry of its list.
CATEGORY_MAPPINGS = {
    'gender':           [0, 1],
    'fastingbloodsugar':[0, 1],
    'chestpain':        [0, 1, 2, 3],
    'restingrelectro':  [0, 1, 2],
    'exerciseangia':    [0, 1],
    # NOTE(review): a previous comment claimed this list had been corrected to
    # include 0, but it does not. If the data contains slope == 0, generated 0s
    # are snapped to 1. Confirm against the data description before adding 0.
    'slope':            [1, 2, 3],
    'noofmajorvessels': [0, 1, 2, 3],
    TARGET_COLUMN:      [0, 1],
}

# Plausible value ranges of the continuous columns; generated values are clipped to them.
CONTINUOUS_BOUNDS = {
    'age':              (20, 80),
    'restingBP':        (94, 200),
    'serumcholestrol':  (126, 564),
    'maxheartrate':     (71, 202),
    'oldpeak':          (0, 6.2)
}

# --- 1. WGAN-GP 模型定义与核心功能函数 ---

# --- WGAN-GP PyTorch网络结构 ---
class Generator(nn.Module):
    """WGAN-GP generator: an MLP mapping latent noise to `output_dim` values.

    Layout: latent_dim -> 128 -> 256 -> 512 -> output_dim, with in-place ReLU
    after each hidden layer and a plain linear (unbounded) output layer.
    """
    def __init__(self, latent_dim, output_dim):
        super().__init__()
        layers = []
        width_in = latent_dim
        for width_out in (128, 256, 512):
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.ReLU(True))
            width_in = width_out
        layers.append(nn.Linear(width_in, output_dim))
        # Stored as `model` (indices 0..6) so state_dict keys match saved checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return the generator output for the noise batch `z`."""
        return self.model(z)

class Critic(nn.Module):
    """WGAN-GP critic: an MLP scoring each input row with one unbounded scalar.

    Layout: input_dim -> 512 -> 256 -> 128 -> 1, with in-place LeakyReLU(0.2)
    after each hidden layer and no output activation.
    """
    def __init__(self, input_dim):
        super().__init__()
        widths = (input_dim, 512, 256, 128)
        layers = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(width_in, width_out), nn.LeakyReLU(0.2, True)]
        layers.append(nn.Linear(widths[-1], 1))
        # Stored as `model` (indices 0..6) so state_dict keys match saved checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return one critic score per row of `x`."""
        return self.model(x)

def gradient_penalty(critic, real, fake, device):
    """Return the WGAN-GP gradient penalty as a scalar tensor.

    Each row is interpolated at a random point between `real` and `fake`. The
    critic's gradient with respect to that point is taken, and the mean of
    (||grad||_2 - 1)^2 over the batch is returned. The graph is kept
    (create_graph=True) so the penalty can be backpropagated into the critic.
    A per-row mixing weight of shape (batch, 1) is used, which presumes 2-D
    (batch, features) inputs.
    """
    batch = real.size(0)
    eps = torch.rand(batch, 1, device=device)
    mixed = eps * real + (1 - eps) * fake
    mixed.requires_grad_(True)
    scores = critic(mixed)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores, device=device),
        create_graph=True,
        retain_graph=True,
    )
    grad_norms = grads.view(grads.size(0), -1).norm(2, dim=1)
    return (grad_norms - 1).pow(2).mean()

# --- WGAN-GP training and sample generation ---
def _wgangp_preprocess(input_df):
    """Fit a scaler/encoder on `input_df`; return (scaler, encoder, float32 matrix).

    The matrix holds the standardised CONTINUOUS_COLS followed by the one-hot
    encoded CATEGORICAL_COLS. This is the row layout the generator learns to emit.
    """
    scaler = StandardScaler()
    cont_std = scaler.fit_transform(input_df[CONTINUOUS_COLS])
    encoder = OneHotEncoder(sparse_output=False, dtype=np.float32, handle_unknown='ignore')
    cat_oh = encoder.fit_transform(input_df[CATEGORICAL_COLS])
    return scaler, encoder, np.hstack([cont_std, cat_oh]).astype(np.float32)


def _wgangp_train(X_processed, wgan_hyperparams, current_device, log_prefix):
    """Train a Generator/Critic pair on `X_processed` and return the trained generator."""
    data_tensor = torch.tensor(X_processed)
    bs = min(wgan_hyperparams['batch_size'], len(data_tensor))
    # Drop a ragged last batch only when there are at least two batches' worth of
    # rows, so very small inputs still yield one batch per epoch.
    dataloader = DataLoader(TensorDataset(data_tensor),
                            batch_size=bs,
                            shuffle=True,
                            drop_last=len(data_tensor) >= bs * 2)

    latent_dim = wgan_hyperparams['latent_dim']
    lambda_gp = wgan_hyperparams['lambda_gp']
    n_critic = wgan_hyperparams['n_critic']
    epochs = wgan_hyperparams['epochs']
    n_features = X_processed.shape[1]
    netG = Generator(latent_dim, n_features).to(current_device)
    netD = Critic(n_features).to(current_device)
    optG = optim.RMSprop(netG.parameters(), lr=wgan_hyperparams['lr'])
    optD = optim.RMSprop(netD.parameters(), lr=wgan_hyperparams['lr'])

    print(f"{log_prefix} 开始WGAN-GP训练 ({epochs} 轮)...")
    log_every = max(1, epochs // 10)
    lossD = lossG = None
    for epoch in range(epochs):
        for (real,) in dataloader:
            real = real.to(current_device)
            # Critic updates. The fake batch is detached: the critic loss must not
            # backpropagate through the generator (those gradients were discarded
            # by optG.zero_grad() anyway, so detaching only saves compute).
            for _ in range(n_critic):
                z = torch.randn(real.size(0), latent_dim, device=current_device)
                fake = netG(z).detach()
                lossD = (torch.mean(netD(fake)) - torch.mean(netD(real))
                         + lambda_gp * gradient_penalty(netD, real, fake, current_device))
                optD.zero_grad()
                lossD.backward()
                optD.step()
            # Generator update: raise the critic's score on fresh fakes.
            z = torch.randn(real.size(0), latent_dim, device=current_device)
            lossG = -torch.mean(netD(netG(z)))
            optG.zero_grad()
            lossG.backward()
            optG.step()
        if (epoch + 1) % log_every == 0 and lossD is not None:
            print(f"{log_prefix} [Epoch {epoch+1}/{epochs}] Critic Loss: {lossD.item():.4f}, Gen Loss: {lossG.item():.4f}")
    return netG


def _snap_to_nearest(values, legal):
    """Map each value to the closest entry of `legal` (first entry wins ties); NaN stays NaN."""
    values = np.asarray(values, dtype=float)
    legal = np.asarray(legal)
    distances = np.abs(values[:, None] - legal.astype(float)[None, :])
    nearest = legal[distances.argmin(axis=1)]
    missing = np.isnan(values)
    if missing.any():
        return np.where(missing, np.nan, nearest.astype(float))
    return nearest


def _wgangp_postprocess(gen_processed, scaler, encoder):
    """Invert the preprocessing for generated rows and return them as a DataFrame.

    Continuous part: inverse standardisation, then clipping to CONTINUOUS_BOUNDS.
    Categorical part: argmax over each column's one-hot slice, mapped back to the
    encoder's categories, then snapped to the nearest value in CATEGORY_MAPPINGS.
    """
    n_cont = len(CONTINUOUS_COLS)
    gen_cont = scaler.inverse_transform(gen_processed[:, :n_cont])
    for i, col in enumerate(CONTINUOUS_COLS):
        lo, hi = CONTINUOUS_BOUNDS[col]
        gen_cont[:, i] = np.clip(gen_cont[:, i], lo, hi)
    augmented_df = pd.DataFrame(gen_cont, columns=CONTINUOUS_COLS)

    gen_cat_oh = gen_processed[:, n_cont:]
    start = 0
    # encoder.categories_ holds one array of categories per column, in CATEGORICAL_COLS order.
    for col, categories in zip(CATEGORICAL_COLS, encoder.categories_):
        stop = start + len(categories)
        labels = categories[np.argmax(gen_cat_oh[:, start:stop], axis=1)]
        numeric = pd.to_numeric(pd.Series(labels), errors='coerce').to_numpy(dtype=float)
        augmented_df[col] = _snap_to_nearest(numeric, CATEGORY_MAPPINGS[col])
        start = stop
    return augmented_df


def train_and_generate_wgangp(input_original_df,
                              wgan_hyperparams,
                              num_samples_to_generate,
                              current_device,
                              fold_num_for_logging=None,
                              output_augmented_data_path=None,
                              seed=None):
    """Train a WGAN-GP on `input_original_df` and return synthetic rows.

    Only `input_original_df` is used to fit the preprocessing and the GAN, so
    calling this on a CV training fold keeps the validation fold unseen.

    Args:
        input_original_df: DataFrame containing every column of CONTINUOUS_COLS
            and CATEGORICAL_COLS; only those columns are modelled.
        wgan_hyperparams: dict with keys 'latent_dim', 'lambda_gp', 'n_critic',
            'lr', 'batch_size' and 'epochs'.
        num_samples_to_generate: number of synthetic rows to return.
        current_device: torch.device used for training and sampling.
        fold_num_for_logging: optional tag included in the log prefix.
        output_augmented_data_path: if given, the result is also written there
            with DataFrame.to_excel.
        seed: optional int. When given, torch's global RNG is seeded with it
            first, so weight init, noise and shuffling repeat across runs (GPU
            kernels may still be nondeterministic).

    Returns:
        DataFrame with CONTINUOUS_COLS followed by CATEGORICAL_COLS and
        `num_samples_to_generate` rows.

    Raises:
        ValueError: if `input_original_df` has no rows.
    """
    log_prefix = "[WGAN-GP"
    if fold_num_for_logging:
        log_prefix += f" {fold_num_for_logging}"
    log_prefix += "]"
    if len(input_original_df) == 0:
        raise ValueError(f"{log_prefix} input_original_df is empty; nothing to train on.")
    print(f"\n{log_prefix} 开始数据增强，输入形状: {input_original_df.shape}")

    if seed is not None:
        torch.manual_seed(seed)

    # 1-2. Preprocess (scaler/encoder fitted on this input only).
    scaler, encoder, X_processed = _wgangp_preprocess(input_original_df)
    # 3-4. Build and train the networks.
    netG = _wgangp_train(X_processed, wgan_hyperparams, current_device, log_prefix)

    # 5. Sample new rows in the preprocessed space.
    print(f"{log_prefix} 训练完成，正在生成 {num_samples_to_generate} 个样本...")
    netG.eval()
    with torch.no_grad():
        z = torch.randn(num_samples_to_generate, wgan_hyperparams['latent_dim'], device=current_device)
        gen_processed = netG(z).cpu().numpy()

    # 6-7. Map back to the original feature space.
    augmented_df = _wgangp_postprocess(gen_processed, scaler, encoder)

    # 8. Optionally persist the synthetic rows.
    if output_augmented_data_path:
        augmented_df.to_excel(output_augmented_data_path, index=False)
        print(f"{log_prefix} 增强数据已保存到: {output_augmented_data_path}")

    return augmented_df

# --- 2. Main pipeline ---
# Load the development set and the held-out final test set. A missing file
# stops the run with exit status 1 so shells/schedulers see the failure
# (the bare `exit()` helper exits with status 0 and only exists when the
# `site` module is loaded).
try:
    development_df_original = pd.read_excel(DEV_SET_FILE)
    final_test_df_original = pd.read_excel(TEST_SET_FILE)
except FileNotFoundError as e:
    print(f"错误: 数据文件未找到，请检查路径。 {e}")
    raise SystemExit(1) from e
print(f"成功加载数据: 开发集形状 {development_df_original.shape}, 最终测试集形状 {final_test_df_original.shape}")

# Split the development and final test sets into features (X) and label (y).
X_dev_original = development_df_original.drop(columns=[TARGET_COLUMN])
y_dev_original = development_df_original[TARGET_COLUMN]
X_final_test = final_test_df_original.drop(columns=[TARGET_COLUMN])
y_final_test = final_test_df_original[TARGET_COLUMN]

# --- Step 3.1: hyperparameter optimisation (HPO) ---
print("\n--- [步骤 3.1] XGBoost 超参数调优 ---")
print("为HPO生成开发集的增强版本...")
# Augmentation is nested inside the HPO folds. For each split, a WGAN-GP is
# trained on that split's real training rows only, and its synthetic rows are
# placed only on the training side; validation uses real rows only.
# A single WGAN-GP fitted on the whole development set would let synthetic
# copies of each validation fold into training and inflate the HPO score.
# Cost: N_SPLITS_KFOLD WGAN-GP trainings instead of one.
hpo_wgan_params = DEFAULT_WGAN_PARAMS.copy()
# Fewer epochs than the final model (2000) to save time; more than the
# 500-epoch default used in the CV step.
hpo_wgan_params['epochs'] = 1000

hpo_kf = KFold(n_splits=N_SPLITS_KFOLD, shuffle=True, random_state=RANDOM_STATE)
hpo_X_parts = [X_dev_original]
hpo_y_parts = [y_dev_original]
hpo_cv_splits = []
# Synthetic rows are appended after the real rows, so real row i keeps position i.
next_row = len(development_df_original)
for hpo_fold_idx, (hpo_train_idx, hpo_val_idx) in enumerate(hpo_kf.split(development_df_original)):
    hpo_train_fold_df = development_df_original.iloc[hpo_train_idx]
    hpo_augmented_fold_df = train_and_generate_wgangp(
        input_original_df=hpo_train_fold_df.copy(),
        wgan_hyperparams=hpo_wgan_params,
        num_samples_to_generate=len(hpo_train_fold_df),  # 1:1 with the real training rows
        current_device=device,
        fold_num_for_logging=f"HPO_Fold_{hpo_fold_idx + 1}"
    )
    synthetic_idx = np.arange(next_row, next_row + len(hpo_augmented_fold_df))
    next_row += len(hpo_augmented_fold_df)
    hpo_X_parts.append(hpo_augmented_fold_df.drop(columns=[TARGET_COLUMN]))
    hpo_y_parts.append(hpo_augmented_fold_df[TARGET_COLUMN])
    # Train on real training rows + this split's synthetic rows; validate on real rows only.
    hpo_cv_splits.append((np.concatenate([hpo_train_idx, synthetic_idx]), hpo_val_idx))

X_combined_dev_for_hpo = pd.concat(hpo_X_parts, ignore_index=True)
y_combined_dev_for_hpo = pd.concat(hpo_y_parts, ignore_index=True)

# XGBoost classifier and randomized search over the precomputed splits
xgb_classifier_for_hpo = xgb.XGBClassifier(objective='binary:logistic', eval_metric='logloss',
                                           random_state=RANDOM_STATE, use_label_encoder=False,
                                           tree_method='gpu_hist' if device.type == 'cuda' else 'auto')

random_search_hpo = RandomizedSearchCV(
    estimator=xgb_classifier_for_hpo, param_distributions=XGB_PARAM_GRID,
    n_iter=N_ITER_RANDOMIZED_SEARCH, cv=hpo_cv_splits,
    scoring='roc_auc',  # ROC AUC is threshold-free and more robust to class imbalance
    verbose=1, random_state=RANDOM_STATE, n_jobs=-1
)
print(f"开始在 {X_combined_dev_for_hpo.shape[0]} 个样本上进行XGBoost超参数搜索...")
random_search_hpo.fit(X_combined_dev_for_hpo, y_combined_dev_for_hpo)
best_overall_xgboost_params = random_search_hpo.best_params_
print(f"\n找到的最佳XGBoost超参数: {best_overall_xgboost_params}")
print(f"最佳HPO ROC AUC Score: {random_search_hpo.best_score_:.4f}")

# --- Step 3.2: K-fold cross-validation with per-fold (dynamic) augmentation ---
print(f"\n--- [步骤 3.2] 在开发集上进行 {N_SPLITS_KFOLD}-折交叉验证 (动态WGAN-GP增强) ---")
kf = KFold(n_splits=N_SPLITS_KFOLD, shuffle=True, random_state=RANDOM_STATE)

# Per-fold results
kfold_cv_val_metrics_list = []
kfold_cv_train_metrics_list = []
cv_train_loss_curves = []
cv_val_loss_curves = []

cv_wgan_params = DEFAULT_WGAN_PARAMS.copy()
augmentation_factor_cv = 1  # synthetic rows generated per real training row in each fold


def _fold_metrics(y_true, y_proba):
    """Return Accuracy/F1 at the 0.5 threshold (>=, as in the final test) and ROC AUC."""
    y_pred = (y_proba >= 0.5).astype(int)
    return {
        'Accuracy': accuracy_score(y_true, y_pred),
        'F1 Score': f1_score(y_true, y_pred),
        'AUC': roc_auc_score(y_true, y_proba),
    }


for fold_idx, (train_indices, val_indices) in enumerate(kf.split(development_df_original)):
    print(f"\n--- K-Fold: 第 {fold_idx + 1}/{N_SPLITS_KFOLD} 折 ---")
    cv_train_original_fold_df = development_df_original.iloc[train_indices]
    cv_val_original_fold_df = development_df_original.iloc[val_indices]

    # WGAN-GP trained on this fold's training rows only: the validation fold stays unseen.
    cv_augmented_fold_df = train_and_generate_wgangp(
        input_original_df=cv_train_original_fold_df.copy(),
        wgan_hyperparams=cv_wgan_params,
        num_samples_to_generate=len(cv_train_original_fold_df) * augmentation_factor_cv,
        current_device=device,
        fold_num_for_logging=f"Fold_{fold_idx + 1}"
    )

    # Fold training set = real + synthetic rows; validation set = real rows only.
    X_cv_train_original_fold = cv_train_original_fold_df.drop(columns=[TARGET_COLUMN])
    y_cv_train_original_fold = cv_train_original_fold_df[TARGET_COLUMN]
    X_cv_augmented_fold = cv_augmented_fold_df.drop(columns=[TARGET_COLUMN])
    y_cv_augmented_fold = cv_augmented_fold_df[TARGET_COLUMN]
    X_cv_train_combined_fold = pd.concat([X_cv_train_original_fold, X_cv_augmented_fold], ignore_index=True)
    y_cv_train_combined_fold = pd.concat([y_cv_train_original_fold, y_cv_augmented_fold], ignore_index=True)

    X_cv_val_fold = cv_val_original_fold_df.drop(columns=[TARGET_COLUMN])
    y_cv_val_fold = cv_val_original_fold_df[TARGET_COLUMN]

    # Train with the best HPO parameters. No early stopping:
    # - the HPO search and the final model both train the full n_estimators, so CV
    #   should estimate that same procedure;
    # - stopping on this validation fold would also bias the validation metrics
    #   reported below.
    # The eval_set is kept only to record the log-loss curves.
    model_fold = xgb.XGBClassifier(
        **best_overall_xgboost_params, objective='binary:logistic', eval_metric='logloss',
        random_state=RANDOM_STATE, use_label_encoder=False,
        tree_method='gpu_hist' if device.type == 'cuda' else 'auto'
    )

    eval_set_fold = [(X_cv_train_combined_fold, y_cv_train_combined_fold), (X_cv_val_fold, y_cv_val_fold)]
    model_fold.fit(X_cv_train_combined_fold, y_cv_train_combined_fold,
                   eval_set=eval_set_fold, verbose=False)

    # Record the per-round log-loss curves (validation_0 = train, validation_1 = validation)
    fold_eval_results = model_fold.evals_result()
    cv_train_loss_curves.append(fold_eval_results['validation_0']['logloss'])
    cv_val_loss_curves.append(fold_eval_results['validation_1']['logloss'])

    # Validation metrics, plus training metrics to monitor overfitting
    kfold_cv_val_metrics_list.append(
        _fold_metrics(y_cv_val_fold, model_fold.predict_proba(X_cv_val_fold)[:, 1]))
    kfold_cv_train_metrics_list.append(
        _fold_metrics(y_cv_train_combined_fold, model_fold.predict_proba(X_cv_train_combined_fold)[:, 1]))
    print(f"Fold {fold_idx + 1} - Val AUC: {kfold_cv_val_metrics_list[-1]['AUC']:.4f} | Train AUC: {kfold_cv_train_metrics_list[-1]['AUC']:.4f}")

# --- Step 3.3: cross-validation summary and plots ---
# One row per fold, one column per metric ('Accuracy', 'F1 Score', 'AUC').
avg_kfold_cv_val_metrics_df, avg_kfold_cv_train_metrics_df = (
    pd.DataFrame(fold_rows)
    for fold_rows in (kfold_cv_val_metrics_list, kfold_cv_train_metrics_list)
)

print("\n--- K-折交叉验证平均性能 (WGAN-GP动态增强) ---")
for _heading, _fold_df in (("--- 平均验证集性能 ---", avg_kfold_cv_val_metrics_df),
                           ("\n--- 平均训练集性能 ---", avg_kfold_cv_train_metrics_df)):
    print(_heading)
    print(_fold_df.mean())  # column-wise mean over the folds

# Figure 1: mean CV training vs. validation score for each metric, as grouped bars.
avg_val_metrics = avg_kfold_cv_val_metrics_df.mean()
avg_train_metrics = avg_kfold_cv_train_metrics_df.mean()
metrics_to_plot = ['Accuracy', 'F1 Score', 'AUC']
x_axis = np.arange(len(metrics_to_plot))
bar_width = 0.4

plt.figure(figsize=(10, 6))
plt.bar(x_axis - bar_width / 2, avg_train_metrics[metrics_to_plot], width=bar_width, label='CV Train Avg.', align='center')
plt.bar(x_axis + bar_width / 2, avg_val_metrics[metrics_to_plot], width=bar_width, label='CV Validation Avg.', align='center')
plt.xticks(x_axis, metrics_to_plot)
plt.ylabel('Score')
plt.title('Average K-Fold CV Train vs. Validation Metrics (WGAN-GP Augmented)')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
# Zoom into the 0.8-1.0 band so small differences are visible. Lower the floor
# (rounded down to a multiple of 0.1) when a score falls below 0.8, so no bar is cut off.
lowest_score = min(avg_train_metrics[metrics_to_plot].min(), avg_val_metrics[metrics_to_plot].min())
y_floor = min(0.8, np.floor(lowest_score * 10) / 10)
plt.ylim(y_floor, 1.0)
plt.savefig(OUTPUT_PLOT_PATH / "kfold_avg_eval_metrics_wgangp.png", dpi=300)
plt.show()

# Figure 2: XGBoost log-loss per boosting round for every fold
# (dashed = the fold's training data, solid = the fold's validation data).
plt.figure(figsize=(12, 7))
fold_curve_pairs = zip(cv_train_loss_curves, cv_val_loss_curves)
for fold_no, (train_curve, val_curve) in enumerate(fold_curve_pairs, start=1):
    plt.plot(train_curve, label=f'Train Fold {fold_no}', linestyle='--')
    plt.plot(val_curve, label=f'Validation Fold {fold_no}', linestyle='-')
plt.xlabel('Boosting Round')
plt.ylabel('Log Loss')
plt.title('Training and Validation Log Loss per Fold (WGAN-GP Augmented)')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.savefig(OUTPUT_PLOT_PATH / "kfold_logloss_per_fold_wgangp.png", dpi=300)
plt.show()

# --- Step 4: train the final model ---
print("\n--- [步骤 4] 训练最终模型 ---")
print("为最终模型生成开发集的完整增强版本...")
# The final augmentation runs the WGAN-GP for 2000 epochs and generates twice
# as many synthetic rows as there are real development rows.
# NOTE(review): the CV in step 3.2 used the 500-epoch default and a 1:1 ratio,
# so its scores estimate a slightly different procedure than the one trained here.
final_wgan_params = {**DEFAULT_WGAN_PARAMS, 'epochs': 2000}
num_augmented_samples_final = 2 * len(development_df_original)

final_augmented_dev_df = train_and_generate_wgangp(
    input_original_df=development_df_original.copy(),
    wgan_hyperparams=final_wgan_params,
    num_samples_to_generate=num_augmented_samples_final,
    current_device=device,
    fold_num_for_logging="Final_Model_Augmentation",
    output_augmented_data_path=AUGMENTED_DATA_OUTPUT_FOLDER / "augmented_for_final_model.xlsx",
)

# Final training data: all real development rows followed by the synthetic rows.
X_final_augmented_dev = final_augmented_dev_df.drop(columns=[TARGET_COLUMN])
y_final_augmented_dev = final_augmented_dev_df[TARGET_COLUMN]
X_train_final_model, y_train_final_model = (
    pd.concat([real_part, synthetic_part], ignore_index=True)
    for real_part, synthetic_part in ((X_dev_original, X_final_augmented_dev),
                                      (y_dev_original, y_final_augmented_dev))
)
print(f"用于训练最终模型的总数据形状: {X_train_final_model.shape}")

# Same estimator configuration as the HPO/CV steps, using the best HPO hyperparameters.
final_model = xgb.XGBClassifier(
    **best_overall_xgboost_params,
    objective='binary:logistic',
    eval_metric='logloss',
    random_state=RANDOM_STATE,
    use_label_encoder=False,
    tree_method='gpu_hist' if device.type == 'cuda' else 'auto',
)
print("开始训练最终模型...")
final_model.fit(X_train_final_model, y_train_final_model)
print("最终模型训练完成。")

# Persist the fitted estimator with joblib (pickle-based, so loading it needs a
# compatible xgboost version).
final_model_path = MODEL_OUTPUT_PATH / "final_xgboost_wgangp_model.joblib"
joblib.dump(final_model, final_model_path)
print(f"最终模型已保存到: {final_model_path}")

# Figure 3: feature importance of the final model (at most 20 features).
# xgb.plot_importance draws on a brand-new figure unless an Axes is passed in.
# Create the sized figure explicitly and hand over its Axes; otherwise the
# requested figsize applies to an empty, orphaned figure.
fig_importance, ax_importance = plt.subplots(figsize=(10, 8))
xgb.plot_importance(final_model, ax=ax_importance, max_num_features=20, height=0.8,
                    title="Feature Importance (Final Model)")
fig_importance.tight_layout()
fig_importance.savefig(OUTPUT_PLOT_PATH / "final_model_feature_importances.png", dpi=300)
plt.show()

# --- Step 5: evaluation on the held-out final test set ---
print("\n--- [步骤 5] 在最终测试集上进行无偏评估 ---")
# Probability of the positive class (column 1), thresholded at 0.5 inclusive.
y_pred_proba_final = final_model.predict_proba(X_final_test)[:, 1]
y_pred_final = (y_pred_proba_final >= 0.5).astype(int)

print("--- 最终模型在最终测试集上的性能 ---")
# Threshold metrics use the hard predictions; AUC uses the probabilities.
# Each metric is computed right before it is printed, and labels are padded
# to 9 characters so the colons line up.
for metric_label, metric_fn, metric_input in (
    ("Accuracy", accuracy_score, y_pred_final),
    ("Precision", precision_score, y_pred_final),
    ("Recall", recall_score, y_pred_final),
    ("F1 Score", f1_score, y_pred_final),
    ("AUC", roc_auc_score, y_pred_proba_final),
):
    print(f"{metric_label:<9}: {metric_fn(y_final_test, metric_input):.4f}")

# Figure 4: confusion matrix of the thresholded predictions on the final test set.
# display_labels presumes the labels are exactly {0, 1}, shown as
# 'Absence' / 'Presence' — TODO confirm both classes occur in the test set.
fig, ax = plt.subplots(figsize=(8, 8))
ConfusionMatrixDisplay.from_predictions(
    y_final_test,
    y_pred_final,
    ax=ax,
    display_labels=['Absence', 'Presence'],
    cmap='Blues',
)
ax.set_title('Confusion Matrix on Final Test Set')
fig.savefig(OUTPUT_PLOT_PATH / "final_model_confusion_matrix.png", dpi=300)
plt.show()

print("\n--- 整体流程执行完毕 ---")
