In [5]:
# predictor.py
# Crash-proof SHAP + CatBoost inference with single-row inputs
# - Uses CatBoost native SHAP (get_feature_importance(..., type="ShapValues"))
# - Forces non-GUI matplotlib backend
# - Avoids shap.TreeExplainer and summary_plot (common segfault sources)
# - Handles binary/multiclass/regressor shapes robustly

import os
import pprint
import numpy as np
import pandas as pd
import joblib

# --- Force non-GUI backend BEFORE importing pyplot ---
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from catboost import CatBoostClassifier, CatBoostRegressor, Pool
import catboost

# Optional: keep import; we’re not using shap’s explainer, only version print if needed
import shap

# ------------------------------
# Load All Models and Metadata
# ------------------------------

def load_model_and_meta(model_path, meta_path):
    is_classifier = any(k in model_path.lower() for k in ["classifier", "response", "type"])
    model = CatBoostClassifier() if is_classifier else CatBoostRegressor()
    model.load_model(model_path)
    meta = joblib.load(meta_path)
    return model, meta

model_type, meta_type = load_model_and_meta(
    "./saved_model/anesthesia_type_model_catBoost_upsampled.cbm",
    "./saved_model/model_metadata.pkl"
)

model_dosage, meta_dosage = load_model_and_meta(
    "./saved_model/general_dosage_model.cbm",
    "./saved_model/general_dosage_metadata.pkl"
)

model_response, meta_response = load_model_and_meta(
    "./saved_model/general_response_confidence_model.cbm",
    "./saved_model/general_response_metadata_confidence.pkl"
)

# ------------------------------
# SHAP Explanation Helpers
# ------------------------------
from feature_importance_map import (
    feature_reference_map_anesthesia_type,
    dosages_feature_reference_map,
    feature_reference_map_response
)

def explain_type_reason(feature, value, shap_val):
    direction = "positive" if shap_val > 0 else "negative"
    base_name = feature.split("_")[0]
    ref_data = feature_reference_map_anesthesia_type.get(base_name, {})
    template = ref_data.get(direction, f"{feature} ({value}) influenced anesthesia type.")
    reference = ref_data.get("reference", "No reference available.")
    return f"{template.format(value=value)}\nRef: {reference}"

def explain_dosage_reason(feature, value, shap_val):
    direction = "positive" if shap_val > 0 else "negative"
    base_name = feature.split("_")[0]
    ref_data = dosages_feature_reference_map.get(base_name, {})
    template = ref_data.get(direction, f"{feature} ({value}) influenced the dose.")
    reference = ref_data.get("reference", "No reference available.")
    return f"{template.format(value=value)}\nRef: {reference}"

def explain_response_reason(feature, value, shap_val):
    direction = "positive" if shap_val > 0 else "negative"
    base_name = feature.split("_")[0]
    ref_data = feature_reference_map_response.get(base_name, {})
    template = ref_data.get(direction, f"{feature} ({value}) influenced response.")
    reference = ref_data.get("reference", "No reference available.")
    return f"{template.format(value=value)}\nRef: {reference}"

def safe_extract(arr, *idx, default=0.0):
    try:
        return float(arr[idx])
    except Exception:
        return float(default)

def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)

def barplot_shap(name: str, features: list, shap_vec: np.ndarray, topk: int = 15):
    """Save a simple bar chart for a single-sample SHAP vector."""
    ensure_dir("shap_plots")
    if shap_vec is None or len(shap_vec) == 0:
        return
    shap_vec = np.asarray(shap_vec).astype(float)
    k = min(topk, len(shap_vec))
    idxs = np.argsort(np.abs(shap_vec))[-k:][::-1]
    labels = [features[i] for i in idxs]
    vals = shap_vec[idxs]

    plt.figure()
    plt.title(f"Top feature impacts – {name}")
    plt.bar(range(len(vals)), vals)
    plt.xticks(range(len(vals)), labels, rotation=90)
    plt.tight_layout()
    plt.savefig(f"shap_plots/{name.lower().replace(' ', '_')}_shap.png", bbox_inches="tight")
    plt.close()

def catboost_shap_values(model, X: pd.DataFrame):
    """
    Use CatBoost native SHAP to avoid shap.TreeExplainer segfault paths.
    Returns:
      - For regressors/binary classifier: [n_samples, n_features+1]
      - For multiclass: [n_samples, n_classes, n_features+1]
    Last column is expected value (base value); first n_features are feature contributions.
    """
    pool = Pool(X)
    shap_vals = model.get_feature_importance(data=pool, type="ShapValues")
    return np.array(shap_vals)

def get_predicted_class_index(labels: np.ndarray, probas: np.ndarray, threshold: float = 0.4):
    """Returns (top_idx, final_class_label) with threshold logic mirrored from user code."""
    top_idx = int(np.argmax(probas))
    top_prob = float(probas[top_idx])
    if top_prob >= threshold:
        return top_idx, labels[top_idx]
    else:
        # pick second-best class when top prob < threshold
        order = np.argsort(probas)
        second_idx = int(order[-2])
        return second_idx, labels[second_idx]

# ------------------------------
# Predict Function
# ------------------------------

def predict_all(input_data: pd.DataFrame, explain: bool = False):
    # Strictly one row
    assert input_data.shape[0] == 1, "Only one row of input is expected."

    # Guard: verify model feature availability
    for tag, meta in [
        ("Type", meta_type),
        ("Dosage", meta_dosage),
        ("Response", meta_response),
    ]:
        missing = [f for f in meta["features"] if f not in input_data.columns]
        assert not missing, f"{tag} model features missing from input: {missing}"

    patient_data = input_data.to_dict(orient="records")[0]

    # ---- TYPE PREDICTION ----
    raw_pred = model_type.predict(input_data[meta_type["features"]])[0]
    if isinstance(raw_pred, np.ndarray) and raw_pred.size > 1:
        if "classes" in meta_type:
            type_idx = int(np.argmax(raw_pred))
            type_pred = meta_type["classes"][type_idx]
        else:
            type_idx = int(np.argmax(raw_pred))
            anesthesia_classes = ["Type1", "Type2", "Type3", "Type4", "Type5", "Type6"]
            type_pred = anesthesia_classes[type_idx] if type_idx < len(anesthesia_classes) else f"Type{type_idx+1}"
    else:
        type_pred = raw_pred
        # best-effort index for SHAP selection later
        try:
            # if model has predict_proba, align index with highest prob
            if hasattr(model_type, "predict_proba"):
                probs_type = model_type.predict_proba(input_data[meta_type["features"]])[0]
                type_idx = int(np.argmax(probs_type))
            else:
                type_idx = 0
        except Exception:
            type_idx = 0

    # ---- DOSAGE (REGRESSION) ----
    dosage_pred = float(model_dosage.predict(input_data[meta_dosage["features"]])[0])

    # ---- RESPONSE (CLASSIFIER) ----
    probas = model_response.predict_proba(input_data[meta_response["features"]])[0]
    labels = model_response.classes_
    top_idx, final_response = get_predicted_class_index(labels, probas, threshold=0.4)

    result = {
        "General_AnesthesiaType": str(type_pred),
        "General_Dosage": round(dosage_pred, 2),
        "General_Response": str(final_response),
        "Patient_Profile": patient_data
    }

    if not explain:
        return result

    explanations = {}
    ensure_dir("shap_plots")

    # ------------------------------
    # SHAP: TYPE (native CatBoost)
    # ------------------------------
    shap_stats_type = {}
    try:
        subset = input_data[meta_type["features"]]
        sv = catboost_shap_values(model_type, subset)
        # Shapes:
        # - Binary/regression-like: [1, F+1]
        # - Multiclass: [1, C, F+1]
        if sv.ndim == 3:
            # [n_samples, n_classes, n_features+1] => take sample 0
            sv_cls = sv[0]                         # [C, F+1]
            base_vals = sv_cls[:, -1]              # [C]
            shap_feats_by_class = sv_cls[:, :-1]   # [C, F]
            shap_for_pred = shap_feats_by_class[top_idx]  # [F]
            # per-class export
            for cls_i in range(shap_feats_by_class.shape[0]):
                cls_name = meta_type.get("classes", [f"Class_{i}" for i in range(shap_feats_by_class.shape[0])])[cls_i] \
                           if "classes" in meta_type else f"Class_{cls_i}"
                shap_stats_type[cls_name] = {
                    f: float(shap_feats_by_class[cls_i, fi]) for fi, f in enumerate(meta_type["features"])
                }
        else:
            # [1, F+1]
            sv1 = sv[0]                 # [F+1]
            base_val = sv1[-1]
            shap_for_pred = sv1[:-1]    # [F]

        # Human-readable per-feature for predicted class (or sole vector)
        type_explanations = []
        for i, f in enumerate(meta_type["features"]):
            val = input_data[f].values[0]
            s = float(shap_for_pred[i])
            type_explanations.append(f"{explain_type_reason(f, val, s)} (SHAP: {s:.4f})")

        barplot_shap("Anesthesia Type", meta_type["features"], shap_for_pred, topk=15)

        explanations["General_AnesthesiaType"] = {
            "explanations": type_explanations,
            "per_class_shap_values": shap_stats_type  # may be empty if not multiclass
        }
    except Exception as e:
        explanations["General_AnesthesiaType"] = [f"Type SHAP error: {str(e)}"]

    # ------------------------------
    # SHAP: DOSAGE (regressor, native CatBoost)
    # ------------------------------
    try:
        subset = input_data[meta_dosage["features"]]
        sv = catboost_shap_values(model_dosage, subset)  # [1, F+1]
        sv1 = sv[0]
        shap_vec = sv1[:-1]  # drop base value
        dosage_explanations = []
        for i, f in enumerate(meta_dosage["features"]):
            val = input_data[f].values[0]
            s = float(shap_vec[i])
            dosage_explanations.append(f"{explain_dosage_reason(f, val, s)} (SHAP: {s:.4f})")

        barplot_shap("Dosage Prediction", meta_dosage["features"], shap_vec, topk=15)

        explanations["General_Dosage"] = {
            "explanations": dosage_explanations
        }
    except Exception as e:
        explanations["General_Dosage"] = [f"Dosage SHAP error: {str(e)}"]

    # ------------------------------
    # SHAP: RESPONSE (classifier, native CatBoost)
    # ------------------------------
    shap_stats_response = {}
    try:
        subset = input_data[meta_response["features"]]
        sv = catboost_shap_values(model_response, subset)
        # Binary: [1, F+1]; Multiclass: [1, C, F+1]
        if sv.ndim == 3:
            sv_cls = sv[0]                       # [C, F+1]
            shap_feats_by_class = sv_cls[:, :-1] # [C, F]
            shap_for_pred = shap_feats_by_class[top_idx]  # [F]
            # per-class dump
            for cls_i in range(shap_feats_by_class.shape[0]):
                cls_name = str(labels[cls_i]) if cls_i < len(labels) else f"Class_{cls_i}"
                shap_stats_response[cls_name] = {
                    f: float(shap_feats_by_class[cls_i, fi]) for fi, f in enumerate(meta_response["features"])
                }
        else:
            sv1 = sv[0]              # [F+1]
            shap_for_pred = sv1[:-1] # [F]

        response_explanations = []
        for i, f in enumerate(meta_response["features"]):
            val = input_data[f].values[0]
            s = float(shap_for_pred[i])
            response_explanations.append(f"{explain_response_reason(f, val, s)} (SHAP: {s:.4f})")

        barplot_shap("Response Prediction", meta_response["features"], shap_for_pred, topk=15)

        explanations["General_Response"] = {
            "explanations": response_explanations,
            "per_class_shap_values": shap_stats_response
        }
    except Exception as e:
        explanations["General_Response"] = [f"Response SHAP error: {str(e)}"]

    result["SHAP_Explanations"] = explanations

    # ------------------------------
    # SHAP Influence Summary by Feature
    # ------------------------------
    try:
        # Build per-feature influence map using the vectors we already computed above when possible.
        # If a section failed, fall back to 0.0.
        feature_wise_summary = {}
        all_features = set(meta_type["features"]) | set(meta_dosage["features"]) | set(meta_response["features"])

        # Extract vectors if populated; default to zeros of appropriate length
        type_vec = None
        resp_vec = None
        dose_vec = None

        # Pull from explanations (not super clean but keeps single-pass logic)
        try:
            # we saved a bar-plot using shap_for_pred; we can reconstruct by matching text if needed,
            # but simpler is to recompute quickly here (cheap for single-row)
            sv_type = catboost_shap_values(model_type, input_data[meta_type["features"]])
            if sv_type.ndim == 3:
                type_vec = sv_type[0][type_idx][:-1]
            else:
                type_vec = sv_type[0][:-1]
        except Exception:
            type_vec = np.zeros(len(meta_type["features"]), dtype=float)

        try:
            sv_resp = catboost_shap_values(model_response, input_data[meta_response["features"]])
            if sv_resp.ndim == 3:
                resp_vec = sv_resp[0][top_idx][:-1]
            else:
                resp_vec = sv_resp[0][:-1]
        except Exception:
            resp_vec = np.zeros(len(meta_response["features"]), dtype=float)

        try:
            sv_dose = catboost_shap_values(model_dosage, input_data[meta_dosage["features"]])
            dose_vec = sv_dose[0][:-1]
        except Exception:
            dose_vec = np.zeros(len(meta_dosage["features"]), dtype=float)

        # Assemble cross-model map
        for f in all_features:
            t_inf = float(type_vec[meta_type["features"].index(f)]) if f in meta_type["features"] else 0.0
            d_inf = float(dose_vec[meta_dosage["features"].index(f)]) if f in meta_dosage["features"] else 0.0
            r_inf = float(resp_vec[meta_response["features"].index(f)]) if f in meta_response["features"] else 0.0
            feature_wise_summary[f] = {
                "General_AnesthesiaType_Influence": t_inf,
                "General_Dosage_Influence": d_inf,
                "General_Response_Influence": r_inf
            }

        result["SHAP_Influence_By_Feature"] = feature_wise_summary
    except Exception:
        # Non-fatal: we already have per-section explanations
        pass

    return result


if __name__ == "__main__":
    # Sanity info dump to quickly spot ABI/version issues when running as script
    import sys
    print("VERSIONS:", {
        "python": sys.version,
        "numpy": np.__version__,
        "shap": getattr(shap, "__version__", "unknown"),
        "catboost": catboost.__version__,
        "matplotlib": matplotlib.__version__,
    })

    # Example run (single row)
    df = pd.read_csv("./anesthesia_dataset_v4.csv").dropna().iloc[[78]]
    out = predict_all(df, explain=True)
    print("\n✅ Prediction Result:")
    pprint.pprint(out, indent=2, width=120)


VERSIONS: {'python': '3.11.0 (main, Oct 24 2022, 18:26:48) [MSC v.1933 64 bit (AMD64)]', 'numpy': '2.2.6', 'shap': '0.48.0', 'catboost': '1.2.8', 'matplotlib': '3.10.6'}

✅ Prediction Result:
{ 'General_AnesthesiaType': "['Desflurane']",
  'General_Dosage': 96.48,
  'General_Response': 'Ineffective',
  'Patient_Profile': { 'ABW': 79.47874320000001,
                       'ALDH2_Genotype': 'G/A',
                       'ASA_Class': 2,
                       'Age': 35,
                       'BMI': 30.367346938775512,
                       'CYP2B6_Type': 'NM',
                       'CYP2C9_Type': '*2/*3',
                       'CYP2D6_Type': 'UM',
                       'CYP3A4_Type': 'UM',
                       'CardiovascularHistory': 'No History',
                       'CurrentMedications': 'Other',
                       'Diabetes': 'No',
                       'Diet': 'Vegetarian',
                       'F5_Variant': 'No Mutation',
                       'GABRA2_Variant': 'Nor