# Explaining the Predictions of a Convolutional Neural Network on Head CT Images

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Sulam-Group/h-shap/blob/jaco/siim-ml-tutorial/demo/RSNA_ICH_detection/explain.ipynb)

## Task

Explain the predictions of a Convolutional Neural Network (CNN) trained to predict the presence of intracranial hemorrhage on the [RSNA 2019 Brain CT Hemorrhage Challenge dataset](https://www.kaggle.com/competitions/rsna-intracranial-hemorrhage-detection/data).

## Requirements

1. Basic understanding of machine learning and deep learning.
2. Programming in Python.

## Learning objectives

1. Load a pre-trained classifier in PyTorch.
2. Use the classifier to predict the presence of hemorrhage in test images.
3. Explain the predictions of the classifier to detect intracranial hemorrhage.

## Acknowledgements

This Jupyter Notebook was based on code by Jacopo Teneggi ([jtenegg1@jhu.edu](mailto:jtenegg1@jhu.edu)).

## Load the pretrained model

Here, we load a model pretrained on the [RSNA 2019 Brain CT Hemorrhage Challenge dataset](https://www.kaggle.com/competitions/rsna-intracranial-hemorrhage-detection/data). The model was trained on a binary classification problem: it predicts `1` when an image contains any type of hemorrhage, and `0` when the image is healthy.

This step requires PyTorch. Follow the instructions at [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/) to install PyTorch. Note that, although not required, this Jupyter Notebook supports execution on a GPU. When not running on GPU, expect long runtimes.

If running on [Google Colab](https://colab.research.google.com/), you can enable GPU runtime by selecting `Runtime > Change runtime type > GPU` in the `Hardware accelerator` dropdown menu. When running on GPU on a free Google Colab account, explaining model predictions should take around 50 seconds per image.

In [None]:
import torch
import torchvision.transforms.functional as tf

# Define the device to run on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained model
torch.set_grad_enabled(False)
model = torch.hub.load(
    "Sulam-Group/h-shap:jaco/siim-ml-tutorial", "rsnahemorrhagenet", trust_repo="check"
)
model = model.to(device)
model.eval()

## Use the Model to Predict on Images

Here, we use the pretrained model to predict on four positive test images from the [CQ500 dataset](http://headctstudy.qure.ai/dataset).

Ground-truth annotations of the bleeds are provided by the [BHX extension](https://physionet.org/content/bhx-brain-bounding-box/1.1/) dataset, and are highlighted with red solid lines in the images.
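
The `annotate` helper in the next cell reads each bounding box from a string and parses it with `json`. As a minimal sketch of that parsing step, the snippet below uses a hypothetical annotation string: the numbers are invented for illustration and are not taken from the dataset, and the assumption that the string holds only `x`, `y`, `width`, and `height` is inferred from the keys the helper reads.

```python
import json

# Hypothetical annotation string (values are made up for illustration).
raw = "{'x': 120.0, 'y': 85.5, 'width': 40.0, 'height': 32.0}"

# The string uses single quotes, so they are swapped for double quotes
# to obtain valid JSON before parsing.
bbox = json.loads(raw.replace("'", '"'))

# (x, y) is the anchor corner handed to matplotlib's Rectangle;
# adding the width and height gives the opposite corner.
x0, y0 = bbox["x"], bbox["y"]
x1, y1 = x0 + bbox["width"], y0 + bbox["height"]
print(f"Box spans x in [{x0}, {x1}] and y in [{y0}, {y1}]")
```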

In [None]:
import io
import json
import requests as req
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm


def download(url: str) -> io.BytesIO:
    """
    A helper function to download an object from a url.

    Parameters:
    -----------
    url: str
        The url to download the object from.
    """
    res = req.get(url)
    res.raise_for_status()
    return io.BytesIO(res.content)


def window(img: np.ndarray, WL: int, WW: int) -> np.ndarray:
    """
    A function that windows the values of an image at level
    `WL` with width `WW` (in Hounsfield Units).

    Parameters:
    -----------
    img: np.ndarray
        The image to window.
    WL: int
        The window level.
    WW: int
        The window width.
    """
    image_min = WL - WW // 2
    image_max = WL + WW // 2
    # Clip to the window without modifying the caller's array in place
    img = np.clip(img, image_min, image_max)

    # Rescale the windowed values to [0, 1]
    img = (img - image_min) / (image_max - image_min)

    return img


def annotate(annotation_df: pd.DataFrame, ax: plt.Axes) -> None:
    """
    A helper function to annotate an image with its ground-truth
    annotations.

    Parameters:
    -----------
    annotation_df: pd.DataFrame
        The DataFrame containing the annotations.
    ax: plt.Axes
        The Axes containing the image to annotate.
    """
    for _, annotation_row in annotation_df.iterrows():
        annotation = annotation_row["data"].replace("'", '"')
        annotation = json.loads(annotation)
        bbox_x = annotation["x"]
        bbox_y = annotation["y"]
        bbox_width = annotation["width"]
        bbox_height = annotation["height"]

        bbox = patches.Rectangle(
            (bbox_x, bbox_y),
            bbox_width,
            bbox_height,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(bbox)


# Download images and ground-truth annotations;
# images are windowed with the standard brain
# setting, i.e., WW = 80 and WL = 40.
base_url = "https://github.com/Sulam-Group/h-shap/blob/jaco/siim-ml-tutorial/demo/RSNA_ICH_detection/images"
sop_ids = [
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713007.1700822",
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713021.1704518",
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713469.1816851",
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713940.1946656",
]
images = np.stack(
    [
        window(
            np.load(download(f"{base_url}/{sop_id}.npy?raw=true")).astype(np.float32),
            WL=40,
            WW=80,
        )
        for sop_id in tqdm(sop_ids)
    ]
)
gt = pd.read_csv(f"{base_url}/gt.csv?raw=true")

# Visualize images and ground truth annotations,
# predict the presence of hemorrhage
# in each image using the pretrained model.
_, axes = plt.subplots(1, len(images), figsize=(16, 9))
for i in range(len(images)):
    # Preprocess image for prediction
    x = images[i]
    x = tf.to_tensor(x)
    x = x.unsqueeze(1).repeat(1, 3, 1, 1)
    x = tf.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Use the model to predict the presence of hemorrhage
    x = x.to(device)
    output = model(x)
    prediction = (output > 0.5).long()

    # Visualize image and ground truth annotations
    ax = axes[i]
    ax.imshow(images[i], cmap="gray")

    annotation_df = gt[gt["SOPInstanceUID"] == sop_ids[i]]
    annotate(annotation_df, ax)

    ax.set_title(f"Prediction: {prediction.item()}")
    ax.axis("off")

plt.show()
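
To make the windowing step concrete, here is a small sketch on made-up Hounsfield values (not taken from the dataset). It re-implements the windowing arithmetic under the hypothetical name `window_sketch`, used only for this illustration. With `WL = 40` and `WW = 80`, values at or below 0 HU map to 0, values at or above 80 HU map to 1, and everything in between is scaled linearly.

```python
import numpy as np


def window_sketch(img: np.ndarray, WL: int, WW: int) -> np.ndarray:
    """Clip `img` to [WL - WW/2, WL + WW/2] and rescale it to [0, 1]."""
    lo = WL - WW // 2
    hi = WL + WW // 2
    return (np.clip(img, lo, hi) - lo) / (hi - lo)


# Made-up Hounsfield values spanning air, the brain window, and bone.
hu = np.array([-1000.0, 0.0, 20.0, 40.0, 80.0, 1000.0])
windowed = window_sketch(hu, WL=40, WW=80)
print(windowed)  # values: 0, 0, 0.25, 0.5, 1, 1
```

Everything outside the window is saturated, so the full display range is spent on the narrow band of intensities where brain tissue and blood differ.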

## Explain Model Predictions Using `h-Shap`, `PartitionExplainer`, and `Grad-CAM`

In the context of machine learning, explanations highlight the features of an input that contributed the most towards a model's prediction. For images, explanation methods typically produce **saliency maps**: heatmaps in which the intensity of each pixel represents its importance. Here, we showcase how to produce saliency maps using `h-Shap`, `PartitionExplainer`, and `Grad-CAM`:

* [`h-shap`](https://ieeexplore.ieee.org/document/9826424) is a Python package that provides an exact, fast, and hierarchical implementation of the [Shapley value](https://en.wikipedia.org/wiki/Shapley_value). Source code is available on [GitHub](https://github.com/Sulam-Group/h-shap).

* `PartitionExplainer` is one of the explanation methods offered by the [`shap`](https://proceedings.neurips.cc/paper/2017/hash/8a20a8621978632d76c43dfd28b67767-Abstract.html) package. Similarly to `h-shap`, `PartitionExplainer` explores coalitions of players as an approximation to the [Owen value](https://link.springer.com/article/10.1007/s10100-009-0100-8). Source code is available on [GitHub](https://github.com/slundberg/shap).

* `pytorch-grad-cam` is a Python package that implements [`Grad-CAM`](https://openaccess.thecvf.com/content_iccv_2017/html/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.html), a gradient-based explanation method based on [Class Activation Mapping](https://openaccess.thecvf.com/content_cvpr_2016/html/Zhou_Learning_Deep_Features_CVPR_2016_paper.html). Source code is available on [GitHub](https://github.com/jacobgil/pytorch-grad-cam).
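
To give some intuition for the Shapley value mentioned above, the following sketch computes it exactly for a toy game. It does not use any of the three packages. The players are the four quadrants of an imaginary image, and `toy_model` is a hypothetical stand-in for a classifier that outputs 1 only when the quadrant containing the bleed is visible. The names `shapley_values`, `toy_model`, and `quadrants` are assumptions made for this illustration.

```python
from itertools import combinations
from math import factorial


def shapley_values(players, value):
    """Exact Shapley values by enumerating every coalition of the other players."""
    n = len(players)
    phi = {}
    for p in players:
        others = [q for q in players if q != p]
        total = 0.0
        for k in range(n):
            # Weight of a coalition of size k in the Shapley formula
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for coalition in combinations(others, k):
                s = frozenset(coalition)
                # Marginal contribution of p when it joins coalition s
                total += weight * (value(s | {p}) - value(s))
        phi[p] = total
    return phi


quadrants = ["top_left", "top_right", "bottom_left", "bottom_right"]


def toy_model(visible):
    # Hypothetical classifier: positive only if the bleed quadrant is visible
    return 1.0 if "top_right" in visible else 0.0


phi = shapley_values(quadrants, toy_model)
for name, score in phi.items():
    print(f"{name}: {score:.2f}")
```

In this toy game all of the importance is assigned to `top_right`, and the values sum to the difference between the model output on the full image and on the fully masked image. Exact enumeration grows exponentially with the number of players, which is why practical methods restrict or approximate the set of coalitions.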

Like other explanation methods based on game-theoretic quantities, `h-shap` and `PartitionExplainer` satisfy certain desirable theoretical properties. `h-shap` is particularly useful when the task is to find a concept of interest in an image (e.g., abnormality detection). On the other hand, gradient-based methods such as `Grad-CAM` currently lack general, precise mathematical guarantees on what features they will retrieve.
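
The hierarchical search for a concept of interest can be illustrated on a toy problem. The sketch below is a simplified illustration and not the `h-shap` package's implementation: it assumes a square image whose side is a power of two, a hypothetical `toy_predict` function that fires when any unmasked pixel is positive, and hypothetical helpers `split` and `hierarchical_saliency`. At each level it plays a four-player Shapley game among the quadrants of the current region and only descends into quadrants with a positive value.

```python
from itertools import combinations
from math import factorial

import numpy as np


def split(r0, r1, c0, c1):
    """Split a rectangular region into its four quadrants."""
    rm, cm = (r0 + r1) // 2, (c0 + c1) // 2
    return [(r0, rm, c0, cm), (r0, rm, cm, c1), (rm, r1, c0, cm), (rm, r1, cm, c1)]


def hierarchical_saliency(image, baseline, predict, min_size, threshold=0.0):
    saliency = np.zeros(image.shape, dtype=float)
    stack = [(0, image.shape[0], 0, image.shape[1])]
    while stack:
        children = split(*stack.pop())
        n = len(children)

        def value(subset):
            # Show only the chosen quadrants; everything else is the baseline
            masked = baseline.copy()
            for i in subset:
                r0, r1, c0, c1 = children[i]
                masked[r0:r1, c0:c1] = image[r0:r1, c0:c1]
            return predict(masked)

        for i in range(n):
            others = [j for j in range(n) if j != i]
            phi = 0.0
            for k in range(n):
                w = factorial(k) * factorial(n - k - 1) / factorial(n)
                for s in combinations(others, k):
                    phi += w * (value(s + (i,)) - value(s))
            if phi <= threshold:
                continue  # irrelevant quadrant: stop exploring it
            r0, r1, c0, c1 = children[i]
            if r1 - r0 <= min_size:
                saliency[r0:r1, c0:c1] = phi  # smallest size reached
            else:
                stack.append(children[i])  # relevant: explore further
    return saliency


def toy_predict(x):
    # Hypothetical classifier: positive if any visible pixel is positive
    return float(x.max() > 0)


image = np.zeros((8, 8))
image[1, 6] = 1.0  # a single "abnormal" pixel
saliency = hierarchical_saliency(image, np.zeros((8, 8)), toy_predict, min_size=2)
print(np.argwhere(saliency > 0))
```

Only the 2 x 2 block containing the positive pixel receives a nonzero score, and the three irrelevant quadrants at the top level are never subdivided.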

For more information on `h-shap`, `PartitionExplainer`, and `Grad-CAM`, refer to the papers [_"Fast Hierarchical Games for Image Explanations"_](https://ieeexplore.ieee.org/document/9826424), [_"A Unified Approach to Interpreting Model Predictions"_](https://proceedings.neurips.cc/paper/2017/hash/8a20a8621978632d76c43dfd28b67767-Abstract.html), and [_"Grad-CAM: Visual Explanations From Deep Networks via Gradient-Based Localization"_](https://openaccess.thecvf.com/content_iccv_2017/html/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.html), respectively.

In [None]:
# Install h-shap
!python -m pip install h-shap --upgrade

# Install shap
!python -m pip install shap --upgrade

# Install pytorch-grad-cam
!python -m pip install grad-cam --upgrade

In [None]:
import shap
from hshap import Explainer
from pytorch_grad_cam import GradCAM
from skimage.filters import threshold_otsu
from time import time


def hshap_explain(hexp: Explainer, x: torch.Tensor) -> np.ndarray:
    """
    A function that uses h-shap to explain a model's prediction.

    Parameters:
    -----------
    hexp: Explainer
        The h-shap explainer to use.
    x: torch.Tensor
        The input to the model.
    """
    s = 64
    R = np.linspace(0, s, 4, endpoint=False)
    A = np.linspace(0, 2 * np.pi, 8, endpoint=False)

    explanation = hexp.cycle_explain(
        x=x,
        label=0,
        s=s,
        R=R,
        A=A,
        threshold_mode="absolute",
        threshold=0.0,
        softmax_activation=False,
        batch_size=2,
        binary_map=False,
    )
    # Move the result to the CPU before converting it to a NumPy array
    return explanation.squeeze().cpu().numpy()


def gradcam_explain(cam: GradCAM, x: torch.Tensor) -> np.ndarray:
    """
    A function that uses grad-cam to explain a model's prediction.

    Parameters:
    -----------
    cam: GradCAM
        The grad-cam explainer to use.
    x: torch.Tensor
        The input to the model.
    """
    x = x.unsqueeze(0)
    # Grad-CAM needs gradients, which are disabled globally in this notebook
    with torch.enable_grad():
        explanation = cam(input_tensor=x)
    return explanation.squeeze()


def partexp_f(x: np.ndarray) -> np.ndarray:
    """
    A helper function to initialize the
    PartitionExplainer instance

    Parameters:
    -----------
    x: np.ndarray
        The input to the model.
    """
    x = torch.tensor(x).float()
    x = x.permute(0, 3, 1, 2)
    x = x.to(device)
    return model(x).detach().cpu().numpy()[..., -1]


def partexp_explain(partexp: shap.PartitionExplainer, x: torch.Tensor) -> np.ndarray:
    """
    A function that uses PartitionExplainer to explain a model's prediction.

    Parameters:
    -----------
    partexp: PartitionExplainer
        The PartitionExplainer explainer to use.
    x: torch.Tensor
        The input to the model.
    """
    # Maximum number of model evaluations allowed per explanation
    max_evals = 128

    # Convert to a batch of one channel-last NumPy image, matching the masker
    x = x.permute(1, 2, 0).cpu().numpy()
    x = np.expand_dims(x, axis=0)
    explanation = partexp(x, max_evals=max_evals, fixed_context=0)
    return explanation.values[0].sum(axis=-1)


def threshold_explanation(explanation):
    """
    A helper function that thresholds an
    explanation using Otsu's method.

    Parameters:
    -----------
    explanation: np.ndarray
        The explanation to threshold.
    """
    _t = threshold_otsu(explanation.flatten())
    explanation = explanation * (explanation > _t)
    abs_values = np.abs(explanation.flatten())
    _max = np.nanpercentile(abs_values, 99)
    return explanation, _max


# Load the reference value used
# to mask features
reference = torch.load(
    download(f"{base_url}/reference.pt?raw=true"), map_location=device
)
reference = reference.to(device)

# Initialize h-shap
hexp = Explainer(model=model, background=reference)

# Initialize PartitionExplainer
reference = reference.permute(1, 2, 0)
reference = reference.cpu().numpy()
masker = shap.maskers.Image(reference)
partexp = shap.Explainer(partexp_f, masker)

# Initialize Grad-CAM; `use_cuda` is omitted because recent grad-cam
# releases do not accept it, and the model is already on `device`
cam = GradCAM(model=model, target_layers=[model.encoder.layer4[-1]])

explainers = [
    ["h-Shap", hexp, hshap_explain],
    ["PartitionExplainer", partexp, partexp_explain],
    ["Grad-CAM", cam, gradcam_explain],
]
fig, axes = plt.subplots(len(explainers), len(images), figsize=(16, 9))
for i in range(len(images)):
    # Preprocess image for explanation
    x = images[i]
    x = tf.to_tensor(x)
    x = x.repeat(3, 1, 1)
    x = tf.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    x = x.to(device)

    for j, (explainer_name, explainer, explain) in enumerate(explainers):
        print(f"Explaining image {i+1} with {explainer_name} ...", end=" ")
        t0 = time()

        # Explain the model prediction on the image
        explanation = explain(explainer, x)
        # Threshold the explanation to reduce noise
        explanation, _max = threshold_explanation(explanation)
        print(f"done in {time() - t0:.2f} seconds")

        # Visualize the image, ground-truth, and explanation
        ax = axes[j, i]
        ax.imshow(images[i], cmap="gray")
        ax.imshow(explanation, cmap="bwr", vmin=-_max, vmax=_max, alpha=0.5)

        annotation_df = gt[gt["SOPInstanceUID"] == sop_ids[i]]
        annotate(annotation_df, ax)

        ax.set_title(explainer_name)
        ax.axis("off")

plt.show()