In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, accuracy_score, classification_report
from xgboost import XGBClassifier
from sklearn.svm import SVC
from imblearn.over_sampling import SMOTE
import shap
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
plt.rcParams['font.family'] = 'serif'

In [2]:
# Load the general patient information workbook.
# NOTE(review): hard-coded absolute path to one user's home directory — consider a configurable data directory.
general_information_path = '/home/jiangjingwen/Downloads/LiHui/data/1.一般资料.xlsx'
general_information = pd.read_excel(general_information_path)

In [3]:
# Load the merged feature / prognosis workbook that every later cell is derived from.
# NOTE(review): hard-coded absolute path to one user's home directory — consider a configurable data directory.
all_feature_path = '/home/jiangjingwen/Downloads/LiHui/data/8.feature_merged_prognosis.xlsx'
all_feature = pd.read_excel(all_feature_path)

In [4]:
# Remove the 'patients' column. The frame is rebound to its own name, so without
# errors='ignore' running this cell a second time would raise KeyError because the
# column is already gone.
all_feature = all_feature.drop(columns=['patients'], errors='ignore')

In [5]:
# Partition the feature table by the value in its first column and take the
# first three groups positionally.
# NOTE(review): assumes the first three group keys correspond to HC, COMA and
# AWAKE in that order — confirm against how the spreadsheet codes the groups.
grouped = all_feature.groupby(all_feature.columns[0])
group_labels = list(grouped.groups.keys())
if len(group_labels) < 3:
    raise ValueError(
        f'Expected at least 3 groups in column {all_feature.columns[0]!r}, found {group_labels}'
    )
hc_data = grouped.get_group(group_labels[0])
coma_data = grouped.get_group(group_labels[1])
awake_data = grouped.get_group(group_labels[2])

# Binary target for the stacked COMA + AWAKE rows: raw code 1 -> 0.0, raw code 2 -> 1.0.
coma_awake_data = pd.concat([coma_data, awake_data], ignore_index=True)
label2id = {1:0., 2:1.}
raw_status = coma_awake_data['status_after_6_months']
y = raw_status.map(label2id)

# Series.map yields NaN for any value missing from label2id; fail here with a
# clear message instead of inside a later model fit.
unmapped_status = raw_status[y.isna()].unique()
if len(unmapped_status) > 0:
    raise ValueError(
        f"'status_after_6_months' contains values not covered by label2id: {list(unmapped_status)}"
    )

In [6]:
# Row groups keyed by cohort name.
datasets = dict(HC=hc_data, COMA=coma_data, AWAKE=awake_data)

# Positional column ranges (used with DataFrame.iloc) for each feature family;
# the last family runs to the final column.
data_ranges = dict(
    Power_Spectrum=slice(3, 33),
    Microstate=slice(33, 81),
    Audio=slice(81, None),
)

In [7]:
# One entry per (cohort, feature family) pair, keyed '<cohort>_<family>',
# holding that cohort's rows restricted to the family's column range.
split_dataset = {
    f'{dataset_name}_{range_name}': dataset.iloc[:, data_range]
    for dataset_name, dataset in datasets.items()
    for range_name, data_range in data_ranges.items()
}

In [8]:
# Bind every (cohort, feature family) subset to an explicit module-level name.
(
    HC_Power_Spectrum, HC_Microstate, HC_Audio,
    COMA_Power_Spectrum, COMA_Microstate, COMA_Audio,
    AWAKE_Power_Spectrum, AWAKE_Microstate, AWAKE_Audio,
) = (
    split_dataset[f'{cohort}_{family}']
    for cohort in ('HC', 'COMA', 'AWAKE')
    for family in ('Power_Spectrum', 'Microstate', 'Audio')
)

# Stack the COMA rows on top of the AWAKE rows for each family, renumbering the index.
COMA_AWAKE_Power_Spectrum, COMA_AWAKE_Microstate, COMA_AWAKE_Audio = (
    pd.concat(cohort_pair, ignore_index=True)
    for cohort_pair in (
        [COMA_Power_Spectrum, AWAKE_Power_Spectrum],
        [COMA_Microstate, AWAKE_Microstate],
        [COMA_Audio, AWAKE_Audio],
    )
)

In [9]:
def plot_roc_curve_and_compute_auc_with_model(X, y, model, test_size=0.4, random_state=42):
    """Fit ``model`` on a random train split and compute its ROC on the held-out split.

    Despite the name, this function draws nothing; it returns the curve so the
    caller can plot it.

    Parameters
    ----------
    X, y : features and labels, passed straight to ``train_test_split``.
    model : estimator exposing ``fit`` and ``predict_proba``; it is fitted in place.
    test_size, random_state : forwarded to ``train_test_split``.

    Returns
    -------
    tuple
        ``(fpr, tpr, roc_auc, X_train, X_test, y_train, y_test)`` where the ROC
        is computed from column 1 of ``predict_proba`` on the test split.
    """
    split = train_test_split(X, y, test_size=test_size, random_state=random_state)
    X_train, X_test, y_train, y_test = split

    model.fit(X_train, y_train)
    # Column 1 of predict_proba is used as the score for the ROC curve.
    positive_scores = model.predict_proba(X_test)[:, 1]

    fpr, tpr, _thresholds = roc_curve(y_test, positive_scores)
    return fpr, tpr, auc(fpr, tpr), X_train, X_test, y_train, y_test

In [10]:
# Feature families to evaluate, with the label used as the plot title for each.
X_name_list = ['COMA_AWAKE_Power_Spectrum', 'COMA_AWAKE_Microstate', 'COMA_AWAKE_Audio']
X_list = [COMA_AWAKE_Power_Spectrum, COMA_AWAKE_Microstate, COMA_AWAKE_Audio]

# Classifiers to compare; names and estimators are kept in matching order.
named_models = [
    ('XGB', XGBClassifier()),
    ('RandomForest', RandomForestClassifier(n_estimators=100, random_state=42)),
    ('SVC', SVC(kernel='rbf', probability=True, random_state=42)),
]
model_name_list = [model_name for model_name, _ in named_models]
model_list = [estimator for _, estimator in named_models]

In [11]:
# For every feature family, fit each model and draw three figures with one panel
# per model: ROC curves (axs), top-10 built-in feature importances (axs2, tree
# models only; the SVC panel stays empty) and top-10 mean |SHAP| values (axs3).
y = np.array(y)
n_models = len(model_list)
panel_size = (n_models * 8, 6)
for X, X_name in zip(X_list, X_name_list):
    feature_names = X.columns
    X = np.array(X)
    # Keep a handle on each figure so tight_layout can be applied to all three,
    # not only to the most recently created one. squeeze=False keeps the axes
    # indexable by j even when model_list holds a single model.
    fig_roc, roc_grid = plt.subplots(1, n_models, figsize=panel_size, squeeze=False)
    fig_imp, imp_grid = plt.subplots(1, n_models, figsize=panel_size, squeeze=False)
    fig_shap, shap_grid = plt.subplots(1, n_models, figsize=panel_size, squeeze=False)
    axs, axs2, axs3 = roc_grid[0], imp_grid[0], shap_grid[0]
    for j, (model, model_name) in enumerate(zip(model_list, model_name_list)):
        fpr, tpr, roc_auc, X_train, X_test, y_train, y_test = plot_roc_curve_and_compute_auc_with_model(X, y, model)
        axs[j].plot(fpr, tpr, lw=2, color='darkorange', label=f'{model_name} (AUC = {roc_auc:.2f})')
        axs[j].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        axs[j].set_xlim([0.0, 1.0])
        axs[j].set_ylim([0.0, 1.05])
        axs[j].set_xlabel('False Positive Rate')
        axs[j].set_ylabel('True Positive Rate')
        axs[j].set_title(f'{X_name}')
        axs[j].legend(loc="upper left")

        if model_name != 'SVC':
            # Tree models: built-in importances, largest bar drawn at the top.
            feature_importance_df = pd.DataFrame(
                {'feature': feature_names, 'importance': model.feature_importances_}
            ).sort_values(by='importance', ascending=False)
            top_10_features = feature_importance_df.head(10)
            axs2[j].barh(top_10_features['feature'][::-1], top_10_features['importance'][::-1], color='skyblue')
            axs2[j].set_xlabel('Importance')
            axs2[j].set_title(f'Top 10 Feature Importances in {model_name}')

            explainer = shap.TreeExplainer(model)
            shap_values = explainer.shap_values(X_test)
            if model_name == 'RandomForest':
                # Collapse the trailing axis before averaging over samples.
                # NOTE(review): assumes shap returns (n_samples, n_features, n_classes)
                # for the random forest — confirm for the installed shap version.
                shap_values = np.abs(shap_values).mean(axis=-1)
        else:
            # SVC has no tree structure: explain model.predict with KernelExplainer,
            # using the training split as background and explaining every row of X.
            explainer = shap.KernelExplainer(model.predict, X_train)
            shap_values = explainer.shap_values(X)

        # Mean |SHAP| per feature, then the ten largest in ascending order
        # (so the largest bar ends up at the top of the horizontal bar chart).
        shap_values_summary = np.abs(shap_values).mean(axis=0)
        sorted_indices = np.argsort(shap_values_summary)
        top_features = feature_names[sorted_indices][-10:]
        top_shap_values = shap_values_summary[sorted_indices][-10:]
        axs3[j].barh(top_features, top_shap_values, color='skyblue')
        axs3[j].set_xlabel('SHAP Value')
        axs3[j].set_title(f'Top 10 SHAP Values in {model_name}')

    for figure in (fig_roc, fig_imp, fig_shap):
        figure.tight_layout()

plt.show()

In [12]:
# Min-max scale every (cohort, feature family) subset and store it under
# '<cohort>_<family>_scaled'. The single scaler object is refit on each subset.
# NOTE(review): each cohort is therefore scaled independently, and scaling is done
# before any train/test split — confirm this is intended, since COMA and AWAKE
# rows are later concatenated into one feature matrix.
scaler = MinMaxScaler()
split_dataset_scaled = {
    f'{dataset_name}_{range_name}_scaled': pd.DataFrame(
        scaler.fit_transform(dataset.iloc[:, data_range]),
        columns=dataset.columns[data_range],
    )
    for dataset_name, dataset in datasets.items()
    for range_name, data_range in data_ranges.items()
}

In [13]:
# Display the keys of the scaled-subset dictionary as a quick sanity check.
split_dataset_scaled.keys()

In [14]:
# Bind every scaled (cohort, feature family) subset to an explicit name.
HC_Power_Spectrum_scaled = split_dataset_scaled['HC_Power_Spectrum_scaled']
HC_Microstate_scaled = split_dataset_scaled['HC_Microstate_scaled']
HC_Audio_scaled = split_dataset_scaled['HC_Audio_scaled']
COMA_Power_Spectrum_scaled = split_dataset_scaled['COMA_Power_Spectrum_scaled']
COMA_Microstate_scaled = split_dataset_scaled['COMA_Microstate_scaled']
COMA_Audio_scaled = split_dataset_scaled['COMA_Audio_scaled']
AWAKE_Power_Spectrum_scaled = split_dataset_scaled['AWAKE_Power_Spectrum_scaled']
AWAKE_Microstate_scaled = split_dataset_scaled['AWAKE_Microstate_scaled']
AWAKE_Audio_scaled = split_dataset_scaled['AWAKE_Audio_scaled']

# Stack the COMA rows on top of the AWAKE rows for each family, renumbering the index.
COMA_AWAKE_Power_Spectrum_scaled = pd.concat([COMA_Power_Spectrum_scaled, AWAKE_Power_Spectrum_scaled], ignore_index=True)
COMA_AWAKE_Microstate_scaled = pd.concat([COMA_Microstate_scaled, AWAKE_Microstate_scaled], ignore_index=True)
COMA_AWAKE_Audio_scaled = pd.concat([COMA_Audio_scaled, AWAKE_Audio_scaled], ignore_index=True)

# Legacy binding: the later cells read the scaled microstate frame through the
# un-suffixed name, so it is kept pointing at the scaled data. Be aware that this
# replaces any earlier value of COMA_AWAKE_Microstate; new code should use
# COMA_AWAKE_Microstate_scaled.
COMA_AWAKE_Microstate = COMA_AWAKE_Microstate_scaled

In [15]:
# Scaled feature families to evaluate, with the label used as the plot title for each.
# COMA_AWAKE_Microstate is the scaled microstate frame at this point: the cell that
# concatenates the scaled cohorts rebinds that name.
X_list = [COMA_AWAKE_Power_Spectrum_scaled, COMA_AWAKE_Microstate, COMA_AWAKE_Audio_scaled]
# The microstate label now names both cohorts, matching the other two labels.
X_name_list = ['COMA_AWAKE_Power_Spectrum_scaled', 'COMA_AWAKE_Microstate_scaled', 'COMA_AWAKE_Audio_scaled']

# Fresh, unfitted classifiers; names and estimators are kept in matching order.
model_list = [XGBClassifier(), RandomForestClassifier(n_estimators=100, random_state=42), SVC(kernel='rbf', probability=True, random_state=42)]
model_name_list = ['XGB', 'RandomForest', 'SVC']

In [16]:
# For every feature family in X_list, fit each model and draw three figures with
# one panel per model: ROC curves (axs), top-10 built-in feature importances
# (axs2, tree models only; the SVC panel stays empty) and top-10 mean |SHAP|
# values (axs3).
y = np.array(y)
n_models = len(model_list)
panel_size = (n_models * 8, 6)
for X, X_name in zip(X_list, X_name_list):
    feature_names = X.columns
    X = np.array(X)
    # Keep a handle on each figure so tight_layout can be applied to all three,
    # not only to the most recently created one. squeeze=False keeps the axes
    # indexable by j even when model_list holds a single model.
    fig_roc, roc_grid = plt.subplots(1, n_models, figsize=panel_size, squeeze=False)
    fig_imp, imp_grid = plt.subplots(1, n_models, figsize=panel_size, squeeze=False)
    fig_shap, shap_grid = plt.subplots(1, n_models, figsize=panel_size, squeeze=False)
    axs, axs2, axs3 = roc_grid[0], imp_grid[0], shap_grid[0]
    for j, (model, model_name) in enumerate(zip(model_list, model_name_list)):
        fpr, tpr, roc_auc, X_train, X_test, y_train, y_test = plot_roc_curve_and_compute_auc_with_model(X, y, model)
        axs[j].plot(fpr, tpr, lw=2, color='darkorange', label=f'{model_name} (AUC = {roc_auc:.2f})')
        axs[j].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        axs[j].set_xlim([0.0, 1.0])
        axs[j].set_ylim([0.0, 1.05])
        axs[j].set_xlabel('False Positive Rate')
        axs[j].set_ylabel('True Positive Rate')
        axs[j].set_title(f'{X_name}')
        axs[j].legend(loc="upper left")

        if model_name != 'SVC':
            # Tree models: built-in importances, largest bar drawn at the top.
            feature_importance_df = pd.DataFrame(
                {'feature': feature_names, 'importance': model.feature_importances_}
            ).sort_values(by='importance', ascending=False)
            top_10_features = feature_importance_df.head(10)
            axs2[j].barh(top_10_features['feature'][::-1], top_10_features['importance'][::-1], color='skyblue')
            axs2[j].set_xlabel('Importance')
            axs2[j].set_title(f'Top 10 Feature Importances in {model_name}')

            explainer = shap.TreeExplainer(model)
            shap_values = explainer.shap_values(X_test)
            if model_name == 'RandomForest':
                # Collapse the trailing axis before averaging over samples.
                # NOTE(review): assumes shap returns (n_samples, n_features, n_classes)
                # for the random forest — confirm for the installed shap version.
                shap_values = np.abs(shap_values).mean(axis=-1)
        else:
            # SVC has no tree structure: explain model.predict with KernelExplainer,
            # using the training split as background and explaining every row of X.
            explainer = shap.KernelExplainer(model.predict, X_train)
            shap_values = explainer.shap_values(X)

        # Mean |SHAP| per feature, then the ten largest in ascending order
        # (so the largest bar ends up at the top of the horizontal bar chart).
        shap_values_summary = np.abs(shap_values).mean(axis=0)
        sorted_indices = np.argsort(shap_values_summary)
        top_features = feature_names[sorted_indices][-10:]
        top_shap_values = shap_values_summary[sorted_indices][-10:]
        axs3[j].barh(top_features, top_shap_values, color='skyblue')
        axs3[j].set_xlabel('SHAP Value')
        axs3[j].set_title(f'Top 10 SHAP Values in {model_name}')

    for figure in (fig_roc, fig_imp, fig_shap):
        figure.tight_layout()

plt.show()

In [17]:
# Scaled feature families for the SVC-only comparison.
# COMA_AWAKE_Microstate is the scaled microstate frame at this point: the cell that
# concatenates the scaled cohorts rebinds that name.
X_list = [COMA_AWAKE_Power_Spectrum_scaled, COMA_AWAKE_Microstate, COMA_AWAKE_Audio_scaled]
# The microstate label now names both cohorts, matching the other two labels.
X_name_list = ['COMA_AWAKE_Power_Spectrum_scaled', 'COMA_AWAKE_Microstate_scaled', 'COMA_AWAKE_Audio_scaled']
# One RBF-kernel SVM, refitted on each feature family in turn.
model = SVC(kernel='rbf', probability=True, random_state=42)

In [18]:
# Refit the shared model on each feature family in X_list order and keep only the
# ROC pieces (fpr, tpr, AUC) of each result.
y = np.array(y)
roc_results = [
    plot_roc_curve_and_compute_auc_with_model(np.array(X), y, model)[:3]
    for X in X_list
]
fpr_list = [result[0] for result in roc_results]
tpr_list = [result[1] for result in roc_results]
roc_auc_list = [result[2] for result in roc_results]

In [19]:
# Global text styling for the figure.
font_size = 12
plt.rcParams['font.family'] = 'serif'
plt.rcParams.update({'font.size': font_size, 'font.weight': 'bold'})

# (index into the ROC result lists, line colour, legend template), in draw order.
# Index 2 is the Audio feature family, which the legend labels 'ERP'.
curve_specs = [
    (0, 'limegreen', 'Power_Spectrum (AUC = %0.2f)'),
    (2, 'royalblue', 'ERP (AUC = %0.2f)'),
    (1, 'darkorange', 'Microstate (AUC = %0.2f)'),
]

fig, ax = plt.subplots()
for curve_index, curve_colour, label_template in curve_specs:
    ax.plot(
        fpr_list[curve_index],
        tpr_list[curve_index],
        color=curve_colour,
        lw=2,
        label=label_template % roc_auc_list[curve_index],
    )
# Chance-level diagonal for reference.
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=font_size, fontweight='bold')
ax.set_ylabel('True Positive Rate', fontsize=font_size, fontweight='bold')
ax.set_title('ROC Curve', fontsize=font_size, fontweight='bold')
ax.legend(loc="lower right")
# Written relative to the current working directory.
fig.savefig('PS_Microstate_ERP_roc.jpg', dpi=400)
plt.show()

In [20]:
# ------------------------------------------------------
# Manual feature selection
# ------------------------------------------------------

In [21]:
# Hand-picked columns from each feature family.
power_spectrum_columns = ['frontal_β', 'central_δ', 'parietal_δ']
microstate_columns = ['MeanDur_A', 'TimeCov_E', 'SegDensity_F', 'C_ToF', 'F_ToC', 'D_ToC', 'C_ToD']
audio_columns = ['MMNFZ__amplitude']

select_power_spectrum_scaled = COMA_AWAKE_Power_Spectrum_scaled[power_spectrum_columns]
# COMA_AWAKE_Microstate refers to the scaled microstate frame here: the cell that
# concatenates the scaled cohorts rebinds that name.
select_microstate = COMA_AWAKE_Microstate[microstate_columns]
select_audio = COMA_AWAKE_Audio_scaled[audio_columns]

In [22]:
# Place the three selected column groups side by side and remember the column labels.
selected_blocks = [select_power_spectrum_scaled, select_microstate, select_audio]
select_feat = pd.concat(selected_blocks, axis='columns')
feature_names = select_feat.columns

In [23]:
# Feature matrix for the selected columns, as an independent NumPy copy.
X = select_feat.to_numpy(copy=True)

In [24]:
# Display the target vector that is passed to train_test_split in the next cell.
y

In [25]:
# Hold out 40% of the rows for testing, with a fixed seed so the split is repeatable.
split_options = dict(test_size=0.4, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

In [26]:
# Build an RBF-kernel SVM and fit it on the training split.
model = SVC(kernel='rbf', probability=True, random_state=42).fit(X_train, y_train)

# Column 1 of predict_proba on the held-out rows is the score used for the ROC.
y_pred_proba = model.predict_proba(X_test)[:, 1]

# ROC curve and the area under it.
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)

# Draw the ROC curve together with the chance-level diagonal.
fig, ax = plt.subplots()
ax.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (AUC = %0.2f)' % roc_auc)
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(loc="lower right")
plt.show()

In [27]:
# Explain the fitted SVM's predict output with KernelExplainer: the training
# split is the background data, and SHAP values are computed for every row of X.
predict_labels = model.predict
explainer = shap.KernelExplainer(predict_labels, X_train)
shap_values = explainer.shap_values(X)

In [28]:
# Display the dimensions of the SHAP result before plotting it.
shap_values.shape

In [29]:
# Summarise the explanation: default SHAP summary plot, showing at most 15 features.
summary_options = dict(feature_names=feature_names, max_display=15)
shap.summary_plot(shap_values, X, **summary_options)

In [30]:
# Same summary drawn as a bar chart, again limited to 15 features.
bar_summary_options = dict(feature_names=feature_names, plot_type='bar', max_display=15)
shap.summary_plot(shap_values, X, **bar_summary_options)