# Explainability Lab
**Date:** 2025-10-29

**What this notebook covers**
- Load & preprocess the UCI Adult dataset (via `fetch_openml`).
- Train a Random Forest Classifier.
- Generate and visualize feature attributions with **SHAP** and **LIME**.
- Train a PyTorch MLP classifier.
- Generate and visualize feature attributions using **Captum** gradient-based methods:
  - Saliency
  - SmoothGrad (NoiseTunnel)
  - Input × Gradient
  - Integrated Gradients.
- Generate and visualize feature attributions using **Zennit** LRP methods:
  - LRP (Layer-wise Relevance Propagation).
- Evaluate every XAI method with **Quantus** metrics.
- Generate counterfactual explanations for tabular data using **DiCE**.

> Notes:
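- This notebook expects an environment with internet access (to fetch the dataset) and with the explainability libraries used below (`captum`, `shap`, `lime`, `zennit`, `quantus`) installed.
- Install tips (if needed): `pip install torch captum lime shap zennit scikit-learn pandas matplotlib quantus dice-ml`.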

## Setup & Imports

In [None]:
import os
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler, OrdinalEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.ensemble import RandomForestClassifier

from captum.attr import (
    Saliency,
    IntegratedGradients,
    NoiseTunnel,
    InputXGradient,
)

from zennit.composites import EpsilonPlusFlat
from zennit.attribution import Gradient

import quantus

from typing import List, Callable, Dict, Tuple


In [None]:
# ---- 1) Global seeds
# One seed shared by every random number generator used in this notebook.
SEED = 42

# NOTE: PYTHONHASHSEED is read when the interpreter starts, so setting it here only
# affects child processes, not the hashing of the already-running kernel.
os.environ["PYTHONHASHSEED"] = str(SEED)

# Seed Python, NumPy and PyTorch (the CUDA call is safe even if no GPU is present).
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

## Load & Preprocess the Adult dataset

In [None]:
# Fetch Adult from OpenML
adult = fetch_openml(name='adult', version=2, as_frame=True)

# Work on a copy: turn '?' placeholders into NaN, then drop incomplete rows (simple approach)
df = adult.frame.copy().replace('?', np.nan).dropna()

# Target is 'class': '>50K' or '<=50K' — encode it as 1 / 0
target_col = 'class'
df[target_col] = df[target_col].eq('>50K').astype(int)

# Features / label split
X_df = df.drop(columns=[target_col])
y = df[target_col].to_numpy()

# Categorical columns are those with a category/object dtype; every other column is numeric
cat_cols = list(X_df.select_dtypes(include=['category', 'object']).columns)
num_cols = [col for col in X_df.columns if col not in cat_cols]


In [None]:
# Overview of the cleaned frame: column dtypes and non-null counts
df.info()

In [None]:
# Build preprocessing pipeline: OrdinalEncoder for categoricals, Standardize numerics
numeric_step = ('num', StandardScaler(), num_cols)
# OrdinalEncoder (instead of one-hot encoding) keeps a single output column per categorical feature
categorical_step = ('cat', OrdinalEncoder(), cat_cols)
preprocess = ColumnTransformer(transformers=[numeric_step, categorical_step])

# NOTE(review): the transformer is fitted on the whole dataset *before* the train/test split,
# so scaling statistics and category codes also see the test rows — confirm this is acceptable.
X_processed = preprocess.fit_transform(X_df)

# Output columns follow the transformer order: numeric block first, categorical block second
feature_names_num = num_cols
feature_names_cat = list(preprocess.named_transformers_['cat'].get_feature_names_out(cat_cols))
feature_names_all = feature_names_num + feature_names_cat

# Split into train test (stratified on the label)
X_train, X_test, y_train, y_test = train_test_split(
    X_processed,
    y,
    test_size=0.2,
    stratify=y,
    random_state=42,
)

In [None]:
# Create the torch tensors: features as float32, labels as float32 column vectors of shape (N, 1)
X_train_t, X_test_t = (
    torch.tensor(split, dtype=torch.float32) for split in (X_train, X_test)
)
y_train_t, y_test_t = (
    torch.tensor(labels, dtype=torch.float32).view(-1, 1) for labels in (y_train, y_test)
)

# Expand the training labels to two columns [1 - y, y]; the test labels keep their (N, 1) shape
is_single_column = (y_train_t.ndim == 1) or (y_train_t.shape[1] == 1)
if is_single_column:
    y_train_t = torch.column_stack((1 - y_train_t, y_train_t))

In [None]:
# Create the train/test dataloader
train_ds, test_ds = TensorDataset(X_train_t, y_train_t), TensorDataset(X_test_t, y_test_t)

# Only the training batches are shuffled; the test loader keeps the row order of X_test
train_loader = DataLoader(train_ds, shuffle=True, batch_size=256)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=256)


## Scikit Learn Random Forest Model & Training

In [None]:
# Fit a 300-tree random forest on the preprocessed training split (all CPU cores)
rf = RandomForestClassifier(n_estimators=300, n_jobs=-1, random_state=SEED)
rf.fit(X_train, y_train)

# Hard labels for the classification report, positive-class probabilities for ROC-AUC
y_pred = rf.predict(X_test)
y_prob = rf.predict_proba(X_test)[:, 1]

print(classification_report(y_test, y_pred, digits=3))
print("ROC-AUC:", round(roc_auc_score(y_test, y_prob), 3))


In [None]:
# Sanity check: shape of the label vector predicted for the test split
a = rf.predict(X_test)
a.shape

In [None]:
# Inspect the training labels (0/1 encoded)
y_train

### SHAP Explainer

In [None]:
import shap
# Load the JavaScript needed by SHAP's interactive (force) plots
shap.initjs()


# explain the model's predictions using SHAP
# (same syntax works for LightGBM, CatBoost, scikit-learn, transformers, Spark, etc.)
explainer = shap.Explainer(rf, feature_names=feature_names_all)
# Only the first 100 test rows are explained
shap_values = explainer(X_test[:100])



In [None]:
# Shape of the slice that was passed to the SHAP explainer
X_test[:100].shape

In [None]:
# Number of labels in the full (train + test) dataset
y.shape

In [None]:
# visualize the first prediction's explanation
# Indexing: first explained sample, all features, model output 0
# (presumably class 0, i.e. '<=50K' — confirm against rf.classes_)
shap.plots.waterfall(shap_values[0, :, 0])

In [None]:
# visualize the first prediction's explanation with a force plot
# (same slice as the waterfall plot: first explained sample, model output 0)
shap.plots.force(shap_values[0, :, 0])

In [None]:
# visualize the explanations of all explained samples (the first 100 test rows), model output 0
shap.plots.force(shap_values[:, :, 0])

In [None]:
# summarize the effects of all the features (over the explained test rows, model output 0)
shap.plots.beeswarm(shap_values[:, :, 0])

In [None]:
# Bar plot built from the same SHAP values (explained test rows, model output 0)
shap.plots.bar(shap_values[:, :, 0])

### LIME

In [None]:
from lime.lime_tabular import LimeTabularExplainer

# LIME documents `categorical_features` as a list of *column indices* (ints).
# Passing the feature names instead means no column is recognised as categorical,
# so translate the categorical names into their positions in the processed matrix.
categorical_feature_idx = [feature_names_all.index(name) for name in feature_names_cat]

explainer_lime = LimeTabularExplainer(
    training_data=X_train,
    feature_names=feature_names_all,
    class_names=['<=50K','>50K'],
    categorical_features=categorical_feature_idx,
    discretize_continuous=True,
    random_state=SEED)

# Explain a single preprocessed test row with the random forest's class probabilities
idx = 0
x_raw = X_test[idx]
exp = explainer_lime.explain_instance(
    data_row=np.array(x_raw),
    predict_fn=rf.predict_proba,
    num_features=10)



In [None]:
# Compatibility shim for LIME + modern IPython
from IPython.display import display, HTML

# Render the explanation ourselves from its HTML export rather than calling exp.show_in_notebook()
html = exp.as_html()
display(HTML(html))


## PyTorch MLP Model & Training

In [None]:
class MLPModel(nn.Module):
    """Fully connected classifier: `n_layers` (Linear + ReLU) blocks followed by a Linear head.

    Args:
        n_layers: Number of hidden (Linear + ReLU) blocks, including the input block.
        input_dim: Number of input features.
        hidden_dim: Width of every hidden layer.
        output_dim: Number of output classes.
    """

    def __init__(self, n_layers, input_dim, hidden_dim, output_dim):
        super().__init__()

        # Input Layer (= first hidden layer)
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]

        # Remaining hidden layers (n_layers - 1 of them)
        for _ in range(n_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]

        # Output Layer
        layers += [nn.Linear(hidden_dim, output_dim)]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits in training mode and class probabilities in eval mode.

        `nn.CrossEntropyLoss` applies log-softmax itself, so feeding it softmax
        outputs would normalise twice and flatten the gradients. Training therefore
        works on logits, while inference/attribution (eval mode) still receives
        probabilities of shape (batch, output_dim) that sum to 1 per row.
        """
        logits = self.network(x)
        if self.training:
            return logits
        return F.softmax(logits, dim=1)


In [None]:
class AverageMeter(object):
    """Tracks the latest value together with a weighted running average.

    Attributes are `val` (last value), `sum` (weighted sum), `count` (total weight)
    and `avg` (`sum / count`).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Put every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n_count=1):
        """Record a new value.

        Args:
            val (float): Current value.
            n_count (int, optional): Weight of the value, i.e. how many items it stands for.
                Defaults to 1.
        """
        self.val = val
        self.count += n_count
        self.sum += val * n_count
        self.avg = self.sum / self.count


In [None]:
def train(model, train_loader, num_epochs, criterion, optimizer, device):
    """Train `model` for `num_epochs` epochs and print the mean loss of each epoch.

    Args:
        model: The torch module to optimise (switched to training mode here).
        train_loader: Iterable yielding `(x_batch, y_batch)` tensor pairs.
        num_epochs: Number of passes over `train_loader`.
        criterion: Loss callable invoked as `criterion(model_output, y_batch)`.
        optimizer: Torch optimizer bound to the model parameters.
        device: Device the batches are moved to.
    """
    model.train()
    loss_meter = AverageMeter()
    for epoch in range(num_epochs):
        # Start from a clean meter so the printed value is this epoch's loss,
        # not a running average over every epoch seen so far.
        loss_meter.reset()
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            outputs = model(x_batch)
            loss = criterion(outputs, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller last batch does not skew the average
            # (assumes the criterion returns a per-batch mean — TODO confirm for custom losses).
            loss_meter.update(loss.item(), x_batch.size(0))
        print(f"Epoch {epoch+1}/{num_epochs} | loss={loss_meter.avg:.3f}")


def predict(model, test_loader, device):
    """Return the predicted class index for every sample served by `test_loader`.

    Args:
        model: Torch module whose output has one column per class, or a single
            column holding the positive-class probability.
        test_loader: Iterable yielding `(features, labels)` pairs; labels are ignored.
        device: Device the feature batches are moved to.

    Returns:
        np.ndarray: 1-D array with the argmax class of each sample.
    """
    model.eval()
    batch_outputs = []
    with torch.no_grad():
        for features, _ in test_loader:
            scores = model(features.to(device))
            batch_outputs.append(scores.detach().cpu().numpy())

    probas = np.concatenate(batch_outputs)

    # Single-column output: rebuild the (N x 2) layout [P(class 0), P(class 1)]
    has_single_column = probas.shape[1] == 1
    if has_single_column:
        probas = np.concatenate((1 - probas, probas), 1)

    return np.argmax(probas, axis=1)


In [None]:
# Use the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Architecture and training hyper-parameters
n_layers = 4
hidden_dim = 47
output_dim = 2 # number of classes
num_epochs = 20
input_dim = X_train_t.shape[1]  # one input unit per preprocessed feature

model = MLPModel(
    n_layers=n_layers,
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Fit the MLP on the training batches
train(
    model,
    train_loader=train_loader,
    num_epochs=num_epochs,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
)

In [None]:
# Predicted class index for every test sample (same order as y_test)
pred = predict(model, test_loader, device)

acc = (pred == y_test).mean()
# ".2f" = two decimals; the previous ":2f" was a field width of 2 with six decimals
print(f"Accuracy on the test set : {acc*100:.2f}%")

In [None]:
# Per-class precision / recall / F1 of the MLP on the test split
print(classification_report(y_test, pred, digits=3))


In [None]:
# Predicted classes as a tensor on the model's device; used as attribution targets below
y_pred_t = torch.from_numpy(pred).to(device)

## Captum Attributions

In [None]:
# Helper method to print importances and visualize distribution
def visualize_importances(feature_names,
                        importances,
                        title="Average Feature Importances",
                        plot=True,
                        axis_title="Features"):
    """Print one importance value per feature and optionally draw them as a bar chart.

    Args:
        feature_names: Sequence of feature labels.
        importances: Indexable of importance values, aligned with `feature_names`.
        title: Heading printed first and used as the figure title.
        plot: When True, a matplotlib bar chart is drawn on a new figure.
        axis_title: Label of the x axis.
    """
    n_features = len(feature_names)

    print(title)
    for idx in range(n_features):
        print(feature_names[idx], ": ", '%.3f'%(importances[idx]))

    if not plot:
        return

    x_pos = np.arange(n_features)
    plt.figure(figsize=(12,6))
    plt.bar(x_pos, importances, align='center')
    plt.xticks(x_pos, feature_names, wrap=True, rotation=45)
    plt.xlabel(axis_title)
    plt.title(title)


def beeswarm_attributions(
    attrs,                 # np.array [n_samples, n_features], signed attributions
    X,                     # np.array [n_samples, n_features], original (preprocessed) feature values
    feature_names,         # list[str] length n_features
    max_display=20,        # how many features to show (top by mean |attr|)
    color_by="feature",    # "feature" (feature values) or "attr" (attribution values) or None
    cmap=None,             # e.g., "coolwarm" or "viridis"; if None uses Matplotlib default
    jitter=0.25,           # vertical jitter scale
    dot_size=8,            # marker size
    title="Beeswarm of Attributions",
    xlabel="Attribution (signed)",
):
    """Draw a SHAP-like beeswarm: one row per feature, one dot per sample.

    Features are ranked by mean absolute attribution and the `max_display` strongest
    ones are shown, the most important at the top. The x position of a dot is the
    sample's attribution; its colour is the per-feature min-max normalised feature
    value (`color_by="feature"`) or attribution (`color_by="attr"`).

    Raises:
        AssertionError: If `attrs` and `X` do not have the same shape.
    """
    attrs = np.asarray(attrs)
    X = np.asarray(X)
    assert attrs.shape == X.shape, "attrs and X must have same shape [n_samples, n_features]"
    _n_samples, _n_features = attrs.shape  # unpacking also enforces 2-D input
    feature_names = list(feature_names)

    # Rank features by mean absolute attribution
    mean_abs = np.mean(np.abs(attrs), axis=0)
    order = np.argsort(-mean_abs)[:max_display]
    attrs_sub = attrs[:, order]
    X_sub = X[:, order]
    names_sub = [feature_names[i] for i in order]

    # Prepare figure
    plt.figure(figsize=(10, 0.4 * len(names_sub) + 2))
    y_plot_positions = []  # row (y coordinate) actually used by each plotted feature
    sc = None              # last scatter artist, needed for the optional colorbar

    # Normalize color reference per-feature (like SHAP)
    def normalize_col(v):
        v = v.astype(float)
        vmin, vmax = np.nanmin(v), np.nanmax(v)
        if vmax == vmin:
            return np.zeros_like(v)  # flat color if constant
        return (v - vmin) / (vmax - vmin)

    for j, (a_col, x_col) in enumerate(zip(attrs_sub.T, X_sub.T)):
        # Rows are filled top-down: the most important feature (j == 0) gets the highest y
        y0 = len(names_sub) - 1 - j
        # Rank-based vertical jitter gives the "swarm" feel and limits overplotting
        ranks = a_col.argsort().argsort()  # 0..n-1 ranks
        jitter_offsets = (ranks - np.median(ranks)) / (np.max(ranks) + 1e-9)
        y_vals = y0 + jitter * jitter_offsets
        y_plot_positions.append(y0)

        # Colors
        if color_by == "feature":
            cvals = normalize_col(x_col)
        elif color_by == "attr":
            cvals = normalize_col(a_col)
        else:
            cvals = None

        if cvals is None:
            sc = plt.scatter(a_col, y_vals, s=dot_size, alpha=0.8, edgecolors='none')
        else:
            # Fixed [0, 1] limits: every row shares one colour scale, so the colorbar
            # built from the last scatter is valid for all rows (even constant ones).
            sc = plt.scatter(a_col, y_vals, s=dot_size, c=cvals, cmap=cmap,
                             vmin=0.0, vmax=1.0, alpha=0.8, edgecolors='none')

    # Axes & labels — each label goes on the row where that feature's dots were drawn
    # (labelling rows 0..n-1 in ranking order would flip the names against the data)
    plt.yticks(y_plot_positions, names_sub)
    plt.xlabel(xlabel)
    plt.title(title)
    plt.grid(axis='x', linestyle=':', alpha=0.4)
    plt.tight_layout()

    # Optional colorbar
    if sc is not None and color_by in ("feature", "attr") and cmap is not None:
        cbar = plt.colorbar(sc, pad=0.01)
        cbar.set_label("Feature value" if color_by == "feature" else "Attribution (normalized)")

    plt.show()


### Saliency

In [None]:
# Gradient-based saliency on the trained MLP
xai_saliency = Saliency(model)

# NOTE: requires_grad_ flips the flag on X_test_t in place, so the shared test tensor
# keeps requiring gradients after this cell.
attr = xai_saliency.attribute(X_test_t.requires_grad_(True).to(device),
                              target=y_pred_t,
                              abs=False
                              )  # target = predicted class per sample; abs=False keeps the sign
attr = attr.detach().cpu().numpy()

In [None]:
# Normalize the attributions (because Captum just return the gradients)
# Each feature column is divided by its largest absolute attribution; eps avoids a division by zero.
eps = 1e-16
denom = np.abs(attr).max(axis=0) + eps     # shape [n_features]
attr_norm = attr / denom

#### Generate and visualize

In [None]:
import numpy as np
from typing import Optional, Sequence, Tuple

def tabular_baseline_replacement_by_indices(
    arr: np.ndarray,                  # shape (N, F)
    indices: np.ndarray,              # shape (N, K) indices to replace per sample
    *,
    baselines: np.ndarray,            # shape (F,) per-feature baseline values
    ordinal_idx: Optional[Sequence[int]] = None,  # indices of ordinal features
    clip_bounds: Optional[Tuple[np.ndarray, np.ndarray]] = None,  # (mins, maxs), each shape (F,)
    round_ordinals: bool = True,
    perturb_baseline=None,
) -> np.ndarray:
    """
    Return a copy of `arr` in which, for every row, the features listed in `indices`
    are overwritten with their per-feature baseline.

    Replaced ordinal features are rounded to the nearest integer when `round_ordinals`
    is set, and when `clip_bounds` is given each whole row is clipped to `(mins, maxs)`.
    `perturb_baseline` is accepted but not used. The input array is left untouched.
    """
    perturbed = arr.copy()
    n_rows, _ = perturbed.shape
    ordinal_cols = np.asarray([] if ordinal_idx is None else ordinal_idx, dtype=int)
    should_round = len(ordinal_cols) and round_ordinals

    for row in range(n_rows):
        cols = indices[row]  # features to replace for this sample
        perturbed[row, cols] = baselines[cols]

        if should_round:
            # only the ordinal columns that were just replaced get rounded
            replaced_ordinals = np.intersect1d(cols, ordinal_cols, assume_unique=False)
            if replaced_ordinals.size:
                perturbed[row, replaced_ordinals] = np.rint(perturbed[row, replaced_ordinals])

        if clip_bounds is not None:
            lower, upper = clip_bounds
            perturbed[row] = np.clip(perturbed[row], lower, upper)

    return perturbed


# The ColumnTransformer emits the 'num' block first and the 'cat' block second, so the
# processed matrix is laid out as [num_cols..., cat_cols...] — NOT in X_df's column order.
# Index the processed columns accordingly (get_indexer on X_df would point at the wrong
# columns, and `~` on an integer index array is a bitwise NOT, not a complement).
n_num = len(num_cols)
num_cols_idx = np.arange(n_num)
cat_cols_idx = np.arange(n_num, n_num + len(cat_cols))

# Continuous → mean; Ordinal → median (then round)
means = X_train[:, num_cols_idx].mean(axis=0)
meds  = np.median(X_train[:, cat_cols_idx], axis=0)

baselines = np.empty(X_train.shape[1], dtype=float)
baselines[num_cols_idx] = means
baselines[cat_cols_idx] = np.rint(meds)  # integer code for ordinal

# Columns whose replaced values must be rounded to an integer code
ordinal_idx = cat_cols_idx

# Optional (but helpful): per-feature clip bounds from train set
mins = X_train.min(axis=0)
maxs = X_train.max(axis=0)
clip_bounds = (mins, maxs)

metric = quantus.FaithfulnessEstimate(
    features_in_step=1,           # ↑ → faster, ↓ → more resolution
    abs=False, normalise=False,   # start simple for tabular
    perturb_func=tabular_baseline_replacement_by_indices,
    perturb_func_kwargs={
        "baselines": baselines,
        "ordinal_idx": ordinal_idx,
        "clip_bounds": clip_bounds,
        "round_ordinals": True,
    },
    similarity_func=None,         # default Pearson is fine
)

In [None]:
import numpy as np
import torch
import quantus
from captum.attr import Saliency

# --- 1. Define your model
# Put the trained model in eval mode and make sure it sits on the target device
model.eval().to(device)

# --- 2. Wrap Captum's Saliency into a callable
xai_saliency = Saliency(model)

def saliency_explainer(model, inputs, targets, **kwargs):
    """
    Explain function with the signature expected by quantus.evaluate.

    The `model` argument and any extra keyword arguments are ignored: attributions
    always come from the notebook-level `xai_saliency` object on the global `device`.
    Returns signed gradients as an np.ndarray with the same shape as `inputs`.
    """
    # numpy -> torch (features as float32, class indices as int64)
    inputs_t = torch.tensor(inputs, dtype=torch.float32, device=device, requires_grad=True)
    targets_t = torch.tensor(targets, dtype=torch.long, device=device)

    # Captum works on tensors; hand the result back as a CPU numpy array
    grads = xai_saliency.attribute(inputs_t, target=targets_t, abs=False)
    return grads.detach().cpu().numpy()

# --- 3. Metric and config
metrics = {
    "RIS": quantus.RelativeInputStability(nr_samples=5),
    "ROS": quantus.RelativeOutputStability(nr_samples=5),
    "Consistency": quantus.Consistency(
        discretise_func=quantus.functions.discretise_func.top_n_sign,
        return_aggregate=False,
    ),
    "Sufficiency": quantus.Sufficiency(
        threshold=0.6,
        return_aggregate=False,
    ),
    "Faithfulness": quantus.FaithfulnessEstimate(
        abs=False,
        normalise=False,
        features_in_step=1,
        perturb_baseline="mean",
    ),
}

# Explanation methods to evaluate, keyed by display name
xai_methods = {"Saliency": saliency_explainer}

# explain_func_kwargs → required (even if empty)
explain_func_kwargs = {}

# call_kwargs for the metric
call_kwargs = {"run": {"device": device}}

# --- 4. Evaluate
# Evaluated on the first 100 test rows, with the true labels as targets
results = quantus.evaluate(
    metrics=metrics,
    xai_methods=xai_methods,
    model=model,
    x_batch=X_test[:100],  # numpy array
    y_batch=y_test[:100],  # numpy array
    explain_func_kwargs=explain_func_kwargs,
    call_kwargs=call_kwargs,
    verbose=True,
)

In [None]:
# Print and plot the per-feature mean of the normalized attributions
mean_importances = attr_norm.mean(axis=0)
visualize_importances(feature_names_all, mean_importances)

In [None]:
# Beeswarm of the normalized attributions over the test set, colored by attribution value
beeswarm_attributions(
    attr_norm,
    X_test,
    feature_names_all,
    color_by='attr',
    cmap="coolwarm",
)

#### Compute explanation metrics

In [None]:
# List the metrics Quantus registers under the 'Faithfulness' category
quantus.AVAILABLE_METRICS['Faithfulness']

In [None]:
def saliency_wrapper(model: nn.Module,
                     inputs: np.ndarray,
                     targets:np.ndarray, **kwargs
                     ) -> np.ndarray:
    """Quantus-compatible explain function returning signed Saliency attributions.

    `model` and `**kwargs` are accepted for signature compatibility only; the
    notebook-level `xai_saliency` object and `device` are used.
    """
    # numpy -> torch
    inputs_t = torch.tensor(inputs, dtype=torch.float32, device=device, requires_grad=True)
    targets_t = torch.tensor(targets, dtype=torch.long, device=device)

    # Signed gradients, handed back as a CPU numpy array
    grads = xai_saliency.attribute(inputs_t, target=targets_t, abs=False)
    return grads.detach().cpu().numpy()


In [None]:
def get_quantus_metrics_multi(model: nn.Module,
                              xai_wrappers: List[Callable],
                              xai_names: List[str],
                              x_data: np.ndarray,
                              y_data: np.ndarray,
                              device: torch.device,
                              verbose: bool = True) -> Dict:
    """
    Run a fixed set of Quantus metrics (RIS, ROS, Consistency, Sufficiency,
    Faithfulness) for one or several explanation methods.

    Args:
        model: The PyTorch model (nn.Module); it is switched to eval mode and moved to `device`.
        xai_wrappers: Explanation callables, one per method.
        xai_names: Display names of the methods, same length and order as `xai_wrappers`.
        x_data: Input data batch (np.ndarray).
        y_data: Target label batch (np.ndarray).
        device: Torch device forwarded to Quantus through `call_kwargs`.
        verbose: If True, prints evaluation details.

    Returns:
        The result dictionary returned by quantus.evaluate.

    Raises:
        ValueError: If `xai_wrappers` and `xai_names` differ in length.
    """
    if len(xai_names) != len(xai_wrappers):
        raise ValueError("The length of 'xai_wrappers' must match the length of 'xai_names'.")

    # Evaluation mode, on the requested device
    model.eval().to(device)

    # Metrics to compute
    metrics = {
        "RIS": quantus.RelativeInputStability(nr_samples=5),
        "ROS": quantus.RelativeOutputStability(nr_samples=5),
        "Consistency": quantus.Consistency(
            discretise_func=quantus.functions.discretise_func.top_n_sign,
            return_aggregate=False,
        ),
        "Sufficiency": quantus.Sufficiency(
            threshold=0.6,
            return_aggregate=False,
        ),
        "Faithfulness": quantus.FaithfulnessEstimate(
            abs=False,
            normalise=False,
            features_in_step=1,
            perturb_baseline="mean",
        ),
    }

    # name -> explanation callable
    named_methods = {name: wrapper for name, wrapper in zip(xai_names, xai_wrappers)}

    return quantus.evaluate(
        metrics=metrics,
        xai_methods=named_methods,
        model=model,
        x_batch=x_data,
        y_batch=y_data,
        explain_func_kwargs={},                    # required by Quantus even when empty
        call_kwargs={"run": {"device": device}},   # keyword arguments for the metric call
        verbose=verbose,
    )


In [None]:
def get_metrics_results(results: dict) -> dict:
    """
    Convert the metrics of the first entry of `results` into float64 NumPy arrays.

    Only the first value of `results` is read; later entries are ignored.
    Values are copied as they are — NaN entries are kept, not filtered out.

    Args:
        results: A dictionary where the values are metric dictionaries
                 (e.g., {'run_1': {'metric_a': [1, 2, np.nan], ...}}).

    Returns:
        A dictionary mapping each metric name to a float64 NumPy array
        (an empty dict when `results` is empty).
    """
    if not results:
        return {}

    first_entry = next(iter(results.values()))
    return {
        name: np.array(values, dtype=np.float64)
        for name, values in first_entry.items()
    }

In [None]:
from captum.attr._utils.attribution import GradientAttribution
from captum.metrics import sensitivity_max, infidelity

def get_captum_metrics(captum_attribution: GradientAttribution,
                       inputs: torch.Tensor,
                       target: torch.Tensor,
                       attribution: torch.Tensor,
                       device: torch.device,
                       **attribute_kwargs) -> Tuple[np.ndarray, np.ndarray]:
    """Get Captum Sensitivity and Infidelity metrics.

    Args:
        captum_attribution (GradientAttribution): Captum attribution object. Its
            `attribute` method is used for the sensitivity and its `forward_func`
            (the model it was built on) for the infidelity, so no notebook-global
            model is needed.
        inputs (torch.Tensor): Data features from which the attribution are computed.
        target (torch.Tensor): Target predictions.
        attribution (torch.Tensor): Captum attributions.
        device (torch.device): Torch device.
        **attribute_kwargs: Optional extra keyword arguments forwarded to
            `captum_attribution.attribute` during the sensitivity computation
            (e.g. `baselines`, `nt_type`), so the metric can use the same settings
            as the attributions being evaluated.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Sensitivity and Infidelity values.
    """

    def perturb_fn(inputs):
        # Small Gaussian noise (std 0.003); returns (perturbation, perturbed inputs)
        noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).to(inputs.device).float()
        return noise, inputs - noise

    inputs = inputs.to(device)
    target = target.to(device)

    sens = sensitivity_max(captum_attribution.attribute,
                           inputs,
                           target=target,
                           **attribute_kwargs
                           ).detach().cpu().numpy()

    # Infidelity is computed against the model wrapped by the attribution object
    infid = infidelity(captum_attribution.forward_func, perturb_fn,
                       inputs=inputs,
                       attributions=attribution.to(device),
                       target=target
                       ).detach().cpu().numpy()
    return sens, infid


In [None]:
# Quantus metrics for Saliency on the whole test set, explaining the predicted class
quantus_metrics = get_quantus_metrics_multi(
    model=model,
    xai_wrappers=[saliency_wrapper],
    xai_names=["saliency"],
    x_data=X_test,
    y_data=pred,  # The target
    device=device,
    verbose=True,
)

In [None]:
# Captum sensitivity / infidelity for the Saliency attributions computed earlier
captum_metrics = get_captum_metrics(
    captum_attribution=xai_saliency,
    inputs=X_test_t,
    target=y_pred_t,
    attribution=torch.from_numpy(attr),
    device=device,
)

In [None]:
# Merge the Quantus results with the two Captum metrics (sensitivity first, infidelity second)
all_metrics = get_metrics_results(quantus_metrics)
all_metrics.update(Sensitivity=captum_metrics[0], Infidelity=captum_metrics[1])

In [None]:
# Beware: the results may contain NaN values, since the RIS/ROS metrics do not always produce a value

In [None]:
# Display every collected metric for the Saliency explanations
all_metrics

### SmoothGrad (NoiseTunnel over Saliency)

In [None]:
# SmoothGrad = Saliency wrapped in a NoiseTunnel
xai_sg = NoiseTunnel(xai_saliency)
attr = xai_sg.attribute(
    X_test_t.requires_grad_(True).to(device),
    target=y_pred_t,
    nt_type='smoothgrad',
    nt_samples=10,  # Lower this to reduce computational time
    stdevs=0.1,
    abs=False,
)
attr = attr.detach().cpu().numpy()

# Normalize the attributions (because Captum just return the gradients)
# Each feature column is divided by its largest absolute attribution; eps avoids a division by zero.
eps = 1e-16
denom = np.abs(attr).max(axis=0) + eps     # shape [n_features]
attr_norm = attr / denom

#### Visualization

In [None]:
# Print and plot the per-feature mean of the normalized SmoothGrad attributions
mean_importances = attr_norm.mean(axis=0)
visualize_importances(feature_names_all, mean_importances)

In [None]:
# Beeswarm of the normalized SmoothGrad attributions, colored by attribution value
beeswarm_attributions(
    attr_norm,
    X_test,
    feature_names_all,
    color_by='attr',
    cmap="coolwarm",
)

#### Explanation metrics

In [None]:
def smoothgrad_wrapper(model: nn.Module,
                     inputs: np.ndarray,
                     targets:np.ndarray,
                     **kwargs
                     ) -> np.ndarray:
    """Quantus-compatible explain function returning signed SmoothGrad attributions.

    `model` and `**kwargs` are accepted for signature compatibility only; the
    notebook-level `xai_sg` object and `device` are used, with the same noise
    settings as the main SmoothGrad cell (stdevs=0.1, 10 samples).
    """
    # numpy -> torch
    inputs_t = torch.tensor(inputs, dtype=torch.float32, device=device, requires_grad=True)
    targets_t = torch.tensor(targets, dtype=torch.long, device=device)

    smoothed = xai_sg.attribute(
        inputs_t,
        target=targets_t,
        nt_type='smoothgrad',
        nt_samples=10,  # Lower this to reduce computational time
        stdevs=0.1,
        abs=False,
    )
    return smoothed.detach().cpu().numpy()


In [None]:
# Quantus metrics for SmoothGrad on the first 10 test samples, explaining the predicted class
quantus_metrics = get_quantus_metrics_multi(
    model=model,
    xai_wrappers=[smoothgrad_wrapper],
    xai_names=["smoothgrad"],
    x_data=X_test[:10],
    y_data=pred[:10],  # The target
    device=device,
    verbose=True,
)

In [None]:
# Use the same first-10 subset as the Quantus evaluation of SmoothGrad, so every
# entry collected in all_metrics describes the same samples (and the run stays fast).
captum_metrics = get_captum_metrics(captum_attribution=xai_sg,
                                    inputs=X_test_t[:10],
                                    target=y_pred_t[:10],
                                    attribution=torch.from_numpy(attr)[:10],
                                    device=device)

In [None]:
# Merge the Quantus results with the two Captum metrics (sensitivity first, infidelity second)
all_metrics = get_metrics_results(quantus_metrics)
all_metrics.update(Sensitivity=captum_metrics[0], Infidelity=captum_metrics[1])

In [None]:
# Display every collected metric for the SmoothGrad explanations
all_metrics

### Input x Gradient

In [None]:
# NOTE: `xai_ig` holds the Input x Gradient explainer (not Integrated Gradients)
xai_ig = InputXGradient(model)

attr = xai_ig.attribute(
    X_test_t.requires_grad_(True).to(device),
    target=y_pred_t,  # predicted class of each sample
)
attr = attr.detach().cpu().numpy()

# Normalize the attributions (because Captum just return the gradients)
# Each feature column is divided by its largest absolute attribution; eps avoids a division by zero.
eps = 1e-16
denom = np.abs(attr).max(axis=0) + eps     # shape [n_features]
attr_norm = attr / denom

#### Visualization

In [None]:
# Print and plot the per-feature mean of the normalized Input x Gradient attributions
mean_importances = attr_norm.mean(axis=0)
visualize_importances(feature_names_all, mean_importances)

In [None]:
# Beeswarm of the normalized Input x Gradient attributions, colored by attribution value
beeswarm_attributions(
    attr_norm,
    X_test,
    feature_names_all,
    color_by='attr',
    cmap="coolwarm",
)

#### Explanation metrics

In [None]:
def inp_grad_wrapper(model, inputs, targets, **kwargs):
    """Quantus-compatible explain function returning Input x Gradient attributions.

    `model` and `**kwargs` are accepted for signature compatibility only; the
    notebook-level `xai_ig` object and `device` are used.
    """
    # numpy -> torch
    inputs_t = torch.tensor(inputs, dtype=torch.float32, device=device, requires_grad=True)
    targets_t = torch.tensor(targets, dtype=torch.long, device=device)

    scores = xai_ig.attribute(inputs_t, target=targets_t)
    return scores.detach().cpu().numpy()


In [None]:
# Quantus metrics for Input x Gradient on the first 10 test samples, explaining the predicted class
quantus_metrics = get_quantus_metrics_multi(
    model=model,
    xai_wrappers=[inp_grad_wrapper],
    xai_names=["inputgrad"],
    x_data=X_test[:10],
    y_data=pred[:10],  # The target
    device=device,
    verbose=True,
)

In [None]:
# Captum sensitivity / infidelity on the same first 10 test samples
captum_metrics = get_captum_metrics(
    captum_attribution=xai_ig,
    inputs=X_test_t[:10],
    target=y_pred_t[:10],
    attribution=torch.from_numpy(attr)[:10],
    device=device,
)

In [None]:
# Merge the Quantus results with the two Captum metrics (sensitivity first, infidelity second)
all_metrics = get_metrics_results(quantus_metrics)
all_metrics.update(Sensitivity=captum_metrics[0], Infidelity=captum_metrics[1])
all_metrics

### Integrated Gradients

In [None]:
xai_int_grad = IntegratedGradients(model)

# Reference point of the integration path: the mean training sample ...
baseline = X_train_t.mean(dim=0)
# ... tiled so that there is one baseline row per test sample
baselines = baseline.repeat(len(X_test_t), 1).to(device)

In [None]:
# Integrated Gradients from the mean-sample baseline to each test sample (50 integration steps)
attr = xai_int_grad.attribute(
    X_test_t.requires_grad_(True).to(device),
    baselines=baselines,
    target=y_pred_t,
    n_steps=50,
)
attr = attr.detach().cpu().numpy()

# Normalize the attributions (because Captum just return the gradients)
# Each feature column is divided by its largest absolute attribution; eps avoids a division by zero.
eps = 1e-16
denom = np.abs(attr).max(axis=0) + eps     # shape [n_features]
attr_norm = attr / denom

#### Visualization

In [None]:
# Print and plot the per-feature mean of the normalized Integrated Gradients attributions
mean_importances = attr_norm.mean(axis=0)
visualize_importances(feature_names_all, mean_importances)

In [None]:
# Beeswarm of the normalized Integrated Gradients attributions, colored by attribution value
beeswarm_attributions(
    attr_norm,
    X_test,
    feature_names_all,
    color_by='attr',
    cmap="coolwarm",
)

#### Explanation metrics

In [None]:
def intgrad_wrapper(model, inputs, targets, **kwargs):
    """Integrated Gradients explanation function with a NumPy-in / NumPy-out interface.

    The `model` argument and any extra keyword arguments are accepted but not
    used: attributions always come from the notebook-level `xai_int_grad`
    object, with the training-set mean (`X_train_t`) as baseline.

    Parameters
    ----------
    model : unused
    inputs : array-like
        Batch of feature rows; converted to a float32 tensor on `device`.
    targets : array-like
        One class index per row; converted to an int64 tensor on `device`.

    Returns
    -------
    np.ndarray
        The attributions returned by `xai_int_grad.attribute`, detached and on CPU.
    """
    # NumPy -> torch (features must be differentiable, labels are class indices)
    batch = torch.tensor(inputs, dtype=torch.float32, device=device, requires_grad=True)
    labels = torch.tensor(targets, dtype=torch.long, device=device)

    # Reference point: the training-set feature mean, tiled once per row of the batch
    reference = X_train_t.mean(dim=0).repeat(batch.shape[0], 1).to(device)

    ig_attr = xai_int_grad.attribute(
        batch,
        target=labels,
        baselines=reference,
        n_steps=50,
    )
    return ig_attr.detach().cpu().numpy()

In [None]:
# Quantus evaluation of the Integrated Gradients wrapper on the first 10 test rows.
quantus_metrics = get_quantus_metrics_multi(model=model,
                                            xai_wrappers=[intgrad_wrapper],
                                            xai_names=["integrated_gradients"],
                                            x_data=X_test[:10],
                                            y_data=pred[:10], # Labels the explanations are computed for (presumably the model's predictions - TODO confirm)
                                            device=device,
                                            verbose= True)

In [None]:
# Captum-side metrics for Integrated Gradients on the same first 10 test rows,
# reusing the raw (un-normalised) attributions stored in `attr`.
captum_metrics = get_captum_metrics(captum_attribution=xai_int_grad,
                                    inputs=X_test_t[:10],
                                    target=y_pred_t[:10],
                                    attribution=torch.from_numpy(attr)[:10],
                                    device=device)

In [None]:
# Combine the Quantus results with the two Captum metrics and display them.
all_metrics = get_metrics_results(quantus_metrics)
all_metrics["Sensitivity"] = captum_metrics[0]
all_metrics["Infidelity"] = captum_metrics[1]
all_metrics

### LRP (Layer-wise Relevance Propagation)

In [None]:
from zennit.composites import EpsilonPlus


# create a composite instance (the set of LRP rules applied to the model's layers)
composite = EpsilonPlus()

# use the following instead to ignore bias for the relevance
# composite = EpsilonPlusFlat(zero_params='bias')

xai_lrp_grad = Gradient(model, composite)

# Relevance starts from the class in `y_pred_t` for each sample: a one-hot row
# selects that output unit. Kept on `device` so it matches the model output
# (the same thing `lrp_wrapper` does for its targets).
targets = torch.nn.functional.one_hot(y_pred_t, num_classes=2).float().to(device)

with xai_lrp_grad:
    # The input must require a gradient. A detached view is used so the shared
    # `X_test_t` tensor is not switched to requires_grad=True in place.
    output, attr = xai_lrp_grad(X_test_t.detach().to(device).requires_grad_(True),
                                targets
                                )

attr = attr.detach().cpu().numpy()

# Rescale each feature column by its largest absolute relevance so that every
# feature lies in [-1, 1] and columns are comparable in the plots below.
eps = 1e-16  # avoids a division by zero for columns that are all zero
denom = np.max(np.abs(attr), axis=0) + eps     # shape [n_features]
attr_norm = attr / denom

#### Visualization

In [None]:
# Global view: normalised LRP relevances averaged over all explained rows, one value per feature.
visualize_importances(feature_names=feature_names_all,
                      importances=np.mean(attr_norm, axis=0)
                      )

In [None]:
# Per-sample view of the normalised LRP relevances alongside the processed test features.
beeswarm_attributions(attrs=attr_norm,
                      X=X_test,
                      feature_names=feature_names_all,
                      cmap="coolwarm",
                      color_by='attr')

#### Explanation metrics

In [None]:
def lrp_wrapper(model, inputs, targets, **kwargs):
    """LRP explanation function with a NumPy-in / NumPy-out interface.

    The `model` argument and any extra keyword arguments are accepted but not
    used: relevances always come from the notebook-level `xai_lrp_grad`
    attributor.

    Parameters
    ----------
    model : unused
    inputs : array-like
        Batch of feature rows; converted to a float32 tensor on `device`.
    targets : array-like
        One class index (0 or 1) per row; one-hot encoded with two classes.

    Returns
    -------
    np.ndarray
        Second element returned by `xai_lrp_grad`, detached and moved to CPU.
    """
    batch = torch.tensor(inputs, dtype=torch.float32, device=device, requires_grad=True)
    class_ids = torch.tensor(targets, dtype=torch.long, device=device)
    # One-hot rows select, per sample, the output unit the relevance starts from
    seed_relevance = torch.nn.functional.one_hot(class_ids, num_classes=2).float()

    # The attributor is entered as a context manager around the call
    with xai_lrp_grad:
        _output, relevance = xai_lrp_grad(batch, seed_relevance)
    return relevance.detach().cpu().numpy()

In [None]:
# Quantus evaluation of the LRP wrapper on the first 10 test rows.
quantus_metrics = get_quantus_metrics_multi(model=model,
                                            xai_wrappers=[lrp_wrapper],
                                            xai_names=["lrp"],
                                            x_data=X_test[:10],
                                            y_data=pred[:10], # Labels the explanations are computed for (presumably the model's predictions - TODO confirm)
                                            device=device,
                                            verbose= True)

In [None]:
# Captum-side metrics for LRP, restricted to the first 10 test rows like the
# other attribution methods in this notebook (the full test set was used here
# before, which is inconsistent with the sibling cells and far slower).
# NOTE(review): `xai_lrp_grad` is a Zennit attributor, not a Captum one - confirm
# that `get_captum_metrics` can drive it.
captum_metrics = get_captum_metrics(captum_attribution=xai_lrp_grad,
                                    inputs=X_test_t[:10],
                                    target=y_pred_t[:10],
                                    attribution=torch.from_numpy(attr)[:10],
                                    device=device)

## Counterfactuals

In [None]:
# Counterfactual strategy: change the continuous features as little as possible
# until the prediction flips; categorical columns are kept fixed for simplicity.

# Positions, within the processed feature vector, of the original numeric columns
num_indices = np.array(
    [position for position, column in enumerate(feature_names_all) if column in num_cols],
    dtype=int,
)

def predict_proba(model, x):
    """Return `sigmoid(model(x))`, evaluated without tracking gradients."""
    with torch.no_grad():
        return torch.sigmoid(model(x))

def counterfactual_search(x_init, target_label=1, steps=300, lr=0.05, l2=0.01):
    """Gradient-based search for a counterfactual of one instance.

    Starting from `x_init`, the columns listed in the notebook-level
    `num_indices` are optimised with Adam so that the notebook-level `model`
    moves towards `target_label`; an L2 penalty keeps them close to their
    original values. Every other column is held exactly at its original value.

    Parameters
    ----------
    x_init : torch.Tensor
        The instance to modify. `.item()` is called on the probability, so a
        single row is expected.
    target_label : int
        Desired class, 0 or 1 (early stopping only triggers for these values).
    steps : int
        Maximum number of optimisation steps.
    lr : float
        Adam learning rate.
    l2 : float
        Weight of the squared-distance penalty on the numeric columns.

    Returns
    -------
    tuple(torch.Tensor, float)
        The detached counterfactual and the model's probability of class 1
        evaluated on exactly that tensor.
    """
    x = x_init.clone().detach().to(device)
    x_cf = x.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([x_cf], lr=lr)
    target = torch.tensor([[float(target_label)]], device=device)

    # Column mask built once (True = frozen) instead of a membership test per
    # column on every step.
    frozen = torch.ones(x.shape[1], dtype=torch.bool, device=device)
    frozen[torch.as_tensor(num_indices, dtype=torch.long, device=device)] = False

    def _positive_logit(out):
        # A single-logit output is used as is. For a two-column output
        # (class scores), logit_1 - logit_0 is the logit of softmax(out)[1].
        # NOTE(review): the attribution cells feed this model one-hot targets
        # with num_classes=2, which suggests the two-column case - confirm.
        if out.dim() == 2 and out.shape[1] == 2:
            return out[:, 1:2] - out[:, 0:1]
        return out

    def _reached_target(p):
        return (target_label == 1 and p >= 0.5) or (target_label == 0 and p < 0.5)

    for _ in range(steps):
        opt.zero_grad()
        logit = _positive_logit(model(x_cf))
        # Stop before stepping again: the current x_cf already gives the target class
        if _reached_target(torch.sigmoid(logit).item()):
            break
        # BCE on the logit (numerically stable form of sigmoid followed by BCE)
        bce = nn.functional.binary_cross_entropy_with_logits(logit, target)
        # Proximity on numeric dims only
        prox = l2 * torch.sum((x_cf - x)[:, ~frozen] ** 2)
        (bce + prox).backward()
        opt.step()
        # Project after the step so the frozen columns are exactly the originals
        with torch.no_grad():
            x_cf[:, frozen] = x[:, frozen]

    # Report the probability of the tensor that is actually returned
    with torch.no_grad():
        final_prob = torch.sigmoid(_positive_logit(model(x_cf))).item()
    return x_cf.detach(), final_prob

# Choose an instance predicted as 0 => try to flip to 1 (or vice versa)
model.eval()
with torch.no_grad():
    logits_test = model(X_test_t.to(device))
    # A two-column output (class scores) is reduced to the logit of class 1,
    # so that exactly one probability per test row is produced.
    # NOTE(review): the attribution cells use one-hot targets with
    # num_classes=2, which suggests the two-column case - confirm.
    if logits_test.dim() == 2 and logits_test.shape[1] == 2:
        logits_test = logits_test[:, 1] - logits_test[:, 0]
    probs_test = torch.sigmoid(logits_test).cpu().numpy().ravel()
preds_test = (probs_test >= 0.5).astype(int)
idx0 = int(np.where(preds_test == 0)[0][0]) if (preds_test==0).any() else 0
# Detach and move to `device`: earlier cells call `requires_grad_` on `X_test_t`,
# and the returned counterfactual lives on `device`, so the subtraction and the
# `.numpy()` call below need a grad-free tensor on the same device.
x_orig = X_test_t[idx0:idx0+1].detach().to(device)
y_pred_orig = preds_test[idx0]
target_label = 1 - y_pred_orig
x_cf, prob_after = counterfactual_search(x_orig, target_label=target_label, steps=400, lr=0.05, l2=0.01)

print('Original pred:', y_pred_orig, '→ Target:', target_label, '| New prob:', round(prob_after, 3))

# Show the changed continuous features, in the processed (standardized) units
delta = (x_cf - x_orig).cpu().numpy().ravel()
changes = {name: delta[i] for i, name in enumerate(feature_names_all) if i in num_indices and abs(delta[i])>1e-6}
# Keep the 12 largest changes by absolute size
changes_sorted = dict(sorted(changes.items(), key=lambda kv: abs(kv[1]), reverse=True)[:12])
pd.DataFrame({'feature': list(changes_sorted.keys()), 'delta_std_units': list(changes_sorted.values())})


## Attacks on LIME and SHAP

In [None]:
# Adult dataset from OpenML
adult = fetch_openml(name='adult', version=2, as_frame=True)
target_col = 'class'

# '?' marks a missing value: turn it into NaN and drop every incomplete row
df = adult.frame.replace('?', np.nan).dropna()

# Binary target: 1 for '>50K', 0 otherwise
df[target_col] = (df[target_col] == '>50K').astype(int)

# Random 0/1 column, unrelated to the target -- this is what we'll have LIME/SHAP explain.
df['unrelated_column'] = np.random.choice([0, 1], size=len(df))

# Features / labels
X_df = df.drop(columns=[target_col])
y = df[target_col].values

# Categorical columns by dtype, plus the (integer-typed) random column
categorical_feature_name = [
    *X_df.select_dtypes(include=['category', 'object']).columns,
    "unrelated_column",
]


In [None]:
# Preprocessing: standardise the numeric columns and integer-encode the categorical
# ones. OrdinalEncoder is used (rather than one-hot encoding) to limit the number
# of features: each categorical column stays a single output column.
preprocess = ColumnTransformer(transformers=[
    ('num', StandardScaler(), num_cols),
    ('cat', OrdinalEncoder(), categorical_feature_name),
])

X_processed = preprocess.fit_transform(X_df)

# Output column names, in transformer order: numeric first, then categorical
feature_names_num = num_cols
feature_names_cat = list(
    preprocess.named_transformers_['cat'].get_feature_names_out(categorical_feature_name)
)
feature_names_all = feature_names_num + feature_names_cat

# Positions of the categorical columns inside the processed matrix
categorical_feature_indcs = [feature_names_all.index(cat_name) for cat_name in feature_names_cat]


# Stratified 80/20 train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X_processed,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y,
)


In [None]:
# Column positions, in the processed matrix, of the two features the toy models read.
sex_indc = feature_names_all.index('sex')
# Despite the plural name this is a single integer index.
unrelated_indcs = feature_names_all.index('unrelated_column')

In [None]:
# Class labels returned by the two toy models defined below.
negative_outcome = 0
positive_outcome = 1

In [None]:
def one_hot_encode(y, n_classes=2):
    """One-hot encode a 1-D vector of class indices.

    Used to turn hard 0/1 predictions into `predict_proba`-style output.
    Adapted from https://stackoverflow.com/questions/29831489/convert-array-of-indices-to-1-hot-encoded-numpy-array

    Parameters
    ----------
    y : np.ndarray
        Integer class indices, one per sample.
    n_classes : int, default 2
        Number of columns of the result (binary labels by default).

    Returns
    ----------
    A np.ndarray of shape (len(y), n_classes) with a single 1 per row.
    """
    labels = np.asarray(y)
    # An empty prediction vector may carry a float dtype (e.g. `np.array([])`),
    # which cannot be used as an index: return the empty result directly.
    if labels.size == 0:
        return np.zeros((0, n_classes))
    y_hat_one_hot = np.zeros((len(labels), n_classes))
    y_hat_one_hot[np.arange(len(labels)), labels] = 1
    return y_hat_one_hot

In [None]:
# Lookup table "column / category / integer code" for the fitted ordinal encoder
enc = preprocess.named_transformers_['cat']
rows = []
for col, categories in zip(categorical_feature_name, enc.categories_):
    rows.extend(
        {'column': col, 'category': cat, 'code': code}
        for code, cat in enumerate(categories)
    )
    # Only relevant when the encoder maps unseen categories to a dedicated value
    if enc.handle_unknown == 'use_encoded_value':
        rows.append({'column': col, 'category': '<UNKNOWN>', 'code': enc.unknown_value})

In [None]:
# Show the category -> code table as a DataFrame (rich, truncated display)
# rather than dumping the raw list of dicts.
pd.DataFrame(rows)

In [None]:
class sexist_model_f:
    """Deliberately biased classifier: the decision depends only on the `sex` column.

    Rows whose encoded `sex` value (column `sex_indc`) equals 0 are given
    `negative_outcome`; every other row is given `positive_outcome`.
    """

    def predict(self, X):
        """Return one hard label per row of the 2-D array-like `X`."""
        # Vectorised rule; also yields an integer array for empty input
        sex_codes = np.asarray(X)[:, sex_indc]
        return np.where(sex_codes == 0, negative_outcome, positive_outcome)

    def predict_proba(self, X):
        """Return the hard labels one-hot encoded, i.e. probabilities of exactly 0 or 1."""
        return one_hot_encode(self.predict(X))

    def score(self, X, y):
        """Return the fraction of rows of `X` whose predicted label equals `y`."""
        return np.sum(self.predict(X) == y) / len(X)
    
class innocuous_model_psi:
    """Harmless-looking classifier: the decision depends only on the random 'unrelated_column'.

    Rows whose value in column `unrelated_indcs` is greater than 0 are given
    `negative_outcome`; every other row is given `positive_outcome`.
    """

    def predict(self, X):
        """Return one hard label per row of the 2-D array-like `X`."""
        # Vectorised rule; also yields an integer array for empty input
        unrelated_values = np.asarray(X)[:, unrelated_indcs]
        return np.where(unrelated_values > 0, negative_outcome, positive_outcome)

    def predict_proba(self, X):
        """Return the hard labels one-hot encoded, i.e. probabilities of exactly 0 or 1."""
        return one_hot_encode(self.predict(X))

    def score(self, X, y):
        """Return the fraction of rows of `X` whose predicted label equals `y`."""
        return np.sum(self.predict(X) == y) / len(X)

### Fooling LIME

In [None]:
from fooling_lime_shap import Adversarial_Lime_Model

# Build the adversarial model for LIME from the biased model f and the innocuous
# model psi, then train it on the first 100 training rows.
adv_lime = Adversarial_Lime_Model(sexist_model_f(), innocuous_model_psi()).train(
    X_train[:100],
    y_train[:100],
    feature_names=feature_names_all,
    categorical_features=categorical_feature_indcs,
)

In [None]:
# Display the column index of 'sex' in the processed feature matrix.
sex_indc

In [None]:
# `import lime` alone does not load the `lime_tabular` submodule; import it
# explicitly so this cell does not depend on an earlier cell having done so.
import lime.lime_tabular

# Pick one random instance of the test set to explain
ex_indc = np.random.choice(X_test.shape[0])

# To get a baseline, we'll look at LIME applied to the biased model f
normal_explainer = lime.lime_tabular.LimeTabularExplainer(X_train, feature_names=adv_lime.get_column_names(),
                                                          discretize_continuous=False,
                                                          categorical_features=categorical_feature_indcs)

normal_exp = normal_explainer.explain_instance(X_test[ex_indc], sexist_model_f().predict_proba)

# print ("Explanation on biased f:\n",normal_exp.as_list()[:3],"\n\n")

# Now, let's look at the explanation of the adversarial model (same explainer configuration)
adv_explainer = lime.lime_tabular.LimeTabularExplainer(X_train,feature_names=adv_lime.get_column_names(),
                                                       discretize_continuous=False,
                                                       categorical_features=categorical_feature_indcs)

adv_exp = adv_explainer.explain_instance(X_test[ex_indc], adv_lime.predict_proba)

# print ("Explanation on adversarial model:\n",adv_exp.as_list()[:3],"\n")

# print("Prediction fidelity: {0:3.2}".format(adv_lime.fidelity(X_test[ex_indc:ex_indc+1])))

In [None]:
from IPython.display import display, HTML

# Render the LIME explanation of the biased model f inline.
html = normal_exp.as_html()
display(HTML(html))

In [None]:
# Render the LIME explanation of the adversarial model inline
# (`display` and `HTML` come from the import in the previous cell).
html = adv_exp.as_html()
display(HTML(html))

### Fooling SHAP

In [None]:
from fooling_lime_shap import Adversarial_Kernel_SHAP_Model


# Build the adversarial model for Kernel SHAP from the biased model f and the
# innocuous model psi, then train it on the first 100 training rows.
adv_shap = Adversarial_Kernel_SHAP_Model(sexist_model_f(), innocuous_model_psi()).train(
    X_train[:100],
    y_train[:100],
    feature_names=feature_names_all,
)

In [None]:
import shap

# Set the background distribution for the shap explainer using kmeans
# (the training set is summarised with k = 100)
background_distribution = shap.kmeans(X_train, 100)

# Let's use the shap kernel explainer and grab a random test point to explain
to_examine = np.random.choice(X_test.shape[0])

# Explain the biased model (hard labels from `predict`, not probabilities)
biased_kernel_explainer = shap.KernelExplainer(sexist_model_f().predict, background_distribution)
biased_shap_values = biased_kernel_explainer.shap_values(X_test[to_examine:to_examine+1])

# Explain the adversarial model on the same point
# NOTE(review): `adv_kerenel_explainer` is a misspelling of "kernel"; the name is
# left as is because it is a notebook-level variable.
adv_kerenel_explainer = shap.KernelExplainer(adv_shap.predict, background_distribution)
adv_shap_values = adv_kerenel_explainer.shap_values(X_test[to_examine:to_examine+1])


In [None]:
# SHAP values of the biased model f for the examined test point.
biased_shap_values

In [None]:
# SHAP values of the adversarial model for the same test point.
adv_shap_values

In [None]:
# Plot it using SHAP's plotting features.
shap.summary_plot(biased_shap_values, feature_names=feature_names_all, plot_type="bar")


In [None]:
# Same bar plot for the adversarial model, for comparison with the previous one.
shap.summary_plot(adv_shap_values, feature_names=feature_names_all, plot_type="bar")
