<img src="https://www.th-koeln.de/img/logo.svg" style="float:right;" width="200">

# Sample solution
## 10th exercise: <font color="#C70039">Interpretable Machine Learning with Shapley Values for image classification</font>
* Course: AML
* Lecturer: <a href="https://www.gernotheisenberg.de/">Gernot Heisenberg</a>
* Author of notebook: <a href="https://www.gernotheisenberg.de/">Gernot Heisenberg</a>
* Date:   04.08.2025

---------------------------------

### <font color="ce33ff">DESCRIPTION</font>:
This notebook is an example implementation that demonstrates explainable AI (XAI) for image classification. It uses the built-in CIFAR-10 dataset, which you have already come across in exercise 8.

In [None]:
import shap
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from tensorflow.keras.datasets import cifar10

## Load dataset
### Load the built-in dataset and preprocess it
Use the CIFAR-10 dataset from exercise 8.

In [None]:
# load data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Keep the original 32x32 resolution; preprocess_input scales the pixel values to [-1, 1]
x_train_processed = preprocess_input(x_train.astype(np.float32))
x_test_processed = preprocess_input(x_test.astype(np.float32))

# Flatten the images for KernelExplainer (see below).
# KernelExplainer expects 1- or 2-dimensional input, so we reshape to (num_samples, num_features).
num_test_samples = x_test_processed[:5].shape[0]
flattened_test_images = x_test_processed[:5].reshape(num_test_samples, -1)

num_train_samples = x_train_processed[:50].shape[0]
flattened_train_images = x_train_processed[:50].reshape(num_train_samples, -1)

# Explicitly convert background data to numpy array with float32 dtype
flattened_train_images = np.array(flattened_train_images).astype(np.float32)
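
The flattening step is lossless: reshaping a batch of images to `(num_samples, num_features)` and back restores every pixel, which is what the prediction wrapper further down relies on. A minimal sketch with NumPy, using random stand-in data instead of CIFAR-10 (the names `images`, `flat` and `restored` are illustrative only):

```python
import numpy as np

# Stand-in for a batch of 5 preprocessed 32x32 RGB images (random values, not CIFAR-10)
rng = np.random.default_rng(0)
images = rng.uniform(-1.0, 1.0, size=(5, 32, 32, 3)).astype(np.float32)

# Flatten: one row per image, one column per pixel-channel value
flat = images.reshape(images.shape[0], -1)
print(flat.shape)  # (5, 3072)

# Restore the image shape from the flattened rows
restored = flat.reshape(-1, 32, 32, 3)
print(np.array_equal(images, restored))  # True
```

Each image therefore becomes a vector of 32 * 32 * 3 = 3072 features, and every one of them is treated as a separate feature by the explainer.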

## Modeling
### Use a pretrained Model (here MobileNetV2)

In [None]:
# include_top=False removes the ImageNet classification layer, since CIFAR-10 has only 10 classes
# input_shape is set to 32x32 to match the CIFAR-10 images
# Note: ImageNet weights might not be ideal for 32x32 input, since the network was trained on much larger images
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

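# Add a new classification layer for CIFAR-10
# (this layer is randomly initialised and not trained in this notebook, so the explanations are only illustrative)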
x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
predictions = tf.keras.layers.Dense(10, activation='softmax')(x)

model = tf.keras.models.Model(inputs=base_model.input, outputs=predictions)

# Wrap the model's predict function for KernelExplainer.
# KernelExplainer passes flattened inputs, so reshape them back to image shape before predicting.
def predict_fn_for_kernel(flattened_images):
    # Reshape flattened images back to the original shape (num_samples, 32, 32, 3)
    images = flattened_images.reshape((-1, 32, 32, 3))
    # verbose=0 suppresses the progress bar, since KernelExplainer calls this function many times
    return model.predict(images, verbose=0)

## Initialize the KernelExplainer 

In [None]:
# KernelExplainer works with the model's predict function
# It requires a background dataset in the flattened format
# Only the first 50 training images serve as background, to limit memory use and runtime
explainer = shap.KernelExplainer(predict_fn_for_kernel, flattened_train_images)
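
The background dataset is what stands in for "missing" features: for each coalition of features, the features outside the coalition are filled in from the background rows and the model output is averaged over those rows. The sketch below illustrates this idea on a toy linear model. It is a simplified illustration under stated assumptions, not SHAP's actual code; `masked_expectation`, `weights` and the toy data are hypothetical.

```python
import numpy as np

def masked_expectation(f, x, mask, background):
    """Average model output when only the features selected by `mask` come from x.

    Features with mask == 0 are filled in from each background row.
    """
    # Broadcast x over all background rows, keeping background values where mask == 0
    samples = np.where(mask.astype(bool), x, background)
    return f(samples).mean(axis=0)

# Toy linear model with 3 features (hypothetical stand-in for the CNN)
weights = np.array([1.0, 2.0, 3.0])
f = lambda X: X @ weights

background = np.array([[0.0, 0.0, 0.0],
                       [2.0, 2.0, 2.0]])
x = np.array([3.0, 1.0, 0.0])

print(masked_expectation(f, x, np.array([0, 0, 0]), background))  # 6.0, background average
print(masked_expectation(f, x, np.array([1, 1, 1]), background))  # 5.0, equals f(x)
print(masked_expectation(f, x, np.array([1, 0, 0]), background))  # 8.0, only feature 0 taken from x
```

Every coalition costs one model evaluation per background row, which is why a small background set keeps the computation manageable.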

## Calculate Shapley values for test images

In [None]:
# Only the first 5 test images are explained to keep the demonstration fast
# nsamples can be increased for better accuracy, at the cost of longer computation time
shap_values = explainer.shap_values(flattened_test_images, nsamples=100)
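
To see what the sampled values approximate, it can help to compute exact Shapley values for a tiny example in which all coalitions can be enumerated. The sketch below does this for a hypothetical three-feature value function. `exact_shapley` and `value` are illustrative names and not part of the SHAP library.

```python
from itertools import combinations
from math import factorial

def exact_shapley(value, n_features):
    """Exact Shapley values for a set function `value` over features 0..n_features-1."""
    players = range(n_features)
    phi = [0.0] * n_features
    for i in players:
        others = [j for j in players if j != i]
        for size in range(len(others) + 1):
            # Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n_features - size - 1) / factorial(n_features)
            for subset in combinations(others, size):
                s = frozenset(subset)
                # Marginal contribution of feature i when added to coalition S
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi

# Hypothetical toy "model": features 0 and 1 contribute additively,
# feature 2 only contributes together with feature 0
def value(coalition):
    total = 0.0
    if 0 in coalition:
        total += 4.0
    if 1 in coalition:
        total += 2.0
    if 0 in coalition and 2 in coalition:
        total += 6.0
    return total

phi = exact_shapley(value, 3)
print([round(p, 6) for p in phi])  # [7.0, 2.0, 3.0]
```

The values sum to `value({0, 1, 2}) - value({})`, here 12, and the interaction term is split evenly between features 0 and 2. With 3072 features per image, enumerating all coalitions is infeasible, so `nsamples=100` gives only a coarse approximation.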

## Visualization of an explanation example

In [None]:
# The output format of KernelExplainer's shap_values depends on the SHAP version.
# For multi-output models it is either a list with one array per class, each of shape
# (num_samples, num_features), or a single array of shape (num_samples, num_features, num_classes).
# Convert the single-array form to the list form first.
if not isinstance(shap_values, list):
    shap_values = [shap_values[..., c] for c in range(shap_values.shape[-1])]

# Reshape each per-class array back to image shape (num_samples, height, width, channels)
shap_values_reshaped = [
    s.reshape((-1, 32, 32, 3)) for s in shap_values
]

# Plot the SHAP values for the first class (index 0).
# The second argument to image_plot should be the original images, here scaled to [0, 1] for display.
shap.image_plot(shap_values_reshaped[0], x_test[:5].astype(np.float32) / 255.0)
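
To inspect the attributions numerically rather than visually, the per-channel values can be summed into one attribution per pixel. The sketch below uses random numbers in place of real SHAP values, so the result carries no meaning. `fake_shap` and `pixel_attr` are hypothetical names.

```python
import numpy as np

# Random stand-in for the SHAP values of one class: (num_samples, height, width, channels)
rng = np.random.default_rng(42)
fake_shap = rng.normal(size=(5, 32, 32, 3))

# Sum over the colour channels to get one attribution per pixel
pixel_attr = fake_shap.sum(axis=-1)
print(pixel_attr.shape)  # (5, 32, 32)

# Locate the pixel with the largest absolute attribution in the first image
row, col = np.unravel_index(np.abs(pixel_attr[0]).argmax(), pixel_attr[0].shape)
print(row, col)
```

With real data, `fake_shap` would be replaced by one element of `shap_values_reshaped`.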