# Grad-CAM Visualization for Sickle Cell Classification

## Introduction

This notebook demonstrates the use of Gradient-weighted Class Activation Mapping (Grad-CAM) to visualize which regions of the input images are most important for the model's predictions. Grad-CAM helps us understand and interpret the decision-making process of our deep learning model, which is crucial for building trust in medical diagnosis applications.

## What is Grad-CAM?

Grad-CAM is a technique that produces a heatmap showing the importance of different regions in the image for the model's prediction. It works by:

1. Computing the gradient of the class score (before softmax) with respect to the feature maps of the last convolutional layer
2. Pooling these gradients to get importance weights
3. Computing a weighted combination of the feature maps
4. Applying ReLU to show only features that have a positive influence on the class of interest

## Why is this important for medical imaging?

1. **Model Interpretability**: Understand what features the model is focusing on
2. **Error Analysis**: Identify potential biases or incorrect focus areas
3. **Clinical Validation**: Ensure the model is looking at biologically relevant features
4. **Trust Building**: Help clinicians understand and trust the model's predictions

## Implementation Overview

1. **Model Setup**:
   - Load the pre-trained ResNet50 model
   - Create a model that maps the input image to the activations of the last convolutional layer and the output predictions

2. **Grad-CAM Algorithm**:
   - Forward pass an image through the network
   - Compute gradients of the top predicted class score with respect to the feature maps
   - Generate the heatmap by weighting the feature maps with the corresponding gradients
   - Apply ReLU to the heatmap to consider only features that have a positive influence

3. **Visualization**:
   - Superimpose the heatmap on the original image
   - Compare model's attention with clinical knowledge

## Expected Outputs

For each test image, we'll display:
1. The original image
2. The Grad-CAM heatmap
3. The superimposed visualization
4. The model's prediction and confidence

## Key Considerations

1. **Clinical Relevance**: The heatmaps should highlight clinically relevant features (e.g., cell morphology)
2. **False Positives/Negatives**: Pay special attention to misclassified examples
3. **Artifacts**: Check if the model is focusing on image artifacts rather than biological features

## Next Steps After Visualization

1. **Error Analysis**: Identify patterns in model mistakes
2. **Model Improvement**: Use insights to improve data augmentation or model architecture
3. **Quantitative Evaluation**: Consider quantitative metrics for model interpretability
4. **Clinical Review**: Have domain experts review the heatmaps for biological plausibility

## References

1. Selvaraju, R. R., et al. "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." *IEEE International Conference on Computer Vision (ICCV)*, 2017.
2. Adebayo, J., et al. "Sanity Checks for Saliency Maps." *NeurIPS*, 2018.
3. Holzinger, A., et al. "Causability and Explainability of AI in Medicine." *Wiley Interdisciplinary Reviews: Data Mining and Knowledge Discovery*, 2019.

## Usage

Run each cell sequentially to:
1. Load the model and setup Grad-CAM
2. Process test images
3. Generate and visualize the attention maps
4. Interpret the results

> **Note**: This visualization is for research/educational purposes and should not be used for clinical decision making without proper validation.

In [15]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import cv2
import os
import json
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions

In [17]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
import json
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input
import matplotlib.pyplot as plt
import cv2

# Load the model
def load_trained_model():
    try:
        # Try to load the model
        model_path = os.path.join('..', '..', 'models', 'resnet50_model.h5')
        print(f"Loading model from: {os.path.abspath(model_path)}")
        
        if not os.path.exists(model_path):
            print("Model file not found. Please check the path.")
            # Try to find any .h5 files in the models directory
            models_dir = os.path.join('..', '..', 'models')
            if os.path.exists(models_dir):
                print(f"Available files in models directory: {os.listdir(models_dir)}")
            return None
        
        # Load the model
        model = load_model(model_path)
        print("Model loaded successfully!")
        print(f"Model input shape: {model.input_shape}")
        print(f"Model output shape: {model.output_shape}")
        return model
        
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None

# Load class indices
def load_class_indices():
    try:
        class_indices_path = os.path.join('..', '..', 'reports', 'class_indices.json')
        print(f"Loading class indices from: {os.path.abspath(class_indices_path)}")
        
        if not os.path.exists(class_indices_path):
            print("Class indices file not found. Using default class mapping.")
            return {'0': 'Negative', '1': 'Positive'}
            
        with open(class_indices_path, 'r') as f:
            class_indices = json.load(f)
        
        # Create reverse mapping (index to class name)
        idx_to_class = {int(v): k for k, v in class_indices.items()}
        print(f"Loaded class mapping: {idx_to_class}")
        return idx_to_class
        
    except Exception as e:
        print(f"Error loading class indices: {str(e)}")
        return {0: 'Negative', 1: 'Positive'}  # Default fallback

# Load the model and class indices
model = load_trained_model()
idx_to_class = load_class_indices()

if model is None:
    print("\nFailed to load the model. Please check if the model file exists and is accessible.")
    print("You may need to train the model first by running the training notebook.")
else:
    print("\nModel and class indices loaded successfully!")
    print(f"Available classes: {idx_to_class}")

Loading model from: /Applications/Projects/Sickle Cell Classifer/models/resnet50_model.h5
Model file not found. Please check the path.
Available files in models directory: ['resnet50_final.h5', 'resnet50_best.h5']
Loading class indices from: /Applications/Projects/Sickle Cell Classifer/reports/class_indices.json
Loaded class mapping: {0: 'Negative', 1: 'Positive'}

Failed to load the model. Please check if the model file exists and is accessible.
You may need to train the model first by running the training notebook.


In [16]:
# Load the saved model
model_path = os.path.join('..', '..', 'models', 'resnet50_model.h5')
model = tf.keras.models.load_model(model_path)

# Load class indices
with open(os.path.join('..', '..', 'reports', 'class_indices.json'), 'r') as f:
    class_indices = json.load(f)
    
# Reverse the class indices for lookup
idx_to_class = {v: k for k, v in class_indices.items()}

# Print class mapping for reference
print("Class mapping:", idx_to_class)

FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = '../../models/resnet50_model.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [None]:
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    # Create a model that maps the input image to the activations of the last conv layer
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(last_conv_layer_name).output, model.output]
    )

    # Compute the gradient of the top predicted class for our input image
    # with respect to the activations of the last conv layer
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_index]

    # This is the gradient of the output neuron (top predicted or chosen)
    # with regard to the output feature map of the last conv layer
    grads = tape.gradient(class_channel, conv_outputs)

    # This is a vector where each entry is the mean intensity of the gradient
    # over a specific feature map channel
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # We multiply each channel in the feature map array
    # by "how important this channel is" with regard to the top predicted class
    # then sum all the channels to obtain the heatmap class activation
    conv_outputs = conv_outputs[0]
    heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # For visualization purpose, we will also normalize the heatmap between 0 & 1
    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    return heatmap.numpy()

In [None]:
def load_and_preprocess_image(img_path, target_size=(224, 224)):
    # Load and preprocess the image
    img = image.load_img(img_path, target_size=target_size)
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)
    return img_array, img

def save_and_display_gradcam(img_path, heatmap, cam_path="cam.jpg", alpha=0.4):
    # Load the original image
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    # Resize heatmap to be the same size as the original image
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    
    # Convert heatmap to RGB
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    
    # Superimpose the heatmap on the original image
    superimposed_img = heatmap * alpha + img
    superimposed_img = np.clip(superimposed_img, 0, 255).astype('uint8')
    
    # Display the result
    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    
    ax[0].imshow(img)
    ax[0].set_title('Original Image')
    ax[0].axis('off')
    
    ax[1].imshow(heatmap)
    ax[1].set_title('GradCAM Heatmap')
    ax[1].axis('off')
    
    ax[2].imshow(superimposed_img)
    ax[2].set_title('Superimposed')
    ax[2].axis('off')
    
    plt.tight_layout()
    plt.savefig(cam_path)
    plt.show()

In [None]:
def run_gradcam(image_path, model, last_conv_layer_name='conv5_block3_out'):
    # Preprocess the image
    img_array, original_img = load_and_preprocess_image(image_path)
    
    # Make prediction
    preds = model.predict(img_array)
    pred_class = np.argmax(preds[0])
    confidence = np.max(preds[0])
    
    print(f"Predicted class: {idx_to_class[pred_class]} (confidence: {confidence:.2f})")
    
    # Generate class activation heatmap
    heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name)
    
    # Save and display the GradCAM output
    save_and_display_gradcam(image_path, heatmap)
    
    return heatmap

In [None]:
# Example usage with a test image
test_image_dir = os.path.join('..', '..', 'Data')  # Updated path to point to the Data directory

# Check if test directory exists
if not os.path.exists(test_image_dir):
    raise FileNotFoundError(f"Data directory not found at: {os.path.abspath(test_image_dir)}")

print(f"Looking for test images in: {os.path.abspath(test_image_dir)}")

# Get all class directories (Positive and Negative)
class_dirs = [d for d in os.listdir(test_image_dir) 
             if os.path.isdir(os.path.join(test_image_dir, d)) and 
             d in ['Positive', 'Negative']]  # Only include these specific directories

if not class_dirs:
    print("No class directories found in the data directory.")
else:
    print(f"Found {len(class_dirs)} classes: {', '.join(class_dirs)}")

# Process one image from each class
for class_dir in class_dirs:
    class_path = os.path.join(test_image_dir, class_dir)
    try:
        # Get all image files in the class directory
        image_files = [
            f for f in os.listdir(class_path) 
            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))
        ]
        
        if not image_files:
            print(f"\nNo images found in class directory: {class_dir}")
            continue
            
        # Take the first image
        test_image_path = os.path.join(class_path, image_files[0])
        
        # Verify the image exists and is readable
        if not os.path.exists(test_image_path):
            print(f"\nImage not found: {test_image_path}")
            continue
            
        print(f"\n{'='*50}")
        print(f"Analyzing image from class: {class_dir}")
        print(f"Image: {image_files[0]}")
        print(f"Full path: {os.path.abspath(test_image_path)}")
        
        # Run GradCAM visualization
        try:
            _ = run_gradcam(test_image_path, model)
            print(f"Successfully processed: {image_files[0]}")
        except Exception as e:
            print(f"Error processing {image_files[0]}: {str(e)}")
            import traceback
            traceback.print_exc()
            
    except Exception as e:
        print(f"Error processing class {class_dir}: {str(e)}")
        import traceback
        traceback.print_exc()
        continue

print("\nGradCAM visualization complete!")

Looking for test images in: /Applications/Projects/Sickle Cell Classifer/Data
Found 2 classes: Positive, Negative

No images found in class directory: Positive

No images found in class directory: Negative

GradCAM visualization complete!


In [None]:
# Example usage with a test image
data_dir = os.path.join('..', '..', 'Data')  # Base data directory
class_dirs = ['Positive', 'Negative']  # Expected class directories

print(f"Looking for images in: {os.path.abspath(data_dir)}")

# Process one image from each class
for class_dir in class_dirs:
    class_path = os.path.join(data_dir, class_dir, 'Unlabelled')  # Added 'Unlabelled' subdirectory
    
    if not os.path.exists(class_path):
        print(f"\nClass directory not found: {class_path}")
        continue
        
    try:
        # Get all image files in the Unlabelled subdirectory
        image_files = [
            f for f in os.listdir(class_path) 
            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))
        ]
        
        if not image_files:
            print(f"\nNo images found in: {class_path}")
            continue
            
        # Take the first image
        test_image_path = os.path.join(class_path, image_files[0])
        
        # Verify the image exists and is readable
        if not os.path.exists(test_image_path):
            print(f"\nImage not found: {test_image_path}")
            continue
            
        print(f"\n{'='*50}")
        print(f"Analyzing image from class: {class_dir}")
        print(f"Image: {image_files[0]}")
        print(f"Full path: {os.path.abspath(test_image_path)}")
        
        # Run GradCAM visualization
        try:
            _ = run_gradcam(test_image_path, model)
            print(f"Successfully processed: {image_files[0]}")
        except Exception as e:
            print(f"Error processing {image_files[0]}: {str(e)}")
            import traceback
            traceback.print_exc()
            
    except Exception as e:
        print(f"Error processing class {class_dir}: {str(e)}")
        import traceback
        traceback.print_exc()
        continue

print("\nGradCAM visualization complete!")

Looking for images in: /Applications/Projects/Sickle Cell Classifer/Data

Analyzing image from class: Positive
Image: 63.jpg
Full path: /Applications/Projects/Sickle Cell Classifer/Data/Positive/Unlabelled/63.jpg
Error processing 63.jpg: name 'model' is not defined

Class directory not found: ../../Data/Negative/Unlabelled

GradCAM visualization complete!


Traceback (most recent call last):
  File "/var/folders/xz/1mhz2xvx29xdpqhg3kk69fph0000gn/T/ipykernel_50202/1083704799.py", line 41, in <module>
    _ = run_gradcam(test_image_path, model)
                                     ^^^^^
NameError: name 'model' is not defined


In [None]:
# First, verify the model is loaded correctly
print("Verifying model...")
if model is None:
    print("Error: Model is not loaded!")
else:
    print("Model loaded successfully")
    print(f"Model input shape: {model.input_shape}")
    print(f"Model output shape: {model.output_shape}")

# Example usage with a test image
data_dir = os.path.abspath(os.path.join('..', '..', 'Data'))  # Get absolute path
class_dirs = ['Positive', 'Negative']  # Expected class directories

print(f"\nLooking for images in: {data_dir}")

# Process one image from each class
for class_dir in class_dirs:
    class_path = os.path.join(data_dir, class_dir, 'Unlabelled')
    print(f"\nChecking directory: {class_path}")
    
    if not os.path.exists(class_path):
        print(f"Directory does not exist: {class_path}")
        # List contents of parent directory to help debug
        parent_dir = os.path.dirname(class_path)
        if os.path.exists(parent_dir):
            print(f"Contents of {parent_dir}: {os.listdir(parent_dir)}")
        continue
        
    try:
        # List all files in the directory
        all_files = os.listdir(class_path)
        print(f"Found {len(all_files)} items in directory")
        
        # Filter for image files
        image_files = [
            f for f in all_files 
            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))
        ]
        
        print(f"Found {len(image_files)} image files")
        
        if not image_files:
            print(f"No image files found in: {class_path}")
            print(f"File types present: {set(os.path.splitext(f)[1] for f in all_files)}")
            continue
            
        # Take the first image
        test_image_path = os.path.join(class_path, image_files[0])
        print(f"Selected image: {test_image_path}")
        
        # Verify the image exists and is readable
        if not os.path.exists(test_image_path):
            print(f"Image file not found: {test_image_path}")
            continue
            
        print(f"\n{'='*50}")
        print(f"Analyzing image from class: {class_dir}")
        print(f"Image: {image_files[0]}")
        print(f"Full path: {test_image_path}")
        
        # Run GradCAM visualization
        try:
            print("Running GradCAM...")
            _ = run_gradcam(test_image_path, model)
            print(f"Successfully processed: {image_files[0]}")
        except Exception as e:
            print(f"Error processing {image_files[0]}: {str(e)}")
            import traceback
            traceback.print_exc()
            
    except Exception as e:
        print(f"Error processing class {class_dir}: {str(e)}")
        import traceback
        traceback.print_exc()
        continue

print("\nGradCAM visualization complete!")

Verifying model...


NameError: name 'model' is not defined