In [None]:
# Cell 6: compute SHAP values, normalize to shap_matrix, build Explanation and plot waterfall
import numpy as np
import shap
import matplotlib.pyplot as plt
import pandas as pd

# Row of X_mis_trans to explain, and the class whose SHAP values are plotted.
row_idx, class_idx = 0, 1

# Compute SHAP values
# Learners: write the code that computes the SHAP values and assign the
# result to `sv` -- everything below reads from that name.

# Quick look at what came back; objects without a `.shape` attribute (e.g. a
# plain list) are reported with the placeholder text instead.
sv_shape_info = getattr(sv, "shape", "list/obj")
print("SHAP computed; type:", type(sv), "shape/info:", sv_shape_info)

# Normalize the raw SHAP output `sv` into `shap_matrix`: a single 2-D array of
# shape (rows of X_mis_trans, transformed features) for one class.
n_trans_features = X_train_trans.shape[1]
n_mis_rows = X_mis_trans.shape[0]


def _resolve_class_idx(requested, n_outputs):
    """Return the class index to slice with, given `n_outputs` available outputs.

    Falls back to 0 when only one output exists; raises ValueError when the
    requested index cannot address any of the available outputs.
    """
    if n_outputs == 1:
        # Only one set of SHAP values is available, so there is nothing to pick.
        return 0
    if not -n_outputs <= requested < n_outputs:
        raise ValueError(f"class_idx {requested} is out of range for {n_outputs} SHAP outputs")
    return requested


# Explanation-style objects keep their array in `.values`; lists and ndarrays
# have no such attribute and pass through unchanged. A callable `.values`
# (e.g. dict.values) is not what we want, so ignore it.
sv_raw = getattr(sv, "values", sv)
if callable(sv_raw):
    sv_raw = sv

if isinstance(sv_raw, list):
    # One array per class. Honor the configured class_idx instead of always
    # forcing class 1, so class 0 (or a third class) can be explained too.
    class_idx = _resolve_class_idx(class_idx, len(sv_raw))
    shap_matrix = np.asarray(sv_raw[class_idx])
else:
    sv_arr = np.asarray(sv_raw)

    if sv_arr.ndim == 2:
        # Already (rows, features): nothing to slice.
        shap_matrix = sv_arr
    elif sv_arr.ndim == 3:
        s0, s1, s2 = sv_arr.shape
        # Work out which axis holds the classes by matching the other two axes
        # against the known row and feature counts.
        if s0 == n_mis_rows and s1 == n_trans_features:
            class_axis = 2
        elif s0 == n_mis_rows and s2 == n_trans_features:
            class_axis = 1
        elif s1 == n_mis_rows and s2 == n_trans_features:
            class_axis = 0
        else:
            class_axis = None

        if class_axis is not None:
            class_idx = _resolve_class_idx(class_idx, sv_arr.shape[class_axis])
            shap_matrix = np.take(sv_arr, class_idx, axis=class_axis)
        elif sv_arr.size == n_mis_rows * n_trans_features:
            # Layout not recognized, but the element count fits exactly
            # (e.g. a stray length-1 axis): flatten to (rows, features).
            shap_matrix = sv_arr.reshape((n_mis_rows, n_trans_features))
        else:
            raise ValueError(f"Unhandled SHAP array shape {sv_arr.shape} for {X_mis_trans.shape} and {n_trans_features} features")
    else:
        raise ValueError(f"Unhandled SHAP array ndim {sv_arr.ndim}")

# Final sanity check: whatever branch produced shap_matrix, it must be a 2-D
# array whose second axis has one column per transformed feature.
n_trans_features = X_train_trans.shape[1]
is_rows_by_features = shap_matrix.ndim == 2 and shap_matrix.shape[1] == n_trans_features
if not is_rows_by_features:
    raise ValueError(f"SHAP shape {shap_matrix.shape} incompatible with transformed feature dim {n_trans_features}")

print("Normalized shap_matrix.shape:", shap_matrix.shape)

# Resolve display names for the transformed features, trying progressively
# cruder sources. The broad `except Exception` is deliberate: name lookup is
# best-effort and must never prevent the plot from being drawn.
try:
    trans_feature_names = prep.get_feature_names_out()
except Exception:
    # Fallback: rebuild the names by hand from the known raw columns.
    # NOTE(review): assumes the preprocessor emits the numeric columns first,
    # followed by the output of a transformer registered as "cat" -- confirm
    # against the cell that builds `prep`.
    numeric_features = ["age","tenure_months","monthly_charges","num_support_tickets","has_internet"]
    cat_features = ["contract_type","payment_method"]
    try:
        cat_names = prep.named_transformers_["cat"].get_feature_names_out(cat_features)
        trans_feature_names = list(numeric_features) + list(cat_names)
    except Exception:
        trans_feature_names = [f"f{i}" for i in range(n_trans_features)]

# Guard against a name list that does not line up with the SHAP columns; a
# mismatch would mislabel (or break) the waterfall plot, so generic names are
# the safer choice.
if len(trans_feature_names) != n_trans_features:
    print(f"Feature-name count {len(trans_feature_names)} does not match {n_trans_features} transformed features; using generic names")
    trans_feature_names = [f"f{i}" for i in range(n_trans_features)]

print("Number of transformed features:", len(trans_feature_names))

# SHAP values for the single row being explained. shap_matrix has already been
# validated as 2-D (rows, features), so an integer row_idx yields a 1-D vector
# with one entry per transformed feature.
row_vals = np.asarray(shap_matrix[row_idx])
if row_vals.ndim != 1:
    # Only reachable when row_idx is not a plain integer (e.g. a slice or list);
    # fail loudly rather than guessing how to collapse several rows.
    raise ValueError(f"row_idx must select a single row; got SHAP values of shape {row_vals.shape}")

# Base value (the starting point of the waterfall) for the chosen class.
ev = getattr(explainer, "expected_value", None)
if ev is None:
    # The explainer exposes no expected value: approximate it with the mean
    # predicted probability of the chosen class over the training data.
    # NOTE(review): this is in probability space; confirm the SHAP values in
    # `sv` are on the same scale before relying on this fallback.
    probs = clf.predict_proba(X_train_trans)[:, class_idx]
    base_val = float(probs.mean())
else:
    # Flatten first so Python scalars, NumPy scalars, 0-d arrays, lists,
    # tuples and (1, k)-shaped arrays are all handled the same way.
    ev_flat = np.ravel(ev)
    base_val = float(ev_flat[class_idx] if ev_flat.size > 1 else ev_flat[0])

# Feature values of the explained row, shown next to each feature name in the
# plot. A sparse X_mis_trans yields a sparse row here, so densify and flatten
# it to a plain 1-D array that lines up with row_vals.
row_data = X_mis_trans[row_idx]
if hasattr(row_data, "toarray"):
    row_data = row_data.toarray()
row_data = np.asarray(row_data).ravel()

# Bundle values, base value, data and names for the single row, then draw the
# waterfall plot for it.
expl = shap.Explanation(values=row_vals,
                        base_values=base_val,
                        data=row_data,
                        feature_names=trans_feature_names)

shap.plots.waterfall(expl)
plt.show()
