# 03 - Explainability with SHAP

Compute SHAP values for a single sample scored by a trained model and generate a Mistral-backed narrative explaining the prediction.


In [None]:
import sys
from pathlib import Path
import pandas as pd

# Ensure the repository root and src directory are on sys.path before importing project modules.
# NOTE(review): Path("..") is resolved against the current working directory, so this assumes
# the notebook is launched from a direct subdirectory of the repository root — confirm.
BASE_DIR = Path("..").resolve()
for _extra_path in (BASE_DIR, BASE_DIR / "src"):
    _entry = str(_extra_path)
    # Drop stale copies first so re-running this cell moves the entry to the front
    # instead of growing sys.path with duplicates. Insertion order matches the
    # original: src ends up ahead of the repository root.
    while _entry in sys.path:
        sys.path.remove(_entry)
    sys.path.insert(0, _entry)

# Import and reload project modules to ensure we pick up recent edits
import importlib
from src import data_pipeline, model_pipeline, explainability, visualization

# Reload in dependency-friendly order; importlib.reload updates each module object
# in place, so the names bound above stay valid.
for _module in (data_pipeline, model_pipeline, explainability, visualization):
    importlib.reload(_module)

# Also ensure any legacy top-level module names (e.g. 'nlp_layer') map to the
# package module so older cached modules can find updated helpers.
import src.nlp_layer as nlp_layer_pkg
importlib.reload(nlp_layer_pkg)
sys.modules['nlp_layer'] = nlp_layer_pkg

# Prepare a single sample and run prediction + explainability
X, y = data_pipeline.prepare_training_data()
sample = X.sample(1, random_state=42)
result = model_pipeline.predict_single("random_forest", sample.iloc[0].to_dict())


def _print_length(label, value):
    """Print ``len(value)`` under ``label``, or a fallback note when it has no length.

    Used for both shap_values and feature_names so a missing/None entry in
    ``result`` is reported instead of crashing the diagnostic output.
    """
    try:
        length = len(value)
    except TypeError:  # None, scalars, 0-d arrays, ...
        print(f"{label} length: could not determine (non-sequence)")
    else:
        print(f"{label} length:", length)


# Debug prints to inspect types/shapes which are the common root cause when plotting fails
print("predict_single result keys:", list(result.keys()))
print("shap_values type:", type(result.get("shap_values")))
_print_length("shap_values", result.get("shap_values"))
print("feature_names type:", type(result.get("feature_names")))
_print_length("feature_names", result.get("feature_names"))

# Build explanation object
explanation = explainability.build_explanation(
    risk_level=result["prediction"],
    probabilities=result["probabilities"],
    shap_values=result["shap_values"],
    feature_names=result["feature_names"],
)

print("Explanation.top_features:", explanation.top_features)

# Try to save and display the SHAP summary inline in the notebook. If plotting fails,
# retry with SHAP values reduced to one value per feature to demonstrate the root cause.
from IPython.display import Image, display
import numpy as np


def _flatten_shap_for_plot(shap_values, n_features):
    """Reduce SHAP values to a 1-D array holding exactly one value per feature.

    Inputs that already contain ``n_features`` values are simply flattened.
    For multi-output arrays, the axis whose length equals ``n_features`` is
    moved last and the first slice over the remaining axes is returned, i.e.
    the values for the first output. For example, the input may have one row
    or column per class.

    The last axis is preferred when several axes match. This reproduces a
    plain row-major slice for ``(n_outputs, n_features)`` layouts, and it
    avoids interleaving outputs for ``(n_features, n_outputs)`` layouts.

    Raises:
        ValueError: if no axis has length ``n_features``. numpy may also raise
            ValueError when ``shap_values`` is a ragged nested sequence.
    """
    arr = np.asarray(shap_values)
    if arr.size == n_features:
        return arr.ravel()
    for axis in range(arr.ndim - 1, -1, -1):
        if arr.shape[axis] == n_features:
            return np.moveaxis(arr, axis, -1).reshape(-1, n_features)[0]
    raise ValueError(
        f"cannot align SHAP values of shape {arr.shape} with {n_features} feature names"
    )


try:
    path = visualization.plot_shap_summary(explanation.shap_values, explanation.feature_names, "notebook_shap")
    print("Saved plot to:", path)
    display(Image(str(path)))
except Exception as e:
    # Broad catch is deliberate: this is a diagnostic cell that should report the
    # failure and continue to the reshaped retry rather than abort.
    print("Error plotting shap summary:", repr(e))
    try:
        flat = _flatten_shap_for_plot(explanation.shap_values, len(explanation.feature_names))
        print("Retrying with flattened shap values (shape):", flat.shape)
        path2 = visualization.plot_shap_summary(flat, explanation.feature_names, "notebook_shap_flat")
        print("Saved flat plot to:", path2)
        display(Image(str(path2)))
    except Exception as e2:
        print("Retry failed:", repr(e2))

# Finally show the narrative text
print("\nNarrative:\n", explanation.narrative)


'[Mocked Mistral Response] Based on the highlighted drivers, the supplier shows elevated risk due to persistent payment delays and clause risk.'