# Week 5–6 Explainability Tests

Interactive companion to the CLI runner introduced in Week 5–6. Use this notebook to iterate on SHAP/LIME settings, capture notebook-native outputs, and sanity-check feature attribution trends before pushing artefacts to `results/explainability/`.

## Workflow Overview

1. The setup cell below pins the project root on `sys.path`, imports the explainability helpers from `src/explainability.py`, and ensures output directories exist.
2. `run_notebook_explainability(...)` mirrors `python -m src.explainability` but surfaces progress bars via `tqdm` so you can monitor model-level operations, selected local examples, and SHAP Kernel iterations.
3. Execute the "Run notebook explainability" cell to kick off the pipeline; tweak the parameters (dataset split, sample sizes, kernel samples) as needed to trade off runtime vs. fidelity.
4. Use the optional helper to inspect the per-model top-feature CSVs directly in the notebook. This is handy when vetting clinician-facing narratives or deciding on threshold experiments.

> **Note:** The notebook writes artefacts to the same locations as the CLI script, so reruns will refresh PNG/HTML assets under `results/explainability/`.

In [1]:
import os
import sys
from pathlib import Path

import pandas as pd

try:
    from tqdm.auto import tqdm
except ImportError:  # pragma: no cover
    tqdm = None

PROJECT_ROOT = Path.cwd().resolve().parents[0]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

from src import explainability as xai


  from .autonotebook import tqdm as notebook_tqdm


Project root: /Users/peter/Desktop/health_xai_project


In [2]:
from datetime import datetime
from typing import Dict, List

import numpy as np
from IPython.display import Markdown, display


def _get_tqdm(iterable, **kwargs):
    if tqdm is None:
        for item in iterable:
            yield item
    else:
        yield from tqdm(iterable, **kwargs)


def _expand_local_indices(probs: np.ndarray, base_indices: List[int], target_count: int) -> List[int]:
    indices = list(base_indices)
    if len(indices) >= target_count:
        return indices
    ranked = np.argsort(probs)[::-1]
    for idx in ranked:
        if idx not in indices:
            indices.append(idx)
        if len(indices) >= target_count:
            break
    return indices


def run_notebook_explainability(
    dataset: str = "validation",
    sample_size: int = 120,
    background_size: int = 35,
    kernel_nsamples: int = 80,
    lime_instances: int = 2,
    random_state: int = 42,
) -> pd.DataFrame:
    """Execute the explainability pipeline with visible progress bars."""

    splits = xai.load_splits()
    X_train = splits["X_train"]
    dataset_prefix = "val" if dataset == "validation" else "test"
    X_target = splits[f"X_{dataset_prefix}"]

    sample_size = min(sample_size, len(X_target))
    sample_df = X_target.sample(n=sample_size, random_state=random_state).copy().reset_index(drop=True)

    models, scaler = xai.load_models(input_dim=X_train.shape[1], include_tuned=True)

    lime_explainer = xai.build_lime_explainer(
        X_train,
        feature_names=list(X_train.columns),
        categorical_features=xai.infer_categorical_indices(X_train.columns),
    )

    kernel_background = X_train.sample(n=min(background_size, len(X_train)), random_state=random_state)

    summary_records: List[Dict[str, object]] = []
    model_iter = list(xai.MODEL_CONFIGS)

    for config in _get_tqdm(model_iter, desc="Models", unit="model"):
        model = models.get(config.model_key)
        if model is None:
            display(Markdown(f"⚠️ **{config.model_key}** not found. Skipping."))
            continue

        model_dir = xai.ensure_directory(xai.EXPLAINABILITY_DIR / config.pretty_name)
        predict_fn = xai.make_predict_function(config.model_key, model, scaler, X_train.columns)
        probs = predict_fn(sample_df)[:, 1]
        base_indices = xai.select_local_indices(probs)
        selected_indices = _expand_local_indices(probs, base_indices, lime_instances)
        instance_rows = [sample_df.iloc[idx] for idx in selected_indices]
        instance_labels = [f"idx{idx}_p{probs[idx]:.2f}" for idx in selected_indices]

        if config.shap_method == "tree":
            shap_values, expected_value = xai.generate_tree_shap(
                config.model_key,
                model,
                sample_df,
                model_dir,
            )
        else:
            shap_values, expected_value = xai.generate_kernel_shap(
                config.model_key,
                predict_fn,
                kernel_background,
                sample_df,
                nsamples=kernel_nsamples,
                output_dir=model_dir,
            )

        importance = np.mean(np.abs(shap_values), axis=0)
        importance_series = pd.Series(importance, index=sample_df.columns).sort_values(ascending=False)
        top_features_path = model_dir / f"{config.model_key}_top_features.csv"
        importance_series.to_csv(top_features_path, header=["mean_abs_shap"])

        force_paths = []
        for idx, label in _get_tqdm(list(zip(selected_indices, instance_labels)), desc=f"Force plots → {config.pretty_name}", leave=False):
            force_path = xai.save_shap_force_plot(
                config.model_key,
                expected_value,
                shap_values[idx],
                sample_df.iloc[idx],
                model_dir,
                label,
            )
            force_paths.append(force_path)

        lime_paths = xai.save_lime_explanations(
            config.model_key,
            lime_explainer,
            predict_fn,
            instance_rows,
            instance_labels,
            model_dir,
        )

        summary_records.append(
            {
                "model": config.pretty_name,
                "dataset": dataset,
                "sample_size": sample_size,
                "lime_examples": ", ".join(Path(path).name for path in lime_paths),
                "shap_force_examples": ", ".join(path.name for path in force_paths),
                "top_features_csv": top_features_path.name,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
        )

    summary_df = pd.DataFrame(summary_records)
    if not summary_df.empty:
        summary_path = xai.EXPLAINABILITY_DIR / f"xai_summary_{dataset}_notebook.csv"
        summary_df.to_csv(summary_path, index=False)
        display(Markdown(f"✅ Notebook artefact manifest saved to **{summary_path}**"))
    else:
        display(Markdown("⚠️ No artefacts generated. Check model availability."))

    return summary_df


def show_top_features(model_dir: str, top_k: int = 10) -> None:
    model_path = xai.EXPLAINABILITY_DIR / model_dir
    csv_path = next(model_path.glob("*_top_features.csv"), None)
    if csv_path is None:
        display(Markdown(f"⚠️ No top-feature CSV found under `{model_path}`."))
        return
    df = pd.read_csv(csv_path).head(top_k)
    display(Markdown(f"### Top {top_k} features · {model_dir}"))
    display(df)


In [3]:
# Run explainability with default parameters (≈30 seconds on laptop hardware)
summary_df = run_notebook_explainability(
    dataset="validation",
    sample_size=120,
    background_size=35,
    kernel_nsamples=80,
    lime_instances=2,
    random_state=42,
)
summary_df

[INFO] Loading data splits from /Users/peter/Desktop/health_xai_project/results/models/data_splits.joblib


Models:   0%|          | 0/3 [00:00<?, ?model/s]

[SHAP] Saved summary plots to /Users/peter/Desktop/health_xai_project/results/explainability/RandomForest_Tuned/random_forest_tuned_shap_summary_dot.png and /Users/peter/Desktop/health_xai_project/results/explainability/RandomForest_Tuned/random_forest_tuned_shap_summary_bar.png


Models:  33%|███▎      | 1/3 [00:03<00:07,  3.62s/model]

[SHAP] Saved summary plots to /Users/peter/Desktop/health_xai_project/results/explainability/XGBoost_Tuned/xgboost_tuned_shap_summary_dot.png and /Users/peter/Desktop/health_xai_project/results/explainability/XGBoost_Tuned/xgboost_tuned_shap_summary_bar.png


100%|██████████| 120/120 [00:00<00:00, 135.41it/s]model]


[SHAP] Saved summary plots to /Users/peter/Desktop/health_xai_project/results/explainability/NeuralNetwork_Tuned/neural_network_tuned_shap_summary_dot.png and /Users/peter/Desktop/health_xai_project/results/explainability/NeuralNetwork_Tuned/neural_network_tuned_shap_summary_bar.png


Models: 100%|██████████| 3/3 [00:09<00:00,  3.23s/model]


✅ Notebook artefact manifest saved to **/Users/peter/Desktop/health_xai_project/results/explainability/xai_summary_validation_notebook.csv**

Unnamed: 0,model,dataset,sample_size,lime_examples,shap_force_examples,top_features_csv,timestamp
0,RandomForest_Tuned,validation,120,"random_forest_tuned_lime_idx0_p0.54.html, rand...","random_forest_tuned_force_idx0_p0.54.png, rand...",random_forest_tuned_top_features.csv,2025-11-16 23:35:32
1,XGBoost_Tuned,validation,120,"xgboost_tuned_lime_idx5_p0.76.html, xgboost_tu...","xgboost_tuned_force_idx5_p0.76.png, xgboost_tu...",xgboost_tuned_top_features.csv,2025-11-16 23:35:35
2,NeuralNetwork_Tuned,validation,120,"neural_network_tuned_lime_idx0_p0.61.html, neu...","neural_network_tuned_force_idx0_p0.61.png, neu...",neural_network_tuned_top_features.csv,2025-11-16 23:35:38


<Figure size 900x250 with 0 Axes>

<Figure size 900x250 with 0 Axes>

<Figure size 900x250 with 0 Axes>

<Figure size 900x250 with 0 Axes>

<Figure size 900x250 with 0 Axes>

<Figure size 900x250 with 0 Axes>

In [4]:
# Inspect SHAP top features for a given tuned model
show_top_features("RandomForest_Tuned")
show_top_features("XGBoost_Tuned")
show_top_features("NeuralNetwork_Tuned")

### Top 10 features · RandomForest_Tuned

Unnamed: 0.1,Unnamed: 0,mean_abs_shap
0,numeric__health,0.133819
1,numeric__dosprt,0.029173
2,numeric__flteeff,0.02144
3,numeric__slprl,0.019759
4,numeric__weighta,0.011915
5,numeric__fltdpr,0.01176
6,numeric__cgtsmok,0.009912
7,numeric__enjlf,0.007531
8,numeric__alcfreq,0.007081
9,numeric__height,0.007072


### Top 10 features · XGBoost_Tuned

Unnamed: 0.1,Unnamed: 0,mean_abs_shap
0,numeric__health,0.912471
1,numeric__dosprt,0.122126
2,numeric__cgtsmok,0.115077
3,numeric__weighta,0.104653
4,numeric__slprl,0.09297
5,numeric__flteeff,0.081739
6,numeric__height,0.062231
7,numeric__alcfreq,0.049674
8,numeric__happy,0.0489
9,numeric__gndr,0.037742


### Top 10 features · NeuralNetwork_Tuned

Unnamed: 0.1,Unnamed: 0,mean_abs_shap
0,numeric__health,0.136413
1,numeric__height,0.02507
2,numeric__weighta,0.017998
3,numeric__flteeff,0.016795
4,numeric__slprl,0.013102
5,numeric__gndr,0.011269
6,numeric__etfruit,0.010236
7,numeric__happy,0.009575
8,numeric__dosprt,0.0083
9,categorical__cntry_IT,0.007247
