In [None]:

import numpy as np
from pathlib import Path
import tensorflow as tf
import matplotlib.pyplot as plt

# Set paths (relative to the notebook's working directory)
PREPROCESSED_DIR = Path("./experiments/preprocessed")
MODEL_DIR = Path("./experiments/models/cnn/xray_cnn_model")

# Load preprocessed X-ray data
xray_data = np.load(PREPROCESSED_DIR / "xray_data.npy")
xray_labels = np.load(PREPROCESSED_DIR / "xray_labels.npy")
# Later cells pair xray_data[i] with xray_labels[i]; fail fast if the two
# files disagree on the number of samples.
assert len(xray_data) == len(xray_labels), (
    f"xray_data has {len(xray_data)} samples but xray_labels has {len(xray_labels)}"
)

# Load trained CNN model
model = tf.keras.models.load_model(MODEL_DIR)

print("Model loaded successfully")
print("X-ray data shape:", xray_data.shape)
print("X-ray labels shape:", xray_labels.shape)


In [None]:

import tensorflow.keras.backend as K

def get_gradcam_heatmap(model, img, class_idx, layer_name=None, layer_index=-3):
    """Compute a Grad-CAM heatmap for a single image.

    The target layer's activations are weighted by the spatially averaged
    gradient of the ``class_idx`` output score. They are then summed over
    channels, passed through ReLU, and scaled to [0, 1].

    Args:
        model: Keras model. A sub-model exposing the target layer's output
            alongside ``model.output`` is built from it.
        img: a single input without a batch axis (one is added here).
        class_idx: column of the model output (``predictions[:, class_idx]``)
            whose score is explained.
        layer_name: optional name of the target layer. When given, it takes
            precedence over ``layer_index``.
        layer_index: index passed to ``model.get_layer`` when ``layer_name`` is
            None. The default -3 reproduces the original hard-coded choice.
            It must select a layer with (batch, H, W, C) activations
            — TODO confirm for this architecture.

    Returns:
        2-D numpy array at the target layer's spatial resolution (not the
        input resolution), with values in [0, 1].
    """
    if layer_name is not None:
        target_layer = model.get_layer(name=layer_name)
    else:
        target_layer = model.get_layer(index=layer_index)

    # model.inputs is already a list; wrapping it again builds a nested input structure.
    grad_model = tf.keras.models.Model(
        model.inputs, [target_layer.output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img[None, ...])
        score = predictions[:, class_idx]
    grads = tape.gradient(score, conv_outputs)

    # One weight per channel: mean gradient over the batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = tf.reduce_sum(conv_outputs[0] * pooled_grads, axis=-1)

    # Apply ReLU *before* normalizing. The max is then >= 0, so the
    # denominator is >= 1e-8 and can never be ~0 (which would give 0/0 = NaN).
    heatmap = tf.nn.relu(heatmap)
    heatmap = heatmap / (tf.reduce_max(heatmap) + 1e-8)
    return heatmap.numpy()

# Visualize Grad-CAM for the first N_GRADCAM images of xray_data.
# NOTE(review): xray_data is the full preprocessed set, not a held-out split.
N_GRADCAM = 3

for i in range(min(N_GRADCAM, len(xray_data))):
    img = xray_data[i]
    label = np.asarray(xray_labels[i])
    # One-hot/probability labels -> argmax; scalar integer labels are used
    # directly (np.argmax of a scalar would always return 0).
    class_idx = int(np.argmax(label)) if label.ndim else int(label)
    heatmap = get_gradcam_heatmap(model, img, class_idx)

    height, width = img.shape[:2]
    fig, (ax_orig, ax_cam) = plt.subplots(1, 2, figsize=(6, 3))

    ax_orig.imshow(img[..., 0], cmap='gray')
    ax_orig.set_title("Original X-ray")
    ax_orig.axis('off')

    ax_cam.imshow(img[..., 0], cmap='gray')
    # The heatmap is at the conv layer's (smaller) resolution. Stretch it over
    # the image's pixel extent; otherwise the second imshow re-limits the axes
    # and only the top-left corner of the X-ray gets covered.
    ax_cam.imshow(
        heatmap,
        cmap='jet',
        alpha=0.5,
        extent=(-0.5, width - 0.5, height - 0.5, -0.5),
    )
    ax_cam.set_title("Grad-CAM Overlay")
    ax_cam.axis('off')
    plt.show()


In [None]:

import shap

# Use a small subset for SHAP to speed up computation: the first 10 images
# form the background distribution and the next 3 are explained.
background = xray_data[:10]
test_imgs = xray_data[10:13]

# Create DeepExplainer
explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(test_imgs)


def _shap_values_per_class(raw_values, input_ndim):
    """Normalize DeepExplainer output to a list with one array per model output.

    Depending on the shap version and the number of model outputs,
    ``shap_values`` is one of:
    - a list of arrays, one per output, each shaped like the explained batch;
    - a single array shaped like the batch (single-output model);
    - a single array with the output axis appended last.
    """
    if isinstance(raw_values, list):
        return [np.asarray(v) for v in raw_values]
    values = np.asarray(raw_values)
    if values.ndim == input_ndim + 1:
        # Output axis is last: split it into one array per output.
        return [values[..., c] for c in range(values.shape[-1])]
    # Single output: the array already matches the input batch shape.
    return [values]


shap_per_class = _shap_values_per_class(shap_values, np.ndim(test_imgs))
n_classes = len(shap_per_class)

# Visualize SHAP values for each class
for i, img in enumerate(test_imgs):
    # Sum attributions over the channel axis (same as [..., 0] for one channel).
    class_maps = [sv[i].sum(axis=-1) for sv in shap_per_class]
    # SHAP values are signed. Centre a diverging colormap on zero and share the
    # scale across classes so panels for the same image are comparable.
    limit = max(float(np.abs(m).max()) for m in class_maps) or 1.0

    fig, axes = plt.subplots(1, n_classes, figsize=(8, 4), squeeze=False)
    for c, (ax, shap_img) in enumerate(zip(axes[0], class_maps)):
        ax.imshow(img[..., 0], cmap='gray')
        ax.imshow(shap_img, cmap='bwr', vmin=-limit, vmax=limit, alpha=0.5)
        ax.set_title(f"Class {c} SHAP")
        ax.axis('off')
    plt.show()
    