<a href="https://colab.research.google.com/github/Kenny625819/Applied-Data-Science/blob/main/Figure3_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ================================================================
#  Multimodal Survival Prediction – Reproducible Analysis Notebook
#  Compatible with MRI_ALLdata_OOF.xlsx
#  Includes:
#      – OOF LightGBM model
#      – ROC curves (formal names, font-size updated)
#      – Calibration plots (title corrected: Calibration plot)
#      – SHAP summary plots (renamed features)
#      – SHAP heatmap (sorted by 3M SHAP)
#      – Performance summary table (Excel)
# ================================================================

!pip install lightgbm shap scikit-learn pandas matplotlib numpy seaborn openpyxl xlsxwriter -q

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import lightgbm as lgb
import shap

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    roc_curve, roc_auc_score, brier_score_loss,
    confusion_matrix, precision_recall_fscore_support
)
from sklearn.isotonic import IsotonicRegression
from scipy.stats import norm
from pathlib import Path

# --------------------------------------------------------------
# Global plotting settings
# --------------------------------------------------------------
plt.rcParams["font.family"] = "DejaVu Sans"
plt.rcParams["axes.unicode_minus"] = False

BLUE, ORANGE, GREEN = "#0072B2", "#E69F00", "#009E73"

OUT = Path("./RESULTS_FIGURES")
OUT.mkdir(exist_ok=True)


# ================================================================
# 1. Load dataset
# ================================================================
df = pd.read_excel("MRI_ALLdata_OOF.xlsx")
df.columns = df.columns.str.strip()


# ================================================================
# 2. Preprocessing helper functions
# ================================================================
def map_sex(v):
    s=str(v).strip().lower()
    if s in ["1","m","male","男","男性"]: return 1
    if s in ["0","f","female","女","女性"]: return 0
    return np.nan

def map_escc(v):
    s=str(v).strip().lower().replace(" ","")
    return {"1b":2,"1c":3,"2":4,"3":5}.get(s, np.nan)

def frankel_bin(v):
    s=str(v).strip().upper()
    if s in ["A","B","C"]: return 0
    if s in ["D","E"]: return 1
    return np.nan

def map_yesno(v):
    s=str(v).strip().lower()
    if s in ["yes","y","true","1","あり","有"]: return 1
    if s in ["no","n","false","0","なし","無"]: return 0
    return np.nan


# ================================================================
# 3. Build processed dataframe
# ================================================================
df_proc = pd.DataFrame({
    "Age": df["Age"],
    "Sex": df["Sex"].apply(map_sex),
    "Number of Spinal Metastases": pd.to_numeric(df["Number of Spinal Metastases"], errors="coerce"),
    "Albumin": pd.to_numeric(df["Serum Albumin"], errors="coerce"),   # ★ Albumin に変更済
    "CRP": pd.to_numeric(df["CRP"], errors="coerce"),
    "ESCC": df["ESCC"].apply(map_escc),
    "ECOG": pd.to_numeric(df["Performance Status (ECOG)"], errors="coerce"),
    "Frankel_bin": df["Frankel Grade"].apply(frankel_bin),
    "Barthel Index": pd.to_numeric(df["Barthel Index (ADL)"], errors="coerce"),
    "Malignancy (Katagiri Score)": pd.to_numeric(df["Malignancy (Katagiri Score)"], errors="coerce"),
    "Visceral Metastasis": df["Visceral Metastasis (Yes=1/No=0)"].apply(map_yesno),
    "BMI": pd.to_numeric(df["Body Mass Index (BMI)"], errors="coerce"),

    "Tokuhashi_binary": (pd.to_numeric(df["Revised Tokuhashi score"], errors="coerce") >= 9).astype(int),
    "Katagiri_binary": (pd.to_numeric(df["New Katagiri score"], errors="coerce") < 7).astype(int),

    "Y_3M": pd.to_numeric(df["3-Month Survival (0=Death, 1=Alive)"], errors="coerce"),
    "Y_6M": pd.to_numeric(df["6-Month Survival (0=Death, 1=Alive)"], errors="coerce"),
    "Y_12M": pd.to_numeric(df["12-Month Survival (0=Death, 1=Alive)"], errors="coerce"),
})

FEATURES = [
    "Age","Sex","Number of Spinal Metastases","Albumin","CRP","ESCC",
    "ECOG","Frankel_bin","Barthel Index","Malignancy (Katagiri Score)",
    "Visceral Metastasis","BMI"
]

# ================================================================
# 4. AUC CI / DeLong / Youden
# ================================================================
def auc_ci_bootstrap(y, s, n_boot=2000):
    rng=np.random.default_rng(42)
    idx=np.arange(len(y))
    aucs=[]
    for _ in range(n_boot):
        b=rng.choice(idx,len(idx),replace=True)
        try: aucs.append(roc_auc_score(y[b],s[b]))
        except: pass
    auc=roc_auc_score(y,s)
    lo,hi=np.percentile(aucs,[2.5,97.5])
    return auc,lo,hi

def delong_test(y,s1,s2):
    order=np.argsort(-s1)
    y=y[order]; s1=s1[order]; s2=s2[order]
    m=int(np.sum(y)); n=len(y)-m
    if m==0 or n==0: return np.nan,np.nan

    def auc_stat(s):
        r=pd.Series(s).rank().values
        return (np.sum(r[:m]) - m*(m+1)/2)/(m*n)

    auc1=auc_stat(s1)
    auc2=auc_stat(s2)
    var=(auc1*(1-auc1)+auc2*(1-auc2))/min(m,n)
    z=(auc1-auc2)/np.sqrt(var)
    p=2*(1-norm.cdf(abs(z)))
    return auc1-auc2,p

def youden_threshold(y,p):
    fpr,tpr,thr=roc_curve(y,p)
    return thr[np.argmax(tpr-fpr)]


# ================================================================
# 5. OOF LightGBM (leakage-free)
# ================================================================
def run_lgb_oof(X, y):
    skf=StratifiedKFold(n_splits=5,shuffle=True,random_state=42)
    params=dict(
        objective="binary", metric="auc",
        learning_rate=0.05, num_leaves=31,
        n_estimators=500, class_weight="balanced",
        random_state=42
    )
    oof=np.zeros(len(y))
    for tr,te in skf.split(X,y):
        model=lgb.LGBMClassifier(**params)
        model.fit(X.iloc[tr],y[tr])
        oof[te]=model.predict_proba(X.iloc[te])[:,1]

    iso=IsotonicRegression(out_of_bounds="clip")
    iso.fit(oof,y)
    calibrated=iso.transform(oof)

    final_model=lgb.LGBMClassifier(**params)
    final_model.fit(X,y)

    return calibrated, final_model


# ================================================================
# 6. Feature rename for SHAP (論文用)
# ================================================================
FEATURE_RENAME = {
    "Malignancy (Katagiri Score)": "Malignancy",
    "ECOG": "ECOG PS",
    "Frankel_bin": "Frankel grade",
    "Albumin": "Albumin"    # ★追加済
}


# ================================================================
# 7. ROC plot（フォント反映・凡例12 pt）
# ================================================================
def plot_roc(ax, y, ai, tok, kat, title):

    fpr_ai,tpr_ai,_ = roc_curve(y,ai)
    fpr_t,tpr_t,_   = roc_curve(y,tok)
    fpr_k,tpr_k,_   = roc_curve(y,kat)

    ax.plot(fpr_ai,tpr_ai,color=BLUE,lw=2.5,
            label=f"AI model (AUC={roc_auc_score(y,ai):.3f})")
    ax.plot(fpr_t,tpr_t,color=ORANGE,ls="--",lw=2,
            label=f"Revised Tokuhashi score (AUC={roc_auc_score(y,tok):.3f})")
    ax.plot(fpr_k,tpr_k,color=GREEN,ls=":",lw=2,
            label=f"New Katagiri score (AUC={roc_auc_score(y,kat):.3f})")

    ax.plot([0,1],[0,1],"--",color="gray",lw=1)

    ax.set_title(title,fontsize=20)
    ax.set_xlabel("1 – Specificity",fontsize=20)
    ax.set_ylabel("Sensitivity",fontsize=20)
    ax.tick_params(axis='both',labelsize=20)

    # ★凡例は小さめ（指示通り12 pt）
    ax.legend(loc="lower right",fontsize=12)


# ================================================================
# 8. Calibration plot（タイトル = Calibration plot）
# ================================================================
def plot_calibration(ax, y, pred, bins=10):

    d=pd.DataFrame({"y":y,"p":pred})
    d["bin"]=pd.qcut(d["p"],q=bins,duplicates="drop")
    g=d.groupby("bin").agg(obs=("y","mean"),pred=("p","mean")).reset_index()

    ax.plot([0,1],[0,1],"--",color="gray",lw=1)
    ax.plot(g["pred"],g["obs"],"o-",color=BLUE)

    ax.set_title("Calibration plot", fontsize=20)
    ax.set_xlabel("Predicted probability", fontsize=20)
    ax.set_ylabel("Observed frequency", fontsize=20)
    ax.tick_params(axis='both',labelsize=20)


# ================================================================
# 9. SHAP summary（rename対応）
# ================================================================
def shap_summary(model, X, title, filename):

    X_disp = X.rename(columns=FEATURE_RENAME)

    explainer = shap.TreeExplainer(model)
    sv = explainer.shap_values(X)
    if isinstance(sv,list):
        sv = sv[1]

    shap.summary_plot(sv, X_disp, show=False, plot_size=(8,6))
    plt.title(title, fontsize=18)
    plt.tight_layout()
    plt.savefig(OUT / filename, dpi=600)
    plt.close()


# ================================================================
# 10. SHAP heatmap（3M順・rename）
# ================================================================
def shap_mean_heatmap(models_dict, X_dict, features, filename):

    display_features=[FEATURE_RENAME.get(f,f) for f in features]
    mean_shap=pd.DataFrame(index=display_features,columns=models_dict.keys())

    for tag in models_dict:
        model=models_dict[tag]
        X=X_dict[tag]

        explainer=shap.TreeExplainer(model)
        sv=explainer.shap_values(X)
        if isinstance(sv,list):
            sv=sv[1]

        for orig,disp,val in zip(features,display_features,np.abs(sv).mean(axis=0)):
            mean_shap.loc[disp,tag]=val

    if "3M" in models_dict:
        mean_shap = mean_shap.sort_values(by="3M",ascending=False)

    plt.figure(figsize=(8,7))
    sns.heatmap(
        mean_shap.astype(float),
        cmap="magma",
        annot=True,
        fmt=".3f",
        cbar_kws={"label":"mean(|SHAP value|)"}
    )
    plt.title("SHAP Heatmap", fontsize=18)
    plt.tight_layout()
    plt.savefig(OUT/filename, dpi=600)
    plt.close()


# ================================================================
# 11. Main loop
# ================================================================
results=[]
models_dict={}
X_dict={}

for tag,ycol in [("3M","Y_3M"),("6M","Y_6M"),("12M","Y_12M")]:

    sub=df_proc.dropna(subset=FEATURES+[ycol])
    X=sub[FEATURES]
    y=sub[ycol].astype(int).values
    X_dict[tag]=X

    p_ai,model = run_lgb_oof(X,y)
    models_dict[tag]=model

    s_tok=sub["Tokuhashi_binary"].astype(float).values
    s_kat=sub["Katagiri_binary"].astype(float).values

    fig,axes=plt.subplots(1,2,figsize=(13,5))
    plot_roc(axes[0],y,p_ai,s_tok,s_kat,f"{tag} Survival")
    plot_calibration(axes[1],y,p_ai)
    plt.tight_layout()
    plt.savefig(OUT/f"Figure3_{tag}.png",dpi=600)
    plt.close()

    shap_summary(model,X,f"SHAP Summary ({tag})",f"SHAP_{tag}.png")

    ai_auc,ai_lo,ai_hi=auc_ci_bootstrap(y,p_ai)
    tk_auc,tk_lo,tk_hi=auc_ci_bootstrap(y,s_tok)
    kt_auc,kt_lo,kt_hi=auc_ci_bootstrap(y,s_kat)

    _,p_tok=delong_test(y,p_ai,s_tok)
    _,p_kat=delong_test(y,p_ai,s_kat)

    thr=youden_threshold(y,p_ai)
    yhat=(p_ai>=thr).astype(int)
    tn,fp,fn,tp=confusion_matrix(y,yhat).ravel()
    sens_ai=tp/(tp+fn)
    spec_ai=tn/(tn+fp)
    _,_,f1_ai,_=precision_recall_fscore_support(y,yhat,average="binary")

    tn_t,fp_t,fn_t,tp_t=confusion_matrix(y,s_tok).ravel()
    tn_k,fp_k,fn_k,tp_k=confusion_matrix(y,s_kat).ravel()

    sens_t=tp_t/(tp_t+fn_t) if (tp_t+fn_t)>0 else np.nan
    spec_t=tn_t/(tn_t+fp_t) if (tn_t+fp_t)>0 else np.nan
    _,_,f1_t,_=precision_recall_fscore_support(y,s_tok,average="binary")

    sens_k=tp_k/(tp_k+fn_k) if (tp_k+fn_k)>0 else np.nan
    spec_k=tn_k/(tn_k+fp_k) if (tn_k+fp_k)>0 else np.nan
    _,_,f1_k,_=precision_recall_fscore_support(y,s_kat,average="binary")

    brier_ai=brier_score_loss(y,p_ai)

    results.append({
        "Timepoint":tag,

        "AI AUC":ai_auc,"AI 95%CI low":ai_lo,"AI 95%CI high":ai_hi,
        "AI Sens":sens_ai,"AI Spec":spec_ai,"AI F1":f1_ai,"AI Brier":brier_ai,

        "Tokuhashi AUC":tk_auc,"Tokuhashi CI low":tk_lo,"Tokuhashi CI high":tk_hi,
        "Tokuhashi Sens":sens_t,"Tokuhashi Spec":spec_t,"Tokuhashi F1":f1_t,

        "Katagiri AUC":kt_auc,"Katagiri CI low":kt_lo,"Katagiri CI high":kt_hi,
        "Katagiri Sens":sens_k,"Katagiri Spec":spec_k,"Katagiri F1":f1_k,

        "p(AI vs Tokuhashi)":p_tok,
        "p(AI vs Katagiri)":p_kat,

        "n":len(sub)
    })

# ================================================================
# 12. SHAP Heatmap（3M順）
# ================================================================
shap_mean_heatmap(models_dict,X_dict,FEATURES,
                  filename="SHAP_heatmap_3M_6M_12M_sorted.png")

# ================================================================
# 13. Save summary Excel
# ================================================================
summary_df=pd.DataFrame(results)
summary_df.to_excel(OUT/"Performance_Summary_OOF.xlsx",index=False)

summary_df


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/175.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m174.1/175.3 kB[0m [31m40.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.3/175.3 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[LightGBM] [Info] Number of positive: 110, number of negative: 30
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000211 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 182
[LightGBM] [Info] Number of data points in the train set: 140, number of used features: 12
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
[LightGBM] [Info] Start training from score 0.000000
[LightGBM] [Info] Number of positive: 110, number of negative: 31
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0

  g=d.groupby("bin").agg(obs=("y","mean"),pred=("p","mean")).reset_index()


[LightGBM] [Info] Number of positive: 84, number of negative: 56
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000054 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 178
[LightGBM] [Info] Number of data points in the train set: 140, number of used features: 12
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=-0.000000
[LightGBM] [Info] Start training from score -0.000000
[LightGBM] [Info] Number of positive: 84, number of negative: 57
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000079 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 183
[LightGBM] [Info] Number of data points in the train set: 141, number of used features: 12
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=-0.000000
[LightGBM] [Info] Start training from score -0.000000
[LightGBM] [Info] Number of po

  g=d.groupby("bin").agg(obs=("y","mean"),pred=("p","mean")).reset_index()


[LightGBM] [Info] Number of positive: 62, number of negative: 78
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000055 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 181
[LightGBM] [Info] Number of data points in the train set: 140, number of used features: 12
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=-0.000000
[LightGBM] [Info] Start training from score -0.000000
[LightGBM] [Info] Number of positive: 63, number of negative: 78
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000057 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 180
[LightGBM] [Info] Number of data points in the train set: 141, number of used features: 12
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
[LightGBM] [Info] Start training from score 0.000000
[LightGBM] [Info] Number of posi

  g=d.groupby("bin").agg(obs=("y","mean"),pred=("p","mean")).reset_index()


Unnamed: 0,Timepoint,AI AUC,AI 95%CI low,AI 95%CI high,AI Sens,AI Spec,AI F1,AI Brier,Tokuhashi AUC,Tokuhashi CI low,...,Tokuhashi F1,Katagiri AUC,Katagiri CI low,Katagiri CI high,Katagiri Sens,Katagiri Spec,Katagiri F1,p(AI vs Tokuhashi),p(AI vs Katagiri),n
0,3M,0.739035,0.655647,0.8145,0.557971,0.815789,0.693694,0.148258,0.605454,0.546883,...,0.441989,0.66762,0.578409,0.756152,0.782609,0.552632,0.821293,1.736238e-06,3.421862e-05,176
1,6M,0.824279,0.761131,0.881402,0.771429,0.746479,0.794118,0.162532,0.657545,0.605672,...,0.527027,0.693897,0.625553,0.76011,0.866667,0.521127,0.791304,1.383791e-10,1.205139e-09,176
2,12M,0.84393,0.786434,0.895383,0.782051,0.744898,0.743902,0.155185,0.695055,0.634361,...,0.595041,0.702643,0.642851,0.757035,0.935897,0.469388,0.719212,2.255753e-10,1.25848e-09,176
