In [None]:
import pandas as pd
from pathlib import Path

current_working_dir = Path.cwd()
data_file_path = current_working_dir.parent / "data"

# Load the BuildingCohort splits (each split is stored as an X/y parquet pair):
#   X_train / y_train -> training set
#   X_test  / y_test  -> tuning set
#   X_val   / y_val   -> validation set
X_train, y_train, X_test, y_test, X_val, y_val = (
    pd.read_parquet(f'{data_file_path}/{split_name}.parquet')
    for split_name in ("X_train", "y_train", "X_test", "y_test", "X_val", "y_val")
)

In [None]:
# Serialized artefacts: one calibrated LightGBM model per outcome, the
# LightGBM feature-name list, and the continuous-feature scaler (+ its columns).
model_dir = current_working_dir.parent / "models"
models_paths = {
    outcome: f'{model_dir}/{outcome}_LightGBM_Calibrated.joblib'
    for outcome in ('Hb', 'PLT', 'WBC_Neut')
}
LGBM_features = f'{model_dir}/lightgbm_feature_names.joblib'
scaler_path = f'{model_dir}/scaler_continuous.joblib'
scaler_features_path = f'{model_dir}/scaler_continuous_features.joblib'

In [None]:
import joblib

# Training-time preprocessing artefacts, loaded in a fixed order: the fitted
# scaler, the column list it is applied to, and the LightGBM feature order.
# joblib.load unpickles its input -- only load files from a trusted source.
scaler, scaler_feature_names, lgbm_feature_names = (
    joblib.load(artifact_path)
    for artifact_path in (scaler_path, scaler_features_path, LGBM_features)
)

In [None]:
# Pool train/tuning/validation splits into a single "Building" cohort; the
# feature frame is put in the column order the LightGBM models expect.
y_Building = pd.concat([y_train, y_test, y_val], axis=0, ignore_index=True)
X_Building = pd.concat(
    [X_train, X_test, X_val], axis=0, ignore_index=True
).reindex(columns=lgbm_feature_names)

In [None]:
from sklearn.linear_model import LogisticRegression

class PlattScalingCalibrator:
    """Platt-scaling wrapper around a fitted classifier.

    A one-feature logistic regression (``platt_lr``) is fitted on the base
    model's scores and maps those scores to calibrated probabilities. The
    score is the positive-class column of ``predict_proba`` when the base
    model has that method, otherwise the output of ``predict``.

    NOTE(review): the ``*_LightGBM_Calibrated.joblib`` files are presumably
    pickled instances of this class (TODO confirm); if so, this class must
    keep its name and its ``base_model`` / ``platt_lr`` attributes so that
    unpickling keeps working.

    Attributes
    ----------
    base_model : the wrapped estimator.
    platt_lr : LogisticRegression fitted on the (n_samples, 1) score column.
    """

    def __init__(self, base_model):
        self.base_model = base_model
        self.platt_lr = LogisticRegression(max_iter=1000)

    def _raw_scores(self, X):
        """Return the base model's scores for ``X`` as an (n_samples, 1) float array."""
        import numpy as np  # local import: numpy is otherwise imported in a later cell

        if hasattr(self.base_model, "predict_proba"):
            scores = self.base_model.predict_proba(X)[:, 1]
        else:
            scores = self.base_model.predict(X)
        # np.asarray also accepts lists / pandas Series (which lack .reshape).
        return np.asarray(scores, dtype=float).reshape(-1, 1)

    def fit(self, X, y):
        """Fit the Platt logistic regression on the base model's scores for ``X``.

        Returns ``self`` so calls can be chained.
        """
        self.platt_lr.fit(self._raw_scores(X), y)
        return self

    def predict_proba(self, X):
        """Return calibrated class probabilities, shape (n_samples, n_classes)."""
        return self.platt_lr.predict_proba(self._raw_scores(X))

In [None]:
import os
import shap
import matplotlib.pyplot as plt

save_dir = current_working_dir.parent / "results" / "SHAP"
save_dir.mkdir(parents=True, exist_ok=True)

# Figure styling: Type-42 (TrueType) font embedding for PDF/PS output,
# Arial as the sans-serif face, 10 pt base font size.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.sans-serif': ['Arial'],
    'font.size': 10,
})

# Outcome key -> label column name ("outcome_<key>").
# NOTE(review): not referenced in the visible cells; confirm it is still needed.
outcome_map = {key: f"outcome_{key}" for key in ("Hb", "PLT", "WBC_Neut")}

def align_features(X, feature_names):
    """Return a copy of ``X`` whose columns are exactly ``feature_names``, in that order.

    Columns missing from ``X`` are created and filled with NaN; columns of ``X``
    that are not listed are dropped (``DataFrame.reindex`` semantics).
    """
    return X.reindex(labels=feature_names, axis="columns")

def apply_scaler(X, scaler_obj=None, feature_names=None):
    """Standardize only the scaler's columns of ``X``; other columns pass through.

    Parameters
    ----------
    X : pandas.DataFrame
        Input features; not modified.
    scaler_obj : object with ``.transform``, optional
        Fitted scaler. Defaults to the module-level ``scaler``.
    feature_names : list of str, optional
        Columns the scaler is applied to, in the scaler's expected order.
        Defaults to the module-level ``scaler_feature_names``.

    Returns
    -------
    pandas.DataFrame
        Float copy of ``X`` (same index) with ``feature_names`` replaced by
        their transformed values.
    """
    # Resolve defaults at call time so existing callers keep using the
    # globally loaded scaler; explicit arguments avoid that hidden dependency.
    if scaler_obj is None:
        scaler_obj = scaler
    if feature_names is None:
        feature_names = scaler_feature_names

    X_out = X.copy().astype(float)
    # reindex feeds the scaler its columns in exactly the listed order.
    # NOTE(review): a listed column absent from X becomes NaN here and is then
    # added to X_out below -- confirm every scaler feature is present in X.
    X_scaled = scaler_obj.transform(X_out.reindex(columns=feature_names))
    X_out.loc[:, feature_names] = X_scaled
    return X_out

def get_base_model(model):
    """Unwrap a calibrator: return ``model.base_model`` if it exists, else ``model`` itself."""
    return getattr(model, "base_model", model)

def compute_shap(model, X_scaled, sample_size=10000, seed=42):
    """Compute positive-class SHAP values on a random subsample of ``X_scaled``.

    The TreeExplainer is built on ``get_base_model(model)``, i.e. the
    uncalibrated base model, so the Platt calibration step is not reflected
    in the values.

    Parameters
    ----------
    model : fitted tree model, or a wrapper exposing ``base_model``.
    X_scaled : pandas.DataFrame in the model's feature order (already scaled).
    sample_size : int
        Maximum number of rows to explain.
    seed : int
        ``random_state`` passed to ``DataFrame.sample``.

    Returns
    -------
    (shap_values, X_use)
        ``shap_values`` is a 2-D array (rows x features); ``X_use`` is the
        sampled frame with its original index preserved.
    """
    n_rows = min(sample_size, len(X_scaled))
    X_use = X_scaled.sample(n_rows, random_state=seed)
    explainer = shap.TreeExplainer(get_base_model(model))
    shap_values = explainer.shap_values(X_use)
    # Classifier output may be a per-class list (older SHAP) or a single
    # (rows, features, classes) array (newer SHAP); keep class 1 either way.
    if isinstance(shap_values, list):
        shap_values = shap_values[1]
    elif getattr(shap_values, "ndim", 2) == 3:
        shap_values = shap_values[:, :, 1]
    return shap_values, X_use

# Pooled Building cohort in model feature order. SHAP is computed on the
# scaled features; plots and CSVs use the raw values of the same sampled rows.
X_Building_raw = pd.concat([X_train, X_test, X_val], axis=0, ignore_index=True)
X_Building_raw = align_features(X_Building_raw, lgbm_feature_names)
X_Building_scaled = apply_scaler(X_Building_raw)

shap_store = {}
for outcome_name, model_path in models_paths.items():
    model = joblib.load(model_path)
    shap_values, X_sample_scaled = compute_shap(model, X_Building_scaled, sample_size=10000)
    # apply_scaler preserves the index, so .loc recovers the matching raw rows.
    X_sample_raw = X_Building_raw.loc[X_sample_scaled.index]
    shap_store[outcome_name] = (shap_values, X_sample_raw)

# (summary_plot type, file-name suffix, title word, extra summary_plot kwargs)
shap_plot_specs = [
    ("dot", "Beeswarm", "Summary", {"color_bar_label": "Feature Value"}),
    ("bar", "Bar", "Bar", {}),
]

# Iterate over shap_store so every model in models_paths gets plotted/exported.
for outcome_name, (shap_values, X_plot) in shap_store.items():
    for plot_type, file_suffix, title_word, extra_kwargs in shap_plot_specs:
        plt.figure(figsize=(7, 5))
        # plot_size=None keeps the 7x5 figure created above; the default
        # plot_size="auto" resizes the current figure and discards that figsize.
        shap.summary_plot(
            shap_values,
            X_plot,
            max_display=15,
            plot_type=plot_type,
            plot_size=None,
            show=False,
            **extra_kwargs,
        )
        plt.title(f"{outcome_name} SHAP {title_word} (Top15)", fontsize=12, fontweight="bold")
        plt.tight_layout()
        plt.savefig(f"{save_dir}/{outcome_name}_SHAP_{file_suffix}.pdf", format="pdf", dpi=300, bbox_inches="tight")
        plt.close()

    # Per-sample SHAP values (one column per feature) for downstream analysis.
    shap_df = pd.DataFrame(shap_values, columns=X_plot.columns)
    shap_df.to_csv(f"{save_dir}/{outcome_name}_SHAP_Values_All.csv", index=False)

In [None]:
import numpy as np

# Cycle-dependence analysis: SHAP contribution of cumulative chemotherapy
# cycles (restricted to cycles 0-10), coloured by each outcome's baseline value.
sample_size = 10000

# outcome -> (baseline feature used for colouring, colour-bar label)
baseline_spec = {
    "Hb": ("base_Hb", "Baseline Hb (g/L)"),
    "PLT": ("base_PLT", "Baseline PLT (×10^9/L)"),
    "WBC_Neut": ("base_WBC", "Baseline WBC (×10^9/L)"),
}
outcome_order = ["Hb", "PLT", "WBC_Neut"]

X_Building_filtered_raw = X_Building_raw[X_Building_raw["cum_chemo"] <= 10].copy()
X_Building_filtered_scaled = apply_scaler(X_Building_filtered_raw)

# Seed immediately before the random draw so the sampled rows are reproducible.
np.random.seed(42)
n_filtered = len(X_Building_filtered_raw)
sample_idx_2 = np.random.choice(n_filtered, min(sample_size, n_filtered), replace=False)

# Positional indices select the same rows from the raw and the scaled frame.
X_sample_raw = X_Building_filtered_raw.iloc[sample_idx_2]
X_sample_scaled = X_Building_filtered_scaled.iloc[sample_idx_2]

models_shap = {name: joblib.load(models_paths[name]) for name in outcome_order}

# Column position of cum_chemo is the same for every model (shared frame).
cum_chemo_idx = X_sample_scaled.columns.get_loc("cum_chemo")
shap_results = {}
for outcome_name, model in models_shap.items():
    explainer = shap.TreeExplainer(get_base_model(model))
    shap_values = explainer.shap_values(X_sample_scaled)
    # Per-class list (older SHAP) or (rows, features, classes) array (newer
    # SHAP): keep the positive class in both cases.
    if isinstance(shap_values, list):
        shap_values = shap_values[1]
    elif getattr(shap_values, "ndim", 2) == 3:
        shap_values = shap_values[:, :, 1]
    baseline_feature = baseline_spec[outcome_name][0]
    shap_results[outcome_name] = {
        "cum_chemo_vals": X_sample_raw["cum_chemo"].values,
        "shap_cum_chemo": shap_values[:, cum_chemo_idx],
        "baseline_vals": X_sample_raw[baseline_feature].values,
        "baseline_name": baseline_feature,
    }

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for idx, (ax, outcome_name) in enumerate(zip(axes, outcome_order)):
    data = shap_results[outcome_name]
    cum_chemo_vals = data["cum_chemo_vals"]
    shap_cum_chemo = data["shap_cum_chemo"]
    baseline_vals = data["baseline_vals"]
    # nanpercentile: a single missing baseline value would make np.percentile
    # return NaN for vmin/vmax and wipe out the whole colour scale.
    scatter = ax.scatter(
        cum_chemo_vals,
        shap_cum_chemo,
        c=baseline_vals,
        cmap="RdYlBu_r",
        alpha=0.5,
        s=15,
        edgecolors="none",
        vmin=np.nanpercentile(baseline_vals, 5),
        vmax=np.nanpercentile(baseline_vals, 95),
    )
    # Median SHAP per observed cycle count (np.unique returns sorted values).
    cum_chemo_unique = np.unique(cum_chemo_vals)
    shap_median = [np.median(shap_cum_chemo[cum_chemo_vals == c]) for c in cum_chemo_unique]
    ax.plot(cum_chemo_unique, shap_median, color="red", linewidth=3, label="Median SHAP", zorder=10)
    ax.set_xlabel("Cumulative Chemotherapy Cycles", fontsize=11, fontweight="bold")
    ax.set_ylabel("SHAP Value" if idx == 0 else "", fontsize=11, fontweight="bold")
    ax.set_title(f"{outcome_name} Suppression", fontsize=12, fontweight="bold", pad=10)
    ax.axhline(y=0, color="gray", linestyle="--", linewidth=1.5, alpha=0.7)
    ax.grid(alpha=0.3, linestyle="--")
    ax.legend(fontsize=9, loc="upper left")
    ax.set_xlim(-0.5, 10.5)
    ax.set_xticks(range(0, 11))
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label(baseline_spec[outcome_name][1], fontsize=10)

plt.suptitle("Cycle-Dependent Risk Dynamics (Cycles 0-10)", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.savefig(f"{save_dir}/CycleDependence_Combined_Truncated.pdf", format="pdf", dpi=300, bbox_inches="tight")
plt.close(fig)