# Explaining the Predictions of a Convolutional Neural Network on Head CT Images

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Sulam-Group/h-shap/blob/jaco/siim-ml-tutorial/demo/RSNA_ICH_detection/explain.ipynb)

## Task

Explain the predictions of a Convolutional Neural Network (CNN) trained to predict the presence of intracranial hemorrhage on the [RSNA 2019 Brain CT Hemorrhage Challenge dataset](https://www.kaggle.com/competitions/rsna-intracranial-hemorrhage-detection/data).

## Requirements

1. Basic understanding of machine learning and deep learning.
2. Programming in Python.

## Learning objectives

1. Load a pre-trained classifier in PyTorch.
2. Use the classifier to predict the presence of hemorrhage in test images.
3. Explain the classifier's intracranial hemorrhage predictions with `h-shap`.

## Acknowledgements

This Jupyter Notebook is based on code by Jacopo Teneggi ([jtenegg1@jhu.edu](mailto:jtenegg1@jhu.edu)).

## Load the pretrained model

Here, we load a pretrained model that was trained on the [RSNA 2019 Brain CT Hemorrhage Challenge dataset](https://www.kaggle.com/competitions/rsna-intracranial-hemorrhage-detection/data). The model solves a binary classification problem: it predicts `1` when an image contains any type of hemorrhage and `0` when the image is healthy.
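The prediction cell further below thresholds the model's output at 0.5 to obtain this binary label. As a minimal sketch, assuming the output is either already a probability or (with `is_logit=True`) a raw logit that needs a sigmoid first, the decision rule looks like this. `predict_label` is a hypothetical helper, not part of this notebook or of `h-shap`.

```python
import math


def predict_label(output: float, threshold: float = 0.5, is_logit: bool = False) -> int:
    """Hypothetical helper: map one model output to the notebook's binary label.

    If `is_logit` is True, the output is first squashed with a sigmoid;
    otherwise it is assumed to already be a probability in [0, 1].
    """
    prob = 1.0 / (1.0 + math.exp(-output)) if is_logit else output
    return int(prob > threshold)  # 1 = any hemorrhage, 0 = healthy


print(predict_label(0.83))                 # 1
print(predict_label(-2.0, is_logit=True))  # sigmoid(-2) is about 0.12, so 0
```

The comparison is strict, so an output of exactly 0.5 maps to `0`, matching `output > 0.5` in the prediction cell.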

This step requires PyTorch. Follow the instructions at [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/) to install PyTorch. A GPU is not required, but the notebook uses one automatically when available; on a CPU, expect long runtimes, especially for the explanation step, which evaluates the model many times per image.

In [None]:
import torch
import torchvision.transforms.functional as tf

# Define the device to run on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

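# Disable gradient tracking globally, since this notebook only runs inference
torch.set_grad_enabled(False)

# Load the pretrained model from the h-shap GitHub repository via torch.hub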
model = torch.hub.load(
    "Sulam-Group/h-shap:jaco/siim-ml-tutorial", "rsnahemorrhagenet", trust_repo="check"
)
model = model.to(device)
# Switch to evaluation mode (affects layers such as dropout and batch norm)
model.eval()

## Use the Model to Predict on Images

Here, we use the pretrained model to predict on 4 test images from the [CQ500 dataset](http://headctstudy.qure.ai/dataset), all of which contain hemorrhage (i.e., their true label is `1`).

Ground-truth bounding boxes of the bleeds come from the [BHX](https://physionet.org/content/bhx-brain-bounding-box/1.1/) dataset, an extension of CQ500, and are drawn as red rectangles in the images.
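The `annotate` helper in the next cell reads each bounding box from the `data` column of `gt.csv`, which (judging from its `.replace("'", '"')` call) stores Python-style dictionaries with single-quoted keys. A sketch of a slightly more robust parser, assuming that format, uses `ast.literal_eval`, which accepts single quotes directly and evaluates only literals. `parse_bbox` and the sample string are hypothetical.

```python
import ast


def parse_bbox(data: str) -> tuple[float, float, float, float]:
    """Parse one annotation string into (x, y, width, height).

    Assumes `data` looks like a Python dict literal with single-quoted keys,
    e.g. "{'x': 120.5, 'y': 88.0, 'width': 30.0, 'height': 25.0}".
    """
    ann = ast.literal_eval(data)  # safe: only evaluates Python literals
    return (float(ann["x"]), float(ann["y"]), float(ann["width"]), float(ann["height"]))


row = "{'x': 120.5, 'y': 88.0, 'width': 30.0, 'height': 25.0}"
x, y, w, h = parse_bbox(row)
# Anchor corner passed to patches.Rectangle, and the opposite corner
print((x, y), (x + w, y + h))
```

Unlike the quote replacement, `literal_eval` does not break if a value ever contains an apostrophe.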

In [None]:
import io
import json
import requests as req
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm


def download(url: str) -> io.BytesIO:
    """
    A helper function to download an object from a url.

    Parameters:
    -----------
    url: str
        The url to download the object from.
    Returns:
    --------
    io.BytesIO
        An in-memory buffer holding the downloaded bytes.
    """
    res = req.get(url, timeout=60)
    res.raise_for_status()
    return io.BytesIO(res.content)


def window(img: np.ndarray, WL: int, WW: int) -> np.ndarray:
    """
    A function that windows the values of an image at level
    `WL` with width `WW` (in Hounsfield Units).

    Parameters:
    -----------
    img: np.ndarray
        The image to window.
    WL: int
        The window level.
    WW: int
        The window width.
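    Returns:
    --------
    np.ndarray
        The windowed image, rescaled to [0, 1].
    """
    image_min = WL - WW // 2
    image_max = WL + WW // 2
    # Clip to the window; np.clip returns a new array,
    # so the input is not modified in place
    img = np.clip(img, image_min, image_max)

    # Rescale the windowed values to [0, 1]
    img = (img - image_min) / (image_max - image_min)

    return img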


def annotate(annotation_df: pd.DataFrame, ax: plt.Axes) -> None:
    """
    A helper function to annotate an image with its ground-truth
    annotations.

    Parameters:
    -----------
    annotation_df: pd.DataFrame
        The DataFrame containing the annotations.
    ax: plt.Axes
        The Axes containing the image to annotate.
    """
    for _, annotation_row in annotation_df.iterrows():
        annotation = annotation_row["data"].replace("'", '"')
        annotation = json.loads(annotation)
        bbox_x = annotation["x"]
        bbox_y = annotation["y"]
        bbox_width = annotation["width"]
        bbox_height = annotation["height"]

        bbox = patches.Rectangle(
            (bbox_x, bbox_y),
            bbox_width,
            bbox_height,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(bbox)


# Download the images and ground-truth annotations;
# images are windowed with the standard brain
# setting, i.e., WL = 40 and WW = 80 (in HU).
base_url = "https://github.com/Sulam-Group/h-shap/blob/jaco/siim-ml-tutorial/demo/RSNA_ICH_detection/images"
sop_ids = [
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713007.1700822",
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713021.1704518",
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713469.1816851",
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713940.1946656",
]
images = np.stack(
    [
        window(
            np.load(download(f"{base_url}/{sop_id}.npy?raw=true")).astype(np.float32),
            WL=40,
            WW=80,
        )
        for sop_id in tqdm(sop_ids)
    ]
)
gt = pd.read_csv(f"{base_url}/gt.csv?raw=true")

# Visualize images and ground truth annotations,
# predict the presence of hemorrhage
# in each image using the pretrained model.
_, axes = plt.subplots(1, len(images), figsize=(16, 9))
for i in range(len(images)):
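    # Preprocess image for prediction: (H, W) -> (1, 3, H, W), replicating
    # the grayscale channel and applying ImageNet mean/std normalization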
    x = images[i]
    x = tf.to_tensor(x)
    x = x.unsqueeze(1).repeat(1, 3, 1, 1)
    x = tf.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Use the model to predict the presence of hemorrhage
    x = x.to(device)
    output = model(x)
    prediction = (output > 0.5).long()

    # Visualize image and ground truth annotations
    ax = axes[i]
    ax.imshow(images[i], cmap="gray")

    annotation_df = gt[gt["SOPInstanceUID"] == sop_ids[i]]
    annotate(annotation_df, ax)

    ax.set_title(f"Prediction: {prediction.item()}")
    ax.axis("off")

plt.show()
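To make the preprocessing above concrete, here is a NumPy-only sketch of the same two steps: brain windowing (WL = 40, WW = 80 maps [0, 80] HU onto [0, 1]) and the conversion of a windowed slice into a 3-channel, ImageNet-normalized batch of shape (1, 3, H, W). The helper names and the HU sample values are illustrative; the notebook itself uses `window` and `torchvision.transforms.functional`.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def window_hu(img: np.ndarray, WL: int = 40, WW: int = 80) -> np.ndarray:
    """Clip to [WL - WW/2, WL + WW/2] HU and rescale to [0, 1] (non-mutating)."""
    lo, hi = WL - WW // 2, WL + WW // 2
    return (np.clip(img, lo, hi) - lo) / (hi - lo)


def to_model_input(img: np.ndarray) -> np.ndarray:
    """(H, W) in [0, 1] -> (1, 3, H, W), ImageNet-normalized."""
    x = np.repeat(img[None, None, :, :], 3, axis=1)  # grayscale -> 3 channels
    mean = IMAGENET_MEAN[None, :, None, None]
    std = IMAGENET_STD[None, :, None, None]
    return (x - mean) / std


# Illustrative HU values: air, 0 HU, window center, 60 HU, window top, bone
hu = np.array([[-1000.0, 0.0, 40.0], [60.0, 80.0, 1000.0]], dtype=np.float32)
w = window_hu(hu)
print(w)                        # 0, 0, 0.5 in the first row; 0.75, 1, 1 in the second
print(to_model_input(w).shape)  # (1, 3, 2, 3)
```

Air and bone saturate at 0 and 1, so the full [0, 1] range is spent on the narrow band of soft-tissue densities around the window level.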

## Explain Model Predictions Using `h-shap`

[`h-shap`](https://github.com/Sulam-Group/h-shap) is a Python package that provides an exact, fast, and hierarchical implementation of the [Shapley value](https://en.wikipedia.org/wiki/Shapley_value) for explaining model predictions on images.
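For reference, the Shapley value of feature $i$ in a cooperative game $v$ over a set $N$ of $n$ features is its marginal contribution averaged over all coalitions $S$ that exclude it. For image explanations, $v(S)$ is typically the model's output when only the features in $S$ are kept and the rest are replaced by a reference value:

$$
\phi_i(v) = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(n - |S| - 1)!}{n!}\,\bigl[v(S \cup \{i\}) - v(S)\bigr]
$$

Evaluating this exactly requires $v$ on all $2^n$ coalitions, which is intractable when every pixel is a feature; `h-shap` instead computes exact Shapley values over a few image regions at a time and refines them hierarchically.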

`h-shap` produces **saliency maps**: heatmaps in which the intensity of each pixel represents its importance towards the model's prediction.

For more information on `h-shap`, refer to the paper [_"Fast Hierarchical Games for Image Explanations"_](https://ieeexplore.ieee.org/document/9826424).
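The idea behind the hierarchy can be illustrated with a toy version. The sketch below is a simplification under our own assumptions, not the `h-shap` implementation: it splits a square image into 4 quadrants, computes their exact Shapley values by masking absent quadrants with a reference image, and recurses only into quadrants whose value exceeds a threshold, stopping at a minimum size. The toy detector (which fires if any pixel exceeds 0.5) and all helper names are hypothetical.

```python
from itertools import combinations
from math import factorial

import numpy as np


def shapley4(v) -> np.ndarray:
    """Exact Shapley values of a 4-player game v(frozenset) -> float."""
    n = 4
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):  # size of the coalition S that excludes i
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for S in combinations(others, k):
                S = frozenset(S)
                phi[i] += weight * (v(S | {i}) - v(S))
    return phi


def hierarchical_explain(f, x, ref, min_size, threshold=0.0):
    """Toy h-shap-style search on a square image whose side is a power of 2.

    Simplification: at each level, everything outside the kept quadrants
    (including pixels outside the current region) is replaced by `ref`.
    """
    saliency = np.zeros(x.shape, dtype=float)
    stack = [(0, 0, x.shape[0])]  # (top row, left col, side length)
    while stack:
        r0, c0, size = stack.pop()
        h = size // 2
        quads = [(r0, c0), (r0, c0 + h), (r0 + h, c0), (r0 + h, c0 + h)]

        def v(S, quads=quads, h=h):
            z = ref.copy()
            for q in S:  # reveal only the quadrants in the coalition
                r, c = quads[q]
                z[r:r + h, c:c + h] = x[r:r + h, c:c + h]
            return f(z)

        phi = shapley4(v)
        for q, (r, c) in enumerate(quads):
            if phi[q] <= threshold:
                continue  # prune quadrants that do not contribute
            if h <= min_size:
                saliency[r:r + h, c:c + h] = phi[q]  # leaf: assign its value
            else:
                stack.append((r, c, h))  # recurse into the quadrant
    return saliency


def detector(z: np.ndarray) -> float:
    """Hypothetical stand-in for the CNN: 1.0 if any pixel exceeds 0.5."""
    return float(z.max() > 0.5)


# Toy 16 x 16 "scan" with a single bright pixel
x = np.zeros((16, 16))
x[9, 13] = 1.0
ref = np.zeros_like(x)

sal = hierarchical_explain(detector, x, ref, min_size=4)
# top-left (8, 12) and bottom-right (11, 15) of the salient block
print(np.argwhere(sal > 0).min(axis=0), np.argwhere(sal > 0).max(axis=0))
```

Each visited region needs only the 16 coalitions of its 4 quadrants, and only 2 regions are visited here, instead of the $2^{256}$ coalitions of individual pixels; the saliency ends up on the 4 × 4 block containing the bright pixel.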

In [None]:
# Install h-shap
!python -m pip install h-shap --upgrade
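The next cell calls `explainer.cycle_explain` with `s`, `R`, and `A`. Our reading, which we have not verified against the `h-shap` source and which should be treated as an assumption, is that `s` is the smallest region size and that each (radius, angle) pair in `R` × `A` defines a shift of the partition grid, with explanations averaged across shifts (cycle spinning) to reduce blocky grid artifacts. Under that assumption, the shifts implied by the notebook's parameters can be enumerated as follows; the `(r sin a, r cos a)` geometry is hypothetical.

```python
import numpy as np

s = 64
R = np.linspace(0, s, 4, endpoint=False)          # radii 0, 16, 32, 48
A = np.linspace(0, 2 * np.pi, 8, endpoint=False)  # angles 0, pi/4, ..., 7*pi/4

# Assumed geometry: one (row, col) grid shift per (radius, angle) pair
offsets = np.array([(r * np.sin(a), r * np.cos(a)) for r in R for a in A])
shifts = np.unique(np.round(offsets).astype(int), axis=0)

print(len(offsets))  # 32 radius/angle pairs
print(len(shifts))   # 25 distinct integer shifts: r = 0 collapses to (0, 0)
```

Under this reading, the 32 (radius, angle) pairs reduce to 25 distinct shifts, which gives a rough sense of how many hierarchical explanations are combined per image.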

In [None]:
from hshap import Explainer
from skimage.filters import threshold_otsu
from time import time

# Load the reference (background) value used to mask
# features; map_location places it on `device` even
# if it was saved from a different device
reference = torch.load(
    download(f"{base_url}/reference.pt?raw=true"), map_location=device
)

# Initialize the explainer and its parameters
explainer = Explainer(model=model, background=reference)

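# Smallest region size (in pixels) explored by the hierarchy
s = 64
# Radii and angles of the grid shifts used for cycle spinning
R = np.linspace(0, s, 4, endpoint=False)
A = np.linspace(0, 2 * np.pi, 8, endpoint=False)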

fig, axes = plt.subplots(1, len(images), figsize=(16, 9))
for i in range(len(images)):
    # Preprocess image for explanation: (H, W) -> (3, H, W),
    # with no batch dimension, normalized as for prediction
    x = images[i]
    x = tf.to_tensor(x)
    x = x.repeat(3, 1, 1)
    x = tf.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    x = x.to(device)

    # Use h-shap to explain the prediction of
    # the model on the image
    print(f"Explaining image {i+1} ...", end=" ")
    t0 = time()
    explanation = explainer.cycle_explain(
        x=x,
        label=0,
        s=s,
        R=R,
        A=A,
        threshold_mode="absolute",
        threshold=0.0,
        softmax_activation=False,
        batch_size=2,
        binary_map=False,
    )
    print(f"done in {time() - t0:.2f} s")

    # Filter explanation using Otsu's
    # method to remove noise
    explanation = explanation.numpy()
    _t = threshold_otsu(explanation.flatten())
    explanation = explanation * (explanation > _t)
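    # Set symmetric color limits at the 99th percentile of |values|
    # so that a few outliers do not wash out the colormap
    abs_values = np.abs(explanation.flatten())
    _max = np.nanpercentile(abs_values, 99)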

    # Visualize image, annotation, and explanation
    ax = axes[i]
    ax.imshow(images[i], cmap="gray")
    ax.imshow(
        explanation.squeeze(),
        cmap="bwr",
        vmin=-_max,
        vmax=_max,
        alpha=0.5,
    )

    annotation_df = gt[gt["SOPInstanceUID"] == sop_ids[i]]
    annotate(annotation_df, ax)

    ax.set_title("Explanation")
    ax.axis("off")

plt.show()
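The explanation cell filters each saliency map with `skimage.filters.threshold_otsu`. To show what that threshold does, here is a NumPy re-implementation of Otsu's criterion (choose the histogram split that maximizes between-class variance), applied to a synthetic map. This is an illustrative sketch, not scikit-image's code, and the synthetic map is made up.

```python
import numpy as np


def otsu_threshold(values: np.ndarray, nbins: int = 256) -> float:
    """Otsu's method: the bin center that maximizes between-class variance."""
    hist, edges = np.histogram(values.ravel(), bins=nbins)
    centers = (edges[:-1] + edges[1:]) / 2
    w0 = np.cumsum(hist)                      # counts at or below each candidate
    w1 = w0[-1] - w0                          # counts above each candidate
    m0 = np.cumsum(hist * centers)            # cumulative first moment
    mu0 = m0 / np.maximum(w0, 1)              # mean of the lower class
    mu1 = (m0[-1] - m0) / np.maximum(w1, 1)   # mean of the upper class
    between = w0 * w1 * (mu0 - mu1) ** 2      # unnormalized between-class variance
    return float(centers[np.argmax(between)])


# Synthetic "explanation": small noise plus one strongly positive 4 x 4 block
rng = np.random.default_rng(0)
expl = rng.normal(0.0, 0.01, size=(32, 32))
expl[10:14, 10:14] += 1.0

t = otsu_threshold(expl)
filtered = expl * (expl > t)  # keep only values above the threshold
vmax = np.nanpercentile(np.abs(filtered), 99)
print(f"threshold={t:.3f}, kept={np.count_nonzero(filtered)}, vmax={vmax:.2f}")
```

Because a saliency map is mostly near zero, the split tends to separate the few strongly contributing pixels from the background noise, and the symmetric `vmin`/`vmax` limits keep the `bwr` colormap centered at zero.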