In [None]:
from sklearn.model_selection import train_test_split, cross_val_score, cross_val_predict
from sklearn.metrics import (
    get_scorer, accuracy_score, recall_score, precision_score,
    roc_auc_score, matthews_corrcoef, average_precision_score
)
from tqdm import tqdm
import pandas as pd
import numpy as np
import scipy.stats
import copy

class Scrambler:
    """Y-scrambling (label-permutation) test for a binary classifier.

    The wrapped model is scored once with the true training labels and then
    ``iterations`` more times with randomly permuted training labels. Each
    metric is standardized over the pooled scores, so a true-label score that
    stands far from the scrambled ones suggests the model learned real signal
    rather than chance structure.
    """

    _METHODS = ("train_test_split", "cross_validation")

    def __init__(self, model, iterations=100):
        """Store the model and the number of scrambling rounds.

        Args:
            model: classifier with ``fit``/``predict``; for the
                ``"cross_validation"`` method it must also be usable with
                scikit-learn's ``cross_val_predict``. ``validate`` never fits
                this object in place.
            iterations: number of label permutations to evaluate.
        """
        self.base_model = model
        self.iterations = iterations
        # Stored but not read by ``validate``, which takes its own flag.
        self.progress_bar = False

    def validate(self, X, Y, method="train_test_split", scoring="accuracy", cross_val_score_aggregator="mean", pvalue_threshold=0.05, cv_kfolds=5, as_df=False, validation_data=None, progress_bar=False):
        """Run the scrambling test and report per-metric significance.

        Args:
            X, Y: features and binary labels (precision/recall use
                ``average='binary'``).
            method: ``"train_test_split"`` (one random hold-out split) or
                ``"cross_validation"`` (out-of-fold predictions).
            scoring: scikit-learn scorer name. It is only validated via
                ``get_scorer``; the reported metrics are a fixed set.
            cross_val_score_aggregator: unused; kept for backward
                compatibility.
            pvalue_threshold: p-values ``<=`` this are flagged significant.
            cv_kfolds: ``cv`` argument for ``"cross_validation"``.
            as_df: return a DataFrame instead of a flat tuple.
            validation_data: unused; kept for backward compatibility.
            progress_bar: show a tqdm progress bar over the scrambled rounds.

        Returns:
            Per-metric arrays of length ``iterations + 1``. Index 0 is the
            model trained on the true labels; the rest are scrambled rounds.
            As a tuple: all metric values, then all z-scores, then all
            p-values, then all significance flags (each group in metric
            order). As a DataFrame: the same data, with columns ``<metric>``,
            ``<metric>_zscore``, ``<metric>_pvalue`` and
            ``<metric>_significant``.

        Raises:
            ValueError: if ``method`` is not supported.
        """
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown method {method!r}; expected one of {self._METHODS}"
            )
        # Kept so that an invalid scorer name still fails fast.
        get_scorer(scoring)

        if method == "train_test_split":
            X_train, X_test, Y_train, Y_test = train_test_split(X, Y)
            metrics = self._evaluate_holdout(
                X_train, Y_train, X_test, Y_test, progress_bar
            )
        else:
            metrics = self._evaluate_cross_validation(X, Y, cv_kfolds, progress_bar)

        return self._summarize(metrics, pvalue_threshold, as_df)

    def _evaluate_holdout(self, X_train, Y_train, X_test, Y_test, progress_bar):
        """Fit on the training split and score on the held-out split."""

        def predict(labels):
            # Fit a fresh copy so the caller's model is left untouched and
            # every round starts from the same initial state.
            model = copy.deepcopy(self.base_model)
            model.fit(X_train, labels)
            y_pred = model.predict(X_test)
            return y_pred, self._ranking_scores(model, X_test, y_pred)

        return self._collect(predict, Y_train, Y_test, progress_bar)

    def _evaluate_cross_validation(self, X, Y, cv_kfolds, progress_bar):
        """Score out-of-fold predictions against the true labels ``Y``."""
        # cross_val_predict clones the estimator for each fold, so
        # self.base_model is never fitted here.
        score_method = self._score_method(self.base_model)

        def predict(labels):
            y_pred = cross_val_predict(self.base_model, X, labels, cv=cv_kfolds)
            if score_method is None:
                return y_pred, y_pred
            y_score = cross_val_predict(
                self.base_model, X, labels, cv=cv_kfolds, method=score_method
            )
            return y_pred, self._positive_column(y_score)

        return self._collect(predict, Y, Y, progress_bar)

    def _collect(self, predict, Y_train, Y_eval, progress_bar):
        """Score the true-label round, then ``self.iterations`` scrambled rounds.

        ``predict(labels)`` trains on ``labels`` and returns
        ``(y_pred, y_score)``; both are scored against the unscrambled
        ``Y_eval``. Returns a dict mapping metric name to a list of scores.
        """
        metrics = {}

        def record(labels):
            y_pred, y_score = predict(labels)
            for name, value in self._score(Y_eval, y_pred, y_score).items():
                metrics.setdefault(name, []).append(value)

        record(Y_train)  # index 0: model trained on the true labels
        for _ in tqdm(range(self.iterations), disable=not progress_bar):
            record(np.random.permutation(Y_train))
        return metrics

    @staticmethod
    def _score(y_true, y_pred, y_score):
        """Compute the fixed metric set.

        Label metrics use ``y_pred``; ranking metrics use ``y_score``.
        """
        return {
            "accuracy": accuracy_score(y_true, y_pred),
            "recall": recall_score(y_true, y_pred, average='binary'),
            "precision": precision_score(y_true, y_pred, average='binary'),
            "roc_auc": roc_auc_score(y_true, y_score),
            "mcc": matthews_corrcoef(y_true, y_pred),
            "avg_precision": average_precision_score(y_true, y_score),
        }

    @staticmethod
    def _score_method(model):
        """Return the name of the model's continuous-output method, or None."""
        for name in ("predict_proba", "decision_function"):
            if hasattr(model, name):
                return name
        return None

    @staticmethod
    def _positive_column(scores):
        """Reduce a 2-D per-class score matrix to its column 1; pass 1-D through.

        Column 1 is the greater label in scikit-learn's sorted class order,
        which is the class ``roc_auc_score`` treats as positive.
        """
        scores = np.asarray(scores)
        return scores[:, 1] if scores.ndim == 2 else scores

    @classmethod
    def _ranking_scores(cls, model, X, y_pred):
        """Return continuous scores from a fitted model, else ``y_pred``."""
        method = cls._score_method(model)
        if method is None:
            return y_pred
        return cls._positive_column(getattr(model, method)(X))

    @staticmethod
    def _summarize(metrics, pvalue_threshold, as_df):
        """Turn per-metric score lists into z-scores, p-values and flags.

        Z-scores are computed over the pooled array (true-label score plus
        every scrambled score). P-values are two-sided normal tail
        probabilities of those z-scores. A metric with zero variance yields
        NaN z-scores and p-values, and therefore ``False`` flags.
        """
        values = {name: np.asarray(scores, dtype=float) for name, scores in metrics.items()}
        zscores = {name: scipy.stats.zscore(v) for name, v in values.items()}
        pvalues = {name: 2 * scipy.stats.norm.sf(np.abs(z)) for name, z in zscores.items()}
        significant = {name: p <= pvalue_threshold for name, p in pvalues.items()}

        if not as_df:
            return (
                *values.values(),
                *zscores.values(),
                *pvalues.values(),
                *significant.values(),
            )

        # Distinct column names: merging the four dicts directly would let
        # later ones overwrite earlier ones, since they share metric keys.
        return pd.DataFrame({
            **values,
            **{f"{name}_zscore": z for name, z in zscores.items()},
            **{f"{name}_pvalue": p for name, p in pvalues.items()},
            **{f"{name}_significant": s for name, s in significant.items()},
        })