<a href="https://colab.research.google.com/github/Kenny625819/Applied-Data-Science/blob/main/Figure2_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc, roc_auc_score
from sklearn.utils import resample
from matplotlib.colors import LinearSegmentedColormap
from sklearn.preprocessing import label_binarize

# ------------------------------------------------------------
# Load data
# ------------------------------------------------------------
# Model predictions (CSV) and annotations holding the consensus label (Excel).
df_pred = pd.read_csv("escc_oof_predictions.csv")
df_gt   = pd.read_excel("ESCC_3_with_consensus.xlsx")

# Normalize filename so both tables share the same join key: "1.png" → "1"
df_pred["filename"] = df_pred["filename"].astype(str).str.replace(r"\.[A-Za-z0-9]+$", "", regex=True)
df_gt["filename"]   = df_gt["filename"].astype(str)

# Merge: the inner join keeps only filenames present in both tables.
df = df_pred.merge(df_gt[["filename", "ESCC_consensus"]], on="filename", how="inner")

# Allowed labels
valid_labels = ["1b", "1c", "2", "3"]

# Keep rows whose predicted AND consensus labels are both allowed.
# NOTE(review): `isin` compares against string labels, so a label cell that
# was parsed as a number (e.g. 2 instead of "2") would be silently dropped
# here — confirm both label columns are read as strings.
has_valid_labels = (
    df["pred_label"].isin(valid_labels)
    & df["ESCC_consensus"].isin(valid_labels)
)
# .copy() makes `df` an independent frame, so the column assignments that
# follow do not write into a slice (avoids pandas' SettingWithCopyWarning).
df = df[has_valid_labels].copy()

df["true_label"] = df["ESCC_consensus"]

# ------------------------------------------------------------
# Confusion Matrix
# ------------------------------------------------------------
# Rows are true labels, columns are predicted labels; passing `labels`
# fixes the row/column order to that of valid_labels.
cm = confusion_matrix(
    y_true=df["true_label"],
    y_pred=df["pred_label"],
    labels=valid_labels,
)

# ------------------------------------------------------------
# Binary ROC (High-grade = 2 or 3)
# ------------------------------------------------------------
# Positive class (1) = true label "2" or "3"; the score is the `high_prob` column.
df["binary_true"] = df["true_label"].isin(["2","3"]).astype(int)
df["binary_prob"] = df["high_prob"].astype(float)

fpr, tpr, _ = roc_curve(df["binary_true"], df["binary_prob"])
auc_point = auc(fpr, tpr)

# -------- Binary ROC Bootstrap 95% CI --------
# A seeded generator makes the reported interval reproducible between runs;
# sharing one generator across iterations still yields a different resample
# on every draw.
bin_rng = np.random.RandomState(42)

boot = []
for _ in range(1000):
    sample = resample(df, random_state=bin_rng)
    # ROC/AUC is undefined when a resample contains only one class; skip it.
    if sample["binary_true"].nunique() < 2:
        continue
    fpr_b, tpr_b = roc_curve(sample["binary_true"], sample["binary_prob"])[:2]
    boot.append(auc(fpr_b, tpr_b))

# Percentile bootstrap: the 2.5th and 97.5th percentiles bound the 95% CI.
bin_ci_l, bin_ci_u = np.percentile(boot, [2.5, 97.5])

# ------------------------------------------------------------
# Macro-AUC (4-class) + Bootstrap 95% CI
# ------------------------------------------------------------
# Derive the probability columns from valid_labels so the score columns are
# guaranteed to line up with the one-hot columns built by label_binarize:
# ["prob_1b", "prob_1c", "prob_2", "prob_3"].
prob_cols = [f"prob_{label}" for label in valid_labels]

y_true_bin = label_binarize(df["true_label"], classes=valid_labels)
y_score_bin = df[prob_cols].values

# With a one-hot indicator matrix, average="macro" is the unweighted mean of
# the per-class one-vs-rest AUCs.
macro_auc_point = roc_auc_score(y_true_bin, y_score_bin, average="macro")

# A seeded generator makes the reported interval reproducible between runs.
macro_rng = np.random.RandomState(42)

macro_boot = []
for _ in range(1000):
    sample = resample(df, random_state=macro_rng)
    # A per-class AUC is undefined when that class is missing from the
    # resample. Depending on the scikit-learn version this either raises or
    # yields NaN (which would poison the percentiles), so skip such
    # resamples explicitly instead of relying on a blanket `except`.
    if sample["true_label"].nunique() < len(valid_labels):
        continue
    ys = label_binarize(sample["true_label"], classes=valid_labels)
    ps = sample[prob_cols].values
    macro_boot.append(roc_auc_score(ys, ps, average="macro"))

# Percentile bootstrap: the 2.5th and 97.5th percentiles bound the 95% CI.
macro_ci_l, macro_ci_u = np.percentile(macro_boot, [2.5, 97.5])

# Print for manuscript: point estimates with their 95% bootstrap intervals,
# formatted to three decimals.
manuscript_lines = [
    "===== AUCs for Manuscript Text =====",
    f"Binary AUC (2–3 vs 1b/1c): {auc_point:.3f} (95%CI {bin_ci_l:.3f}–{bin_ci_u:.3f})",
    f"Macro AUC (4-class): {macro_auc_point:.3f} (95%CI {macro_ci_l:.3f}–{macro_ci_u:.3f})",
]
print("\n".join(manuscript_lines))


# ------------------------------------------------------------
# Custom Blue colormap (unified with the ROC blue)
# ------------------------------------------------------------
# Two-stop linear gradient from a pale blue (#e6f2ff) to #1f77b4.
roc_blue_cmap = LinearSegmentedColormap.from_list(
    name="roc_blue",
    colors=["#e6f2ff", "#1f77b4"],
)

# ------------------------------------------------------------
# FIGURE 2 PLOT
# ------------------------------------------------------------
# One row, two panels: confusion matrix (left) and binary ROC curve (right).
fig, (ax_cm, ax_roc) = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))

# ============================================================
# A. Confusion Matrix
# ============================================================
heatmap = ax_cm.imshow(cm, cmap=roc_blue_cmap)

ax_cm.set_title("Confusion Matrix", fontsize=20)
ax_cm.set_xlabel("Predicted label", fontsize=20)
ax_cm.set_ylabel("True label", fontsize=20)

# One tick per class on both axes, labelled with the class names.
tick_positions = np.arange(len(valid_labels))
ax_cm.set_xticks(tick_positions)
ax_cm.set_xticklabels(valid_labels, fontsize=20)
ax_cm.set_yticks(tick_positions)
ax_cm.set_yticklabels(valid_labels, fontsize=20)

# Write the count into every cell; cells above half of the maximum count
# are drawn dark, so they get white text for contrast.
half_of_max = cm.max() * 0.5
n_classes = len(valid_labels)
for row, col in np.ndindex(n_classes, n_classes):
    count = cm[row, col]
    text_color = "white" if count > half_of_max else "black"
    # imshow puts columns on x and rows on y, hence (col, row).
    ax_cm.text(col, row, count, ha="center", va="center",
               fontsize=12, color=text_color)

fig.colorbar(heatmap, ax=ax_cm)

# ============================================================
# B. ROC Curve (legend WITHOUT CI)
# ============================================================
ax_roc.plot(
    fpr, tpr,
    color="#1f77b4", linewidth=2,
    label=f"AUC = {auc_point:.3f}"
)

# Diagonal reference line from (0, 0) to (1, 1).
ax_roc.plot([0, 1], [0, 1], color="gray", linestyle="--", linewidth=1)

ax_roc.set_title("High-grade ESCC Detection", fontsize=20)
ax_roc.set_xlabel("1 - Specificity", fontsize=20)
ax_roc.set_ylabel("Sensitivity", fontsize=20)

ax_roc.tick_params(labelsize=20)
ax_roc.legend(fontsize=12, loc="lower right")

fig.tight_layout()
fig.savefig("Figure2_ESCC_with_macroCI_blueUnified.png", dpi=600)
# Close the figure once it has been written to disk.
plt.close(fig)

print("✓ Figure2 with macro-AUC CI (no CI in legend) saved.")


===== AUCs for Manuscript Text =====
Binary AUC (2–3 vs 1b/1c): 0.744 (95%CI 0.634–0.838)
Macro AUC (4-class): 0.606 (95%CI 0.522–0.687)
✓ Figure2 with macro-AUC CI (no CI in legend) saved.
