In [None]:
#Cell 6:compute
import numpy as np
import shap
import matplotlib.pyplot as plt
import pandas as pd

# Which example and class to inspect
row_idx = 0                # index (within the misclassified rows) of the sample to explain
class_idx = 1              # for binary tasks: 1 -> positive class (adjust if needed)

# Compute SHAP values for the misclassified rows using the previously created explainer.
# The explainer must be a TreeExplainer created from the tree estimator (see prior cell).

# TODO(learner): compute the SHAP values here and assign them to `sv`.
# Until `sv` is defined, the print below (and the rest of this cell) raises NameError.

# Report what came back: the `.shape` when the result has one, otherwise the "list/obj" marker.
print("SHAP computed; type:", type(sv), "shape/info:", getattr(sv, "shape", "list/obj"))

# Normalize the SHAP output to a 2D matrix `shap_matrix` with shape
# (n_samples, n_transformed_features). Different SHAP versions / model types return
# different containers: a list with one array per class, a 2D array, or a 3D array.
n_mis_samples = X_mis_trans.shape[0]
n_trans_features = X_train_trans.shape[1]

if isinstance(sv, list):
    # One contribution array per class. Honor the class_idx selected at the top of the
    # cell rather than overwriting it; only a single-output list forces index 0.
    if len(sv) == 1:
        class_idx = 0
    elif not -len(sv) <= class_idx < len(sv):
        raise ValueError(f"class_idx {class_idx} is out of range for {len(sv)} SHAP class outputs")
    shap_matrix = np.asarray(sv[class_idx])
else:
    # Explanation-like objects carry the raw contributions in a `.values` ndarray;
    # anything else is converted directly.
    raw_values = getattr(sv, "values", None)
    sv_arr = raw_values if isinstance(raw_values, np.ndarray) else np.asarray(sv)

    if sv_arr.ndim == 2:
        # already (n_samples, n_features)
        shap_matrix = sv_arr
    elif sv_arr.ndim == 3:
        # 3D layouts vary; identify which axes are samples/features and slice out the
        # class axis. The checks run in a fixed order, so if the sample count happens to
        # equal the feature count the first matching layout wins.
        s0, s1, s2 = sv_arr.shape
        if s0 == n_mis_samples and s1 == n_trans_features:
            # (n_samples, n_features, n_classes)
            shap_matrix = sv_arr[:, :, class_idx]
        elif s0 == n_mis_samples and s2 == n_trans_features:
            # (n_samples, n_classes, n_features)
            shap_matrix = sv_arr[:, class_idx, :]
        elif s1 == n_mis_samples and s2 == n_trans_features:
            # (n_classes, n_samples, n_features)
            shap_matrix = sv_arr[class_idx, :, :]
        else:
            # Last resort: flatten the trailing axes and accept only an exact feature match
            reshaped = sv_arr.reshape((n_mis_samples, -1))
            if reshaped.shape[1] == n_trans_features:
                shap_matrix = reshaped
            else:
                raise ValueError(f"Unhandled SHAP array shape {sv_arr.shape} for X_mis_trans {X_mis_trans.shape} and {n_trans_features} features")
    else:
        raise ValueError(f"Unhandled SHAP array ndim {sv_arr.ndim}")

# Verify we ended up with a 2D array carrying the expected number of transformed features
if shap_matrix.ndim != 2 or shap_matrix.shape[1] != n_trans_features:
    raise ValueError(f"SHAP shape {shap_matrix.shape} incompatible with transformed feature dim {n_trans_features}")

print("Normalized shap_matrix.shape:", shap_matrix.shape)

# Derive transformed feature names to label the SHAP values in plots.
# Resolution order: the preprocessor's own names, then a hand-built list, then f0, f1, ...
try:
    trans_feature_names = list(prep.get_feature_names_out())
except Exception:
    # Fallback construction: numeric features + one-hot names from the categorical
    # transformer if available.
    # NOTE(review): assumes the numeric columns precede the "cat" block in the
    # transformed output -- confirm against how `prep` was built.
    numeric_features = ["age", "tenure_months", "monthly_charges", "num_support_tickets", "has_internet"]
    cat_features = ["contract_type", "payment_method"]
    try:
        cat_names = prep.named_transformers_["cat"].get_feature_names_out(cat_features)
        trans_feature_names = list(numeric_features) + list(cat_names)
    except Exception:
        # Final fallback: positional feature names (f0, f1, ...)
        trans_feature_names = [f"f{i}" for i in range(n_trans_features)]

# Names that do not line up one-to-one with the SHAP feature dimension would mislabel
# (or break) the plot, so prefer honest positional names over wrong ones.
if len(trans_feature_names) != n_trans_features:
    print(
        "Feature-name count", len(trans_feature_names),
        "does not match transformed feature dim", n_trans_features,
        "- using positional names",
    )
    trans_feature_names = [f"f{i}" for i in range(n_trans_features)]

print("Number of transformed features:", len(trans_feature_names))

# Extract the SHAP values for the selected row.
# `shap_matrix` has already been verified to be 2D (n_samples, n_features) earlier in
# this cell, so a single row is always a 1D vector; fail loudly instead of guessing a
# layout (and possibly mis-assigning contributions) if that ever stops being true.
n_explained_rows = shap_matrix.shape[0]
if not -n_explained_rows <= row_idx < n_explained_rows:
    raise IndexError(f"row_idx {row_idx} is out of range for {n_explained_rows} explained rows")

row_vals = np.asarray(shap_matrix[row_idx])
if row_vals.ndim != 1:
    raise ValueError(f"Expected a 1D SHAP vector for row {row_idx}, got shape {row_vals.shape}")

# Determine the base (expected) value of the explainer output for the chosen class.
# `expected_value` may be a plain scalar, a 0-d array, or a per-class sequence; flatten
# it first so every form is handled the same way (len() fails on 0-d arrays).
ev = getattr(explainer, "expected_value", None)
ev_flat = np.asarray(ev, dtype=float).ravel() if ev is not None else np.empty(0)

if ev_flat.size > 1:
    # one expected value per class
    base_val = float(ev_flat[class_idx])
elif ev_flat.size == 1:
    # single-output explainer
    base_val = float(ev_flat[0])
else:
    # Fallback: use mean predicted probability on training data for the chosen class.
    # NOTE(review): this is in probability units; if the SHAP values are in margin /
    # log-odds units the waterfall baseline will be on a different scale -- confirm.
    probs = clf.predict_proba(X_train_trans)[:, class_idx]
    base_val = float(probs.mean())

# Pull the feature values of the selected row as a dense 1D array.
# The transformed matrix may be a NumPy array, a pandas DataFrame (where plain `[...]`
# would be a column lookup), or a SciPy sparse matrix (where a row is still 2D).
if hasattr(X_mis_trans, "iloc"):
    row_data = X_mis_trans.iloc[row_idx]
else:
    row_data = X_mis_trans[row_idx]
if hasattr(row_data, "toarray"):
    row_data = row_data.toarray()
row_data = np.asarray(row_data).ravel()

# Build a SHAP Explanation object for the selected row (used by SHAP plotting functions)
expl = shap.Explanation(
    values=row_vals,
    base_values=base_val,
    data=row_data,
    feature_names=trans_feature_names,
)

# Render a waterfall plot for the selected explanation
shap.plots.waterfall(expl)
plt.show()
