<a href="https://colab.research.google.com/github/VladimirBoshnjakovski/explainable-ai-thesis-code/blob/main/06_xai_shap_local.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ──────────────── IMPORTS ────────────────
from google.colab import files
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    log_loss,
    classification_report,
    confusion_matrix
)

import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, regularizers, optimizers

import shap

In [None]:
# ──────────────── FILE UPLOAD ────────────────
# Opens Colab's file-upload dialog. `uploaded` is a dictionary keyed by filename.
uploaded = files.upload()

# Fail with an actionable message instead of the bare StopIteration that
# next(iter(...)) would raise on an empty dictionary.
if not uploaded:
    raise ValueError("No file was uploaded. Re-run this cell and select the dataset CSV.")

# Only the first uploaded file is read; any additional uploads are ignored.
csv_filename = next(iter(uploaded))
df = pd.read_csv(csv_filename)

In [None]:
# ──────────────── PREPROCESS ────────────────
# Expects `df` to have been loaded by the upload cell.
# errors='ignore' keeps this cell re-runnable: on a second run the 'source'
# column is already gone and the rename has nothing left to rename, so the
# cell no longer raises KeyError.
df = (
    df.drop(columns=['source'], errors='ignore')
      .rename(columns={'Presence of Heart Disease (1=Yes)': 'target'})
)

# Surface a schema problem here rather than as an opaque error further down.
if 'target' not in df.columns:
    raise KeyError(
        "Expected a 'Presence of Heart Disease (1=Yes)' (or 'target') column in the uploaded CSV."
    )

# Feature matrix and label vector as plain NumPy arrays
X = df.drop(columns=['target']).values
y = df['target'].values

# stratified train/test split (80/20, fixed seed so the split is repeatable)
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    stratify=y,
    random_state=42
)

# standardize: fit the scaler on the training split only, then apply the same
# transform to the test split
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test  = scaler.transform(X_test)

# ──────────────── MODEL BUILDER ────────────────
def build_model(input_dim, l2_strength=1e-4, dropout_rate=0.3, learning_rate=1e-3):
    """Build and compile the feed-forward binary classifier.

    Layers: Dense(64, relu) → Dropout → Dense(32, relu) → Dropout →
    Dense(16, relu) → Dense(1, sigmoid). The three hidden Dense layers share
    one L2 kernel regularizer.

    Args:
        input_dim: Number of input features.
        l2_strength: L2 penalty factor applied to the hidden layers' kernels.
        dropout_rate: Drop probability used by both Dropout layers.
        learning_rate: Learning rate passed to the Adam optimizer.

    Returns:
        A compiled Keras ``Sequential`` model using binary cross-entropy loss
        and reporting accuracy and AUC. The AUC metric is named 'auc', which
        is what makes 'val_auc' available to the training callbacks.
    """
    kernel_reg = regularizers.l2(l2_strength)
    model = models.Sequential([
        layers.Input(shape=(input_dim,)),
        layers.Dense(64, activation='relu', kernel_regularizer=kernel_reg),
        layers.Dropout(dropout_rate),
        layers.Dense(32, activation='relu', kernel_regularizer=kernel_reg),
        layers.Dropout(dropout_rate),
        layers.Dense(16, activation='relu', kernel_regularizer=kernel_reg),
        layers.Dense(1, activation='sigmoid')
    ])

    model.compile(
        optimizer=optimizers.Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )
    return model

# Seed the Python, NumPy and TensorFlow random generators before the model is
# created, so weight initialisation, dropout masks and batch shuffling are
# repeatable across kernel restarts. 42 matches the seed used for the
# train/test split. (Some GPU kernels may still introduce small run-to-run
# differences.)
tf.keras.utils.set_random_seed(42)

model = build_model(X_train.shape[1])

# ──────────────── CALLBACKS ────────────────
es = callbacks.EarlyStopping(
    monitor='val_auc',      # stop when AUC stops improving
    mode='max',             # we want to MAXIMIZE AUC
    patience=10,
    restore_best_weights=True
)

# Halve the learning rate after 5 epochs without val_loss improvement,
# never going below 1e-5.
rlr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-5,
    verbose=1
)

# ──────────────── TRAIN ────────────────
# validation_split holds out 20% of the training data for the val_* metrics
# that both callbacks monitor.
history = model.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=100,
    batch_size=16,
    callbacks=[es, rlr],
    verbose=2
)

# ──────────────── EVALUATE ────────────────
# evaluate() returns the loss followed by the compiled metrics (accuracy, AUC)
loss, acc, auc = model.evaluate(X_test, y_test, verbose=0)
print(f"Test loss: {loss:.4f}   |   Test accuracy: {acc:.4f}   |   Test AUC: {auc:.4f}")

# Predicted probabilities for the test set, thresholded at 0.5 into hard labels
y_probs = model.predict(X_test, verbose=0).ravel()
y_preds = (y_probs >= 0.5).astype(int)

# Rows are true labels, columns are predicted labels
cm = confusion_matrix(y_test, y_preds)

# Draw the confusion matrix on an explicit Axes
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_title('Confusion Matrix on Test Set')
fig.tight_layout()
plt.show()


In [None]:
# ──────────────── SHAP EXPLANATION SETUP ────────────────
def predict_quiet(data):
    """Prediction function handed to SHAP; same as model.predict but with Keras progress output disabled."""
    return model.predict(data, verbose=0)

# KernelExplainer needs a prediction callable and a background dataset; here
# the background is the full standardized training set.
# NOTE(review): a large background set makes each explanation slower; SHAP
# offers shap.sample / shap.kmeans to summarise it, at the cost of changing
# the resulting values — left unchanged here on purpose.
explainer = shap.KernelExplainer(predict_quiet, X_train)

# Select a specific instance from the test set for local explanation
instance_idx = 40                                           # Index of the instance to explain
x_instance = X_test[instance_idx].reshape(1, -1)            # 2D array of shape (1, n_features)

In [None]:
# ─────────────────────────────────────────────────────────────
# 🔍 Benchmarking SHAP Explanation Time on a Single Instance
# ─────────────────────────────────────────────────────────────
# This block runs SHAP explanations 5 times on the same instance
# to measure the average run time and observe any variability.
# This is useful when evaluating the feasibility of integrating SHAP
# into time-sensitive or resource-constrained environments.
# ─────────────────────────────────────────────────────────────

import time

# `explainer` and `x_instance` come from the SHAP setup cell; `x_instance` is a
# NumPy array reshaped to (1, n_features), i.e. a single sample.
shap_values_list = []   # SHAP output of each run, in run order
time_taken_list = []    # elapsed seconds of each run, in run order

for i in range(5):
    # perf_counter() is the clock intended for measuring short intervals;
    # unlike time.time() it is not affected by system clock adjustments.
    start_time = time.perf_counter()
    shap_values = explainer.shap_values(x_instance)
    end_time = time.perf_counter()

    shap_values_list.append(shap_values)
    time_taken_list.append(end_time - start_time)
    print(f"Run {i+1} - Time taken: {time_taken_list[-1]:.4f} seconds")

# Summary of timing
print("\n⏱️ SHAP explanation time summary:")
print(f"Average time: {np.mean(time_taken_list):.4f} seconds")
print(f"Min time:     {np.min(time_taken_list):.4f} seconds")
print(f"Max time:     {np.max(time_taken_list):.4f} seconds")


In [None]:
# ────────────────────────────────────────────────────────────────
# 📊 Visualizing Top SHAP Features for a Single Prediction Instance
# ────────────────────────────────────────────────────────────────
# This block extracts SHAP values for a specific instance, maps them to feature names,
# and visualizes the top 8 most influential features using a horizontal bar plot.
# ────────────────────────────────────────────────────────────────

# `instance_idx` is intentionally not reassigned here: the SHAP values below
# were computed for the instance chosen in the SHAP setup cell, so the plot
# title must report that same index.
shap_values = shap_values_list[0]  # Treating this as an example (first run)

# Indexing [0] and flattening is expected to yield one SHAP value per feature
# for the single explained row.
shap_values_flat = shap_values[0].flatten()

# Get the feature names from your dataset (same column order as X)
feature_names = df.drop(columns=['target']).columns.tolist()

# Map SHAP values to feature names; raises if the two lengths disagree
shap_values_df = pd.DataFrame({
    "SHAP Value": shap_values_flat,
    "Feature": feature_names,
})

# Absolute value is used only for ranking; the signed value is what gets plotted
shap_values_df["abs_SHAP Value"] = shap_values_df["SHAP Value"].abs()

# Sort the DataFrame by absolute SHAP Value in descending order
shap_values_df = shap_values_df.sort_values(by="abs_SHAP Value", ascending=False)

# Select top 8 features with the highest absolute SHAP values
top_features = shap_values_df.head(8)

# Set a sketch-like style with Seaborn (this changes the global plot style)
sns.set(style="white", palette="muted", font_scale=1.2)

# Create the plot
plt.figure(figsize=(10, 6))

# Single-colour bars: the first colour of the "Set2" palette
sns.barplot(
    x="SHAP Value",
    y="Feature",
    data=top_features,
    color=sns.color_palette("Set2", n_colors=1)[0],
    edgecolor='black',
    linewidth=2
)

# Add gridlines for a "net" effect
plt.grid(True, axis='x', linestyle='--', linewidth=0.7, color='gray', alpha=0.5)

# Set labels and title with a bit more padding and font size
plt.xlabel('SHAP Value', fontsize=12, labelpad=10)
plt.ylabel('Feature', fontsize=12, labelpad=10)

# Title in two lines with bold font
plt.title(
    f'Top 8 Most Influential Features\nBased on Absolute SHAP Values for Instance {instance_idx}',
    fontsize=14,
    fontweight='bold',
    pad=20
)

# Adjust the layout to avoid clipping the labels
plt.tight_layout()

# Show the plot
plt.show()

In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# 📌 SHAP Force Plot: Local Explanation for a Single Prediction (Top 8 Features)
# ──────────────────────────────────────────────────────────────────────────────
# This block generates a SHAP force plot for one test instance,
# focusing on the top 8 most influential features based on SHAP value magnitude.
# It uses already-computed SHAP values (from shap_values_list[0])
# and outputs both an interactive plot and an HTML export.
# ──────────────────────────────────────────────────────────────────────────────

# Initialize JavaScript rendering for SHAP plots
shap.initjs()

# Reuse the instance chosen in the SHAP setup cell so the feature values shown
# here always belong to the row the SHAP values were computed for.
instance_index = instance_idx

# Extract and reshape the input instance for SHAP (1 sample, 2D)
x_instance_raw = X_test[instance_index].reshape(1, -1)

# Use already-computed SHAP values (first run) for the model's single output
shap_value_instance = shap_values_list[0][0].flatten()  # 1D array of SHAP values

# Flattened feature values for the same instance. A separate name is used so
# the 2D `x_instance` that the benchmark cell passes to the explainer is not
# overwritten. These are standardized values, since X_test was scaled.
x_instance_flat = x_instance_raw.flatten()

# Get original feature names (excluding target column)
feature_names = df.drop(columns='target').columns.tolist()

# Identify top 8 features with highest absolute SHAP values
top8_indices = np.argsort(np.abs(shap_value_instance))[::-1][:8]
shap_top8 = shap_value_instance[top8_indices]
x_top8 = x_instance_flat[top8_indices]
feature_names_top8 = [feature_names[i] for i in top8_indices]

# Reshape feature values as required by SHAP (2D)
x_top8_2D = x_top8.reshape(1, -1)

# ────────────────────────────────────────────────────────────────
# 💡 Generate SHAP Force Plot
# ────────────────────────────────────────────────────────────────
# NOTE(review): only the top-8 attributions are passed in, so the output value
# drawn by the plot presumably omits the contributions of the remaining
# features — confirm this is the intended presentation.
shap_plot = shap.force_plot(
    base_value     = explainer.expected_value[0],
    shap_values    = shap_top8,
    features       = x_top8_2D,
    feature_names  = feature_names_top8
)

# ────────────────────────────────────────────────────────────────
# 💾 Save Force Plot as Interactive HTML
# ────────────────────────────────────────────────────────────────
# This export allows embedding or sharing the force plot outside the notebook.
# `files` is the google.colab helper imported in the first cell.
# ────────────────────────────────────────────────────────────────
html_path = f"shap_forceplot_instance_{instance_index}.html"
shap.save_html(html_path, shap_plot)
files.download(html_path)

# Display in notebook: a notebook renders only the LAST expression of a cell,
# so the plot object must come after the export statements.
shap_plot


In [None]:
# Build one tidy DataFrame per benchmark run (SHAP value + feature name).
# Loop-local names are used so the `shap_values`, `shap_values_flat` and
# `shap_values_df` variables of earlier cells are not overwritten.
shap_values_df_list = []

for run_shap_values in shap_values_list:
    run_flat = run_shap_values[0].flatten()  # one value per feature (expected)

    run_df = pd.DataFrame({
        "SHAP Value": run_flat,
        "Feature": feature_names,  # Link feature names to SHAP values
    })
    run_df["abs_SHAP Value"] = run_df["SHAP Value"].abs()

    shap_values_df_list.append(run_df)

# Combine the SHAP values from all runs into one wide DataFrame:
# rows = features, columns = "Run 1" … "Run N"
combined_shap_values = pd.concat(
    [frame.set_index("Feature")["SHAP Value"] for frame in shap_values_df_list],
    axis=1,
)
combined_shap_values.columns = [f"Run {i+1}" for i in range(len(shap_values_df_list))]

# Rank features by their largest absolute SHAP value in any run and keep the top 5
feature_ranking = combined_shap_values.abs().max(axis=1).sort_values(ascending=False).index
combined_shap_values = combined_shap_values.reindex(feature_ranking).head(5)

# Reshape to long format (Feature, Run, SHAP Value) as expected by Seaborn;
# the later plotting cells read this long-format frame.
combined_shap_values = combined_shap_values.reset_index()
combined_shap_values = pd.melt(combined_shap_values, id_vars="Feature", var_name="Run", value_name="SHAP Value")

# Set a sketch-like style (this changes the global plot style)
sns.set(style="white", palette="muted", font_scale=1.2)

# Create the plot
plt.figure(figsize=(10, 6))

# Grouped bars: one bar per (feature, run). Each bar is a single observation,
# so there are no error bars to style; the former `capsize`/`errwidth`
# arguments had no visible effect, and `errwidth` is deprecated in recent
# seaborn releases.
sns.barplot(x="SHAP Value", y="Feature", hue="Run", data=combined_shap_values, dodge=True,
            edgecolor='black', linewidth=2)

# Add a vertical line at zero
plt.axvline(x=0, color='black', linestyle='--', linewidth=1.5)

# Add gridlines for the "net" effect (background grid)
plt.grid(True, linestyle='--', linewidth=0.7, color='gray', alpha=0.5)

# Set labels and title with a bit more padding and font size
plt.xlabel('SHAP Value', fontsize=12, labelpad=10)
plt.ylabel('Feature', fontsize=12, labelpad=10)
plt.title('SHAP Values for Top 5 Features Across Runs', fontsize=14, pad=20)

# Adjust the layout to avoid clipping the labels
plt.tight_layout()

# Show the plot
plt.show()


In [None]:
# ───────────────────────────────────────────────────────────────────────────────
# 📈 SHAP Value Stability Plot Across Multiple Runs (Top 5 Features)
# ───────────────────────────────────────────────────────────────────────────────
# One line per top-5 feature, tracing its SHAP value from run to run. Flat
# lines mean the explanation is stable when repeated on the same instance;
# visible movement means it is not.
# ───────────────────────────────────────────────────────────────────────────────

fig, ax = plt.subplots(figsize=(10, 6))

# sort=False keeps the features in their order of first appearance
for feature_name, feature_rows in combined_shap_values.groupby("Feature", sort=False):
    ax.plot(feature_rows["Run"], feature_rows["SHAP Value"], marker='o', label=feature_name)

# Zero reference line separates positive from negative contributions
ax.axhline(0, color='black', linestyle='--')

ax.set_xlabel("Run")
ax.set_ylabel("SHAP Value")
ax.set_title("SHAP Values Across Runs for Top 5 Features")

# Legend sits outside the axes so it does not cover the lines
ax.legend(title="Feature", bbox_to_anchor=(1.05, 1), loc='upper left')

fig.tight_layout()
ax.grid(True, linestyle='--', linewidth=0.5)

plt.show()


In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# 🔥 SHAP Value Heatmap Across Runs (Top 5 Features)
# ──────────────────────────────────────────────────────────────────────────────
# Each cell is the SHAP value of one top-5 feature in one run. The diverging
# colour map is centred at zero, so red marks positive and blue marks negative
# contributions, with intensity following magnitude.
# ──────────────────────────────────────────────────────────────────────────────

# Long → wide: features as rows, runs as columns
pivot = combined_shap_values.pivot(index="Feature", columns="Run", values="SHAP Value")

fig, ax = plt.subplots(figsize=(8, 5))

heatmap_options = dict(
    annot=True,              # print the numeric SHAP value in each cell
    cmap="coolwarm",         # blue = negative, red = positive
    center=0,                # zero maps to the neutral midpoint colour
    linewidths=0.5,
    linecolor='gray',
)
sns.heatmap(pivot, ax=ax, **heatmap_options)

ax.set_title("SHAP Value Heatmap for Top 5 Features Across Runs")
fig.tight_layout()

plt.show()


In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# 🎯 SHAP Value Distribution per Feature (Dot Plot Across Runs)
# ──────────────────────────────────────────────────────────────────────────────
# One dot per (feature, run) pair for the top 5 features. Tightly clustered
# dots on a row mean the feature's SHAP value is stable across runs; spread-out
# dots mean it varies between runs.
# ──────────────────────────────────────────────────────────────────────────────

fig, ax = plt.subplots(figsize=(10, 6))

# SHAP value on x, feature on y, one colour per run
sns.stripplot(
    x="SHAP Value",
    y="Feature",
    hue="Run",
    data=combined_shap_values,
    size=8,            # marker size
    dodge=True,        # offset the runs from one another within a feature row
    jitter=True,       # small random offset so overlapping dots stay visible
    ax=ax,
)

# Zero reference line (neutral SHAP value)
ax.axvline(0, color='black', linestyle='--')

ax.set_title("SHAP Values Distribution (Dot Plot)")
fig.tight_layout()

plt.show()


In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# 📊 Standard Deviation of SHAP Values per Feature (Variability Analysis)
# ──────────────────────────────────────────────────────────────────────────────
# This plot displays the standard deviation of SHAP values across runs for each
# of the top-5 features held in `combined_shap_values`. High variability may
# indicate instability in explanations; stable features will have a low
# standard deviation across runs.
# ──────────────────────────────────────────────────────────────────────────────

# Compute standard deviation of SHAP values grouped by feature, sorted descending
std_dev_df = (
    combined_shap_values
    .groupby("Feature")["SHAP Value"]
    .std()
    .sort_values(ascending=False)
    .reset_index()
)

# Explicit figure/axes so the legend can be handled on the Axes object
fig, ax = plt.subplots(figsize=(8, 5))

# Passing `palette` without `hue` is deprecated in recent seaborn releases.
# Assigning the category column to `hue` gives the same one-colour-per-bar
# result; dodge=False keeps the bars at full width.
sns.barplot(
    x="SHAP Value",
    y="Feature",
    hue="Feature",
    dodge=False,
    data=std_dev_df,
    palette="OrRd",  # Orange-Red gradient to emphasize magnitude
    ax=ax
)

# Some seaborn versions add a legend for the hue variable; it only repeats the
# y-axis labels, so remove it when present.
hue_legend = ax.get_legend()
if hue_legend is not None:
    hue_legend.remove()

# Add axis label and plot title
ax.set_xlabel("Standard Deviation of SHAP Value")
ax.set_title("Feature-wise SHAP Variability Across Runs")

# Optimize layout and render plot
fig.tight_layout()
plt.show()
