# Explaining Predictions of AI Models in Radiology

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/Sulam-Group/AI-Deep-Learning-Lab-2023/blob/bbjt-nb-fairness_interpretability/sessions/ai-fairness/explainability_solution.ipynb)

In this notebook, we will explore explainability of AI models in radiology.
We will use a pretrained model to predict the presence of intracranial hemorrhage (ICH) in head CT scans.

We will use h-Shap to explain the predictions of the model and evaluate how well the explanations align with the ground truth segmentations provided by expert radiologists.

---

**Before we start**

1. Change the Colab runtime type to GPU.
2. Add a shortcut to the shared Google Drive folder to your My Drive: [https://drive.google.com/drive/folders/1p90aGBS8vIX54x9ytaW8h-vk4NHXDhpR?usp=sharing](https://drive.google.com/drive/folders/1p90aGBS8vIX54x9ytaW8h-vk4NHXDhpR?usp=sharing)

## Setup and Imports

In [None]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from sklearn.metrics import roc_curve, auc
from skimage.filters import threshold_otsu
from scipy.spatial.distance import dice
from tqdm import tqdm

# Shortcut to the shared folder, added to My Drive in step 2 above
LAB_PATH = os.path.join("/content/drive/MyDrive", "RSNA2023-FAIRNESS-LAB")
sys.path.append(LAB_PATH)

!python -m pip install entmax
from utils import (
    get_dataset,
    get_slice_predictor,
    get_hshap_slice_explainer,
    get_ground_truth_mask,
    viz_explanations_with_annotations,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)

sns.set_theme()
sns.set_context("talk")


## Compute the ROC Curve and Find an Optimal Operating Point

Here, we will use a pretrained model to predict the presence of intracranial hemorrhage within slices from CT scans of the brain.

---

Objectives:

1. Use the model to classify slices from the test set.
2. Plot the ROC curve and evaluate the AUC.
3. Pick the threshold that maximizes Youden's J statistic.
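
To make these objectives concrete, here is a minimal sketch on a handful of made-up scores (toy data, not the ICH test set). It sweeps every distinct score as a candidate cutoff, computes TPR and FPR at each one, and keeps the cutoff that maximizes Youden's J = TPR - FPR. The helper `youden_threshold` is hypothetical; in the notebook itself, `sklearn.metrics.roc_curve` performs the sweep.

```python
import numpy as np


def youden_threshold(y_true, scores):
    """Return (threshold, J) for the cutoff that maximizes Youden's J = TPR - FPR.

    A sample is called positive when its score is >= threshold, the same rule
    the notebook applies later with `logits >= t`.
    """
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    n_pos, n_neg = y_true.sum(), (~y_true).sum()

    best_t, best_j = None, -np.inf
    for t in np.unique(scores):  # every distinct score is a candidate cutoff
        pred = scores >= t
        tpr = (pred & y_true).sum() / n_pos   # sensitivity
        fpr = (pred & ~y_true).sum() / n_neg  # 1 - specificity
        if tpr - fpr > best_j:
            best_t, best_j = t, tpr - fpr
    return best_t, best_j


# Toy labels and logits (illustrative only)
y = [0, 0, 1, 0, 1, 1, 0, 1]
s = [-2.0, -1.0, 0.1, -0.5, 0.3, 1.2, 0.5, 2.0]
t, j = youden_threshold(y, s)
print(f"t = {t:.2f}, J = {j:.2f}")  # t = 0.10, J = 0.75
```

With `drop_intermediate=False`, `roc_curve` evaluates the same distinct-score cutoffs (plus one above the highest score), so taking the threshold at `np.argmax(tpr - fpr)` should land on the same operating point for this toy data.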

In [None]:
# Load the model
model = get_slice_predictor(device)

# Load the test dataset
dataset = get_dataset()
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Here:
# Predict whether a slice contains any signs of ICH.
# You can use the following lists to store the ground truth labels and the predicted logits.
ground_truth, logits = [], []
for series, _, labels, _ in tqdm(dataloader):
    series = series.to(device)

    output = model(series)

    # flatten() keeps one entry per slice, even when a series has a single slice
    # (squeeze() would produce a scalar there, and list.extend would fail)
    ground_truth.extend(labels.flatten().cpu().tolist())
    logits.extend(output.flatten().cpu().tolist())

# Here:
# 1. Compute the ROC curve
# 2. Evaluate the AUC
# 3. Find the optimal threshold that maximizes Youden's J statistic
# hint: Youden's J statistic is TPR - FPR
fpr, tpr, threshold = roc_curve(ground_truth, logits, drop_intermediate=False)
_auc = auc(fpr, tpr)
j = tpr - fpr
tj = np.argmax(j)
t = threshold[tj]

# Here:
# 1. Plot the ROC curve
# 2. Mark the optimal threshold
_, ax = plt.subplots(figsize=(4, 4))
ax.plot(fpr, tpr)
ax.scatter(fpr[tj], tpr[tj], color="r")
ax.set_xlabel("FPR")
ax.set_ylabel("TPR")
ax.set_title(f"AUC: {_auc:.2f}, t = {t:.2f}")
ax.set_xticks([0, 0.5, 1.0])
ax.set_yticks([0, 0.5, 1.0])
plt.show()

## Explain the Predictions of the Model Using Shapley Values

Here, we will use h-Shap to explain the predictions of the model on some true positive and false positive examples.
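
h-Shap explains a prediction by computing Shapley values hierarchically rather than over individual pixels. The toy sketch below illustrates that idea under simplifying assumptions: the current region is split into four quadrants, the exact Shapley value of each quadrant is computed (four players means only 2^4 = 16 coalitions), and only quadrants whose value exceeds a tolerance `tol` are split further, down to `min_size` pixels. Quadrants absent from a coalition, and all pixels outside the current region, are replaced by a `baseline` image; that masking rule is an assumption of this sketch. `shapley_values`, `hierarchical_shap`, and `toy_model` are hypothetical helpers, not the API of the h-Shap package or of `get_hshap_slice_explainer`.

```python
import itertools
from math import factorial

import numpy as np


def shapley_values(f, n):
    """Exact Shapley values of n players for a set function f(frozenset) -> float."""
    phi = np.zeros(n)
    for i in range(n):
        others = [p for p in range(n) if p != i]
        for k in range(n):
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for subset in itertools.combinations(others, k):
                s = frozenset(subset)
                phi[i] += weight * (f(s | {i}) - f(s))
    return phi


def hierarchical_shap(model, image, baseline, min_size=2, tol=0.0):
    """Toy h-Shap-style saliency map (hypothetical helper).

    Assumes a square image whose side is a power of 2. Players are the four
    quadrants of the current region; quadrants not in a coalition, and every
    pixel outside the current region, are replaced by `baseline`.
    """
    saliency = np.zeros(image.shape, dtype=float)

    def explain(r0, c0, size):
        half = size // 2
        quads = [(r0, c0), (r0, c0 + half), (r0 + half, c0), (r0 + half, c0 + half)]

        def f(coalition):
            x = baseline.copy()
            for q in coalition:
                r, c = quads[q]
                x[r:r + half, c:c + half] = image[r:r + half, c:c + half]
            return model(x)

        for q, value in enumerate(shapley_values(f, 4)):
            if value <= tol:
                continue  # prune quadrants that do not support the prediction
            r, c = quads[q]
            if half <= min_size:
                saliency[r:r + half, c:c + half] = value  # leaf: record its value
            else:
                explain(r, c, half)  # relevant and still large: split further

    explain(0, 0, image.shape[0])
    return saliency


# Toy "detector": positive if any pixel is bright
def toy_model(x):
    return float(x.max() > 0.5)


image = np.zeros((8, 8))
image[5, 6] = 1.0  # a single bright "lesion" pixel
saliency = hierarchical_shap(toy_model, image, baseline=np.zeros_like(image))
print(np.argwhere(saliency > 0).tolist())  # [[4, 6], [4, 7], [5, 6], [5, 7]]
```

Because irrelevant quadrants are pruned, the number of model evaluations grows with the number of relevant regions rather than with the number of pixels, which is what keeps exact Shapley values tractable on images in this scheme.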

---

Objectives:

1. Select some true positive and false positive examples. You can use the `dataset.get_slice(idx)` function to retrieve slices from the dataset.
2. Use h-Shap to explain the predictions of the model.
3. Compute the Dice score between the explanations of the true positive predictions and the ground truth annotations.
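
The Dice score measures the overlap between a binarized explanation A and a ground truth mask B as 2|A ∩ B| / (|A| + |B|): 1 means perfect overlap and 0 means none. Note that `scipy.spatial.distance.dice` returns the Dice *dissimilarity*, which for boolean inputs equals 1 minus this score. The sketch below computes both on two toy masks; `dice_coefficient` is a hypothetical helper written for this illustration, and returning 1.0 when both masks are empty is an assumed convention.

```python
import numpy as np
from scipy.spatial.distance import dice


def dice_coefficient(a, b):
    """Dice coefficient 2|A & B| / (|A| + |B|) of two binary masks (hypothetical helper)."""
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b, dtype=bool)
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0  # assumed convention: two empty masks agree perfectly
    return 2.0 * np.logical_and(a, b).sum() / total


# Toy binarized explanation and ground truth mask: 4 pixels each, 3 shared
explanation = np.zeros((4, 4), dtype=bool)
explanation[0:2, 0:2] = True
mask = np.zeros((4, 4), dtype=bool)
mask[0, 0:2] = True
mask[1:3, 0] = True

score = dice_coefficient(explanation, mask)
dissimilarity = dice(explanation.ravel(), mask.ravel())
print(score, dissimilarity, 1 - dissimilarity)  # 0.75 0.25 0.75
```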


In [None]:
# You can use this function to visualize the slices
def _show(slices, title, nrow=4):
    _, ax = plt.subplots(figsize=(16, 9))
    im = make_grid(dataset.denormalize(slices), nrow=nrow)
    ax.imshow(im.permute(1, 2, 0))
    ax.axis("off")
    ax.set_title(title)
    plt.show()


# Here:
# 1. Threshold predictions with the optimal threshold
# 2. Randomly sample 4 true positive and 4 false positive slices
# 3. Visualize the slices
m = 4
ground_truth, logits = np.array(ground_truth), np.array(logits)
predictions = logits >= t

tp = np.where((ground_truth == 1) & (predictions == 1))[0]
fp = np.where((ground_truth == 0) & (predictions == 1))[0]

tp_idx = np.random.choice(tp, m, replace=False)
fp_idx = np.random.choice(fp, m, replace=False)

tp_slices = torch.stack([dataset.get_slice(idx) for idx in tp_idx])
fp_slices = torch.stack([dataset.get_slice(idx) for idx in fp_idx])

_show(tp_slices, "True positives")
_show(fp_slices, "False positives")

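You can use the `get_ground_truth_mask` and `explain_slice` functions to get the ground truth segmentation mask of a slice and to explain the model's prediction on it, respectively. `explain_slice` is created in the next cell by `get_hshap_slice_explainer` and takes a slice tensor (as returned by `dataset.get_slice`), not an index. For example:
```python
slice_idx = 200
mask = get_ground_truth_mask(dataset, slice_idx)
explanation = explain_slice(dataset.get_slice(slice_idx))
```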

In [None]:
explain_slice = get_hshap_slice_explainer(model, device)

# Here:
# 1. Explain predictions on true positive and false positive slices
# 2. Compute the Dice score between the ground truth segmentation and the explanation
# hint: threshold explanations with Otsu's method to obtain better results
fp_explanations = [explain_slice(fp_slices[idx]) for idx in tqdm(range(m))]
tp_explanations = [explain_slice(tp_slices[idx]) for idx in tqdm(range(m))]

tp_ground_truth = [get_ground_truth_mask(dataset, idx) for idx in tp_idx]

def _dice_score(explanation, mask):
    # Binarize the explanation with Otsu's threshold
    binary_explanation = explanation > threshold_otsu(explanation.flatten())
    binary_mask = mask > 0
    # scipy's `dice` returns the Dice *dissimilarity* (1 - Dice coefficient),
    # so subtract it from 1 to obtain the Dice score
    return 1 - dice(binary_explanation.flatten(), binary_mask.flatten())


tp_dice_score = [
    _dice_score(explanation, mask)
    for explanation, mask in zip(tp_explanations, tp_ground_truth)
]

viz_explanations_with_annotations(dataset, tp_slices, tp_explanations, tp_idx, tp_dice_score)
viz_explanations_with_annotations(dataset, fp_slices, fp_explanations, fp_idx)
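
The hint in the cell above recommends binarizing each explanation with Otsu's method before computing the Dice score. Otsu's method picks the threshold that maximizes the between-class variance w0 * w1 * (mu0 - mu1)^2 of the two groups it creates, where w0, w1 are the fractions of values below and above the cutoff and mu0, mu1 their means. The hypothetical `otsu_threshold` below is a histogram-based sketch of that criterion on toy data; the notebook itself calls `skimage.filters.threshold_otsu`, whose binning and tie-breaking may differ slightly from this sketch.

```python
import numpy as np


def otsu_threshold(values, nbins=256):
    """Histogram-based Otsu threshold: maximizes between-class variance."""
    values = np.asarray(values, dtype=float).ravel()
    hist, edges = np.histogram(values, bins=nbins)
    centers = (edges[:-1] + edges[1:]) / 2
    p = hist / hist.sum()  # normalized histogram

    # For each split "bins <= k" (background) vs "bins > k" (foreground):
    w0 = np.cumsum(p)            # background probability
    w1 = 1.0 - w0                # foreground probability
    m0 = np.cumsum(p * centers)  # unnormalized background mean
    m_total = m0[-1]

    with np.errstate(divide="ignore", invalid="ignore"):
        mu0 = m0 / w0
        mu1 = (m_total - m0) / w1
        between = w0 * w1 * (mu0 - mu1) ** 2
    between = np.nan_to_num(between)  # splits with an empty class score 0
    return centers[np.argmax(between)]


# Toy "explanation": mostly near-zero values plus a small bright region
rng = np.random.default_rng(0)
values = np.concatenate([rng.normal(0.0, 0.05, 900), rng.normal(1.0, 0.05, 100)])
t = otsu_threshold(values)
foreground = values > t
print(f"threshold = {t:.2f}, foreground pixels = {foreground.sum()}")
```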