# Explaining the Predictions of a Convolutional Neural Network on Head CT Images

## Task

Explain the predictions of a Convolutional Neural Network (CNN) trained to detect intracranial hemorrhage in head CT images from the RSNA 2019 Brain CT Hemorrhage Challenge dataset. The data are publicly available at [https://www.kaggle.com/competitions/rsna-intracranial-hemorrhage-detection/data](https://www.kaggle.com/competitions/rsna-intracranial-hemorrhage-detection/data).

## Requirements

1. Basic understanding of machine learning and deep learning.
2. Programming in Python.

## Learning objectives

1. Load a pre-trained classifier in PyTorch.
2. Use the classifier to predict the presence of hemorrhage in test images.
3. Explain the classifier's predictions of intracranial hemorrhage.

## Acknowledgements

This Jupyter Notebook was based on code by Jacopo Teneggi ([jtenegg1@jhu.edu](mailto:jtenegg1@jhu.edu)).

## Load the pretrained model

This step requires PyTorch. To install it, follow the instructions at [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/). A GPU is not required, but the notebook uses one when it is available.

In [None]:
import torch
import torchvision.transforms.functional as tf

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Disable gradient tracking globally, since the model is only used for inference here
torch.set_grad_enabled(False)
model = torch.hub.load(
    "Sulam-Group/h-shap:jaco/siim-ml-tutorial",
    "rsnahemorrhagenet",
    pretrained=True,
    trust_repo="check",
)
model = model.to(device)
model.eval()
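
The cell above disables gradient tracking globally with `torch.set_grad_enabled(False)`. As a minimal sketch of a scoped alternative, the `torch.no_grad()` context manager limits that setting to a single block. The small network `standin` and the 512x512 input size are hypothetical, chosen only so the example runs without downloading the pretrained weights; they say nothing about the real model's architecture or expected input.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the pretrained network (NOT the real architecture).
standin = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 1),
)
standin.eval()  # switch layers such as dropout and batch norm to inference behavior

# Assumed input: a batch with one 3-channel 512x512 image.
dummy = torch.zeros(1, 3, 512, 512)

# Gradients are disabled only inside this block.
with torch.no_grad():
    out = standin(dummy)

n_params = sum(p.numel() for p in standin.parameters())
print(tuple(out.shape), out.requires_grad, standin.training, n_params)
```

The same pattern (`eval()`, a dummy batch, a parameter count) can serve as a quick sanity check on `model` after it is loaded.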

## Use the model to predict on images
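
The cell below applies a window with level `WL=40` and width `WW=80` to each image before prediction. As a minimal sketch of what that does, assuming the stored pixel values are in Hounsfield units (the source does not state this), the window keeps intensities in [0, 80], clips everything outside that range, and rescales the result to [0, 1]. The name `window_clip` and the sample values are illustrative only.

```python
import numpy as np


def window_clip(hu: np.ndarray, WL: int, WW: int) -> np.ndarray:
    """Clip to [WL - WW // 2, WL + WW // 2] and rescale to [0, 1]."""
    lo = WL - WW // 2
    hi = WL + WW // 2
    return (np.clip(hu, lo, hi) - lo) / (hi - lo)


# Illustrative intensities: far below the window, the lower edge,
# the window center, the upper edge, and far above the window.
hu = np.array([-1000.0, 0.0, 40.0, 80.0, 1000.0])
scaled = window_clip(hu, WL=40, WW=80)
print(scaled)  # values: 0, 0, 0.5, 1, 1
```

Everything below the lower edge collapses to 0 and everything above the upper edge collapses to 1, so the full grayscale range is spent on the narrow band of intensities inside the window.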

In [None]:
import io
import json
import requests as req
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm


def download(url: str) -> io.BytesIO:
    # Fetch a remote file into memory; fail on HTTP errors or a stalled connection
    res = req.get(url, timeout=30)
    res.raise_for_status()
    return io.BytesIO(res.content)


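def window(img: np.ndarray, WL: int, WW: int) -> np.ndarray:
    # Clip intensities to the window [WL - WW // 2, WL + WW // 2]
    # and rescale to [0, 1]. np.clip returns a new array, so the
    # caller's input is not modified in place.
    image_min = WL - WW // 2
    image_max = WL + WW // 2
    img = np.clip(img, image_min, image_max)

    img = (img - image_min) / (image_max - image_min)

    return img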


def annotate(annotation_df: pd.DataFrame, ax: plt.Axes) -> None:
    for _, annotation_row in annotation_df.iterrows():
        annotation = annotation_row["data"].replace("'", '"')
        annotation = json.loads(annotation)
        bbox_x = annotation["x"]
        bbox_y = annotation["y"]
        bbox_width = annotation["width"]
        bbox_height = annotation["height"]

        bbox = patches.Rectangle(
            (bbox_x, bbox_y),
            bbox_width,
            bbox_height,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(bbox)


# Download images and ground truth annotations
base_url = "https://github.com/Sulam-Group/h-shap/blob/jaco/siim-ml-tutorial/demo/RSNA_ICH_detection/images"
sop_ids = [
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713007.1700822",
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713021.1704518",
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713469.1816851",
    "1.2.276.0.7230010.3.1.4.296485376.1.1521713940.1946656",
]
images = np.stack(
    [
        window(
            np.load(download(f"{base_url}/{sop_id}.npy?raw=true")).astype(np.float32),
            WL=40,
            WW=80,
        )
        for sop_id in tqdm(sop_ids)
    ]
)
gt = pd.read_csv(f"{base_url}/gt.csv?raw=true")

# Visualize images and ground truth annotations,
# predict the presence of hemorrhage
# in each image using the pretrained model.
_, axes = plt.subplots(1, len(images), figsize=(16, 9))
for i in range(len(images)):
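    # Preprocess image for prediction:
    # (H, W) array -> (1, H, W) tensor -> (1, 3, H, W) batch with the
    # grayscale image repeated across 3 channels, then normalization
    # with the ImageNet channel means and standard deviations
    x = images[i]
    x = tf.to_tensor(x)
    x = x.unsqueeze(1).repeat(1, 3, 1, 1)
    x = tf.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])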

    # Use the model to predict the presence of hemorrhage:
    # outputs above 0.5 are labeled 1 (hemorrhage), all others 0
    x = x.to(device)
    output = model(x)
    prediction = (output > 0.5).long()

    # Visualize image and ground truth annotations
    ax = axes[i]
    ax.imshow(images[i], cmap="gray")

    annotation_df = gt[gt["SOPInstanceUID"] == sop_ids[i]]
    annotate(annotation_df, ax)

    ax.set_title(f"Prediction: {prediction.item()}")
    ax.axis("off")

plt.show()
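
The `annotate` function above swaps single quotes for double quotes before calling `json.loads`, which suggests that the `data` column holds Python-style dict strings. The sketch below shows why that swap is needed and compares it with `ast.literal_eval`, which parses such strings directly. The string `raw` is a hypothetical example and is not taken from `gt.csv`.

```python
import ast
import json

# Hypothetical annotation string in the single-quoted style implied by annotate().
raw = "{'x': 120.5, 'y': 88.0, 'width': 64.0, 'height': 42.5}"

# JSON requires double-quoted keys, so the raw string is rejected.
try:
    json.loads(raw)
    json_ok = True
except json.JSONDecodeError:
    json_ok = False

# Approach used in annotate(): swap the quotes, then parse as JSON.
via_replace = json.loads(raw.replace("'", '"'))

# Alternative: parse the string as a Python literal.
via_literal = ast.literal_eval(raw)

print(json_ok, via_replace == via_literal, via_literal["width"])
```

Both approaches give the same dict for this input. The quote swap would fail on a value that itself contains an apostrophe, which `ast.literal_eval` handles.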