# Explaining the Predictions of a Convolutional Neural Network on Blood Smears with Different Methods

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Sulam-Group/h-shap/blob/jaco/chembe_ml/demo/BBBC041/explain_compare.ipynb)

## Task

To explain the predictions of a convolutional neural network trained to predict the presence of trophozoites (i.e., malaria-infected cells) in human blood smears. The network was trained on the [BBBC041 dataset](https://data.broadinstitute.org/bbbc/BBBC041/).

## Requirements



## Load the Model

In [None]:
import torch

# Switch autograd off globally before anything is run through the model
torch.set_grad_enabled(False)

# Use the GPU when torch can see one, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the pretrained network through torch.hub, move it to the chosen
# device, and put it in evaluation mode
model = torch.hub.load(
    "Sulam-Group/h-shap:jaco/chembe_ml",
    "bbbc041trophozoitenet",
    trust_repo="check",
).to(device)
model.eval()

## Use the Model to Predict on Images

In [None]:
import io
import requests as req
import pandas as pd
import torch
import torchvision.transforms as t
import torchvision.transforms.functional as tf
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image


def download(url: str, timeout: float = 30.0) -> io.BytesIO:
    """
    A helper function to download an object from a url.

    Parameters:
    -----------
    url: str
        The url to download the object from.
    timeout: float
        How many seconds to wait for the server before giving up.

    Returns:
    --------
    io.BytesIO
        An in-memory buffer holding the body of the response.

    Raises:
    -------
    requests.HTTPError
        If the server answers with an error status code.
    requests.Timeout
        If the server does not answer within `timeout` seconds.
    """
    # Without an explicit timeout the request can block forever on a
    # stalled connection, which would freeze the notebook
    res = req.get(url, timeout=timeout)
    res.raise_for_status()
    return io.BytesIO(res.content)


def annotate(cell_df: pd.DataFrame, ax: plt.Axes) -> None:
    """
    Outline in green every ground-truth cell whose category is
    "trophozoite".

    Parameters:
    -----------
    cell_df: pd.DataFrame
        The annotated cells of one image. Every item obtained by iterating
        over it is indexed with "category" and "bounding_box".
    ax: plt.Axes
        The Axes containing the image to annotate.
    """
    trophozoites = (cell for cell in cell_df if cell["category"] == "trophozoite")
    for cell in trophozoites:
        corner_min = cell["bounding_box"]["minimum"]
        corner_max = cell["bounding_box"]["maximum"]
        top, left = corner_min["r"], corner_min["c"]
        bottom, right = corner_max["r"], corner_max["c"]

        # Rectangle is anchored at (column, row) of the minimum corner
        outline = patches.Rectangle(
            (left, top),
            abs(right - left),
            abs(bottom - top),
            linewidth=2,
            edgecolor="g",
            facecolor="none",
        )
        ax.add_patch(outline)


# Location of the demo assets (images and ground truth annotations)
base_url = (
    "https://raw.githubusercontent.com/Sulam-Group/h-shap/jaco/chembe_ml/demo/BBBC041"
)
image_ids = [
    "1f8f08ea-b5b3-4f68-94d4-3cc071b7dce8",
    "1fdf99d1-a494-4174-bcd7-efbe457ab899",
    "2aeb7b85-7df9-4a63-8d37-2eecaaa190e7",
    "2bed2796-75a9-4d93-bb2d-e3ccbe1a5cbe",
    "2f6224be-50d0-4e85-94ef-88315df561b6",
]

# Fetch every image listed above, in order
image_urls = [f"{base_url}/images/{image_id}.png" for image_id in image_ids]
images = [Image.open(download(image_url)) for image_url in image_urls]

# Fetch the ground truth annotations and key them by image file name
annotations = pd.read_json(download(f"{base_url}/annotations.json")).set_index(
    "image"
)

# Preprocessing applied to an image before it is fed to the model
transform = t.Compose(
    [
        t.ToTensor(),
        t.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Show each image next to the others, with its ground truth annotations
# drawn on top and the prediction of the model in the title
_, axes = plt.subplots(1, len(images), figsize=(16, 9))
for ax, image, image_id in zip(axes, images, image_ids):
    # Preprocess the image into a batch of one on the model's device
    x = transform(image).unsqueeze(0).to(device)

    # Use the model to predict the presence of trophozoites
    prediction = torch.argmax(model(x), dim=1)

    # Visualize image, ground truth annotations, and prediction
    ax.imshow(image)
    annotate(annotations.at[f"{image_id}.png", "objects"], ax)
    ax.set_title(f"Prediction: {prediction.item()}")
    ax.axis("off")
plt.show()

## Install Explainers

In [None]:
# Install, or upgrade to the latest available release of, each explanation
# library imported further down. The leading `!` runs the line as a shell
# command from the notebook.
!python -m pip install grad-cam --upgrade
!python -m pip install lime --upgrade
!python -m pip install shap --upgrade
!python -m pip install h-shap --upgrade

## Load Reference Value

In [None]:
# Load the reference value used to mask features. `map_location` lets the
# tensor be loaded on a CPU-only machine even if it was saved from a GPU.
reference = torch.load(
    download(f"{base_url}/reference.pt?raw=true"), map_location=device
)
reference = reference.to(device)

# Check that the prediction of the model on the reference is 0
reference_output = model(reference.unsqueeze(0))
reference_prediction = torch.argmax(reference_output, dim=1).cpu().item()

# Visualize reference. The preprocessing transform computes
# (x - mean) / std per channel, so it is undone with x * std + mean.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
denorm_reference = reference.cpu() * norm_std + norm_mean
# Keep the values inside the range imshow accepts for floating point RGB
denorm_reference = denorm_reference.clamp(0, 1)

_, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
ax.imshow(denorm_reference.permute(1, 2, 0))
ax.axis("off")
ax.set_title(f"Reference value (prediction = {reference_prediction})")
plt.show()

## Define Explainers and Explanation Functions

In [None]:
import numpy as np
import torch.nn.functional as f
import shap
import hshap
from pytorch_grad_cam import GradCAM
from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm

# Grad-CAM initialization function
def _cam_init():
    """
    Build the Grad-CAM explainer for the model.

    The explainer targets the last block of `model.layer4` and runs on the
    GPU only when the notebook selected a CUDA device.

    Returns:
    --------
    GradCAM
        The Grad-CAM explainer.
    """
    target_layers = [model.layer4[-1]]
    try:
        # Follow the device chosen for the model instead of assuming a GPU
        explainer = GradCAM(
            model=model,
            target_layers=target_layers,
            use_cuda=device.type == "cuda",
        )
    except TypeError:
        # NOTE(review): newer grad-cam releases appear to have removed the
        # `use_cuda` argument and take the device from the model -- confirm
        # against the installed version
        explainer = GradCAM(model=model, target_layers=target_layers)
    return explainer


# LIME initialization function
def _lime_init():
    """
    Build the LIME explainer for images.

    Returns:
    --------
    lime_image.LimeImageExplainer
        A LIME image explainer created with its default arguments.
    """
    return lime_image.LimeImageExplainer()


# PartitionExplainer initialization function
def _partexp_init():
    """
    Build the shap explainer that masks features with the reference value.

    Returns:
    --------
    shap.Explainer
        An explainer wrapping the model behind an image masker.
    """
    # The masker receives the reference with the channel axis moved last
    background = reference.permute(1, 2, 0).cpu().numpy()

    def _predict(batch):
        # Move the channel axis back in front before calling the model
        batch = torch.tensor(batch).float().permute(0, 3, 1, 2).to(device)
        # Only the last entry of the model output is explained
        return model(batch).cpu().numpy()[..., -1]

    return shap.Explainer(_predict, shap.maskers.Image(background))


# h-Shap initialization function
def _hshap_init():
    """
    Build the h-Shap explainer for the model.

    Returns:
    --------
    hshap.Explainer
        An explainer that uses the reference value as background.
    """
    return hshap.Explainer(model=model, background=reference)


# Grad-CAM explanation function
def _cam_explain(explainer, image):
    """
    Explain the prediction of the model on an image with Grad-CAM.

    Parameters:
    -----------
    explainer:
        The Grad-CAM explainer.
    image:
        The image to explain, before preprocessing.

    Returns:
    --------
    The output of the explainer with its singleton dimensions removed.
    """
    image_t = transform(image)
    image_t = image_t.to(device)
    image_t = image_t.unsqueeze(0)

    # Autograd is switched off globally when the model is loaded, but
    # Grad-CAM needs gradients, so turn them back on for this call only
    with torch.enable_grad():
        explanation = explainer(input_tensor=image_t).squeeze()
    return explanation


# LIME explanation function
def _lime_explain(explainer, image):
    """
    Explain the prediction of the model on an image with LIME.

    Parameters:
    -----------
    explainer:
        The LIME image explainer.
    image:
        The image to explain, before preprocessing.

    Returns:
    --------
    The mask LIME returns for the top predicted label, restricted to
    positive contributions and computed over all segments.
    """

    def _predict_proba(batch):
        # Preprocess every sample, then run the stack through the model
        tensors = [transform(sample) for sample in batch]
        logits = model(torch.stack(tensors, dim=0).to(device))
        return f.softmax(logits, dim=1).cpu().numpy()

    segmenter = SegmentationAlgorithm(
        "quickshift", kernel_size=4, max_dist=200, ratio=0.2
    )
    pixels = np.array(image.convert("RGB"))

    lime_result = explainer.explain_instance(
        pixels,
        _predict_proba,
        top_labels=1,
        num_samples=100,
        segmentation_fn=segmenter,
    )
    # The first returned value (the image) is not needed, only the mask
    _, mask = lime_result.get_image_and_mask(
        lime_result.top_labels[0],
        positive_only=True,
        num_features=len(lime_result.segments),
        hide_rest=False,
    )
    return mask


# PartitionExplainer explanation function
def _partexp_explain(explainer, image):
    """
    Explain the prediction of the model on an image with the shap explainer.

    Parameters:
    -----------
    explainer:
        The shap explainer.
    image:
        The image to explain, before preprocessing.

    Returns:
    --------
    The values of the explanation of the image, summed over the last axis.
    """
    # Channel axis last, with a leading batch axis of size one
    channel_last = transform(image).permute(1, 2, 0).numpy()
    batch = channel_last[np.newaxis, ...]

    shap_values = explainer(batch, max_evals=128, fixed_context=0)
    return shap_values.values[0].sum(axis=-1)


# h-Shap explanation function
def _hshap_explain(explainer, image):
    """
    Explain the prediction of the model on an image with h-Shap.

    Parameters:
    -----------
    explainer:
        The h-Shap explainer.
    image:
        The image to explain, before preprocessing.

    Returns:
    --------
    The explanation returned by the explainer for label 1.
    """
    image_t = transform(image)
    image_t = image_t.to(device)

    min_s = 80
    threshold_mode, threshold_value = "relative", 70
    # Explain the preprocessed tensor on the model's device, not the raw
    # image: the model and the reference both live in the preprocessed space
    explanation = explainer.explain(
        image_t,
        label=1,
        s=min_s,
        threshold_mode=threshold_mode,
        threshold_value=threshold_value,
    )
    return explanation


# One entry per method: a display name, the initialized explainer, and the
# function that takes (explainer, image) and returns the explanation
explainers = [
    {"name": "Grad-CAM", "explainer": _cam_init(), "explain": _cam_explain},
    # "explain" must be the explanation function, not the initializer
    {"name": "LIME", "explainer": _lime_init(), "explain": _lime_explain},
    {
        "name": "PartitionExplainer",
        "explainer": _partexp_init(),
        "explain": _partexp_explain,
    },
    {"name": "h-Shap", "explainer": _hshap_init(), "explain": _hshap_explain},
]