# In [None]:
# Cell 7: compute/align SHAP output, initialize SHAP's JS if available, and render a force plot (matplotlib waterfall fallback)
import numpy as np
import shap
import matplotlib.pyplot as plt
import pandas as pd

# Parameters
#   row_idx   : position (within X_mis_trans / the aligned SHAP matrix) of the
#               misclassified example to explain
#   class_idx : output/class whose SHAP slice and base value are plotted; choose a
#               different index to inspect another class of a multiclass model
row_idx, class_idx = 0, 1

# Preconditions -- these names must already exist in the session:
#   explainer           : fitted shap.Explainer / TreeExplainer
#   X_mis_trans         : transformed numeric array of misclassified examples,
#                         shape (n_samples, n_trans_features)
#   X_train_trans       : transformed numeric training array; its column count sets
#                         the feature dimension, and it is also fed to
#                         clf.predict_proba in the base-value fallback
#   trans_feature_names : transformed feature names, len == n_trans_features
#   clf                 : trained classifier, used only as a fallback source for the
#                         expected/base value
# Optional:
#   sv / shap_values / expl_all : SHAP output computed earlier (from
#                         explainer.shap_values(X_mis_trans) or explainer(X_mis_trans));
#                         reused instead of recomputing when one of them is present
#   misclassified       : original DataFrame for raw-feature display in the original script

# 1) Reuse SHAP output from an earlier cell when available, looking it up under
#    the names `sv`, `shap_values`, then `expl_all`; the first one that is bound
#    to something other than None wins. Otherwise compute it for the
#    misclassified rows. Skipping None-valued names means a placeholder such as
#    `sv = None` triggers a fresh computation instead of flowing into step 2.
sv = next(
    (obj for obj in map(globals().get, ("sv", "shap_values", "expl_all")) if obj is not None),
    None,
)
if sv is None:
    sv = explainer.shap_values(X_mis_trans)

# 2) Align SHAP output into a 2-D array `sm` with shape (n_samples, n_features).
#    `arr` holds the full (unsliced) SHAP array; `sm` is the slice for `class_idx`.
n_samples = X_mis_trans.shape[0]
n_features = X_train_trans.shape[1]

# `explainer(X)` returns a shap.Explanation whose attribution array lives in
# `.values`; unwrap it so NumPy receives a plain numeric array.
raw_sv = sv.values if isinstance(sv, shap.Explanation) else sv
arr = np.asarray(raw_sv)


def _class_slot(axis_len):
    """Return the index to take along a class axis of length `axis_len`.

    A length-1 class axis means a single-output model, so only slot 0 exists
    (the same rule used for expected_value in step 3); otherwise use class_idx.
    """
    return class_idx if axis_len > 1 else 0


if arr.ndim == 2:
    # already (n_samples, n_features)
    sm = arr
elif arr.ndim == 3:
    s0, s1, s2 = arr.shape
    if isinstance(raw_sv, list) and s1 == n_samples and s2 == n_features:
        # A Python list of per-output arrays (the multi-output return of older
        # `shap_values()` APIs) is class-first by construction; trust that over
        # shape matching, which is ambiguous when n_classes == n_samples.
        sm = arr[_class_slot(s0)]
    elif s0 == n_samples and s1 == n_features:
        # (n_samples, n_features, n_classes) -> select class axis last
        sm = arr[:, :, _class_slot(s2)]
    elif s0 == n_samples and s2 == n_features:
        # (n_samples, n_classes, n_features) -> select class axis middle
        sm = arr[:, _class_slot(s1), :]
    elif s1 == n_samples and s2 == n_features:
        # (n_classes, n_samples, n_features) -> select class axis first
        sm = arr[_class_slot(s0)]
    elif s0 == n_samples and s1 * s2 == n_features:
        # last resort: rows are already on axis 0, so flattening the remaining
        # axes keeps each sample's values together (no cross-row scrambling)
        sm = arr.reshape((n_samples, -1))
    else:
        raise ValueError(f"Unhandled SHAP array shape {arr.shape} for expected {(n_samples, n_features)}")
else:
    raise ValueError(f"Unhandled SHAP array ndim {arr.ndim}")

# Check both axes: a cached `sv` computed on a different dataset can have the
# right feature count but the wrong number of rows.
if sm.shape != (n_samples, n_features):
    raise ValueError(f"After alignment, SHAP shape is {sm.shape} but expected {(n_samples, n_features)}")

# 3) Determine a scalar expected/base value for the chosen class.
ev = getattr(explainer, "expected_value", None)
# np.ravel flattens a Python/NumPy scalar, a 0-d array, a list/tuple or a 1-D
# array alike, so one code path covers every form expected_value can take
# (a 0-d ndarray has no len(), so a len()-based check would raise TypeError).
ev_flat = np.ravel(ev) if ev is not None else np.empty(0)
if ev_flat.size > 1:
    # one base value per output: pick the requested class
    base_value = float(ev_flat[class_idx])
elif ev_flat.size == 1:
    # single-output explainer: the only base value applies to every class index
    base_value = float(ev_flat[0])
else:
    # No usable expected_value on the explainer: approximate it with the mean
    # predicted probability of the class over the training set.
    # NOTE(review): this is in probability space; if the SHAP values are in
    # margin/log-odds space the base value will not line up -- confirm.
    probs = clf.predict_proba(X_train_trans)[:, class_idx]
    base_value = float(probs.mean())

# 4) Prepare the feature values and SHAP vector for the selected row.
# DataFrames need positional `.iloc` access; plain `X[row_idx]` on a DataFrame
# looks up a column label instead of a row.
_rows = X_mis_trans.iloc if hasattr(X_mis_trans, "iloc") else X_mis_trans
instance_data = _rows[row_idx]
# Objects exposing `toarray` (e.g. a sparse matrix row) are densified before flattening.
instance_data = instance_data.toarray().ravel() if hasattr(instance_data, "toarray") else np.asarray(instance_data).ravel()
shap_vals_vec = sm[row_idx]
if shap_vals_vec.ndim != 1 or shap_vals_vec.shape[0] != n_features:
    raise ValueError("SHAP values vector shape mismatch with feature dimension")
# The plots pair feature values with SHAP values position by position, so a
# length mismatch would silently mislabel features.
if instance_data.shape[0] != n_features:
    raise ValueError(f"Instance data has {instance_data.shape[0]} values but expected {n_features} features")

# 5) Render an interactive JS force plot when the notebook front end supports it;
#    if any step fails (IPython missing, shap.initjs unavailable, plotting error),
#    fall back to a static matplotlib waterfall plot.
try:
    # explicit import: `display` is only implicitly available inside IPython
    from IPython.display import display

    shap.initjs()
    force_obj = shap.force_plot(
        base_value,
        shap_vals_vec,
        instance_data,
        feature_names=trans_feature_names,
    )
    display(force_obj)
except Exception as exc:  # deliberate best-effort boundary: any failure -> static plot
    print(f"Interactive force plot unavailable ({type(exc).__name__}: {exc}); falling back to waterfall plot.")
    expl_single = shap.Explanation(values=shap_vals_vec, base_values=base_value, data=instance_data, feature_names=trans_feature_names)
    # show=False so the single plt.show() below renders the figure exactly once
    shap.plots.waterfall(expl_single, show=False)
    plt.show()
