# In [None]:
# -*- coding: utf-8 -*-
"""
Full pipeline script (training optional + post-processing priority + SHAP)
- Set FORCE_TRAIN=True on first run to generate preprocessor/weights/parameters/assets.
- Afterwards set FORCE_TRAIN=False to perform only post-processing and explanation.
- When DATA_PATH, feature slices, MONO_MODE, hyperparameters, etc. change,
  the script will detect the difference and retrain once to ensure consistency.
"""

import os
import json
import pickle
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
import shap
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    classification_report,
    mean_absolute_error,
    f1_score,
    accuracy_score,
    recall_score,
    confusion_matrix
)
from sklearn.isotonic import IsotonicRegression

# ===================== Basic Configuration (modify as needed) =====================
# Everything in this section is a module-level constant read by the functions and
# the top-level pipeline below. The directories are created eagerly so that later
# save calls can assume they exist.
# Data / Output paths (example uses GI dataset)
DATA_PATH = "./data/extracted_Gradually_Increasing.csv"  # Replace with other datasets (e.g. FD), but adjust slices & monotonic direction
OUT_DIR   = "./results/Gradually_Increasing"
os.makedirs(OUT_DIR, exist_ok=True)

# Set True on first run to save assets; False afterwards for post-processing & explanation only
FORCE_TRAIN = False

# Feature column slices (GI: 9:417; FD: 9:76)
# These are positional column ranges of the CSV; column 1 is read as the label.
PROTEIN_SLICE   = slice(9, 417)     # GI
COVARIATE_SLICE = slice(2, 9)

# Monotonic direction: FD → "dec" (higher protein → milder condition), GI → "inc" (higher → worse)
MONO_MODE = "inc"                   # GI

# Model & training hyperparameters
RANDOM_STATE = 42
# Seed NumPy and TensorFlow once at import time for reproducibility.
np.random.seed(RANDOM_STATE)
tf.random.set_seed(RANDOM_STATE)
NUM_CLASSES = 5      # number of ordinal classes (labels 0..NUM_CLASSES-1)
TAU = 3.5            # divisor applied to (score - threshold) before the sigmoid
SIGMA_LDL = 1.2      # std-dev of the Gaussian used to build soft label distributions
BATCH_SIZE = 256
EPOCHS_GS  = 120     # max epochs per grid-search fit
EPOCHS_FIN = 120     # max epochs per final ensemble member
N_ENSEMBLE = 3       # number of ensemble members trained with the best configuration

# Combined objective weights (overall quality dominant; middle classes recall has small weight)
OBJ_WEIGHTS = dict(acc=0.45, qwk=0.35, macro=0.10, midrec=0.10)

# ===================== File Paths (cache / assets) =====================
# Cached ensemble probabilities and labels for the validation / test splits.
VAL_PROBS_NPY  = os.path.join(OUT_DIR, "cache_val_probs.npy")
TEST_PROBS_NPY = os.path.join(OUT_DIR, "cache_test_probs.npy")
VAL_Y_NPY      = os.path.join(OUT_DIR, "cache_val_y.npy")
TEST_Y_NPY     = os.path.join(OUT_DIR, "cache_test_y.npy")
# Final calibration parameters chosen by the post-processing search.
BEST_SEL_JSON  = os.path.join(OUT_DIR, "best_selection.json")

# Fitted imputers / scalers (protein and covariate blocks are preprocessed separately).
PREPROC_DIR    = os.path.join(OUT_DIR, "preproc")
os.makedirs(PREPROC_DIR, exist_ok=True)
IMP_P_PKL      = os.path.join(PREPROC_DIR, "imp_p.pkl")
IMP_C_PKL      = os.path.join(PREPROC_DIR, "imp_c.pkl")
SC_P_PKL       = os.path.join(PREPROC_DIR, "sc_p.pkl")
SC_C_PKL       = os.path.join(PREPROC_DIR, "sc_c.pkl")

# Best grid-search configuration and the weights of the first ensemble member.
MODEL_PARAMS_JSON = os.path.join(OUT_DIR, "best_model_params.json")
ENS0_WEIGHTS_H5   = os.path.join(OUT_DIR, "ens0.weights.h5")

# Row indices of each data split, saved for reproducibility.
IDX_DIR       = os.path.join(OUT_DIR, "indices")
os.makedirs(IDX_DIR, exist_ok=True)
IDX_TRAIN_NPY = os.path.join(IDX_DIR, "idx_train.npy")
IDX_VAL_NPY   = os.path.join(IDX_DIR, "idx_val.npy")
IDX_TEST_NPY  = os.path.join(IDX_DIR, "idx_test.npy")

# Feature names and the configuration fingerprint used to validate the cache.
FEATURE_NAMES_JSON = os.path.join(OUT_DIR, "feature_names.json")
META_JSON    = os.path.join(OUT_DIR, "cache_meta.json")

# ===================== Utility Functions =====================
def display_pred_dist(name, y):
    """Print ``name`` followed by a ``{value: count}`` mapping of the values in ``y``."""
    values, occurrences = np.unique(y, return_counts=True)
    histogram = {value: count for value, count in zip(values, occurrences)}
    print(name, histogram)

def gaussian_label_distribution(y, C, sigma=1.2):
    """Convert labels into Gaussian-shaped soft label distributions.

    Args:
        y: 1-D NumPy array of labels (must support ``.astype``).
        C: number of classes; the distribution is evaluated at 0..C-1.
        sigma: standard deviation of the Gaussian centred on each label.

    Returns:
        float32 array of shape ``(len(y), C)``; each row is normalised to sum
        to (approximately) one.
    """
    centers = y.astype(np.float32)[:, None]
    class_grid = np.arange(C, dtype=np.float32)[None, :]
    unnormalised = np.exp(-(class_grid - centers)**2 / (2.0 * (sigma**2)))
    # The small epsilon keeps the division finite even if a row underflows to zero.
    row_mass = unnormalised.sum(axis=1, keepdims=True) + 1e-8
    soft = unnormalised / row_mass
    return soft.astype(np.float32)

def quadratic_weighted_kappa(y_true, y_pred, C=5):
    """Compute Cohen's kappa with quadratic disagreement weights.

    The confusion matrix is accumulated with a single vectorised call instead
    of a per-sample Python loop, because this function is evaluated for every
    candidate cut during the calibration search.

    Args:
        y_true: integer labels in ``0..C-1`` (array-like).
        y_pred: integer predictions in ``0..C-1``, same length as ``y_true``.
        C: number of ordinal classes (must be at least 2).

    Returns:
        The kappa value as a float; ``0.0`` when there are no samples.

    Raises:
        ValueError: if ``C < 2`` or the two inputs differ in length.
    """
    if C < 2:
        raise ValueError("quadratic_weighted_kappa requires C >= 2")

    true_idx = np.asarray(y_true).ravel()
    pred_idx = np.asarray(y_pred).ravel()
    if true_idx.shape != pred_idx.shape:
        raise ValueError("y_true and y_pred must have the same length")
    if true_idx.size == 0:
        return 0.0

    # Quadratic penalty: 0 on the diagonal, 1 for the largest possible disagreement.
    grid = np.arange(C, dtype=np.float64)
    W = ((grid[:, None] - grid[None, :]) / (C - 1)) ** 2

    # Unbuffered scatter-add so repeated (true, pred) pairs are all counted.
    CM = np.zeros((C, C), dtype=np.float64)
    np.add.at(CM, (true_idx, pred_idx), 1)

    n = CM.sum()
    if n == 0:
        return 0.0
    O = CM / n
    # Expected agreement under independence of the two marginals.
    E = np.outer(O.sum(axis=1), O.sum(axis=0))
    denom = (W * E).sum() + 1e-12  # epsilon avoids 0/0 when one class absorbs everything
    return 1.0 - (W * O).sum() / denom

def softmax_power(p, gamma):
    """Sharpen or flatten row-wise probabilities by raising them to ``gamma``.

    Values are clipped to ``[1e-8, 1]`` before the power, and each row is then
    renormalised (with a small epsilon in the denominator).
    """
    sharpened = np.clip(p, 1e-8, 1.0) ** gamma
    row_totals = sharpened.sum(axis=1, keepdims=True)
    row_totals = row_totals + 1e-8
    return sharpened / row_totals

def digitize_with_cuts(ey, cuts):
    """Map scores to class indices using ascending cut points.

    A score below the first cut maps to 0; a score equal to a cut falls into
    the class above that cut.
    """
    # Pad with infinities so every finite score lands inside some bin.
    edges = [-np.inf, *cuts, np.inf]
    bin_index = np.digitize(ey, edges)
    return bin_index - 1

def combo_objective(y_true, y_pred, weights=OBJ_WEIGHTS):
    """Score predictions with a weighted blend of four metrics.

    The blend combines accuracy, quadratic weighted kappa, macro-F1 and the
    mean recall of classes 2 and 3, weighted by the ``acc``/``qwk``/``macro``/
    ``midrec`` entries of ``weights``.

    Returns:
        ``(obj, acc, qwk, macro, midrec)``.
    """
    def _recall_of(label):
        # Recall restricted to a single class.
        return recall_score(y_true, y_pred, labels=[label], average='macro', zero_division=0)

    acc = accuracy_score(y_true, y_pred)
    qwk = quadratic_weighted_kappa(y_true, y_pred, C=NUM_CLASSES)
    macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
    recall_2 = _recall_of(2)
    recall_3 = _recall_of(3)
    midrec = 0.5 * (recall_2 + recall_3)

    term_acc = weights['acc'] * acc
    term_qwk = weights['qwk'] * qwk
    term_macro = weights['macro'] * macro
    term_mid = weights['midrec'] * midrec
    obj = term_acc + term_qwk + term_macro + term_mid
    return obj, acc, qwk, macro, midrec

def make_dataset(X, y, T, batch_size, training=True, mixup_alpha=0.3):
    """Build a batched ``tf.data`` pipeline yielding ``(x, (y, T))``.

    Args:
        X: feature matrix.
        y: hard labels.
        T: soft label distributions aligned with ``y``.
        batch_size: batch size; incomplete final batches are dropped when
            ``training`` is true.
        training: enables shuffling (and MixUp when ``mixup_alpha > 0``).
        mixup_alpha: shape parameter of the two gamma draws that form the
            per-sample mixing coefficient; ``0`` disables MixUp.

    Returns:
        A prefetched ``tf.data.Dataset``.
    """
    def _mixup_batch(features, targets):
        hard, soft = targets
        n_rows = tf.shape(features)[0]
        partner = tf.random.shuffle(tf.range(n_rows))
        # lam = g_a / (g_a + g_b) with two gamma(alpha, 1) draws per sample.
        g_a = tf.random.gamma([n_rows,1], mixup_alpha, beta=1.0)
        g_b = tf.random.gamma([n_rows,1], mixup_alpha, beta=1.0)
        lam = tf.cast(g_a / (g_a + g_b), tf.float32)
        partner_features = tf.gather(features, partner, axis=0)
        partner_soft = tf.gather(soft, partner, axis=0)
        mixed_features = lam * features + (1 - lam) * partner_features
        mixed_soft = lam * soft + (1 - lam) * partner_soft
        # Only the features and the soft targets are blended; the hard labels
        # are passed through unchanged.
        return mixed_features, (hard, mixed_soft)

    dataset = tf.data.Dataset.from_tensor_slices((X, (y, T)))
    if training:
        dataset = dataset.shuffle(len(X), seed=RANDOM_STATE, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_size, drop_remainder=training)
    use_mixup = training and mixup_alpha > 0
    if use_mixup:
        dataset = dataset.map(_mixup_batch, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)

# ===================== Model Definition (monotonic direction switchable) =====================
class OrdinalLDL(tf.keras.Model):
    """Ordinal classifier built from a scalar score and ordered thresholds.

    A dense backbone maps the input to a single score. ``num_classes - 1``
    strictly increasing thresholds (``get_alphas``) turn that score into
    cumulative probabilities ``P(y > k)``, which are differenced into class
    probabilities. Training combines a soft-label cross-entropy, a squared
    CDF distance, an expected-label MSE, an optional KL term against a class
    prior, an entropy bonus, a monotonicity penalty on the input gradients of
    the protein features, a threshold-margin penalty and L2 terms.

    Args:
        input_dim: accepted for interface compatibility; not used by the
            constructor (layers are built lazily on first call).
        num_classes: number of ordinal classes.
        layer_sizes: hidden layer widths of the backbone.
        dropout_rate: dropout rate applied after every hidden layer.
        l2_lambda: L2 kernel regularisation strength.
        protein_indices: column indices whose score-gradient sign is penalised.
        tau: divisor applied to ``score - alpha`` before the sigmoid.
        ey_lambda, emd_lambda, kl_lambda, entropy_lambda, score_l2_lambda,
        mono_lambda, margin_lambda: weights of the corresponding loss terms.
        alpha_margin: minimum desired gap between consecutive thresholds.
        train_prior: optional class prior for the KL term (``None`` disables it).
        gaussian_noise_std: std-dev of the input noise layer.
        mono_mode: ``"dec"`` penalises positive gradients; any other value
            penalises negative gradients.
    """

    def __init__(self, input_dim, num_classes, layer_sizes, dropout_rate,
                 l2_lambda, protein_indices,
                 tau=TAU,
                 ey_lambda=0.8, emd_lambda=1.0, kl_lambda=1.2,
                 entropy_lambda=0.10, score_l2_lambda=1.0e-3,
                 mono_lambda=5e-4, alpha_margin=0.5, margin_lambda=2e-2,
                 train_prior=None, gaussian_noise_std=0.01, mono_mode="inc"):
        super().__init__()
        self.num_classes = num_classes
        self.tau = tf.constant(tau, tf.float32)
        self.mono_mode = mono_mode  # "dec" / "inc"
        self.in_noise = tf.keras.layers.GaussianNoise(gaussian_noise_std)
        self.backbone = tf.keras.Sequential(name="backbone")
        for s in layer_sizes:
            self.backbone.add(tf.keras.layers.Dense(
                s, activation='relu',
                kernel_regularizer=tf.keras.regularizers.l2(l2_lambda)))
            self.backbone.add(tf.keras.layers.Dropout(dropout_rate))
        self.score_head = tf.keras.layers.Dense(1, activation=None,
                kernel_regularizer=tf.keras.regularizers.l2(l2_lambda), name="score")
        # Unconstrained parameters; get_alphas() maps them to ordered thresholds.
        # NOTE(review): this is a raw tf.Variable attribute — confirm that the
        # Keras version in use includes it in `trainable_variables` and in
        # `save_weights`.
        self.alpha_raw = tf.Variable(tf.random.normal([num_classes-1], stddev=0.1),
                                     trainable=True, name="alpha_raw")
        self.protein_indices = tf.constant(protein_indices, dtype=tf.int32)
        self.ey_lambda = tf.constant(ey_lambda, tf.float32)
        self.emd_lambda = tf.constant(emd_lambda, tf.float32)
        self.kl_lambda = tf.constant(kl_lambda, tf.float32)
        self.entropy_lambda = tf.constant(entropy_lambda, tf.float32)
        self.score_l2_lambda = tf.constant(score_l2_lambda, tf.float32)
        self.mono_lambda = tf.constant(mono_lambda, tf.float32)
        self.margin_lambda = tf.constant(margin_lambda, tf.float32)
        self.alpha_margin  = tf.constant(alpha_margin, tf.float32)
        self.train_prior = tf.constant(train_prior, tf.float32) if train_prior is not None else None

    def call(self, inputs, training=False):
        """Return the raw scalar score for each input row."""
        x = self.in_noise(inputs, training=training)
        h = self.backbone(x, training=training)
        score = self.score_head(h)
        return score

    def get_alphas(self):
        """Return the strictly increasing, zero-mean threshold vector."""
        # softplus increments are positive, so their cumulative sum is increasing.
        inc = tf.nn.softplus(self.alpha_raw) + 1e-6
        alpha = tf.cumsum(inc)
        alpha = alpha - tf.reduce_mean(alpha)
        return alpha

    def class_probs_from_score(self, score, alpha):
        """Convert scores and thresholds into normalised class probabilities."""
        logits = (score - alpha[tf.newaxis, :]) / self.tau
        probs_gt = tf.sigmoid(logits)               # P(y > k)
        p0 = 1 - probs_gt[:, :1]
        p_mid = probs_gt[:, :-1] - probs_gt[:, 1:] if probs_gt.shape[1] > 1 else tf.zeros((tf.shape(score)[0], 0), score.dtype)
        plast = probs_gt[:, -1:]
        p = tf.concat([p0, p_mid, plast], axis=1)
        # Clip then renormalise so the log terms in the losses stay finite.
        p = tf.clip_by_value(p, 1e-7, 1.0)
        p = p / tf.reduce_sum(p, axis=1, keepdims=True)
        return p

    @staticmethod
    def _cumdist(p):  # (n, C)
        """Cumulative distribution along the class axis."""
        return tf.cumsum(p, axis=1)

    @staticmethod
    def _batch_accuracy(p, y_int):
        """Fraction of rows whose arg-max class equals the hard label.

        ``y_int`` is flattened to shape ``(n,)`` before the comparison. The
        step functions hold it as ``(n, 1)``; comparing that against the
        ``(n,)`` arg-max would broadcast to an ``(n, n)`` matrix and report a
        value that is not an accuracy.
        """
        y_pred = tf.argmax(p, axis=1, output_type=tf.int32)
        y_true = tf.reshape(tf.cast(y_int, tf.int32), (-1,))
        return tf.reduce_mean(tf.cast(tf.equal(y_pred, y_true), tf.float32))

    def train_step(self, data):
        """Run one optimisation step on a ``(x, (y_int, T_soft))`` batch."""
        x, (y_int, T_soft) = data
        x = tf.cast(x, tf.float32)
        y_int = tf.cast(tf.reshape(y_int, (-1,1)), tf.float32)
        T_soft = tf.cast(T_soft, tf.float32)
        # Outer tape: gradients w.r.t. the parameters. Inner tape: gradient of
        # the score w.r.t. the inputs, needed for the monotonicity penalty.
        with tf.GradientTape() as tape_params:
            with tf.GradientTape() as tape_x:
                tape_x.watch(x)
                score = self(x, training=True)
                alpha = self.get_alphas()
                p = self.class_probs_from_score(score, alpha)
                classes = tf.cast(tf.range(self.num_classes), tf.float32)[tf.newaxis,:]
                ey = tf.reduce_sum(p * classes, axis=1, keepdims=True)

                CE = -tf.reduce_mean(tf.reduce_sum(T_soft * tf.math.log(p + 1e-8), axis=1))
                Ft = self._cumdist(T_soft)
                Fp = self._cumdist(p)
                EMD2 = tf.reduce_mean(tf.reduce_sum(tf.square(Fp - Ft), axis=1))
                MSE_ey = tf.reduce_mean(tf.square(ey - y_int))
                KL = 0.0
                if self.train_prior is not None:
                    # KL between the batch-mean prediction and the supplied prior.
                    mean_pred = tf.reduce_mean(p, axis=0)
                    KL = tf.reduce_sum(mean_pred * (tf.math.log(mean_pred + 1e-8) - tf.math.log(self.train_prior + 1e-8)))
                Ent = -tf.reduce_mean(tf.reduce_sum(p * tf.math.log(p + 1e-8), axis=1))

                score_mean = tf.reduce_mean(score)
                grads_x = tape_x.gradient(score_mean, x)
                prot_grads = tf.gather(grads_x, indices=self.protein_indices, axis=1)
                if self.mono_mode == "dec":
                    # Score should not increase with the protein features.
                    mono_pen = tf.reduce_mean(tf.nn.relu(prot_grads)**2)
                else:
                    # Score should not decrease with the protein features.
                    mono_pen = tf.reduce_mean(tf.nn.relu(-prot_grads)**2)

                deltas = alpha[1:] - alpha[:-1]
                margin_pen = tf.reduce_mean(tf.nn.relu(self.alpha_margin - deltas)**2)
                score_l2 = tf.reduce_mean(tf.square(score))
                reg_loss = tf.add_n(self.losses) if self.losses else 0.0

                # The entropy term is subtracted, i.e. higher entropy lowers the loss.
                loss = (CE + self.emd_lambda*EMD2 + self.ey_lambda*MSE_ey + self.kl_lambda*KL
                        - self.entropy_lambda*Ent + self.mono_lambda*mono_pen
                        + self.margin_lambda*margin_pen + self.score_l2_lambda*score_l2 + reg_loss)

        grads = tape_params.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        acc = self._batch_accuracy(p, y_int)
        return {"loss": loss, "accuracy": acc}

    def test_step(self, data):
        """Evaluate a batch using only the data-fit terms of the loss."""
        x, (y_int, T_soft) = data
        x = tf.cast(x, tf.float32)
        y_int = tf.cast(tf.reshape(y_int, (-1,1)), tf.float32)
        T_soft = tf.cast(T_soft, tf.float32)
        score = self(x, training=False)
        alpha = self.get_alphas()
        p = self.class_probs_from_score(score, alpha)
        classes = tf.cast(tf.range(self.num_classes), tf.float32)[tf.newaxis,:]
        ey = tf.reduce_sum(p*classes, axis=1, keepdims=True)
        CE = -tf.reduce_mean(tf.reduce_sum(T_soft * tf.math.log(p + 1e-8), axis=1))
        Ft = self._cumdist(T_soft)
        Fp = self._cumdist(p)
        EMD2 = tf.reduce_mean(tf.reduce_sum(tf.square(Fp - Ft), axis=1))
        MSE_ey = tf.reduce_mean(tf.square(ey - y_int))
        acc = self._batch_accuracy(p, y_int)
        return {"loss": CE + self.emd_lambda*EMD2 + self.ey_lambda*MSE_ey, "accuracy": acc}

    def predict_class_probs(self, X):
        """Return class probabilities for ``X`` as a NumPy array (inference mode)."""
        score = self(X, training=False)
        alpha = self.get_alphas()
        return self.class_probs_from_score(score, alpha).numpy()

def predict_probs_and_ey(model, X, tau=TAU):
    """Return class probabilities and the expected class index for ``X``.

    The number of classes is taken from the probability matrix itself (one
    more than the number of thresholds returned by ``model.get_alphas()``), so
    the function also works for models built with a class count other than
    the module-level default.

    Args:
        model: object exposing ``__call__(X, training=False)`` and
            ``get_alphas()``.
        X: array-like feature matrix; converted to float32.
        tau: divisor applied to ``score - alpha`` before the sigmoid.

    Returns:
        ``(p, ey)``: ``p`` is an ``(n, C)`` NumPy array of row-normalised
        probabilities and ``ey`` the ``(n,)`` expected class index.
    """
    score = model(np.asarray(X, dtype=np.float32), training=False)
    alpha = model.get_alphas()
    probs_gt = tf.sigmoid((score - alpha[tf.newaxis, :]) / tau)  # P(y > k)
    if probs_gt.shape[1] > 1:
        p_mid = probs_gt[:, :-1] - probs_gt[:, 1:]
    else:
        p_mid = tf.zeros((tf.shape(score)[0], 0), score.dtype)
    p = tf.concat([1 - probs_gt[:, :1], p_mid, probs_gt[:, -1:]], axis=1).numpy()
    p = p / (p.sum(axis=1, keepdims=True) + 1e-8)
    ey = (p * np.arange(p.shape[1])[None, :]).sum(axis=1)
    return p, ey

# ===================== Calibration Search (combined objective + extreme-class friendly) =====================
def _precompute_sorted(ey, y):
    """Sort scores and labels by score and build cumulative per-class counts.

    Returns:
        ``(ey_sorted, y_sorted, cum, order)`` where ``cum[i]`` holds the number
        of samples of each class among the ``i`` lowest-scoring samples (row 0
        is all zeros) and ``order`` is the sorting permutation.
    """
    order = np.argsort(ey)
    sorted_labels = y[order]
    indicator = np.eye(NUM_CLASSES, dtype=np.int32)[sorted_labels]
    running_counts = np.cumsum(indicator, axis=0)
    zero_row = np.zeros((1, NUM_CLASSES), dtype=np.int32)
    cum = np.vstack([zero_row, running_counts])
    return ey[order], sorted_labels, cum, order

def _indices_from_quantiles(ey_s, qs=(0.2, 0.4, 0.6, 0.8)):
    n = len(ey_s)
    cuts = [np.quantile(ey_s, q) for q in qs]
    idxs = [int(np.searchsorted(ey_s, c, side='right')) for c in cuts]
    idxs = np.clip(np.maximum.accumulate(idxs), 1, n-1)
    return tuple(idxs)

def _cd_cutpoints(ey, y, obj_func, max_iter=4, win_frac=0.18, step_frac=0.02, min_gap_frac=0.01):
    """Coordinate-descent search for four cut points on the score ``ey``.

    Cuts are parameterised by positions in the sorted score array; each cut
    value is the midpoint of the two neighbouring sorted scores. Starting from
    the 20/40/60/80% quantile positions, one position at a time is moved
    within a window around its current value while the others stay fixed.

    Args:
        ey: 1-D array of scores.
        y: labels aligned with ``ey``.
        obj_func: callable ``(y_true, y_pred) -> tuple`` whose first element is
            the objective to maximise.
        max_iter: maximum number of sweeps over the four positions.
        win_frac: initial half-width of the search window, as a fraction of n.
        step_frac: step between candidate positions, as a fraction of n.
        min_gap_frac: minimum distance kept between neighbouring positions
            during the search, as a fraction of n.

    Returns:
        ``(best_cuts, best_obj)``: the tuple of four cut values and the
        objective value they achieve.
    """
    # Only the sorted scores are used below; y_s, cum and order are not read.
    ey_s, y_s, cum, order = _precompute_sorted(ey, y)
    n = len(ey_s)
    # Convert the fractional settings to sample counts (at least one each).
    GAP = max(1, int(min_gap_frac*n))
    STEP = max(1, int(step_frac*n))
    WIN0 = max(1, int(win_frac*n))
    i1,i2,i3,i4 = _indices_from_quantiles(ey_s)

    def _eval_from_idxs(i1,i2,i3,i4):
        # Turn sorted positions into cut values and score the resulting labels.
        cuts = []
        for i in (i1,i2,i3,i4):
            left  = ey_s[max(i-1, 0)]
            right = ey_s[min(i,   n-1)]
            cuts.append(0.5*(left+right))
        y_hat = digitize_with_cuts(ey, cuts)
        return obj_func(y, y_hat), tuple(cuts)

    best_metrics, best_cuts = _eval_from_idxs(i1,i2,i3,i4)
    best_obj = best_metrics[0]
    win = WIN0
    improved = True
    it = 0
    idxs = [i1,i2,i3,i4]
    # Sweep until a full pass brings no improvement or max_iter is reached.
    while improved and it < max_iter:
        improved = False
        for k in range(4):
            # Feasible range: keep GAP from the neighbours, stay within the window.
            lo = (idxs[k-1]+GAP) if k > 0 else 1
            hi = (idxs[k+1]-GAP) if k < 3 else (n-1)
            lo = max(lo, idxs[k]-win)
            hi = min(hi, idxs[k]+win)
            best_local = best_obj
            best_loc = idxs[k]
            best_local_cuts = best_cuts
            for idx in range(lo, hi+1, STEP):
                cur = idxs.copy()
                cur[k] = idx
                metrics, cuts = _eval_from_idxs(*cur)
                if metrics[0] > best_local:
                    best_local = metrics[0]
                    best_loc = idx
                    best_local_cuts = cuts
            # Accept the move only if it strictly improves the global best.
            if best_local > best_obj + 1e-12:
                best_obj = best_local
                idxs[k] = best_loc
                best_cuts = best_local_cuts
                improved = True
        # Shrink the window after every sweep to refine the search.
        win = max(1, int(win*0.7))
        it += 1
    return best_cuts, best_obj

def extreme_objective(y_true, y_pred):
    """Objective that emphasises recall on the two extreme classes (0 and 4).

    Returns:
        ``(obj, acc, nan, macro, ext)`` — the third slot is ``np.nan`` so the
        tuple has the same layout as ``combo_objective`` (no QWK is computed).
    """
    def _recall_of(label):
        return recall_score(y_true, y_pred, labels=[label], average='macro', zero_division=0)

    acc = accuracy_score(y_true, y_pred)
    macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
    recall_low = _recall_of(0)
    recall_high = _recall_of(4)
    ext = 0.5 * (recall_low + recall_high)
    obj = 0.30 * acc + 0.20 * macro + 0.50 * ext
    return obj, acc, np.nan, macro, ext

def search_calibration_ext(y_val, p_val, weights=OBJ_WEIGHTS):
    """Search post-hoc calibration parameters on the validation set.

    Phase 1 grid-searches the power ``gamma`` and an affine map ``a*ey + b``
    (centred on a clipped least-squares fit) and optimises the four cut points
    for the combined objective. Phase 2 perturbs the Phase-1 optimum and
    optimises the cuts for the extreme-class objective. The Phase-2 result
    replaces the Phase-1 result only if it raises the extreme-class recall by
    at least 0.06 while losing at most 0.02 accuracy.

    When the replacement happens, every reported metric — including ``obj`` —
    is recomputed for the parameters actually returned, so callers that rank
    selections by ``sel['obj']`` compare like with like.

    Args:
        y_val: validation labels.
        p_val: validation class probabilities, shape ``(n, NUM_CLASSES)``.
        weights: weights of the combined objective.

    Returns:
        dict with keys ``obj``, ``acc``, ``qwk``, ``macro``, ``midrec``,
        ``mode`` (always ``"cut"``), ``gamma``, ``a``, ``b`` and ``cuts``.
    """
    classes = np.arange(NUM_CLASSES)

    def _expected(gamma):
        # Expected class index after power-sharpening the probabilities.
        return (softmax_power(p_val, gamma) * classes[None, :]).sum(axis=1)

    def _obj_combo(y_t, y_h):
        return combo_objective(y_t, y_h, weights)

    # Phase 1: combined objective (γ, a, b → cutpoints)
    best = {"obj": -1.0}
    for g in [0.9, 1.0, 1.1, 1.2]:
        ey_g = _expected(g)
        a_ls, b_ls = np.polyfit(ey_g, y_val, 1)
        # Keep the affine correction mild.
        a_ls = float(np.clip(a_ls, 0.85, 1.20))
        b_ls = float(np.clip(b_ls, -0.30, 0.30))
        for a in [a_ls-0.05, a_ls, a_ls+0.05]:
            for b in [b_ls-0.10, b_ls, b_ls+0.10]:
                ey_ab = a*ey_g + b
                cuts, _ = _cd_cutpoints(ey_ab, y_val, _obj_combo, max_iter=4)
                y_hat = digitize_with_cuts(ey_ab, cuts)
                obj, acc, qwk, macro, midrec = combo_objective(y_val, y_hat, weights)
                if obj > best["obj"]:
                    best = {
                        "obj": obj, "acc": acc, "qwk": qwk, "macro": macro, "midrec": midrec,
                        "mode": "cut", "gamma": g, "a": a, "b": b, "cuts": cuts
                    }

    # Phase 2: extreme-class friendly fine-tuning (around best point)
    g0, a0, b0 = best["gamma"], best["a"], best["b"]
    best_ext = {"obj": -1.0}
    for dg in [0.95, 1.0, 1.05]:
        ey_g = _expected(g0*dg)  # independent of da/db, so computed once per dg
        for da in [0.95, 1.0, 1.05]:
            for db in [-0.10, 0.0, 0.10]:
                ey_ab = (a0*da) * ey_g + (b0 + db)
                cuts, _ = _cd_cutpoints(ey_ab, y_val, extreme_objective, max_iter=3)
                y_hat = digitize_with_cuts(ey_ab, cuts)
                obj, acc, _, macro, ext = extreme_objective(y_val, y_hat)
                if obj > best_ext["obj"]:
                    best_ext = {
                        "obj": obj, "acc": acc, "macro": macro, "ext": ext,
                        "gamma": g0*dg, "a": a0*da, "b": b0+db, "cuts": cuts
                    }

    # Replacement condition (mild): ext gain ≥0.06 and acc drop ≤0.02
    y_hat0 = digitize_with_cuts(best["a"]*_expected(best["gamma"]) + best["b"], best["cuts"])
    _, _, _, _, ext0 = extreme_objective(y_val, y_hat0)

    y_hat1 = digitize_with_cuts(
        best_ext["a"]*_expected(best_ext["gamma"]) + best_ext["b"], best_ext["cuts"])
    acc1 = accuracy_score(y_val, y_hat1)

    if (best_ext["ext"] - ext0 >= 0.06) and (acc1 >= best["acc"] - 0.02):
        # Re-score the replacement with the combined objective so that "obj"
        # describes the returned parameters rather than the discarded ones.
        obj1, acc1, qwk1, macro1, midrec1 = combo_objective(y_val, y_hat1, weights)
        best = {
            "obj": obj1, "acc": acc1, "qwk": qwk1, "macro": macro1, "midrec": midrec1,
            "mode": "cut",
            "gamma": best_ext["gamma"],
            "a": best_ext["a"],
            "b": best_ext["b"],
            "cuts": best_ext["cuts"]
        }
    return best

def apply_selection(sel, p_mat):
    """Apply a calibration selection to a probability matrix.

    The probabilities are power-sharpened with ``sel['gamma']``, reduced to an
    expected class index and mapped through ``sel['a'] * ey + sel['b']``. In
    ``"cut"`` mode (the default when ``mode`` is absent) the result is binned
    with ``sel['cuts']``; otherwise ``sel['iso'].predict`` is used and the
    output is rounded and clipped to the valid class range.
    """
    sharpened = softmax_power(p_mat, sel['gamma'])
    class_ids = np.arange(NUM_CLASSES)[None,:]
    expected = (sharpened * class_ids).sum(axis=1)
    expected = sel['a'] * expected + sel['b']
    if sel.get("mode", "cut") != "cut":
        iso_pred = sel['iso'].predict(expected)
        return np.rint(iso_pred).astype(int).clip(0, NUM_CLASSES-1)
    return digitize_with_cuts(expected, sel['cuts'])

# ===================== Cache Fingerprint (meta) & Validity Check =====================
# Fingerprint of every setting that influences the trained models and cached
# probabilities. It is stored next to the cache and compared on later runs, so
# a change to any of these values forces one retrain.
meta_now = dict(
    data_path=os.path.abspath(DATA_PATH),
    protein_slice=repr(PROTEIN_SLICE),
    cov_slice=repr(COVARIATE_SLICE),
    mono_mode=MONO_MODE,
    num_classes=int(NUM_CLASSES),
    tau=float(TAU),
    random_state=int(RANDOM_STATE),
    # Training hyperparameters: the module docstring promises a retrain when
    # they change, so they must be part of the fingerprint as well.
    sigma_ldl=float(SIGMA_LDL),
    batch_size=int(BATCH_SIZE),
    epochs_gs=int(EPOCHS_GS),
    epochs_fin=int(EPOCHS_FIN),
    n_ensemble=int(N_ENSEMBLE),
    obj_weights={k: float(v) for k, v in OBJ_WEIGHTS.items()},
)

def assets_exist():
    """Return True when every on-disk asset written by the training branch exists."""
    need = [
        IMP_P_PKL, IMP_C_PKL, SC_P_PKL, SC_C_PKL,
        MODEL_PARAMS_JSON, ENS0_WEIGHTS_H5,
        IDX_TRAIN_NPY, IDX_VAL_NPY, IDX_TEST_NPY,
        FEATURE_NAMES_JSON
    ]
    return all(os.path.exists(p) for p in need)

# In a notebook session the arrays of a previous run may still be bound as globals.
ready_in_memory     = all(k in globals() for k in ["p_val_avg","p_test_avg","y_val","y_test"])
ready_probs_on_disk = all(os.path.exists(p) for p in [VAL_PROBS_NPY, TEST_PROBS_NPY, VAL_Y_NPY, TEST_Y_NPY])
ready_assets_on_disk = assets_exist()

# Fingerprint gate
if ready_probs_on_disk or ready_assets_on_disk:
    if not os.path.exists(META_JSON):
        print("[CACHE] Found old cache but missing meta fingerprint: will ignore cache and retrain.")
        ready_in_memory = ready_probs_on_disk = ready_assets_on_disk = False
    else:
        try:
            with open(META_JSON, "r", encoding="utf-8") as meta_fh:
                meta_old = json.load(meta_fh)
        except (OSError, ValueError):
            # An unreadable or corrupt fingerprint cannot vouch for the cache.
            meta_old = None
        if meta_old != meta_now:
            print("[CACHE] Cache fingerprint does not match current config: ignoring old cache and retraining.")
            ready_in_memory = ready_probs_on_disk = ready_assets_on_disk = False

if FORCE_TRAIN:
    print("[FORCE] Force retraining to generate/update assets and probability cache.")
    ready_in_memory = ready_probs_on_disk = ready_assets_on_disk = False

# ===================== Data Branch: Train or Load from Cache =====================
# Obtain validation/test probabilities and labels from (in order of preference)
# the current session, the on-disk cache, or a fresh training run.
if ready_in_memory:
    print("[INFO] Reusing probabilities and labels from current session (no retrain)")
    p_val_avg = globals()["p_val_avg"]
    p_test_avg = globals()["p_test_avg"]
    y_val = globals()["y_val"]
    y_test = globals()["y_test"]

elif ready_probs_on_disk and ready_assets_on_disk:
    print("[INFO] Loading probabilities, labels and assets from disk cache (no retrain)")
    p_val_avg  = np.load(VAL_PROBS_NPY)
    p_test_avg = np.load(TEST_PROBS_NPY)
    y_val      = np.load(VAL_Y_NPY)
    y_test     = np.load(TEST_Y_NPY)

else:
    print("[INFO] No valid cache found → performing training to generate probabilities and assets (reusable next time)")

    # Load data & feature names
    df = pd.read_csv(DATA_PATH)
    protein_names    = df.columns[PROTEIN_SLICE].tolist()
    covariates_names = df.columns[COVARIATE_SLICE].tolist()
    all_feature_names = protein_names + covariates_names
    num_proteins = len(protein_names)

    # Index-level split, save indices for reproducibility
    idx_all = np.arange(len(df))
    y_all = df.iloc[:, 1].values.astype(np.int32)

    # Step 1: hold out 20% as the test set.
    idx_trainval, idx_test = train_test_split(
        idx_all, test_size=0.2, random_state=RANDOM_STATE, stratify=y_all
    )
    # Step 2: carve 30% of the remainder out as the validation set.
    idx_subtrain, idx_val = train_test_split(
        idx_trainval, test_size=0.3, random_state=RANDOM_STATE, stratify=y_all[idx_trainval]
    )
    # Train ONLY on the rows that are not in the validation set. Training on the
    # full train+val pool would leak validation rows into the fit, which would
    # invalidate early stopping, the grid search and the calibration search
    # (all of which are driven by validation data).
    idx_train = idx_subtrain

    # Save split indices (train / val / test are mutually disjoint)
    np.save(IDX_TRAIN_NPY, idx_train)
    np.save(IDX_VAL_NPY,   idx_val)
    np.save(IDX_TEST_NPY,  idx_test)

    # Extract matrices
    y_train = y_all[idx_train]
    y_val   = y_all[idx_val]
    y_test  = y_all[idx_test]
    Xp_train = df.iloc[idx_train, PROTEIN_SLICE].values.astype(np.float32)
    Xc_train = df.iloc[idx_train, COVARIATE_SLICE].values.astype(np.float32)
    Xp_val   = df.iloc[idx_val,   PROTEIN_SLICE].values.astype(np.float32)
    Xc_val   = df.iloc[idx_val,   COVARIATE_SLICE].values.astype(np.float32)
    Xp_test  = df.iloc[idx_test,  PROTEIN_SLICE].values.astype(np.float32)
    Xc_test  = df.iloc[idx_test,  COVARIATE_SLICE].values.astype(np.float32)

    # Preprocessors (fit on train), save
    imp_p = SimpleImputer(strategy='mean')
    imp_c = SimpleImputer(strategy='mean')
    Xp_tr = imp_p.fit_transform(Xp_train)
    Xp_val = imp_p.transform(Xp_val)
    Xp_te = imp_p.transform(Xp_test)
    Xc_tr = imp_c.fit_transform(Xc_train)
    Xc_val = imp_c.transform(Xc_val)
    Xc_te = imp_c.transform(Xc_test)

    sc_p = StandardScaler()
    sc_c = StandardScaler()
    Xp_tr = sc_p.fit_transform(Xp_tr)
    Xp_val = sc_p.transform(Xp_val)
    Xp_te = sc_p.transform(Xp_te)
    Xc_tr = sc_c.fit_transform(Xc_tr)
    Xc_val = sc_c.transform(Xc_val)
    Xc_te = sc_c.transform(Xc_te)

    # Concatenate (protein columns first, so their indices are 0..num_proteins-1)
    X_train = np.hstack([Xp_tr, Xc_tr]).astype(np.float32)
    X_val_m = np.hstack([Xp_val, Xc_val]).astype(np.float32)
    X_test  = np.hstack([Xp_te, Xc_te]).astype(np.float32)

    # Save preprocessors
    joblib.dump(imp_p, IMP_P_PKL)
    joblib.dump(imp_c, IMP_C_PKL)
    joblib.dump(sc_p,  SC_P_PKL)
    joblib.dump(sc_c,  SC_C_PKL)

    # Save feature names
    with open(FEATURE_NAMES_JSON, "w", encoding="utf-8") as names_fh:
        json.dump(
            {"protein": protein_names, "covariates": covariates_names},
            names_fh,
            ensure_ascii=False,
            indent=2
        )

    # Prior (boost middle classes): equal blend of the empirical class
    # frequencies and a fixed, middle-heavy distribution.
    prior_counts = np.array([(y_train == k).sum() for k in range(NUM_CLASSES)], dtype=np.float32)
    train_prior = prior_counts / prior_counts.sum()
    mid_prior = np.array([0.08, 0.22, 0.40, 0.22, 0.08], dtype=np.float32)
    mid_prior /= mid_prior.sum()
    mix_prior = 0.5 * train_prior + 0.5 * mid_prior

    T_train = gaussian_label_distribution(y_train, NUM_CLASSES, sigma=SIGMA_LDL)
    T_val   = gaussian_label_distribution(y_val,   NUM_CLASSES, sigma=SIGMA_LDL)
    T_test  = gaussian_label_distribution(y_test,  NUM_CLASSES, sigma=SIGMA_LDL)

    # Grid search (simplified — no mixup during GS for stability)
    layer_configs = [[512,256,128], [256,128]]
    dropout_rates = [0.2]
    l2_lambdas    = [1e-3]
    learning_rates= [1e-3, 5e-4, 3e-4]
    callbacks_gs = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4,
                                             min_lr=1e-6, verbose=1),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10,
                                         restore_best_weights=True, verbose=1)
    ]
    best_params = None
    best_obj = -1.0
    best_sel = None
    protein_indices_arr = np.arange(len(protein_names))

    for layer_sizes in layer_configs:
        for dr in dropout_rates:
            for l2l in l2_lambdas:
                for lr in learning_rates:
                    print(f"[GS] layers={layer_sizes}, dropout={dr}, l2={l2l}, lr={lr}")
                    tf.keras.backend.clear_session()
                    model = OrdinalLDL(
                        input_dim=X_train.shape[1],
                        num_classes=NUM_CLASSES,
                        layer_sizes=layer_sizes,
                        dropout_rate=dr,
                        l2_lambda=l2l,
                        protein_indices=protein_indices_arr,
                        tau=TAU,
                        emd_lambda=1.0,
                        entropy_lambda=0.10,
                        kl_lambda=1.2,
                        train_prior=mix_prior,
                        mono_mode=MONO_MODE
                    )
                    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr, clipvalue=1.0))
                    model.fit(
                        make_dataset(X_train, y_train, T_train, BATCH_SIZE, training=True, mixup_alpha=0.0),
                        epochs=EPOCHS_GS,
                        validation_data=make_dataset(X_val_m, y_val, T_val, BATCH_SIZE, training=False),
                        callbacks=callbacks_gs,
                        verbose=0
                    )
                    # Rank configurations by the calibrated validation objective.
                    p_val_tmp, ey_val_tmp = predict_probs_and_ey(model, X_val_m)
                    sel = search_calibration_ext(y_val, p_val_tmp, weights=OBJ_WEIGHTS)
                    if sel['obj'] > best_obj:
                        best_obj = sel['obj']
                        best_params = {'layers':layer_sizes, 'dropout':dr, 'l2':l2l, 'lr':lr}
                        best_sel = sel

    print("[GS] Best configuration:", best_params, "| sel=", best_sel)

    # Save best architecture / hyperparameters (used for model reconstruction & SHAP)
    with open(MODEL_PARAMS_JSON, "w", encoding="utf-8") as params_fh:
        json.dump(
            {"best_params": best_params, "mono_mode": MONO_MODE, "tau": TAU, "num_classes": NUM_CLASSES},
            params_fh,
            ensure_ascii=False,
            indent=2
        )

    # Final training + ensemble (batch-level MixUp)
    callbacks_final = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5,
                                             min_lr=1e-6, verbose=1),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=12,
                                         restore_best_weights=True, verbose=1)
    ]
    models = []
    train_ds = make_dataset(X_train, y_train, T_train, BATCH_SIZE, training=True, mixup_alpha=0.3)
    val_ds   = make_dataset(X_val_m, y_val,   T_val,   BATCH_SIZE, training=False)

    for seed in range(N_ENSEMBLE):
        tf.keras.backend.clear_session()
        # A different seed per member so the ensemble members differ.
        np.random.seed(RANDOM_STATE + seed)
        tf.random.set_seed(RANDOM_STATE + seed)
        m = OrdinalLDL(
            input_dim=X_train.shape[1],
            num_classes=NUM_CLASSES,
            layer_sizes=best_params['layers'],
            dropout_rate=best_params['dropout'],
            l2_lambda=best_params['l2'],
            protein_indices=protein_indices_arr,
            tau=TAU,
            emd_lambda=1.0,
            entropy_lambda=0.10,
            kl_lambda=1.2,
            train_prior=mix_prior,
            mono_mode=MONO_MODE
        )
        m.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=best_params['lr'], clipvalue=1.0))
        m.fit(train_ds, epochs=EPOCHS_FIN, validation_data=val_ds,
              callbacks=callbacks_final, verbose=2)
        models.append(m)

    # Save first ensemble member's weights (for SHAP)
    models[0].save_weights(ENS0_WEIGHTS_H5)

    # Compute & cache ensemble probabilities + labels
    def ensemble_probs(models, X):
        """Average the class probabilities of all ensemble members on ``X``."""
        ps = []
        for m in models:
            p, _ = predict_probs_and_ey(m, X)
            ps.append(p)
        return np.mean(ps, axis=0)

    p_val_avg  = ensemble_probs(models, X_val_m)
    p_test_avg = ensemble_probs(models, X_test)

    np.save(VAL_PROBS_NPY, p_val_avg)
    np.save(TEST_PROBS_NPY, p_test_avg)
    np.save(VAL_Y_NPY, y_val)
    np.save(TEST_Y_NPY, y_test)

    # Write meta fingerprint last, so an interrupted run never leaves a
    # fingerprint that vouches for an incomplete cache.
    with open(META_JSON, "w", encoding="utf-8") as meta_out_fh:
        json.dump(meta_now, meta_out_fh, ensure_ascii=False, indent=2)

# ===================== Post-processing Calibration (extreme-class friendly) =====================
# All decision parameters in this section are chosen on the validation
# probabilities (p_val_avg / y_val); the test set is only used afterwards.
print("\n[CALIB] Performing extreme-class-friendly post-processing search ...")
best_sel = search_calibration_ext(y_val, p_val_avg, weights=OBJ_WEIGHTS)

# Optional: fine-tune only the last cut to boost class-4 recall (with safety constraints)
base_acc = base_qwk = base_macro = None
from sklearn.metrics import recall_score

# The fine-tune is only attempted for cut-based selections that carry explicit cuts.
if best_sel.get("mode", "cut") == "cut" and "cuts" in best_sel:
    classes = np.arange(NUM_CLASSES)
    def _eval_on(y_true, p_mat, sel):
        """Return (accuracy, QWK, macro-F1, class-4 recall) of selection ``sel`` applied to ``p_mat``."""
        y_hat = apply_selection(sel, p_mat)
        acc   = accuracy_score(y_true, y_hat)
        qwk   = quadratic_weighted_kappa(y_true, y_hat, C=NUM_CLASSES)
        macro = f1_score(y_true, y_hat, average='macro', zero_division=0)
        # Restricting labels to [4] makes the "macro" average equal the recall of class 4 alone.
        rec4  = recall_score(y_true, y_hat, labels=[4], average='macro', zero_division=0)
        return acc, qwk, macro, rec4

    # Reference metrics of the unmodified selection on the validation set.
    base_acc, base_qwk, base_macro, base_rec4 = _eval_on(y_val, p_val_avg, best_sel)
    cuts0 = list(best_sel["cuts"])
    best_cand = None

    # 36 evenly spaced non-positive shifts of the last cut (-0.35 ... 0.0);
    # delta = 0.0 reproduces the baseline selection.
    for delta in np.linspace(-0.35, 0.0, 36):
        c2 = cuts0.copy()
        c2[-1] = cuts0[-1] + float(delta)
        # Keep the cuts strictly increasing: skip shifts that would move the
        # last cut to (or below) the previous one.
        if len(c2) >= 2 and c2[-1] <= c2[-2] + 1e-6:
            continue
        sel2 = dict(best_sel)
        sel2["cuts"] = tuple(c2)
        acc, qwk, macro, rec4 = _eval_on(y_val, p_val_avg, sel2)
        # Safety constraint: accuracy may drop by at most 0.005 and QWK by at
        # most 0.010 relative to the baseline.
        safe = (acc >= base_acc - 0.005) and (qwk >= base_qwk - 0.010)
        if not safe:
            continue
        # Ranking: higher class-4 recall wins; on a recall tie, higher accuracy
        # wins; on an accuracy tie as well, higher QWK wins.
        if (best_cand is None or
            (rec4 > best_cand["rec4"] + 1e-9) or
            (abs(rec4 - best_cand["rec4"]) < 1e-9 and
             (acc > best_cand["acc"] or
              (abs(acc - best_cand["acc"]) < 1e-12 and qwk > best_cand["qwk"]))))):
            best_cand = dict(sel=sel2, acc=acc, qwk=qwk, macro=macro, rec4=rec4, delta=delta)

    # Adopt the best safe candidate (if any) as the working selection.
    if best_cand is not None:
        print(f"[t4-tune] Applying Δt4={best_cand['delta']:.3f} | "
              f"Val Acc {base_acc:.4f} → {best_cand['acc']:.4f}, "
              f"QWK {base_qwk:.4f} → {best_cand['qwk']:.4f}, "
              f"Rec4 {base_rec4:.4f} → {best_cand['rec4']:.4f}")
        best_sel = best_cand["sel"]

# Save final calibration parameters. The context manager closes the file
# handle deterministically (the bare open() call previously leaked it).
with open(BEST_SEL_JSON, "w", encoding="utf-8") as best_sel_fh:
    json.dump(best_sel, best_sel_fh, ensure_ascii=False, indent=2)
print("[CALIB] Best selection (final):", best_sel)

# ===== Apply to test set & report =====
# The selection tuned on the validation set is applied unchanged to the
# ensemble-averaged test probabilities.
y_pred_cal = apply_selection(best_sel, p_test_avg)

print("\n== After calibration (final) ==")
print(classification_report(y_test, y_pred_cal, digits=4, zero_division=0))
print("MAE (calibrated):", mean_absolute_error(y_test, y_pred_cal))
print("QWK (calibrated):", quadratic_weighted_kappa(y_test, y_pred_cal, C=NUM_CLASSES))

# Visualization: row-normalized confusion matrix
classes = np.arange(NUM_CLASSES)
cm = confusion_matrix(y_test, y_pred_cal, labels=classes)
# 1e-12 avoids a division by zero for classes with no true samples.
row_sums = cm.sum(axis=1, keepdims=True) + 1e-12
cm_norm = cm / row_sums

plt.figure(figsize=(6,5))
im = plt.imshow(cm_norm, interpolation='nearest', aspect='auto')
plt.colorbar(im, fraction=0.046, pad=0.04)
plt.title("Confusion Matrix (Row-normalized)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.xticks(classes)
plt.yticks(classes)
# Annotate every cell with its share of the true-class row, in percent.
for i in range(NUM_CLASSES):
    for j in range(NUM_CLASSES):
        val = cm_norm[i, j] * 100
        plt.text(j, i, f"{val:.1f}%", ha="center", va="center", fontsize=8)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "viz_confusion_matrix_rownorm_extcal.pdf"), format="pdf")
plt.close()

# Adjacent accuracy
# Fraction of test samples whose predicted class is at most one step away from the truth.
adj_acc = (np.abs(y_test - y_pred_cal) <= 1).mean()
print(f"Adjacent Accuracy = {adj_acc:.4f}")

# ===================== SHAP: E[y]_rev (requires assets to exist) =====================
def ey_rev_from_score(score, alpha, tau=TAU):
    """Convert latent scores into the expected *reversed* class index.

    Steps:
      1. ``sigmoid((score - alpha_k) / tau)`` is evaluated for every threshold
         in ``alpha``.
      2. Class probabilities are formed as ``1 - first column``, the adjacent
         column differences, and the last column.
      3. The probabilities are clipped to ``[1e-7, 1]`` and renormalised per row.
      4. The result is ``sum_k p_k * (C - 1 - k)``, returned with the reduced
         axis kept (``keepdims=True``).
    """
    exceed = tf.sigmoid((score - alpha[tf.newaxis, :]) / tau)

    lowest = 1 - exceed[:, :1]
    highest = exceed[:, -1:]
    if exceed.shape[1] > 1:
        between = exceed[:, :-1] - exceed[:, 1:]
    else:
        # Single threshold: there are no interior classes.
        between = tf.zeros((tf.shape(score)[0], 0), score.dtype)

    probs = tf.clip_by_value(tf.concat([lowest, between, highest], axis=1), 1e-7, 1.0)
    probs = probs / tf.reduce_sum(probs, axis=1, keepdims=True)

    n_classes = tf.shape(probs)[1]
    # Weights C-1, C-2, ..., 0 so that class 0 contributes the largest value.
    reversed_idx = tf.cast(tf.range(n_classes - 1, -1, -1), probs.dtype)
    return tf.reduce_sum(probs * reversed_idx, axis=1, keepdims=True)

if assets_exist():
    print("[SHAP] Loading assets and reconstructing model/data for SHAP computation ...")

    # Load parameters & feature names (files are closed as soon as they are parsed)
    with open(MODEL_PARAMS_JSON, "r", encoding="utf-8") as params_fh:
        params_all = json.load(params_fh)
    best_params = params_all["best_params"]
    with open(FEATURE_NAMES_JSON, "r", encoding="utf-8") as names_fh:
        feature_names = json.load(names_fh)
    protein_names = feature_names["protein"]
    covariates_names = feature_names["covariates"]
    num_proteins = len(protein_names)

    # Temperature stored with the trained assets; falls back to the current
    # TAU only when the params file has no "tau" entry. The same value must be
    # used for the rebuilt model AND for the E[y]_rev head below.
    tau_used = params_all.get("tau", TAU)

    # Load indices & preprocessors, reconstruct X_test (consistent with training)
    idx_test = np.load(IDX_TEST_NPY)
    df = pd.read_csv(DATA_PATH)
    Xp_test_raw = df.iloc[idx_test, PROTEIN_SLICE].values.astype(np.float32)
    Xc_test_raw = df.iloc[idx_test, COVARIATE_SLICE].values.astype(np.float32)

    imp_p = joblib.load(IMP_P_PKL)
    imp_c = joblib.load(IMP_C_PKL)
    sc_p  = joblib.load(SC_P_PKL)
    sc_c  = joblib.load(SC_C_PKL)
    # Impute first, then scale; proteins and covariates use separate fitted transformers.
    Xp_te = sc_p.transform(imp_p.transform(Xp_test_raw))
    Xc_te = sc_c.transform(imp_c.transform(Xc_test_raw))
    X_test_for_shap = np.hstack([Xp_te, Xc_te]).astype(np.float32)

    # Rebuild model (single) and load ens0 weights
    tf.keras.backend.clear_session()
    m0 = OrdinalLDL(
        input_dim=X_test_for_shap.shape[1],
        num_classes=NUM_CLASSES,
        layer_sizes=best_params["layers"],
        dropout_rate=best_params["dropout"],
        l2_lambda=best_params["l2"],
        protein_indices=np.arange(num_proteins),
        tau=tau_used,
        emd_lambda=1.0,
        entropy_lambda=0.10,
        kl_lambda=1.2,
        train_prior=None,
        mono_mode=params_all.get("mono_mode", MONO_MODE)
    )
    m0.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=best_params["lr"], clipvalue=1.0))
    # The model is built explicitly so its variables exist before the weights are loaded.
    m0.build((None, X_test_for_shap.shape[1]))
    m0.load_weights(ENS0_WEIGHTS_H5)

    # Build Keras model that outputs E[y]_rev
    inp = tf.keras.Input(shape=(X_test_for_shap.shape[1],), dtype=tf.float32)
    score = m0(inp, training=False)
    alpha_tf = m0.get_alphas()
    # Use the saved temperature (tau_used) rather than the global TAU so the
    # explained output matches the temperature the model was rebuilt with.
    ey_rev_out = tf.keras.layers.Lambda(
        lambda s: ey_rev_from_score(s, alpha_tf, tau_used),
        name="expected_severity_rev"
    )(score)
    EyRev_model = tf.keras.Model(inputs=inp, outputs=ey_rev_out)

    # Background & subset
    background = shap.kmeans(X_test_for_shap, 100).data
    n_samples = min(2000, len(X_test_for_shap))
    test_subset = X_test_for_shap[:n_samples].astype(np.float32)

    explainer = shap.GradientExplainer(EyRev_model, background)
    sv = explainer.shap_values(test_subset)
    # Normalise the explainer output to a 2-D (samples, features) array:
    # unwrap a list result and drop a trailing singleton output axis.
    if isinstance(sv, list):
        sv = sv[0]
    sv = np.asarray(sv)
    if sv.ndim == 3 and sv.shape[-1] == 1:
        sv = sv[...,0]
    elif sv.ndim != 2:
        raise ValueError(f"Unexpected SHAP shape {sv.shape}")
    print("SHAP values shape:", sv.shape)

    # Export importance
    # Protein columns come first in X_test_for_shap (proteins are hstacked before covariates).
    sv_prot = sv[:, :num_proteins]
    mean_abs = np.mean(np.abs(sv_prot), axis=0).ravel()
    pd.DataFrame({"feature": protein_names, "mean_abs_shap_Ey_rev": mean_abs}) \
      .to_csv(os.path.join(OUT_DIR, "protein_importance_Ey_rev.csv"),
              index=False, encoding="utf-8-sig")
    print("Saved: protein_importance_Ey_rev.csv")

    # Top-K dependence plots
    TOPK = 10
    sorted_idx = np.argsort(-mean_abs)
    for rank, j in enumerate(sorted_idx[:TOPK], 1):
        plt.figure()
        shap.dependence_plot(
            ind=j,
            shap_values=sv,
            features=test_subset,
            feature_names=protein_names + covariates_names,
            interaction_index=None,
            show=False
        )
        plt.title(f"E[y]_rev SHAP Dependence — Top{rank}: {protein_names[j]}")
        plt.tight_layout()
        plt.savefig(os.path.join(OUT_DIR, f"dep_Eyrev_top{rank}_{protein_names[j]}.pdf"),
                    format="pdf")
        plt.close()
    print("Saved: dep_Eyrev_top*.pdf")

    # Overall beeswarm (proteins only)
    plt.figure()
    shap.summary_plot(
        sv[:, :num_proteins],
        test_subset[:, :num_proteins],
        feature_names=protein_names,
        show=False
    )
    plt.title("E[y]_rev SHAP Summary (Proteins)")
    plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR, "shap_summary_beeswarm_Ey_rev_proteins.pdf"),
                format="pdf")
    plt.close()
    print("Saved: shap_summary_beeswarm_Ey_rev_proteins.pdf")

else:
    print("[SHAP] Warning: missing model assets (preprocessor/weights/params/indices/feature names). "
          "Skipping SHAP this time.\n"
          "To run SHAP, set FORCE_TRAIN=True and run once, or delete cache to force retraining.")

print("\n✅ Pipeline completed. Outputs are located in:", OUT_DIR)

In [None]:
# === Stage-C: Three-strategy joint enhancement of the tail (class 4), post-processing only ===
# Parameters are tuned on validation set and applied to test set

from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix

# Allowed performance degradation relative to the baseline E[y] cut (best_sel on validation)
# Both are absolute tolerances: a Stage-C candidate is only considered when
# acc >= base acc - MAX_ACC_DROP and qwk >= base qwk - MAX_QWK_DROP.
MAX_ACC_DROP = 0.015   # 1.5%
MAX_QWK_DROP = 0.020   # 0.02

# Class indices 0..NUM_CLASSES-1; used below as weights when forming E[y].
classes = np.arange(NUM_CLASSES)

def eval_metrics(y_true, y_hat):
    """Summarise a prediction vector as a dict of scalar metrics.

    Keys: ``acc`` (accuracy), ``qwk`` (quadratic weighted kappa over
    NUM_CLASSES), ``macro`` (macro-F1), ``rec4`` / ``rec3`` (recall of class 4
    and class 3 respectively).
    """
    def _recall_of(label):
        # With a single label, the "macro" average is that label's recall.
        return recall_score(y_true, y_hat, labels=[label], average='macro', zero_division=0)

    return {
        'acc': accuracy_score(y_true, y_hat),
        'qwk': quadratic_weighted_kappa(y_true, y_hat, C=NUM_CLASSES),
        'macro': f1_score(y_true, y_hat, average='macro', zero_division=0),
        'rec4': _recall_of(4),
        'rec3': _recall_of(3),
    }

# Baseline: performance of current best_sel (E[y] cut) on validation set
# `base` is the reference that the safety constraints of every Stage-C
# strategy below are measured against.
y_hat_base_val = apply_selection(best_sel, p_val_avg)
base = eval_metrics(y_val, y_hat_base_val)

def apply_selection_cdf(params, p_mat):
    """Ordered-CDF decision rule.

    ``p_mat`` is sharpened with ``params['gamma']``; for every class k >= 1 the
    tail mass ``S_k = sum_{j>=k} p_j`` is compared with ``params['thresh'][k]``.
    Each sample is assigned the largest k whose tail mass reaches its
    threshold, or 0 if none does. ``thresh[0]`` is a placeholder and is never
    read; for the last class the tail mass is just that class's probability.
    """
    sharpened = softmax_power(p_mat, params['gamma'])
    # tail[:, k] = sum_{j>=k} sharpened[:, j]  (reverse cumulative sum)
    tail = np.flip(np.cumsum(np.flip(sharpened, axis=1), axis=1), axis=1)
    cutoffs = np.array(params['thresh'], dtype=float)

    decision = np.zeros(sharpened.shape[0], dtype=int)
    # Ascending k: a later (larger) class overwrites an earlier one, which
    # yields the largest passing k per sample.
    for k in range(1, NUM_CLASSES):
        passed = tail[:, k] >= cutoffs[k]
        decision[passed] = k
    return decision

def reweight_probs(p, w_vec):
    """Multiply each class column by its weight and renormalise every row.

    Parameters
    ----------
    p : array-like, shape (n_samples, C)
        Class probabilities.
    w_vec : array-like, length C
        Per-class multiplicative weights. Lists and tuples are accepted as
        well as ndarrays (weights are stored as plain lists in the selection
        parameters).

    Returns
    -------
    ndarray, shape (n_samples, C)
        Reweighted probabilities; a 1e-12 term in the denominator guards
        against all-zero rows.
    """
    probs = np.asarray(p)
    weights = np.asarray(w_vec)
    scaled = probs * weights[None, :]
    return scaled / (scaled.sum(axis=1, keepdims=True) + 1e-12)

def choose_better(current_best, candidate, why):
    """Return whichever of ``current_best`` / ``candidate`` ranks higher.

    Ranking is lexicographic over Rec4, then QWK, then ACC, then Macro-F1, with
    values closer than 1e-12 treated as tied. When the candidate wins (or there
    is no incumbent) a *copy* of it is returned with ``'_why'`` set to ``why``;
    otherwise the incumbent object itself is returned unchanged.
    """
    def _tagged():
        winner = dict(candidate)
        winner['_why'] = why
        return winner

    if current_best is None:
        return _tagged()

    eps = 1e-12
    for key in ('rec4', 'qwk', 'acc', 'macro'):
        new_val, old_val = candidate[key], current_best[key]
        if new_val > old_val + eps:
            return _tagged()
        if not abs(new_val - old_val) <= eps:
            # Strictly worse on this key: lower-priority keys cannot rescue it.
            break
    return current_best

best_pick = None


def _within_safety(metrics):
    """Return True when ``metrics`` stays inside the allowed ACC/QWK drop relative to ``base``."""
    return ((metrics['acc'] >= base['acc'] - MAX_ACC_DROP) and
            (metrics['qwk'] >= base['qwk'] - MAX_QWK_DROP))


# ---------- Strategy 1: CDF joint thresholds (t3 & t4) + gamma sharpening ----------
gamma_base = float(best_sel.get('gamma', 1.0))
gamma_mult_list = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5]
t3_grid = np.linspace(0.45, 0.65, 9)    # tuning t3 reduces 4→3 misclassifications
t4_grid = np.linspace(0.25, 0.60, 15)   # tuning t4 directly boosts Rec4

for gmul in gamma_mult_list:
    gamma_try = gamma_base * gmul
    for t3 in t3_grid:
        for t4 in t4_grid:
            # Index 0 is a placeholder (never read by apply_selection_cdf);
            # t1,t2 commonly set to 0.5
            t = [1.0, 0.50, 0.50, float(t3), float(t4)]
            params = {"mode": "cdf", "gamma": gamma_try, "thresh": t}
            y_hat_val = apply_selection_cdf(params, p_val_avg)
            m = eval_metrics(y_val, y_hat_val)
            # Safety constraints
            if _within_safety(m):
                best_pick = choose_better(best_pick, dict(m, params=params), "CDF(t3,t4)+gamma")

# ---------- Strategy 2: Reweight p4 (can slightly suppress p3), then apply E[y] cut ----------
a0, b0 = float(best_sel.get('a', 1.0)), float(best_sel.get('b', 0.0))
cuts0 = best_sel.get('cuts', None)
if cuts0 is None:
    # Without explicit cuts (e.g. a non-cut baseline selection) the E[y]
    # digitisation below cannot run; skip the strategy instead of crashing.
    print("[Stage-C] best_sel has no 'cuts'; skipping Strategy 2 (reweight + E[y] cut).")
else:
    for gmul in [1.0, 1.1, 1.2, 1.3, 1.4]:
        gamma_try = gamma_base * gmul
        p_g_val = softmax_power(p_val_avg, gamma_try)
        for w4 in [1.0, 1.2, 1.4, 1.6, 1.8]:
            for w3 in [1.0, 0.95, 0.90, 0.85]:
                w = np.ones(NUM_CLASSES, dtype=float)
                w[4] = w4
                w[3] = w3
                p_rw = reweight_probs(p_g_val, w)
                ey = (p_rw * classes[None, :]).sum(axis=1)
                y_hat_val = digitize_with_cuts(a0 * ey + b0, cuts0)
                m = eval_metrics(y_val, y_hat_val)
                if _within_safety(m):
                    best_pick = choose_better(best_pick,
                                              dict(m, params={"mode": "ey+reweight", "gamma": gamma_try, "w": w.tolist()}),
                                              "Reweight(p4↑,p3↓)+E[y]")

# ---------- Strategy 3: Two-stage hierarchy: first 4-vs-rest (S4=p4 threshold), rest use original E[y] ----------
# The base decision does not depend on gamma_try or th4, so it is computed once
# (keep original gamma-based E[y] cut as base) and copied per candidate.
y_hat_hier_base = apply_selection(best_sel, p_val_avg)
for gmul in [1.0, 1.1, 1.2, 1.3]:
    gamma_try = gamma_base * gmul
    p_g_val = softmax_power(p_val_avg, gamma_try)
    p4 = p_g_val[:, 4]
    for th4 in np.linspace(0.25, 0.65, 17):
        mask4 = (p4 >= th4)
        # First apply original E[y] decision, then override mask4 to class 4
        y_hat_val = y_hat_hier_base.copy()
        y_hat_val[mask4] = 4
        m = eval_metrics(y_val, y_hat_val)
        if _within_safety(m):
            best_pick = choose_better(best_pick,
                                      dict(m, params={"mode": "hier4", "gamma": gamma_try, "th4": float(th4)}),
                                      "Hierarchical 4-vs-rest")

# ---------- Select best and apply to test set ----------
if best_pick is None:
    print("[Stage-C] No method found that further improves Rec4 within safety constraints; keeping original E[y] cut.")
    y_pred_cal = apply_selection(best_sel, p_test_avg)
else:
    print(f"[Stage-C] Selected strategy: {best_pick['_why']} | "
          f"Val: Acc {base['acc']:.4f} → {best_pick['acc']:.4f}, "
          f"QWK {base['qwk']:.4f} → {best_pick['qwk']:.4f}, "
          f"Rec4 {base['rec4']:.4f} → {best_pick['rec4']:.4f}")
    mode = best_pick['params']['mode']

    # In every branch the test predictions are computed BEFORE best_sel is
    # updated, so the base E[y] cut still uses the pre-Stage-C parameters.
    if mode == "cdf":
        y_pred_cal = apply_selection_cdf(best_pick['params'], p_test_avg)
        best_sel = dict(best_sel)
        best_sel.update(best_pick['params'])
    elif mode == "ey+reweight":
        p_g_test = softmax_power(p_test_avg, best_pick['params']['gamma'])
        p_rw_test = reweight_probs(p_g_test, np.array(best_pick['params']['w'], dtype=float))
        ey = (p_rw_test * classes[None, :]).sum(axis=1)
        y_pred_cal = digitize_with_cuts(a0 * ey + b0, cuts0)
        best_sel = dict(best_sel)
        best_sel.update(best_pick['params'])
    elif mode == "hier4":
        p_g_test = softmax_power(p_test_avg, best_pick['params']['gamma'])
        y_pred_cal = apply_selection(best_sel, p_test_avg).copy()
        y_pred_cal[p_g_test[:, 4] >= best_pick['params']['th4']] = 4
        # NOTE(review): the update below overwrites best_sel['gamma'] with the
        # threshold gamma, while the base cut above used the previous gamma.
        # Anything that later rebuilds predictions from the saved best_sel with
        # a single gamma may differ from y_pred_cal when the two gammas are not
        # equal — confirm against the reload path.
        best_sel = dict(best_sel)
        best_sel.update(best_pick['params'])
    else:
        # Fallback (should not reach here)
        y_pred_cal = apply_selection(best_sel, p_test_avg)

# —— Print 4→3 confusion fraction on test set for verification —— #
cm_test = confusion_matrix(y_test, y_pred_cal, labels=np.arange(NUM_CLASSES))
if cm_test[4].sum() > 0:
    frac_4_to_3 = cm_test[4, 3] / cm_test[4].sum()
    print(f"[Test] Class-4 accuracy={cm_test[4,4]/cm_test[4].sum():.3f}, "
          f"4→3 misclassification fraction={frac_4_to_3:.3f} (row-normalized)")
else:
    print("[Test] No true class-4 samples in the test set.")

In [None]:
# —— Finalize and save this scheme (save best_sel + test set predictions) ——
import os
import json
import numpy as np
from sklearn.metrics import classification_report

# (Optional) If you saved a baseline before Stage-C: best_sel_base = dict(best_sel_before_stageC)
# Here we save it together for future comparison (if it exists)
artifacts = {
    "best_sel_final": best_sel,   # Already includes mode, gamma, th4, etc. if Hier4 was selected
}

# Save final configuration as JSON
with open(os.path.join(OUT_DIR, "best_selection_final.json"), "w", encoding="utf-8") as f:
    json.dump(artifacts, f, ensure_ascii=False, indent=2)

# Save final test predictions
np.save(os.path.join(OUT_DIR, "y_pred_test_final.npy"), y_pred_cal)

# Print and save test set report
rep_new = classification_report(y_test, y_pred_cal, digits=4, zero_division=0)
# Label the report with the mode that was actually selected; Stage-C may keep
# the original cut or pick a strategy other than Hier4.
final_mode = best_sel.get("mode", "cut")
print(f"\n[TEST REPORT — {final_mode} (final)]\n", rep_new)

# The output file name is kept unchanged so existing downstream references to
# it keep working, even though the selected mode is not necessarily hier4.
with open(os.path.join(OUT_DIR, "test_report_hier4_final.txt"), "w", encoding="utf-8") as f:
    f.write(rep_new)

In [None]:
# ===== Fixed Version: Load Final Scheme + Reproduce Predictions + Generate All Plots =====
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from sklearn.isotonic import IsotonicRegression
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_curve,
    auc,
    precision_recall_curve,
    average_precision_score
)

# >>> 1) Align paths to your training/saving directory <<< (very important)
OUT_DIR = "./results/Gradually_Increasing"

# Cache and final selection files (note the filenames)
# The four cache_* arrays are read with np.load further down in this cell;
# BEST_SEL_FINAL is parsed as JSON.
VAL_PROBS_NPY  = os.path.join(OUT_DIR, "cache_val_probs.npy")
TEST_PROBS_NPY = os.path.join(OUT_DIR, "cache_test_probs.npy")
VAL_Y_NPY      = os.path.join(OUT_DIR, "cache_val_y.npy")
TEST_Y_NPY     = os.path.join(OUT_DIR, "cache_test_y.npy")
BEST_SEL_FINAL = os.path.join(OUT_DIR, "best_selection_final.json")  # Use the final file

# --- Helper Functions ---
def softmax_power(p, gamma):
    """Sharpen (gamma > 1) or flatten (gamma < 1) row-wise probabilities.

    Entries are clipped to ``[1e-8, 1]``, raised to the power ``gamma`` and
    divided by the row sum (plus 1e-8 to keep the denominator non-zero).
    """
    clipped = np.clip(p, 1e-8, 1.0)
    powered = clipped ** gamma
    row_total = powered.sum(axis=1, keepdims=True) + 1e-8
    return powered / row_total

def digitize_with_cuts(ey, cuts):
    """Map continuous scores to class indices using ascending cut points.

    The cuts are padded with -inf / +inf, so a score below the first cut maps
    to 0 and a score at or above the last cut maps to ``len(cuts)``.
    """
    edges = [-np.inf, *cuts, np.inf]
    bin_number = np.digitize(ey, edges)   # 1-based bin index
    return bin_number - 1

def ey_from_probs(p, gamma, a, b):
    """Return ``(a * E[y] + b, p_gamma)`` for a probability matrix ``p``.

    ``p_gamma`` is ``p`` after ``softmax_power`` with exponent ``gamma``;
    ``E[y]`` is the class-index expectation under ``p_gamma``, with class
    indices 0..C-1 taken from the number of columns of ``p``.
    """
    sharpened = softmax_power(p, gamma)
    class_idx = np.arange(p.shape[1])
    expected = (sharpened * class_idx[None, :]).sum(axis=1)
    return a * expected + b, sharpened

def _iso_round_predict(ey_val, y_val, ey_test, n_classes):
    """Fit isotonic regression ``ey_val -> y_val`` and return rounded, clipped test predictions."""
    iso = IsotonicRegression(y_min=0.0, y_max=float(n_classes - 1),
                             increasing=True, out_of_bounds="clip")
    iso.fit(ey_val, y_val.astype(float))
    return np.rint(iso.predict(ey_test)).astype(int).clip(0, n_classes - 1)


def apply_selection(best_sel, p_val, p_test, y_val=None):
    """
    Reproduce test predictions from a saved selection dict.

    Supported modes (``best_sel['mode']``, case-insensitive, default 'cut'):
    - 'cut'        : Use ey + cutpoints
    - 'iso'        : Fit isotonic regression on ey (requires y_val/p_val)
    - 'hier4'      : First apply th4 threshold for 4-vs-rest, others keep cut (or iso) decision
    - 'cdf'        : Ordered-CDF rule: largest k with sum_{j>=k} p_j >= thresh[k]
                     (same rule as the Stage-C ``apply_selection_cdf``)
    - 'ey+reweight': Multiply γ-scaled probabilities by class weights ``w``,
                     renormalise, then apply a*E[y]+b with the saved cuts
                     (same rule as Stage-C Strategy 2)

    Parameters: ``best_sel`` also supplies 'gamma' (default 1.0) and the affine
    terms 'a' (default 1.0) / 'b' (default 0.0).

    Returns: y_pred_test, p_curve_test (γ-scaled test probabilities, for ROC),
    ey_test (a*E[y]+b computed from the γ-scaled, un-reweighted probabilities).

    Raises: ValueError for an unknown mode or when isotonic fitting is needed
    but ``y_val`` is missing; KeyError when a mode's required key is absent.

    NOTE(review): Stage-C evaluates the hier4 base cut with the pre-Stage-C
    gamma but stores the threshold gamma under the same 'gamma' key, so hier4
    predictions rebuilt here may differ when those gammas are not equal —
    confirm against the Stage-C cell.
    """
    mode = best_sel.get('mode', 'cut').lower()
    gamma = float(best_sel.get('gamma', 1.0))
    a = float(best_sel.get('a', 1.0))
    b = float(best_sel.get('b', 0.0))
    n_classes = p_test.shape[1]

    # Compute ey and temperature-scaled probabilities first
    ey_val,  p_val_g  = ey_from_probs(p_val,  gamma, a, b)
    ey_test, p_test_g = ey_from_probs(p_test, gamma, a, b)

    if mode == 'iso':
        if y_val is None:
            raise ValueError("mode='iso' requires y_val to fit isotonic regression on validation set.")
        y_pred_test = _iso_round_predict(ey_val, y_val, ey_test, n_classes)

    elif mode == 'cut':
        y_pred_test = digitize_with_cuts(ey_test, best_sel['cuts'])

    elif mode == 'hier4':
        # Base: first apply cut (or iso base) for non-class-4 predictions
        if 'cuts' in best_sel:
            y_pred_test = digitize_with_cuts(ey_test, best_sel['cuts'])
        elif best_sel.get('iso_base', False):
            # Rare branch: if you saved an iso baseline, recover it here (default: not present)
            if y_val is None:
                raise ValueError("hier4 + iso_base requires y_val to fit isotonic regression.")
            y_pred_test = _iso_round_predict(ey_val, y_val, ey_test, n_classes)
        else:
            # Neither cuts nor an iso baseline: there is no base decision to
            # override. (The previous "fallback to cut" could only raise a bare
            # KeyError here; the exception type is kept, the message is clearer.)
            raise KeyError("mode='hier4' requires 'cuts' (or iso_base=True) in best_sel")

        # Top-level threshold: Y≥4 vs rest, using s4 = P(Y=4)
        th4 = float(best_sel['th4']) if 'th4' in best_sel else 0.5
        s4 = p_test_g[:, -1]             # probability of the last class
        y_pred_test[s4 >= th4] = n_classes - 1

    elif mode == 'cdf':
        thresh = np.asarray(best_sel['thresh'], dtype=float)   # thresh[0] is a placeholder
        tail = np.cumsum(p_test_g[:, ::-1], axis=1)[:, ::-1]   # tail[:, k] = sum_{j>=k} p_j
        y_pred_test = np.zeros(p_test_g.shape[0], dtype=int)
        # Ascending k so the largest passing class wins.
        for k in range(1, n_classes):
            y_pred_test[tail[:, k] >= thresh[k]] = k

    elif mode == 'ey+reweight':
        weights = np.asarray(best_sel['w'], dtype=float)
        q = p_test_g * weights[None, :]
        q = q / (q.sum(axis=1, keepdims=True) + 1e-12)
        ey_rw = (q * np.arange(n_classes)[None, :]).sum(axis=1)
        y_pred_test = digitize_with_cuts(a * ey_rw + b, best_sel['cuts'])

    else:
        raise ValueError(f"Unknown mode: {mode}")

    return y_pred_test, p_test_g, ey_test

# --- Load cache and final selection ---
# Validate every file that is loaded below with explicit exceptions
# (assert statements are stripped when Python runs with -O).
for _required in (VAL_PROBS_NPY, TEST_PROBS_NPY, VAL_Y_NPY, TEST_Y_NPY):
    if not os.path.exists(_required):
        raise FileNotFoundError(f"File not found: {_required}")
if not os.path.exists(BEST_SEL_FINAL):
    raise FileNotFoundError(
        f"File not found: {BEST_SEL_FINAL} (please run 'Finalize and save scheme' first)"
    )
p_val_avg  = np.load(VAL_PROBS_NPY)
p_test_avg = np.load(TEST_PROBS_NPY)
y_val      = np.load(VAL_Y_NPY)
y_test     = np.load(TEST_Y_NPY)

with open(BEST_SEL_FINAL, "r", encoding="utf-8") as f:
    js = json.load(f)
# Handle both save formats: {"best_sel_final": {...}} or direct {...}
best_sel = js.get("best_sel_final", js)

# The number of classes is taken from the cached probability matrix.
NUM_CLASSES = p_test_avg.shape[1]

print("[INFO] OUT_DIR =", OUT_DIR)
print("[INFO] best_sel.mode =", best_sel.get("mode"), "| keys:", list(best_sel.keys()))

# Generate final predictions (supports hier4)
y_pred_cal, p_curve, ey_curve = apply_selection(best_sel, p_val_avg, p_test_avg, y_val=y_val)

# ====== Quick sanity check: row-normalized class 4 ======
cm = confusion_matrix(y_test, y_pred_cal, labels=np.arange(NUM_CLASSES))
# max(..., 1) avoids a division by zero when the last class has no true samples.
row = cm[-1] / max(cm[-1].sum(), 1)  # last row = true class 4
print(f"[CHECK] Class-4 row-normalized: {np.round(row, 3)}  (cols=0..4)  | Accuracy = {row[-1]:.3f}")

# ====== Generate all plots (consistent with your previous scripts) ======
classes = np.arange(NUM_CLASSES)

# 1) Confusion Matrix (row-normalized)
# 1e-12 avoids a division by zero for classes with no true samples.
row_sums = cm.sum(axis=1, keepdims=True) + 1e-12
cm_norm = cm / row_sums
plt.figure(figsize=(6,5))
im = plt.imshow(cm_norm, interpolation='nearest', aspect='auto')
plt.colorbar(im, fraction=0.046, pad=0.04)
# `tag` summarises the loaded selection and is reused in the titles of all
# figures produced below.
tag = f"Mode={best_sel.get('mode','?')}, γ={best_sel.get('gamma',1.0):.2f}"
if best_sel.get('mode','').lower() == 'hier4':
    tag += f", th4={best_sel.get('th4','-')}"
plt.title(f"Confusion Matrix (Row-normalized)\n{tag}")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.xticks(classes)
plt.yticks(classes)
# Annotate every cell with its share of the true-class row, in percent.
for i in range(NUM_CLASSES):
    for j in range(NUM_CLASSES):
        val = cm_norm[i, j] * 100
        plt.text(j, i, f"{val:.1f}%", ha="center", va="center", fontsize=8)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "viz_confusion_matrix_rownorm_FINAL.pdf"), format="pdf")
plt.close()

# 2) One-vs-Rest ROC
# All curves in sections 2-4 are computed from p_curve (the γ-scaled test
# probabilities returned by apply_selection); they do not depend on
# y_pred_cal, i.e. on the cuts or the th4 override.
plt.figure(figsize=(6,5))
macro_auc = []
for k in classes:
    # Binary target: class k vs. all other classes, scored by column k.
    y_true_k = (y_test == k).astype(int)
    fpr, tpr, _ = roc_curve(y_true_k, p_curve[:, k])
    auc_k = auc(fpr, tpr)
    macro_auc.append(auc_k)
    plt.plot(fpr, tpr, label=f"class {k} (AUC={auc_k:.3f})")
plt.plot([0,1], [0,1], '--')  # chance diagonal
plt.title(f"OVR ROC — {tag}")
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.legend(fontsize=8)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "viz_roc_ovr_FINAL.pdf"), format="pdf")
plt.close()
print(f"OVR macro AUC: {np.mean(macro_auc):.3f}")

# 3) Ordinal ROC (cumulative Y≥k)
plt.figure(figsize=(6,5))
auc_ord = []
for k in range(1, NUM_CLASSES):
    # Binary target Y >= k, scored by the tail mass sum_{j>=k} p_j.
    y_true_bin = (y_test >= k).astype(int)
    s_k = p_curve[:, k:].sum(axis=1)
    fpr, tpr, _ = roc_curve(y_true_bin, s_k)
    auc_k = auc(fpr, tpr)
    auc_ord.append(auc_k)
    plt.plot(fpr, tpr, label=f"Y≥{k} (AUC={auc_k:.3f})")
plt.plot([0,1], [0,1], '--')  # chance diagonal
plt.title(f"Ordinal ROC (cumulative) — {tag}")
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.legend(fontsize=8)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "viz_roc_ordinal_cumulative_FINAL.pdf"), format="pdf")
plt.close()
print(f"Ordinal ROC mean AUC: {np.mean(auc_ord):.3f}")

# 4) One-vs-Rest Precision-Recall
plt.figure(figsize=(6,5))
auprc = []
for k in classes:
    y_true_k = (y_test == k).astype(int)
    prec, rec, _ = precision_recall_curve(y_true_k, p_curve[:, k])
    # Average precision is reported in the legend as the per-class summary.
    ap = average_precision_score(y_true_k, p_curve[:, k])
    auprc.append(ap)
    plt.plot(rec, prec, label=f"class {k} (AP={ap:.3f})")
plt.title(f"OVR Precision-Recall — {tag}")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend(fontsize=8)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "viz_pr_ovr_FINAL.pdf"), format="pdf")
plt.close()
print(f"OVR mean AUPRC: {np.mean(auprc):.3f}")

# 5) E[y] Calibration (binned means)
ey_plot = (p_curve * classes[None, :]).sum(axis=1)  # Here using γ-scaled but un-affine E[y]; use ey_curve if you want a,b applied
# 11 quantile edges -> 10 bins; digitizing against the 9 interior edges gives
# bin indices 0..9.
q = np.quantile(ey_plot, np.linspace(0, 1, 11))
idx = np.digitize(ey_plot, q[1:-1], right=True)
bin_pred, bin_true, bin_n = [], [], []
for b in range(len(q)-1):
    m = (idx == b)
    # Bins with fewer than 30 samples are left out of the plot.
    if m.sum() >= 30:
        bin_pred.append(ey_plot[m].mean())
        bin_true.append(y_test[m].mean())
        bin_n.append(m.sum())
plt.figure(figsize=(6,5))
# Dashed diagonal = perfect agreement between predicted E[y] and observed mean label.
plt.plot([0, NUM_CLASSES-1], [0, NUM_CLASSES-1], linestyle="--")
# Marker area scales with the bin size (half the count, clipped to [10, 200]).
plt.scatter(bin_pred, bin_true, s=np.clip(np.array(bin_n)/2, 10, 200))
for bp, bt, n in zip(bin_pred, bin_true, bin_n):
    plt.text(bp, bt, str(n), fontsize=8, ha='left', va='bottom')
plt.title(f"E[y] Calibration (bin means) — {tag}")
plt.xlabel("Predicted E[y]")
plt.ylabel("Observed mean(y)")
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "viz_calibration_Ey_bins_FINAL.pdf"), format="pdf")
plt.close()

# 6) Adjacent Accuracy by True Class
# adj_ok[i] is 1.0 when the prediction is at most one class away from the truth.
adj_ok = (np.abs(y_pred_cal - y_test) <= 1).astype(float)
vals = []
for k in classes:
    m = (y_test == k)
    # Classes absent from the test labels get NaN instead of a mean over an empty set.
    vals.append(adj_ok[m].mean() if m.any() else np.nan)
plt.figure(figsize=(6,4))
plt.bar(classes, vals)
plt.ylim(0, 1)
plt.title(f"Adjacent Accuracy by True Class — {tag}")
plt.xlabel("True class")
plt.ylabel("Proportion")
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "viz_adjacent_accuracy_by_class_FINAL.pdf"), format="pdf")
plt.close()

# 7) Error Distribution
# Signed error: positive values mean the predicted class is above the true class.
err = y_pred_cal - y_test
plt.figure(figsize=(6,4))
# Unit-width bins centred on every integer error from -(C-1) to +(C-1).
bins = np.arange(-(NUM_CLASSES-1)-0.5, (NUM_CLASSES-1)+1.5, 1.0)
plt.hist(err, bins=bins, density=True)
plt.xticks(range(-(NUM_CLASSES-1), (NUM_CLASSES-1)+1))
plt.title(f"Error distribution (ŷ−y) — {tag}")
plt.xlabel("Error")
plt.ylabel("Density")
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "viz_error_hist_FINAL.pdf"), format="pdf")
plt.close()

# 8) QWK Cost Heatmap
# W[i, j] = ((i - j) / (C - 1))^2 is the quadratic disagreement weight between
# true class i and predicted class j (0 on the diagonal, 1 at the corners).
W = np.zeros((NUM_CLASSES, NUM_CLASSES))
for i in range(NUM_CLASSES):
    for j in range(NUM_CLASSES):
        W[i, j] = ((i - j) / (NUM_CLASSES - 1)) ** 2
# Each cell's cost = weight x share of all test samples falling in that cell;
# 1e-12 guards against an empty confusion matrix.
cost = W * (cm / (cm.sum() + 1e-12))
plt.figure(figsize=(6,5))
im = plt.imshow(cost, interpolation='nearest', aspect='auto')
plt.colorbar(im, fraction=0.046, pad=0.04)
plt.title(f"QWK Cost Map — {tag}")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.xticks(classes)
plt.yticks(classes)
# Annotate every cell with its weighted cost, in percent.
for i in range(NUM_CLASSES):
    for j in range(NUM_CLASSES):
        plt.text(j, i, f"{cost[i,j]*100:.2f}%", ha="center", va="center", fontsize=8)
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "viz_qwk_cost_map_FINAL.pdf"), format="pdf")
plt.close()

print("\n✅ All new plots have been saved to:", OUT_DIR)

In [None]:
# -*- coding: utf-8 -*-
"""
ROC Comparison (Standalone Script)
- Loads probability cache from your deep ordinal model + calibration parameters
  (supports best_selection_final.json / best_selection.json)
- Reconstructs the same train/test split used in training
- Trains baseline models: RandomForest, XGBoost, Ordered Logit
- Plots Ordinal ROC curves (cumulative Y≥k) and overlays "DeepOrdinal(γ+Hier4 aware)" curve
- Outputs: compare_ordinal_roc_with_hier4.pdf, compare_ordinal_auc_bar_with_hier4.pdf
"""

import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import roc_curve, auc

# ========= 1) Paths & Data Configuration (must match training/saving setup) =========
DATA_PATH = "./data/extracted_Gradually_Increasing.csv"
OUT_DIR   = "./results/Gradually_Increasing"
os.makedirs(OUT_DIR, exist_ok=True)

# Cached deep-model outputs; only the test probabilities are mandatory for
# load_deep_caches(), the other three arrays are loaded when present.
VAL_PROBS_NPY  = os.path.join(OUT_DIR, "cache_val_probs.npy")
TEST_PROBS_NPY = os.path.join(OUT_DIR, "cache_test_probs.npy")
VAL_Y_NPY      = os.path.join(OUT_DIR, "cache_val_y.npy")
TEST_Y_NPY     = os.path.join(OUT_DIR, "cache_test_y.npy")
# Selection parameters: the *_final file is preferred, BEST_SEL_JSON is the fallback.
BEST_SEL_FINAL = os.path.join(OUT_DIR, "best_selection_final.json")
BEST_SEL_JSON  = os.path.join(OUT_DIR, "best_selection.json")

# Column slices used during training (GI: 9:417; FD: 9:76).
# DATA_PATH / OUT_DIR of this script point at the Gradually_Increasing (GI)
# dataset, so the GI protein slice is required here; with the FD slice the
# baseline models would silently be trained on a subset of the protein columns.
PROTEIN_SLICE   = slice(9, 417)     # GI
COVARIATE_SLICE = slice(2, 9)

RANDOM_STATE = 42
NUM_CLASSES  = 5

# ========= 2) Helper Functions =========
def softmax_power(p, gamma):
    """Raise row-wise probabilities to the power ``gamma`` and renormalise.

    Entries are clipped to ``[1e-8, 1]`` first; ``gamma`` is coerced to float;
    1e-8 is added to each row sum to keep the denominator non-zero.
    """
    exponent = float(gamma)
    powered = np.clip(p, 1e-8, 1.0) ** exponent
    row_total = powered.sum(axis=1, keepdims=True) + 1e-8
    return powered / row_total

def load_deep_caches():
    """Load the cached deep-model outputs and the calibration selection.

    Returns
    -------
    (p_val_avg, p_test_avg, y_val, y_test, best_sel)
        ``p_test_avg`` is mandatory; the other three arrays are ``None`` when
        their cache file does not exist. ``best_sel`` is read from
        ``BEST_SEL_FINAL`` (unwrapping a ``"best_sel_final"`` envelope if
        present) or, failing that, from ``BEST_SEL_JSON``.

    Raises
    ------
    FileNotFoundError
        If the test-probability cache is missing, or if neither selection file
        exists.
    """
    # Explicit exception instead of assert: asserts are stripped under -O, and
    # this matches the FileNotFoundError raised for missing selection files.
    if not os.path.exists(TEST_PROBS_NPY):
        raise FileNotFoundError(f"Test probabilities not found: {TEST_PROBS_NPY}")
    p_test_avg = np.load(TEST_PROBS_NPY)
    p_val_avg  = np.load(VAL_PROBS_NPY)  if os.path.exists(VAL_PROBS_NPY)  else None
    y_test     = np.load(TEST_Y_NPY)     if os.path.exists(TEST_Y_NPY)     else None
    y_val      = np.load(VAL_Y_NPY)      if os.path.exists(VAL_Y_NPY)      else None

    # Prefer final scheme; fallback to best_selection.json if missing
    if os.path.exists(BEST_SEL_FINAL):
        with open(BEST_SEL_FINAL, "r", encoding="utf-8") as f:
            obj = json.load(f)
        # Accept both {"best_sel_final": {...}} and a bare selection dict.
        best_sel = obj.get("best_sel_final", obj)
    elif os.path.exists(BEST_SEL_JSON):
        with open(BEST_SEL_JSON, "r", encoding="utf-8") as f:
            best_sel = json.load(f)
    else:
        raise FileNotFoundError("Neither best_selection_final.json nor best_selection.json found.")

    return p_val_avg, p_test_avg, y_val, y_test, best_sel

def load_split_Xy():
    """Build standardized train/test matrices for the baseline models.

    Reads DATA_PATH, takes protein and covariate columns by slice and the label
    from column 1, then makes a stratified 80/20 split seeded with
    RANDOM_STATE. Each feature group is mean-imputed and standardized with
    transformers fitted on the training part only. The split is intended to
    match the one used for the deep model — not verifiable from this cell.

    Returns ``(X_train, X_test, y_train, y_test)`` with proteins stacked before
    covariates and features cast to float32.
    """
    table = pd.read_csv(DATA_PATH)
    proteins = table.iloc[:, PROTEIN_SLICE].values.astype(np.float32)
    covariates = table.iloc[:, COVARIATE_SLICE].values.astype(np.float32)
    labels = table.iloc[:, 1].values.astype(np.int32)

    prot_tr, prot_te, cov_tr, cov_te, labels_tr, labels_te = train_test_split(
        proteins, covariates, labels,
        test_size=0.2, random_state=RANDOM_STATE, stratify=labels
    )

    def _impute_then_scale(train_part, test_part):
        # Fit on the training part only; the test part is transformed with the
        # training statistics.
        imputer = SimpleImputer(strategy='mean')
        scaler = StandardScaler()
        train_ready = scaler.fit_transform(imputer.fit_transform(train_part))
        test_ready = scaler.transform(imputer.transform(test_part))
        return train_ready, test_ready

    prot_tr, prot_te = _impute_then_scale(prot_tr, prot_te)
    cov_tr, cov_te = _impute_then_scale(cov_tr, cov_te)

    features_train = np.hstack([prot_tr, cov_tr]).astype(np.float32)
    features_test = np.hstack([prot_te, cov_te]).astype(np.float32)
    return features_train, features_test, labels_tr, labels_te

def ordinal_auc_list(y_true, scores_list):
    """Compute one ROC AUC per ordinal threshold and their NaN-ignoring mean.

    The i-th entry of ``scores_list`` (1-based index k) is scored against the
    binary target ``y_true >= k``. A threshold whose binary target has a single
    class yields ``np.nan``.

    Returns:
        Tuple ``(aucs, macro)``: the per-threshold list and ``np.nanmean`` of it.
    """
    per_threshold = []
    for k, scores_k in enumerate(scores_list, start=1):
        is_at_least_k = (y_true >= k).astype(int)  # binary target Y >= k
        if len(np.unique(is_at_least_k)) >= 2:
            fpr, tpr, _ = roc_curve(is_at_least_k, scores_k)
            per_threshold.append(auc(fpr, tpr))
        else:
            # AUC is undefined when only one class is present.
            per_threshold.append(np.nan)
    return per_threshold, np.nanmean(per_threshold)

# ========= 3) Load Deep Model Cache =========
# Cached ensemble probabilities, cached labels (may be None) and the selection dict.
p_val_avg, p_test_avg, y_val_cache, y_test_cache, best_sel = load_deep_caches()
assert p_test_avg.shape[1] == NUM_CLASSES, "NUM_CLASSES inconsistent with probability columns"
n_test = p_test_avg.shape[0]

# ========= 4) Reconstruct X/y split for traditional models =========
X_train, X_test, y_train, y_test = load_split_Xy()
# Align labels with cache if lengths match
# NOTE(review): only the lengths are compared; row order between the cached
# probabilities and the reconstructed split is not verified here.
if y_test_cache is not None and len(y_test_cache) == n_test:
    y_test = y_test_cache  # use cached labels when perfectly aligned

# ========= 5) Deep Model Scores (two variants) =========
# (a) "γ version": S_k = sum_{j>=k} p_j^γ for each threshold k
# gamma falls back to 1.0 when the selection dict has no 'gamma' entry.
p_deep_gamma = softmax_power(p_test_avg, best_sel.get('gamma', 1.0))
scores_deep_gamma = [p_deep_gamma[:, k:].sum(axis=1) for k in range(1, NUM_CLASSES)]

# (b) "γ+Hier4 aware": k=1..3 use cumulative prob, k=4 uses affine-adjusted ey
classes = np.arange(NUM_CLASSES)
# Expected class index under the gamma-adjusted distribution.
Ey_gamma = (p_deep_gamma * classes[None, :]).sum(axis=1)
# Affine map a * E[y] + b; defaults (a=1, b=0) leave E[y] unchanged.
Ey_aff   = best_sel.get('a', 1.0) * Ey_gamma + best_sel.get('b', 0.0)
scores_deep_hier = []
for k in range(1, NUM_CLASSES):
    if k < NUM_CLASSES - 1:
        scores_deep_hier.append(p_deep_gamma[:, k:].sum(axis=1))  # cumulative probability
    else:
        # Only the highest threshold (k = NUM_CLASSES - 1) uses the scalar score.
        scores_deep_hier.append(Ey_aff)  # coherent scalar reflecting Stage-C direction

# ========= 6) Train Traditional Baseline Models =========
print("[INFO] Training RandomForest ...")
from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier(
    n_estimators=600,
    max_depth=None,
    min_samples_leaf=2,
    class_weight='balanced_subsample',
    n_jobs=-1,
    random_state=42
)
rf.fit(X_train, y_train)
p_rf_raw = rf.predict_proba(X_test)
# Re-index columns by class label so column c always holds P(y = c),
# even if some class is absent from the training labels.
p_rf = np.zeros((len(X_test), NUM_CLASSES))
for j, cls in enumerate(rf.classes_):
    p_rf[:, int(cls)] = p_rf_raw[:, j]
# Threshold score for Y >= k is the cumulative probability of classes k..C-1.
scores_rf = [p_rf[:, k:].sum(axis=1) for k in range(1, NUM_CLASSES)]

# XGBoost (skipped if not available)
scores_xgb = None
try:
    from xgboost import XGBClassifier
    print("[INFO] Training XGBoost ...")
    # eval_metric is set on the estimator: XGBoost >= 2.0 no longer accepts it in
    # fit(), and the resulting TypeError was swallowed by the except below, which
    # silently dropped XGBoost from the comparison.
    xgb = XGBClassifier(
        objective='multi:softprob',
        num_class=NUM_CLASSES,
        n_estimators=500,
        learning_rate=0.05,
        max_depth=6,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_lambda=1.0,
        tree_method='hist',
        eval_metric='mlogloss',
        n_jobs=-1,
        random_state=42
    )
    xgb.fit(X_train, y_train, verbose=False)
    p_xgb_raw = xgb.predict_proba(X_test)
    p_xgb = np.zeros((len(X_test), NUM_CLASSES))
    for j, cls in enumerate(xgb.classes_):
        p_xgb[:, int(cls)] = p_xgb_raw[:, j]
    scores_xgb = [p_xgb[:, k:].sum(axis=1) for k in range(1, NUM_CLASSES)]
except Exception as e:
    # Best-effort baseline: any import or training failure skips this model.
    print("[WARN] XGBoost unavailable, skipped:", e)

# Ordered Logit (skipped if not available)
scores_ologit = None
try:
    from statsmodels.miscmodels.ordinal_model import OrderedModel
    print("[INFO] Training Ordered Logit ...")
    ologit = OrderedModel(y_train, X_train, distr='logit')
    res = ologit.fit(method='bfgs', maxiter=200, disp=False)
    p_ologit = res.model.predict(res.params, exog=X_test, which='prob')  # (n, C)
    scores_ologit = [p_ologit[:, k:].sum(axis=1) for k in range(1, NUM_CLASSES)]
except Exception as e:
    # Best-effort baseline: any import or fitting failure skips this model.
    print("[WARN] statsmodels OrderedModel unavailable, skipped:", e)

# ========= 7) Compute AUCs and Plot =========
# One panel title per ordinal threshold Y >= k, k = 1..NUM_CLASSES-1.
panel_labels = [f"Y≥{k}" for k in range(1, NUM_CLASSES)]
# 2x2 grid flattened to a 1-D array of axes, indexed by threshold position.
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
axes = axes.ravel()

# Model name -> mean ordinal AUC; filled as each model is plotted.
macro_table = {}

def plot_model(ax_idx, name, scores):
    """Draw one model's ROC curve on every threshold panel and record its mean AUC.

    Args:
        ax_idx: accepted for call compatibility; not used (the module-level
            ``axes`` array is indexed by threshold position instead).
        name: legend label and key under which the mean AUC goes in ``macro_table``.
        scores: list of score vectors, one per threshold Y >= k (k = 1, 2, ...).
    """
    per_threshold_auc, mean_auc = ordinal_auc_list(y_test, scores)
    macro_table[name] = mean_auc
    for panel, score_k in enumerate(scores):
        positives = (y_test >= (panel + 1)).astype(int)
        # Skip panels whose binary target has a single class (no ROC curve).
        if len(np.unique(positives)) >= 2:
            fpr, tpr, _ = roc_curve(positives, score_k)
            panel_auc = per_threshold_auc[panel]
            axes[panel].plot(fpr, tpr, label=f"{name} (AUC={panel_auc:.3f})")

# Plot traditional models first (γ deep, RF, XGB, OLogit)
# The first argument is not used by plot_model; it indexes the global `axes` itself.
plot_model(axes, "DeepOrdinal(γ)", scores_deep_gamma)
plot_model(axes, "RandomForest",  scores_rf)
# Optional baselines are plotted only if their training step produced scores.
if scores_xgb is not None:
    plot_model(axes, "XGBoost", scores_xgb)
if scores_ologit is not None:
    plot_model(axes, "OrderedLogit", scores_ologit)

# Overlay "γ+Hier4 aware" curve (highlight lift on Y≥4)
aucs_hier, macro_hier = ordinal_auc_list(y_test, scores_deep_hier)
macro_table["DeepOrdinal(γ+Hier4)"] = macro_hier
for i, (s_k, A) in enumerate(zip(scores_deep_hier, aucs_hier)):
    y_bin = (y_test >= (i + 1)).astype(int)
    # No ROC curve when the binary target has a single class.
    if len(np.unique(y_bin)) < 2:
        continue
    fpr, tpr, _ = roc_curve(y_bin, s_k)
    axes[i].plot(fpr, tpr, label=f"DeepOrdinal(γ+Hier4) (AUC={A:.3f})", linewidth=2)

# Axes & legends
for ax, lab in zip(axes, panel_labels):
    # Diagonal reference line (chance level).
    ax.plot([0,1], [0,1], '--', color='gray', linewidth=1)
    ax.set_title(f"Ordinal ROC — {lab}")
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.legend(fontsize=8)

plt.tight_layout()
out1 = os.path.join(OUT_DIR, "compare_ordinal_roc_with_hier4.pdf")
plt.savefig(out1, format="pdf")
plt.close()
print("[OK] Saved:", out1)

# Macro average bar chart (including γ+Hier4)
plt.figure(figsize=(7,4))
# Bars follow the insertion order of macro_table.
names = list(macro_table.keys())
vals  = [macro_table[n] for n in names]
plt.bar(range(len(names)), vals)
plt.xticks(range(len(names)), names, rotation=15)
plt.ylim(0, 1)
plt.ylabel("Mean Ordinal AUC")
plt.title("Mean Ordinal AUC across thresholds (Y≥k)")
plt.tight_layout()
out2 = os.path.join(OUT_DIR, "compare_ordinal_auc_bar_with_hier4.pdf")
plt.savefig(out2, format="pdf")
plt.close()
print("[OK] Saved:", out2)

# Text summary of the same macro averages.
print("\n== Mean Ordinal AUC ==")
for n in names:
    print(f"{n:>22}: {macro_table[n]:.4f}")

print("\n[Completed] All plots saved to directory:", OUT_DIR)

In [None]:
# === Class-4: Combined PR + Decision Curve Comparison (Deep vs RF/XGB/OrderedLogit) ===
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc
from sklearn.metrics import confusion_matrix
from sklearn.ensemble import RandomForestClassifier

# Optional: train temporarily if p_xgb / p_ologit are missing
# Each flag is True when the corresponding probability matrix is not already
# defined at module level (e.g. by an earlier cell), so it must be trained here.
_need_rf    = 'p_rf' not in globals()
_need_xgb   = 'p_xgb' not in globals()
_need_ologit = 'p_ologit' not in globals()

# -------- Basic Checks --------
# Names this cell reads but does not define itself.
required = ['p_test_avg', 'best_sel', 'X_train', 'X_test',
            'y_train', 'y_test', 'NUM_CLASSES', 'OUT_DIR']
missing = [v for v in required if v not in globals()]
assert not missing, f"Missing required variables: {missing}. Please run your deep model code first."
# The class-4 analyses need at least one test sample of the highest class.
assert NUM_CLASSES >= 2 and (NUM_CLASSES-1) in np.unique(y_test), "Label set must include the highest class."

# -------- Helpers --------
def softmax_power(p, gamma):
    """Raise each probability to the power ``gamma`` and renormalize every row.

    Inputs are clipped to [1e-8, 1] first, and 1e-8 is added to each row total
    before dividing, so rows sum to slightly less than 1.
    """
    powered = np.power(np.clip(p, 1e-8, 1.0), gamma)
    row_totals = powered.sum(axis=1, keepdims=True) + 1e-8
    return powered / row_totals

def ensure_prob_matrix(raw_probs, raw_classes, C):
    """Scatter model probability columns into an (n, C) matrix ordered 0..C-1.

    Column j of ``raw_probs`` is written to column ``int(raw_classes[j])`` of
    the result; classes absent from ``raw_classes`` keep probability 0.
    """
    n_rows = raw_probs.shape[0]
    aligned = np.zeros((n_rows, C), dtype=float)
    target_cols = [int(label) for label in raw_classes]
    for src_col, dst_col in enumerate(target_cols):
        aligned[:, dst_col] = raw_probs[:, src_col]
    return aligned

# -------- DeepOrdinal Scores (used for PR / Decision Curve) --------
classes = np.arange(NUM_CLASSES)
# Fall back to gamma = 1.0 when the selection dict has no 'gamma' entry, the same
# default the ordinal-ROC comparison uses, instead of raising KeyError.
p_deep_gamma = softmax_power(p_test_avg, best_sel.get('gamma', 1.0))
s4_deep = p_deep_gamma[:, -1]                  # Deep: P(y=4)
# Binary target: 1 for the highest class, 0 otherwise.
y_true4 = (y_test == (NUM_CLASSES-1)).astype(int)
# Prevalence of the highest class in the test labels.
prev4 = y_true4.mean()

models_probs = {"DeepOrdinal(γ)": p_deep_gamma}  # name -> (n, C) probability matrix

# -------- RandomForest Probabilities --------
# Train only when no `p_rf` was found at module level; otherwise reuse it.
if _need_rf:
    print("[INFO] Training RandomForest to obtain probabilities ...")
    rf = RandomForestClassifier(
        n_estimators=600,
        max_depth=None,
        min_samples_leaf=2,
        class_weight='balanced_subsample',
        n_jobs=-1,
        random_state=42
    )
    rf.fit(X_train, y_train)
    _p = rf.predict_proba(X_test)
    # Re-index columns by class label so column c holds P(y = c).
    p_rf = ensure_prob_matrix(_p, rf.classes_, NUM_CLASSES)
models_probs["RandomForest"] = p_rf

# -------- XGBoost Probabilities --------
# Train only when no `p_xgb` was found at module level; otherwise reuse it.
if _need_xgb:
    print("[INFO] Training XGBoost to obtain probabilities ...")
    from xgboost import XGBClassifier
    # eval_metric is configured on the estimator: XGBoost >= 2.0 rejects it as a
    # fit() argument, which would abort this cell with a TypeError.
    xgb = XGBClassifier(
        objective='multi:softprob',
        num_class=NUM_CLASSES,
        n_estimators=500,
        learning_rate=0.05,
        max_depth=6,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_lambda=1.0,
        tree_method='hist',
        eval_metric='mlogloss',
        n_jobs=-1,
        random_state=42
    )
    xgb.fit(X_train, y_train, verbose=False)
    _p = xgb.predict_proba(X_test)
    # Re-index columns by class label so column c holds P(y = c).
    p_xgb = ensure_prob_matrix(_p, xgb.classes_, NUM_CLASSES)
models_probs["XGBoost"] = p_xgb

# -------- Ordered Logit Probabilities --------
# Train only when no `p_ologit` was found at module level; otherwise reuse it.
# Unlike the optional baselines elsewhere, a missing statsmodels raises here.
if _need_ologit:
    print("[INFO] Training Ordered Logit to obtain probabilities ...")
    from statsmodels.miscmodels.ordinal_model import OrderedModel
    ologit = OrderedModel(y_train, X_train, distr='logit')
    res = ologit.fit(method='bfgs', maxiter=200, disp=False)
    p_ologit = res.model.predict(res.params, exog=X_test, which='prob')  # (n, C)
models_probs["OrderedLogit"] = p_ologit

# ================= 1) Class-4 Precision–Recall Curve (All Models in One Plot) =================
plt.figure(figsize=(7,5))
# Model name -> area under its precision-recall curve.
auprc_table = {}
for name, P in models_probs.items():
    s4 = P[:, -1]                         # Directly use P(y=4)
    prec, rec, _ = precision_recall_curve(y_true4, s4)
    # Area under the PR curve, integrating precision over recall.
    au = auc(rec, prec)
    auprc_table[name] = au
    plt.plot(rec, prec, label=f"{name} (AUPRC={au:.3f})")

# Plot prevalence reference line (expected precision of random classifier)
plt.hlines(prev4, 0, 1, colors='gray', linestyles='--', linewidth=1,
           label=f"prevalence={prev4:.3f}")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Class-4 Precision–Recall (all models)")
plt.legend(fontsize=9)
plt.tight_layout()
out_pr = os.path.join(OUT_DIR, "class4_pr_all_models.pdf")
plt.savefig(out_pr, format="pdf")
plt.close()
print("[OK] Saved PR plot:", out_pr)
print("AUPRC:", {k: f"{v:.3f}" for k, v in auprc_table.items()})

# ================= 2) Class-4 Decision Curve (All Models in One Plot) =================
def decision_curve(y_true, prob, thresholds=np.linspace(0.01, 0.80, 80)):
    """Compute the net benefit of a probabilistic classifier at each threshold.

    For threshold ``pt`` a sample is flagged when ``prob >= pt`` and the net
    benefit is ``TP/N - pt/(1-pt) * FP/N``.

    Args:
        y_true: binary label array (cast to int).
        prob: predicted probability array, same length as ``y_true``.
        thresholds: iterable of threshold probabilities below 1.

    Returns:
        Tuple ``(thresholds, net_benefit)``; ``thresholds`` is returned as passed.
    """
    labels = y_true.astype(int)
    n_total = len(labels)
    is_pos = labels == 1
    is_neg = labels == 0

    def _net_benefit(pt):
        flagged = (prob >= pt).astype(int) == 1
        true_pos = np.sum(flagged & is_pos)
        false_pos = np.sum(flagged & is_neg)
        fp_weight = pt / (1.0 - pt)  # cost weight of a false positive
        return true_pos / n_total - fp_weight * false_pos / n_total

    return thresholds, np.array([_net_benefit(pt) for pt in thresholds])

plt.figure(figsize=(7,5))
# Threshold grid shared by the reference lines and every model curve.
ts_ref = np.linspace(0.01, 0.80, 80)

# Treat-all / Treat-none reference lines
# Treat-none flags nobody, so its net benefit is 0 at every threshold.
nb_none = np.zeros_like(ts_ref)
# Treat-all flags everyone: TP/N = prevalence and FP/N = 1 - prevalence.
nb_all  = prev4 - ts_ref / (1 - ts_ref) * (1 - prev4)
plt.plot(ts_ref, nb_all,  label="Treat-all",  linestyle="--", linewidth=1)
plt.plot(ts_ref, nb_none, label="Treat-none", linestyle="--", linewidth=1)

for name, P in models_probs.items():
    # Last column is the probability of the highest class.
    s4 = P[:, -1]
    ts, nb = decision_curve(y_true4, s4, thresholds=ts_ref)
    plt.plot(ts, nb, label=name)

plt.xlabel("Threshold probability (class 4)")
plt.ylabel("Net Benefit")
plt.title("Decision Curve — Class 4 (all models)")
plt.legend(fontsize=9)
plt.tight_layout()
out_dc = os.path.join(OUT_DIR, "class4_decision_curve_all_models.pdf")
plt.savefig(out_dc, format="pdf")
plt.close()
print("[OK] Saved Decision Curve plot:", out_dc)

# ================= 3) Appendix: Suggested Operating Points under PPV Constraints =================
from sklearn.metrics import precision_recall_curve

def best_at_precision(y_true, prob, min_ppv):
    """Pick the threshold with the highest recall among those meeting a PPV floor.

    Args:
        y_true: binary labels.
        prob: predicted scores for the positive class.
        min_ppv: minimum acceptable precision.

    Returns:
        Dict with ``threshold``, ``precision`` and ``recall`` (plain floats), or
        ``None`` when no threshold reaches ``min_ppv``.
    """
    precision, recall, cutoffs = precision_recall_curve(y_true, prob)
    # The final precision/recall point has no matching threshold, so drop it.
    eligible = np.flatnonzero(precision[:-1] >= min_ppv)
    if eligible.size == 0:
        return None
    best = eligible[np.argmax(recall[eligible])]
    return {
        "threshold": float(cutoffs[best]),
        "precision": float(precision[best]),
        "recall": float(recall[best]),
    }

# For each PPV floor, print every model's best operating point
# (None when the model never reaches that precision).
for ppv in [0.20, 0.25, 0.30, 0.35]:
    row = {}
    for name, P in models_probs.items():
        # P[:, -1] is the probability of the highest class.
        pick = best_at_precision(y_true4, P[:, -1], ppv)
        row[name] = pick
    print(f"\n[Operating Point Suggestion @ PPV ≥ {ppv:.2f}]")
    for k, v in row.items():
        print(f"{k:>16} :", v)

In [None]:
# -*- coding: utf-8 -*-
"""
Run SHAP Only (Gradually_Increasing, direction=inc), No Training Required.
Prerequisite: OUT_DIR already contains trained assets:
  - preproc/*.pkl
  - indices/*.npy
  - ens0.weights.h5
  - best_model_params.json
  - feature_names.json
  - and your data CSV file.
Output: Consistency report + SHAP summary / dependence plots for "consistent proteins" only.
"""

import os
import json
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
import shap
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from sklearn.isotonic import IsotonicRegression
from scipy.stats import spearmanr, kendalltau

# ===================== Paths (as provided) =====================
DATA_PATH = "./data/extracted_Gradually_Increasing.csv"
OUT_DIR   = "./results/Gradually_Increasing"
os.makedirs(OUT_DIR, exist_ok=True)

# Asset files
# All of these are expected to already exist under OUT_DIR; they are only read here.
PREPROC_DIR = os.path.join(OUT_DIR, "preproc")
IDX_DIR     = os.path.join(OUT_DIR, "indices")
IMP_P_PKL   = os.path.join(PREPROC_DIR, "imp_p.pkl")
IMP_C_PKL   = os.path.join(PREPROC_DIR, "imp_c.pkl")
SC_P_PKL    = os.path.join(PREPROC_DIR, "sc_p.pkl")
SC_C_PKL    = os.path.join(PREPROC_DIR, "sc_c.pkl")
IDX_TEST_NPY = os.path.join(IDX_DIR, "idx_test.npy")
MODEL_PARAMS_JSON = os.path.join(OUT_DIR, "best_model_params.json")
ENS0_WEIGHTS_H5   = os.path.join(OUT_DIR, "ens0.weights.h5")
FEATURE_NAMES_JSON = os.path.join(OUT_DIR, "feature_names.json")

# Fallback slice if feature_names.json columns are not found
# NOTE(review): the main pipeline configuration uses slice(9, 417) for this same
# GI dataset; slice(9, 63) selects a different number of protein columns —
# confirm which one the saved imputer/scaler expect.
PROTEIN_SLICE   = slice(9, 63)
COVARIATE_SLICE = slice(2, 9)

# Task & Plotting Options
MONO_MODE_EXPECTED = "inc"   # Gradually_Increasing: higher value → worse condition
BACKGROUND_K       = 100     # Number of clusters for SHAP background
SHAP_SUBSET_N      = 2000    # Number of samples to use for SHAP computation
TOPK_DEP_PLOTS     = 30      # Max number of dependence plots for consistent proteins (None = all)
FORCE_PLOT_DIRECTION = False # For plotting only: flip sign of inconsistent proteins to unify direction (no effect on saved files)

# ===================== Lightweight Model (only for loading weights & forward pass) =====================
class OrdinalLDL(tf.keras.Model):
    """Forward-only ordinal model: an MLP backbone, a scalar score head, and thresholds.

    Only ``num_classes``, ``layer_sizes``, ``dropout_rate``, ``l2_lambda`` and
    ``tau`` are used by this class. The remaining constructor arguments
    (``input_dim``, ``protein_indices``, the ``*_lambda`` values,
    ``train_prior``, ``mono_mode``, ``gaussian_noise_std``) are accepted and ignored.

    NOTE(review): layers and variables are presumably created in the same order
    as in the training-time model so saved weights load correctly — do not
    reorder without confirming.
    """
    # Note: includes mono_mode for construction compatibility; not used here (forward only)
    def __init__(self, input_dim, num_classes, layer_sizes, dropout_rate,
                 l2_lambda, protein_indices=None, tau=3.5,
                 emd_lambda=1.0, entropy_lambda=0.10, kl_lambda=1.2,
                 train_prior=None, mono_mode=None, gaussian_noise_std=0.0):
        super().__init__()
        self.num_classes = num_classes
        self.tau = tf.constant(tau, tf.float32)
        # Backbone: one L2-regularized ReLU Dense layer plus Dropout per entry in layer_sizes.
        self.backbone = tf.keras.Sequential(name="backbone")
        for s in layer_sizes:
            self.backbone.add(tf.keras.layers.Dense(
                s, activation='relu',
                kernel_regularizer=tf.keras.regularizers.l2(l2_lambda)))
            self.backbone.add(tf.keras.layers.Dropout(dropout_rate))
        # Linear head producing a single scalar score per sample.
        self.score_head = tf.keras.layers.Dense(1, activation=None,
                kernel_regularizer=tf.keras.regularizers.l2(l2_lambda), name="score")
        # Unconstrained parameters behind the num_classes - 1 ordinal thresholds.
        self.alpha_raw = tf.Variable(tf.random.normal([num_classes-1], stddev=0.1),
                                     trainable=True, name="alpha_raw")

    def call(self, inputs, training=False):
        """Return the scalar score produced by the backbone and the score head."""
        h = self.backbone(inputs, training=training)
        score = self.score_head(h)
        return score

    def get_alphas(self):
        """Return strictly increasing, zero-mean thresholds derived from ``alpha_raw``."""
        # Softplus keeps every increment positive, so the cumulative sum is increasing.
        inc = tf.nn.softplus(self.alpha_raw) + 1e-6
        alpha = tf.cumsum(inc)
        # Center the thresholds around zero.
        alpha = alpha - tf.reduce_mean(alpha)
        return alpha

# ---- Compute probabilities / expected value from score (SHAP target) ----
def probs_from_score(score, alpha, tau):
    """Convert a scalar score into class probabilities via cumulative sigmoids.

    Column k of ``sigmoid((score - alpha_k) / tau)`` is treated as P(y > k).
    Class probabilities are the differences of adjacent columns, clipped to
    [1e-7, 1] and renormalized so each row sums to 1.
    """
    exceed = tf.sigmoid((score - alpha[tf.newaxis, :]) / tau)
    first = 1 - exceed[:, :1]
    last = exceed[:, -1:]
    if exceed.shape[1] > 1:
        middle = exceed[:, :-1] - exceed[:, 1:]
    else:
        # Single threshold: there are no interior classes.
        middle = tf.zeros((tf.shape(score)[0], 0), score.dtype)
    stacked = tf.concat([first, middle, last], axis=1)
    clipped = tf.clip_by_value(stacked, 1e-7, 1.0)
    return clipped / tf.reduce_sum(clipped, axis=1, keepdims=True)

def ey_from_score(score, alpha, tau):
    """Return the expected class index E[y] for each row, keeping a trailing axis of 1."""
    probs = probs_from_score(score, alpha, tau)
    num_classes = tf.shape(probs)[1]
    class_values = tf.cast(tf.range(num_classes), probs.dtype)  # 0..C-1
    weighted = probs * class_values
    return tf.reduce_sum(weighted, axis=1, keepdims=True)

# ===================== Asset Existence Check =====================
# Every file this cell loads; fail early with the full list of missing paths.
need = [IMP_P_PKL, IMP_C_PKL, SC_P_PKL, SC_C_PKL,
        IDX_TEST_NPY, MODEL_PARAMS_JSON, ENS0_WEIGHTS_H5, FEATURE_NAMES_JSON]
missing = [p for p in need if not os.path.exists(p)]
if missing:
    raise FileNotFoundError(
        "Missing required training assets for SHAP-only mode:\n" + "\n".join(missing) +
        "\nPlease run the full pipeline script once with FORCE_TRAIN=True to generate them."
    )

# ===================== Load Parameters / Feature Names, Reconstruct X_test =====================
# Context managers close the JSON files promptly instead of leaking the handles.
with open(MODEL_PARAMS_JSON, "r", encoding="utf-8") as _fh:
    params_all = json.load(_fh)
best_params = params_all["best_params"]
# tau / num_classes fall back to 3.5 / 5 when the JSON does not record them.
TAU = float(params_all.get("tau", 3.5))
NUM_CLASSES = int(params_all.get("num_classes", 5))

with open(FEATURE_NAMES_JSON, "r", encoding="utf-8") as _fh:
    feat = json.load(_fh)
protein_names = feat["protein"]
covariates_names = feat["covariates"]
num_proteins = len(protein_names)

df = pd.read_csv(DATA_PATH)
# Use columns from feature_names.json; fallback to slice if not found
try:
    prot_idx = [df.columns.get_loc(c) for c in protein_names]
    cov_idx  = [df.columns.get_loc(c) for c in covariates_names]
    Xp_all = df.iloc[:, prot_idx].values.astype(np.float32)
    Xc_all = df.iloc[:, cov_idx ].values.astype(np.float32)
except Exception as _err:
    # Still best-effort, but report it: the positional slices may not select the
    # columns the saved preprocessors were fitted on.
    print(f"[WARN] Could not select columns by name ({_err!r}); "
          "falling back to PROTEIN_SLICE / COVARIATE_SLICE.")
    Xp_all = df.iloc[:, PROTEIN_SLICE  ].values.astype(np.float32)
    Xc_all = df.iloc[:, COVARIATE_SLICE].values.astype(np.float32)

# Select the saved test rows.
idx_test = np.load(IDX_TEST_NPY)
Xp_test_raw = Xp_all[idx_test]
Xc_test_raw = Xc_all[idx_test]

# Apply the saved (already fitted) imputers and scalers; nothing is re-fitted here.
imp_p = joblib.load(IMP_P_PKL)
imp_c = joblib.load(IMP_C_PKL)
sc_p  = joblib.load(SC_P_PKL)
sc_c  = joblib.load(SC_C_PKL)
Xp_te = sc_p.transform(imp_p.transform(Xp_test_raw))
Xc_te = sc_c.transform(imp_c.transform(Xc_test_raw))
# Proteins first, then covariates.
X_test_for_shap = np.hstack([Xp_te, Xc_te]).astype(np.float32)
print("[INFO] X_test_for_shap shape:", X_test_for_shap.shape)

# ===================== Rebuild Model & Load Weights (ens0) =====================
tf.keras.backend.clear_session()
# Architecture hyperparameters come from best_model_params.json; the loss-related
# arguments are placeholders that OrdinalLDL accepts but does not use.
m0 = OrdinalLDL(
    input_dim=X_test_for_shap.shape[1],
    num_classes=NUM_CLASSES,
    layer_sizes=best_params["layers"],
    dropout_rate=best_params["dropout"],
    l2_lambda=best_params["l2"],
    protein_indices=np.arange(num_proteins),
    tau=TAU,
    emd_lambda=1.0,
    entropy_lambda=0.10,
    kl_lambda=1.2,
    train_prior=None,
    mono_mode=None  # forward pass only, not used in training
)
m0.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=best_params["lr"], clipvalue=1.0))
# Build with the input width so the layer variables exist before weights are loaded.
m0.build((None, X_test_for_shap.shape[1]))
m0.load_weights(ENS0_WEIGHTS_H5)

# ===================== Build E[y] Model & Compute SHAP =====================
# Wrap the loaded model so its output is the expected class index E[y].
inp = tf.keras.Input(shape=(X_test_for_shap.shape[1],), dtype=tf.float32)
score = m0(inp, training=False)
# Thresholds are evaluated once here and captured by the Lambda layer below.
alpha_tf = m0.get_alphas()
ey_out = tf.keras.layers.Lambda(lambda s: ey_from_score(s, alpha_tf, TAU),
                                name="expected_severity")(score)
Ey_model = tf.keras.Model(inputs=inp, outputs=ey_out)

# Background samples (k-means clustering)
k = min(BACKGROUND_K, len(X_test_for_shap))
background = shap.kmeans(X_test_for_shap, k).data
# Explain the first n_samples test rows (no shuffling or random sampling).
n_samples = min(SHAP_SUBSET_N, len(X_test_for_shap))
test_subset = X_test_for_shap[:n_samples].astype(np.float32)

explainer = shap.GradientExplainer(Ey_model, background)
sv = explainer.shap_values(test_subset)
# Normalize the explainer output to a 2-D array: unwrap a list and drop a
# trailing singleton output axis if present.
if isinstance(sv, list):
    sv = sv[0]
sv = np.asarray(sv)
if sv.ndim == 3 and sv.shape[-1] == 1:
    sv = sv[..., 0]
assert sv.ndim == 2 and sv.shape[0] == test_subset.shape[0], f"Unexpected SHAP shape: {sv.shape}"
print("[INFO] SHAP values shape:", sv.shape)

# Protein-only SHAP values
# Protein columns are the first num_proteins features.
sv_prot = sv[:, :num_proteins]
# Global importance: mean absolute SHAP value per protein.
mean_abs = np.mean(np.abs(sv_prot), axis=0).ravel()
pd.DataFrame({"feature": protein_names, "mean_abs_shap_Ey": mean_abs}) \
  .to_csv(os.path.join(OUT_DIR, "protein_importance_Ey.csv"), index=False, encoding="utf-8-sig")
print("Saved: protein_importance_Ey.csv")

# ===================== Consistency Report (GI expects positive correlation) =====================
def shap_dynamic_consistency_report(sv, X, protein_names, num_proteins, out_dir,
                                    Ey_model=None, min_bin_size=40, expected="inc"):
    """Score how consistently each protein's SHAP values follow the expected direction.

    For every protein column, the feature value is compared with its SHAP value
    using rank correlations, a trimmed and a decile-binned linear slope, the
    fraction of adjacent steps (sorted by feature value) that move up or
    against the expected direction, an isotonic-fit R², and (when ``Ey_model``
    is given) the fraction of E[y] quintiles whose within-bin Spearman
    correlation has the expected sign.

    Args:
        sv: SHAP values, rows aligned with ``X``; the first ``num_proteins``
            columns are used.
        X: feature matrix; the first ``num_proteins`` columns are used.
        protein_names: names for those columns.
        num_proteins: number of leading protein columns.
        out_dir: directory receiving ``shap_dynamic_consistency_report.csv``.
        Ey_model: optional model with ``predict``; enables the per-quintile check.
        min_bin_size: minimum samples for a quintile to take part in that check.
        expected: ``"inc"`` for an expected increasing relation; any other
            value is treated as decreasing.

    Returns:
        The report DataFrame, sorted by ``strong``, ``dyn_consistency_score``
        and ``mean_abs_shap`` (all descending). ``mvr_up`` is the fraction of
        upward steps; ``mvr_violation`` is the fraction of steps against
        ``expected`` and drives ``strong`` and ``dyn_consistency_score``.
    """
    Xp  = X[:, :num_proteins]
    svp = sv[:, :num_proteins]
    rows = []
    ey_bins_arr = None
    if Ey_model is not None:
        # Quintile label of the predicted E[y] for every sample.
        Ey_vals = Ey_model.predict(X.astype(np.float32), verbose=0).ravel()
        ey_bins = pd.qcut(Ey_vals, q=5, labels=[f"Q{i+1}" for i in range(5)])
        ey_bins_arr = np.asarray(ey_bins)

    for j in range(num_proteins):
        x = Xp[:, j]
        y = svp[:, j]
        rho, rho_p = spearmanr(x, y)
        tau, tau_p = kendalltau(x, y)

        # Linear slope restricted to the central 10%-90% range of the feature.
        slope_trim = np.nan
        if np.std(x) > 1e-8:
            ql, qr = np.quantile(x, [0.1, 0.9])
            m = (x >= ql) & (x <= qr)
            if np.sum(m) > 1:
                slope_trim = np.polyfit(x[m], y[m], 1)[0]

        # Linear slope through decile-bin means; left as NaN if binning/fitting fails.
        slope_bins = np.nan
        try:
            bins10 = pd.qcut(x, q=10, duplicates='drop')
            dfb = pd.DataFrame({'x': x, 'y': y, 'bin': bins10}).groupby('bin').mean().reset_index()
            if len(dfb) > 1:
                slope_bins = np.polyfit(dfb['x'].values, dfb['y'].values, 1)[0]
        except Exception:
            pass

        order = np.argsort(x)
        dy = np.diff(y[order])
        mvr_up = float(np.mean(dy > 0)) if dy.size > 0 else np.nan
        # A step violates monotonicity when it moves against the expected direction:
        # downward for "inc", upward otherwise. Using the upward rate for both
        # directions would penalize exactly the proteins that behave as expected
        # under "inc".
        if dy.size > 0:
            against = (dy < 0) if expected == "inc" else (dy > 0)
            mvr_violation = float(np.mean(against))
        else:
            mvr_violation = np.nan

        # R² of an isotonic fit in the expected direction; NaN if the fit fails.
        r2_iso = np.nan
        try:
            ir = IsotonicRegression(increasing=(expected == "inc"))
            y_iso = ir.fit_transform(x, y)
            r2_iso = 1 - np.sum((y - y_iso)**2) / (np.sum((y - np.mean(y))**2) + 1e-12)
        except Exception:
            pass

        mean_abs_local = float(np.mean(np.abs(y)))

        # Fraction of sufficiently large E[y] bins whose Spearman sign matches `expected`.
        stage_dir_ok_frac = np.nan
        if ey_bins_arr is not None:
            levels = np.unique(ey_bins_arr)
            ok = tot = 0
            for lvl in levels:
                idx = (ey_bins_arr == lvl)
                if np.sum(idx) >= min_bin_size:
                    rr, _ = spearmanr(x[idx], y[idx])
                    if np.isfinite(rr):
                        tot += 1
                        if (expected == "inc" and rr > 0) or (expected == "dec" and rr < 0):
                            ok += 1
            if tot > 0:
                stage_dir_ok_frac = ok / tot

        rows.append([
            protein_names[j],
            mean_abs_local,
            rho,
            rho_p,
            tau,
            tau_p,
            slope_trim,
            slope_bins,
            mvr_up,
            mvr_violation,
            r2_iso,
            stage_dir_ok_frac
        ])

    cols = [
        'feature', 'mean_abs_shap', 'spearman_rho', 'spearman_p',
        'kendall_tau', 'kendall_p', 'trimmed_linear_slope', 'binned_slope',
        'mvr_up', 'mvr_violation', 'isotonic_R2', 'stage_dir_ok_frac'
    ]
    df = pd.DataFrame(rows, columns=cols)

    # +1 when an increasing relation is expected, -1 otherwise.
    sign = +1 if expected == "inc" else -1
    df['direction_ok'] = ((sign * df['spearman_rho'] > 0) & (sign * df['binned_slope'] > 0))
    df['strong'] = (
        df['direction_ok'] &
        (df['spearman_p'] < 0.05) &
        (df['mvr_violation'] <= 0.40) &
        (df['isotonic_R2'] >= 0.10) &
        # A missing per-quintile check (NaN) does not block the "strong" flag.
        (df['stage_dir_ok_frac'].fillna(1.0) >= 0.60)
    )
    df['dyn_consistency_score'] = (
        (sign * df['spearman_rho'].fillna(0)) *
        df['mean_abs_shap'].fillna(0) *
        (1 - df['mvr_violation'].fillna(0)) *
        (df['isotonic_R2'].fillna(0) + 1e-6)
    )
    df.sort_values(['strong', 'dyn_consistency_score', 'mean_abs_shap'],
                   ascending=[False, False, False], inplace=True)

    out_path = os.path.join(out_dir, 'shap_dynamic_consistency_report.csv')
    df.to_csv(out_path, index=False, encoding='utf-8-sig')
    print(f"Saved consistency report: {out_path}")
    return df

# Build (and save) the per-protein consistency report for the explained subset.
report = shap_dynamic_consistency_report(
    sv, test_subset, protein_names, num_proteins,
    OUT_DIR, Ey_model=Ey_model, expected=MONO_MODE_EXPECTED
)

# ===================== Plot Only "Consistent" Proteins =====================
# Proteins whose Spearman rho and binned slope both have the expected sign.
keep_feats = report.loc[report['direction_ok'], 'feature'].tolist()
if not keep_feats:
    print("[WARN] No proteins passed direction consistency check; falling back to top 20 by importance.")
    keep_idx = list(np.argsort(-mean_abs)[:20])
    keep_feats = [protein_names[i] for i in keep_idx]
else:
    # Map names back to column positions (in the report's sort order).
    keep_idx = [protein_names.index(f) for f in keep_feats]

# Optional: flip sign of inconsistent proteins for plotting only (no effect on saved results)
# Works on a copy so `sv` itself stays untouched. The flipped columns are the
# ones not in keep_feats; the summary plot below draws only keep_idx columns.
sv_for_plot = sv.copy()
if FORCE_PLOT_DIRECTION and MONO_MODE_EXPECTED == "inc":
    bad_feats = [f for f in protein_names if f not in keep_feats]
    bad_idx = [protein_names.index(f) for f in bad_feats]
    sv_for_plot[:, bad_idx] *= -1

# ---- Summary plot (consistent proteins only) ----
plt.figure()
shap.summary_plot(
    sv_for_plot[:, keep_idx],
    test_subset[:, keep_idx],
    feature_names=keep_feats,
    show=False
)
plt.title("E[y] SHAP Summary (consistent proteins only; right = heavier, red = high value)")
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "shap_summary_Ey_consistent.pdf"), format="pdf")
plt.close()
print("Saved: shap_summary_Ey_consistent.pdf (consistent proteins only)")

# ---- Dependence plots (consistent proteins only; limit to TOPK if too many) ----
# Rank the kept proteins by mean |SHAP|, most important first.
order_by_importance = np.argsort(-mean_abs[keep_idx])
dep_list = [keep_idx[i] for i in order_by_importance]
if TOPK_DEP_PLOTS is not None:
    dep_list = dep_list[:TOPK_DEP_PLOTS]

for j in dep_list:
    plt.figure()
    # `ind=j` indexes the full feature matrix (proteins followed by covariates),
    # hence the concatenated feature_names list.
    shap.dependence_plot(
        ind=j,
        shap_values=sv_for_plot,
        features=test_subset,
        feature_names=protein_names + covariates_names,
        interaction_index=None,
        show=False
    )
    plt.title(f"E[y] SHAP Dependence — {protein_names[j]} (right = heavier)")
    plt.tight_layout()
    # NOTE(review): the protein name is embedded in the file name unescaped;
    # confirm names contain no path separators.
    fname = os.path.join(OUT_DIR, f"dep_Ey_{protein_names[j]}_consistent.pdf")
    plt.savefig(fname, format="pdf")
    plt.close()

print(f"Saved {len(dep_list)} dependence plots to: {OUT_DIR}")
print("\n✅ SHAP analysis completed (plots & report only, no training performed).")