In [0]:
# "standard"
import numpy as np
import pandas as pd

# machine learning and statistics
import tensorflow as tf
from tensorflow import keras
from keras.callbacks import EarlyStopping
from tensorflow.keras.models import Sequential
from keras import applications, models, layers, optimizers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import confusion_matrix
import shap

# plotting
import matplotlib.pyplot as plt
import seaborn as sns

# misc
import cv2, magic, datetime, sys, os, wget, pickle, time, boto3, io, tempfile
from IPython.display import clear_output

# src: project-local helper modules
# NOTE(review): absolute, user-specific workspace path — consider making it configurable.
import sys  # already imported in the misc block above; redundant but harmless
sys.path.append('/Workspace/Users/bjedelma@gmail.com/Alzheimers-MRI-Classification/src')
from visualize import visualize_training
from custom_pruning import global_prune_model
from data_io import save_model_s3, load_model_s3, save_pickle_s3, load_pickle_s3

# Clear any output produced while importing.
clear_output(wait=False)

Mount the `AD_MRI_classification/raw/` prefix of the project's AWS S3 bucket under `/mnt/raw`.

In [0]:
# AWS credentials come from the Databricks secret scope rather than being hardcoded.
ACCESS_KEY = dbutils.secrets.get(scope="brad-aws", key="access_key")
SECRET_KEY = dbutils.secrets.get(scope="brad-aws", key="secret_key")

# S3 prefix to mount. The mount point takes the last path segment of the
# prefix: the trailing "/" leaves an empty final element, hence [-2] -> "/mnt/raw".
AWS_S3_BUCKET = "databricks-workspace-stack-brad-personal-bucket/AD_MRI_classification/raw/"
MOUNT_NAME = f"/mnt/{AWS_S3_BUCKET.split('/')[-2]}"
SOURCE_URL = f"s3a://{AWS_S3_BUCKET}"
EXTRA_CONFIGS = {
    "fs.s3a.access.key": ACCESS_KEY,
    "fs.s3a.secret.key": SECRET_KEY,
}

# Mount only when this mount point is not already registered.
if MOUNT_NAME not in (mount.mountPoint for mount in dbutils.fs.mounts()):
    dbutils.fs.mount(SOURCE_URL, MOUNT_NAME, extra_configs=EXTRA_CONFIGS)
    print(f"{MOUNT_NAME} is now mounted.")
else:
    print(f"{MOUNT_NAME} is already mounted.")

In [0]:
# Load the preprocessed train/test split (a dict) from S3 and unpack it.
# NOTE(review): presumably unpickled by load_pickle_s3 — only load trusted files.
bucket_name = "databricks-workspace-stack-brad-personal-bucket"
s3_file_name = 'AD_MRI_classification/preprocessed/data_preprocessed.pkl'
data = load_pickle_s3(bucket_name, s3_file_name, dbutils)

train_data, train_lab = data['train_data'], data['train_labels']
test_data, test_lab = data['test_data'], data['test_labels']

# One-hot encode the integer class labels into 4 classes (train first, then test).
train_lab_cat, test_lab_cat = (
    to_categorical(labels.astype('int8'), num_classes=4)
    for labels in (train_lab, test_lab)
)

It is important to at least attempt to explain which image features drive the classifier's decisions. Because the dataset consists of registered anatomical MRI scans, a given pixel corresponds to roughly the same anatomical location in every image. Pixel-level feature importance can therefore implicate specific brain regions in Alzheimer's disease (AD) progression.

In [0]:
# Load the fine-tuned (pre-pruning) model saved to S3.
bucket_name = "databricks-workspace-stack-brad-personal-bucket"
s3_file_path = 'AD_MRI_classification/results/model_resnet50_fine_tune.h5'
pre_pruned_model = load_model_s3(bucket_name, s3_file_path, dbutils)

# Add a trailing channel axis: (n_images, 128, 128, 1).
# NOTE(review): assumes test_data holds 128x128 single-channel images matching
# the model input — confirm against pre_pruned_model.input_shape.
test_data2 = test_data.reshape((test_data.shape[0], 128, 128, 1))

# GradientExplainer (expected gradients) samples reference inputs from this
# background set.
# NOTE(review): the background is the full test set, i.e. the same data that is
# explained below; a random subset of training images is the more common
# choice and uses less memory — consider switching.
explainer = shap.GradientExplainer(pre_pruned_model, test_data2)

# Explain the *mean* test image (a synthetic average scan, not a real subject),
# kept 4-D as a batch of one: (1, 128, 128, 1).
sample = test_data2.mean(axis=0, keepdims=True)
shap_values = explainer.shap_values(sample)

# Clear loader output, then plot the SHAP map for every model output
# (one panel per output class), not just the first.
clear_output(wait=False)
shap.image_plot(shap_values, sample)

Compute SHAP values for a subset of the test set (the first 100 images) and average their absolute values over those images. This gives one importance map per class, highlighting the most influential features (pixels).

In [0]:
# Explain a fixed leading subset of the test set; more images give a steadier
# average at proportionally higher cost.
N_EXPLAIN = 100
shap_values_all = explainer.shap_values(test_data2[:N_EXPLAIN])

# Put the class axis first: (n_classes, n_images, H, W, C). Older SHAP versions
# return one array per output class (a list); newer ones may return a single
# array with the class axis last — NOTE(review): confirm for the installed version.
if isinstance(shap_values_all, list):
    shap_by_class = np.stack(shap_values_all, axis=0)
else:
    shap_arr = np.asarray(shap_values_all)
    shap_by_class = (np.moveaxis(shap_arr, -1, 0)
                     if shap_arr.ndim == test_data2.ndim + 1
                     else shap_arr[np.newaxis])

# Average |SHAP| over images (axis 1). Averaging axis 0 of the per-class list
# would collapse the classes instead of the samples.
mean_shap = np.abs(shap_by_class).mean(axis=1)  # (n_classes, H, W, C)

# image_plot expects one (n_images, H, W, C) array per class: draw each class's
# average importance map over the mean of the explained images.
mean_image = test_data2[:N_EXPLAIN].mean(axis=0, keepdims=True)
class_labels = np.array([[f"class {k}" for k in range(mean_shap.shape[0])]])
shap.image_plot([class_map[np.newaxis] for class_map in mean_shap], mean_image, class_labels)


For misclassified test images, visualize SHAP values to look for patterns in how the model misreads image features.

In [0]:
# Derive predicted and true integer labels here so the cell does not depend on
# names (`predictions`, `test_labels`) that no earlier cell defines.
# NOTE(review): assumes the model outputs one score per class, shape (n, 4).
predictions = np.argmax(pre_pruned_model.predict(test_data2, verbose=0), axis=1)
true_labels = np.argmax(test_lab_cat, axis=1)
misclassified = np.flatnonzero(predictions != true_labels)
print(f"{misclassified.size} of {true_labels.size} test images misclassified")

# Explain the first few misclassified scans individually. Use a cell-local
# name so the global `shap_values` from the mean-image cell is not overwritten.
for idx in misclassified[:5]:
    print(f"Test index {idx}: true class {true_labels[idx]}, predicted class {predictions[idx]}")
    shap_values_mis = explainer.shap_values(test_data2[idx:idx + 1])
    shap.image_plot(shap_values_mis, test_data2[idx:idx + 1])


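Explain the mean image only with respect to its highest-ranked class(es) by passing `ranked_outputs` to `shap_values`. Increase the number of ranked outputs to compare SHAP maps across several classes, ordered by model score.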

In [0]:
# With `ranked_outputs`, shap_values returns a (shap_values, class_indexes)
# pair for the TOP_K highest-scoring outputs; unpack it before plotting
# (image_plot cannot take the tuple itself).
TOP_K = 1  # set to 4 to compare all classes, ordered by model output
shap_values_top, top_class_idx = explainer.shap_values(sample, ranked_outputs=TOP_K)
# Label each panel with the class index it explains: shape (n_samples, TOP_K).
top_class_names = np.vectorize(lambda k: f"class {k}")(np.asarray(top_class_idx))
shap.image_plot(shap_values_top, sample, top_class_names)


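Summary: the cell below ranks pixels by mean absolute SHAP value (per class) in a bar summary plot. **TODO (cluster):** group similar samples by their SHAP value patterns to identify clusters in the decision-making process; this is not implemented yet.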

In [0]:
# summary_plot needs 2-D (n_samples, n_features) inputs, and the feature matrix
# must cover exactly the images that were explained: shap_values_all was
# computed on a leading subset of test_data2, not on the whole test set.
# Normalize SHAP's output (list per class, or one array with the class axis
# last) into a list of per-class arrays first.
if isinstance(shap_values_all, list):
    shap_per_class = list(shap_values_all)
else:
    shap_arr = np.asarray(shap_values_all)
    shap_per_class = ([shap_arr[..., k] for k in range(shap_arr.shape[-1])]
                      if shap_arr.ndim == test_data2.ndim + 1
                      else [shap_arr])
n_explained = shap_per_class[0].shape[0]
# Name each flattened feature by its (row, col, channel) pixel position.
pixel_names = [f"pixel {pos}" for pos in np.ndindex(shap_per_class[0].shape[1:])]
shap.summary_plot(
    [class_vals.reshape(n_explained, -1) for class_vals in shap_per_class],
    test_data2[:n_explained].reshape(n_explained, -1),
    feature_names=pixel_names,
    plot_type="bar",
)