In [None]:
import pandas as pd
import numpy as np
import joblib
import shap
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# Locations of the imputed training/validation tables and the serialized model.
train_path = r"D:\临床数据\NHANES数据清洗\train_imputed.csv"
val_path = r"D:\临床数据\NHANES数据清洗\val_imputed.csv"
model_path = r"D:\临床数据\NHANES数据清洗\LR-生活方式+检验学指标-3年心因死亡.joblib"

# Read the training table first, then the validation table, then the model.
df_train, df_val = (
    pd.read_csv(csv_path, low_memory=False) for csv_path in (train_path, val_path)
)
# NOTE: joblib.load unpickles the file; only load artifacts from a trusted source.
estimator = joblib.load(model_path)

# Columns that are not model inputs: the outcome label and the SEQN identifier.
outcome_col = "3 year heart death"
non_feature_cols = [outcome_col, "SEQN"]

# errors="ignore" lets the drop succeed even if one of these columns is absent;
# the label lookup below still raises KeyError when the outcome column is missing.
x_train = df_train.drop(columns=non_feature_cols, errors="ignore")
y_train = df_train[outcome_col]
x_test = df_val.drop(columns=non_feature_cols, errors="ignore")
y_test = df_val[outcome_col]

plt.rcParams.update({"font.family": "Arial"})
shap.initjs()

# The training features serve as the background data of the linear explainer;
# SHAP values are then computed for every row of the validation features.
explainer = shap.LinearExplainer(estimator, x_train)
shap_values = explainer.shap_values(x_test)

# feature_group maps a display name to the model-input column(s) pooled under it.
# Insertion order matters: downstream code uses it as the column order of the
# aggregated SHAP matrix and of the feature-name list.

# Display names whose columns are "<name>_<level>" for the listed levels.
# NOTE(review): these look like dummy columns of categorical variables with the
# first level dropped -- confirm against the preprocessing step.
_dummy_levels = {
    "Doctor told you have diabetes": (2, 3),
    "Education Level - Adults 20+": (2, 3, 4, 5),
    "Ever told you had a stroke": (2,),
    "Ever told you had coronary heart disease": (2,),
    "Ever told you had high blood pressure": (2,),
    "Had at least 12 alcohol drinks/1 yr?": (2,),
    "Marital Status": (2, 3, 4, 5, 6),
    "Moderate recreational activities": (2,),
    "Moderate work activity": (2,),
    "Race": (2, 3, 4, 5),
    "Smoked at least 100 cigarettes in life": (2,),
    "Vigorous recreational activities": (2,),
    "Vigorous work activity": (2,),
    "Walk or bicycle": (2,),
}

# Display names backed by a single column of the same name.
_single_columns = [
    "Age",
    "Minutes sedentary activity",
    "Ratio of family income to poverty",
    "Waist Circumference (cm)",
    "Alanine Aminotransferase (ALT) (U/L)",
    "Albumin (g/L)",
    "Albumin_urine (mg/L)",
    "Alkaline Phosphatase (ALP) (IU/L)",
    "Bicarbonate (mmol/L)",
    "Blood urea nitrogen (mmol/L)",
    "Chloride (mmol/L)",
    "Cholesterol (mmol/L)",
    "Creatinine (umol/L)",
    "Creatinine_urine (umol/L)",
    "Direct HDL-Cholesterol (mmol/L)",
    "Gamma Glutamyl Transferase (GGT) (U/L)",
    "Globulin (g/L)",
    "Glucose_serum (mmol/L)",
    "Glycohemoglobin (%)",
    "Hematocrit (%)",
    "Hemoglobin (g/dL)",
    "Iron_refigerated (umol/L)",
    "Lactate Dehydrogenase (LDH) (U/L)",
    "Lymphocyte number (1000 cells/uL)",
    "Lymphocyte percent (%)",
    "Mean cell hemoglobin (pg)",
    "Mean cell hemoglobin concentration (g/dL)",
    "Mean cell volume (fL)",
    "Monocyte number (1000 cells/uL)",
    "Monocyte percent (%)",
    "Osmolality (mmol/Kg)",
    "Phosphorus (mmol/L)",
    "Platelet count (1000 cells/uL)",
    "Potassium (mmol/L)",
    "Red blood cell count (million cells/uL)",
    "Red cell distribution width (%)",
    "Segmented neutrophils num (1000 cell/uL)",
    "Segmented neutrophils percent (%)",
    "Total protein (g/L)",
    "Uric acid (umol/L)",
]

feature_group = {
    name: [f"{name}_{level}" for level in levels]
    for name, levels in _dummy_levels.items()
}
feature_group.update((name, [name]) for name in _single_columns)

# Serum creatinine is accepted under two spellings of the same column name.
# Re-assigning an existing key keeps its position in the dict.
# NOTE(review): the "?" looks like a mis-encoded micro sign -- confirm against
# the actual CSV header before changing it.
feature_group["Creatinine (umol/L)"] = ["Creatinine (umol/L)", "Creatinine (?mol/L)"]

# Collapse per-column SHAP values into one value per display feature.
#   aggregated_features    -- one column per group; passed to the summary plot
#                             as the feature values (used for point colouring).
#   aggregated_shap_values -- (rows of x_test) x (number of groups) matrix whose
#                             column order follows feature_group.
aggregated_features = pd.DataFrame(index=x_test.index)
aggregated_shap_values = np.zeros((len(x_test), len(feature_group)))

for i, (feature_name, columns) in enumerate(feature_group.items()):
    valid_cols = [c for c in columns if c in x_test.columns]
    if not valid_cols:
        print(f"Warning: {feature_name} 的列 {columns} 未在 x_test 中找到")
        # Keep a placeholder column so aggregated_features stays aligned with
        # aggregated_shap_values (whose column i remains all zeros); without it
        # the two would have different widths and the summary plot would fail.
        aggregated_features[feature_name] = np.nan
        continue

    col_positions = [x_test.columns.get_loc(c) for c in valid_cols]
    # Feature value shown for the group: row-wise maximum of its columns
    # (for a single-column group this is just the column itself).
    aggregated_features[feature_name] = x_test[valid_cols].max(axis=1)
    # SHAP contributions are additive, so the contribution of a group is the
    # SUM of its columns' SHAP values. Taking the maximum would discard
    # negative contributions and break additivity.
    aggregated_shap_values[:, i] = shap_values[:, col_positions].sum(axis=1)

# Draw the SHAP summary plot on fig1 and export it as a single-page PDF.
fig1 = plt.figure()
# summary_plot draws on the current figure. show=False stops it from calling
# plt.show() internally, so the figure is still intact when it is saved below
# (showing first can leave a blank or already-destroyed figure on some backends).
shap.summary_plot(
    aggregated_shap_values,
    features=aggregated_features,
    feature_names=list(feature_group.keys()),
    show=False,
)
pdf_path = r"D:\临床数据\NHANES数据清洗\LR-SHAP图-3年心因死亡.pdf"
with PdfPages(pdf_path) as pdf:
    # bbox_inches="tight" keeps long feature labels from being clipped.
    pdf.savefig(fig1, bbox_inches="tight")

# Display only after the file has been written, then release this specific
# figure (a bare plt.close() closes whichever figure happens to be current).
plt.show()
plt.close(fig1)

# Global importance of each display feature: the mean absolute aggregated SHAP
# value over all rows (one number per column of aggregated_shap_values).
shap_values_mean_abs = np.abs(aggregated_shap_values).mean(axis=0)

# Pair every display name with its importance and rank from most to least important.
_importance_columns = {
    "Feature": list(feature_group),
    "Importance": shap_values_mean_abs,
}
feature_importance = pd.DataFrame(_importance_columns).sort_values(
    by="Importance", ascending=False
)

# Write the ranked table to Excel without the DataFrame index column.
output_excel = r"D:\临床数据\NHANES数据清洗\LR-SHAP值-3年心因死亡.xlsx"
feature_importance.to_excel(output_excel, index=False)