# Explaining Predictions of AI Models in Radiology (Bonus notebook)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/Sulam-Group/AI-Deep-Learning-Lab-2023/blob/bbjt-nb-fairness_interpretability/sessions/ai-fairness/bonus_explainability.ipynb)

In this notebook, we will explore the explainability of AI models in radiology.
Unlike before, we will now use a model that predicts the presence of ICH at the examination level instead of the slice level.

We will use h-Shap to explain the predictions of the model and retrieve the slices in the examination that contain signs of hemorrhage.

---

**Before we start**

1. Change the Colab runtime to GPU.
2. Add a shortcut to the shared Google Drive folder: [https://drive.google.com/drive/folders/1p90aGBS8vIX54x9ytaW8h-vk4NHXDhpR?usp=sharing](https://drive.google.com/drive/folders/1p90aGBS8vIX54x9ytaW8h-vk4NHXDhpR?usp=sharing)

## Setup and Imports

In [None]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from sklearn.metrics import roc_curve, auc
from tqdm import tqdm

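# Absolute path under the mount point above, so it does not depend on the working directory
LAB_PATH = os.path.join("/content/drive", "MyDrive", "RSNA2023-FAIRNESS-LAB")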
sys.path.append(LAB_PATH)

!python -m pip install entmax
from utils import (
    get_dataset,
    get_random_positive_examination,
    get_series_predictor,
    get_hshap_examination_explainer,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)

sns.set_theme()
sns.set_context("talk")


## Does the Model Correctly Classify Examinations?

Here, we will verify whether the model has good generalization performance on the examination-level binary classification task.

---

Objectives:

1. Use the model to classify examinations. Pass the `series=True` flag to obtain the model's prediction for the entire examination, for example:
```python
output, _ = model(x, series=True)
```
2. Plot the ROC curve and evaluate the AUC.
3. Pick the threshold that maximizes Youden's J statistic.
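
As a reference for objectives 2 and 3, here is a minimal sketch on synthetic data. The labels and scores below are made up and only stand in for the `ground_truth` and `logits` lists collected in the loop. Youden's J statistic is $J = \text{sensitivity} + \text{specificity} - 1 = \text{TPR} - \text{FPR}$, so the optimal threshold is the one at which the ROC curve is farthest above the diagonal.

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

# Synthetic stand-ins for examination-level labels and model scores (assumption)
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)
y_score = rng.normal(loc=2.0 * y_true, scale=1.0)  # positives score higher on average

# ROC curve and area under it
fpr, tpr, thresholds = roc_curve(y_true, y_score)
roc_auc = auc(fpr, tpr)

# Youden's J statistic at every candidate threshold
j = tpr - fpr
best = int(np.argmax(j))
best_threshold = thresholds[best]

print(f"AUC = {roc_auc:.3f}")
print(f"Best threshold = {best_threshold:.3f} (TPR = {tpr[best]:.3f}, FPR = {fpr[best]:.3f})")
```

Because the ROC curve only depends on the ranking of the scores, it is the same whether raw logits or probabilities are used; the selected threshold, however, lives on whichever scale was passed to `roc_curve`.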

In [None]:
# Load the pretrained model
model = get_series_predictor(device)

# Load the test dataset
dataset = get_dataset()
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Here:
# Predict whether an examination contains any signs of ICH
# Note: these lists will store the ground-truth labels and the predicted logits
ground_truth, logits = [], []
for i, data in enumerate(tqdm(dataloader)):
    series, target, _, _ = data

    ### YOUR CODE HERE ###

# Here:
# 1. Compute the ROC curve
# 2. Evaluate the AUC
# 3. Find the optimal threshold that maximizes Youden's J statistic

### YOUR CODE HERE ###

# Here:
# Plot everything!

### YOUR CODE HERE ###

## Do the Model's Attention Weights Align with Image-level Labels?

Here, we will investigate whether the attention weights of the trained model align well with the image-level binary labels.

---

Objectives:

1. Use the `return_attention=True` flag to return the attention weights of the model, for example:
```python
output, aux = model(series, series=True, return_attention=True)
attention = aux["attention"]
```
2. Plot the ground truth image-level labels and the attention weights.
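
Besides plotting, the alignment can also be summarized numerically. The sketch below uses synthetic labels and attention weights (a softmax over random scores, used purely as a stand-in for the model's attention) and computes two simple, hypothetical summary measures: the fraction of attention mass placed on positive slices, and the fraction of positive slices among the top-$k$ most attended ones, where $k$ is the number of positive slices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic slice-level labels: a hypothetical contiguous region with hemorrhage
n_slices = 32
labels = np.zeros(n_slices, dtype=int)
labels[12:18] = 1

# Synthetic attention weights: softmax over scores that are higher on positive slices
scores = rng.normal(size=n_slices) + 3.0 * labels
attention = np.exp(scores - scores.max())
attention /= attention.sum()

# Fraction of the total attention mass assigned to positive slices
positive_mass = attention[labels == 1].sum()

# Among the k most attended slices, the fraction that are truly positive
k = int(labels.sum())
top_k = np.argsort(attention)[::-1][:k]
top_k_precision = labels[top_k].mean()

print(f"Attention mass on positive slices: {positive_mass:.3f}")
print(f"Positive slices among the top-{k} attended: {top_k_precision:.3f}")
```

With the real model, `labels` and `attention` would come from the examination and from `aux["attention"]`; converting them to NumPy arrays of the same length is assumed here.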

In [None]:
# Get random positive examination from dataset
patient_number, (series, target, labels, _) = get_random_positive_examination()
patient_number = patient_number.item()

# Visualize up to 5 random positive slices within the examination
positive_idx = np.where(labels == 1)[0]
# Guard against examinations with fewer than 5 positive slices
m = min(5, len(positive_idx))
_, ax = plt.subplots(figsize=(16, 9))
slice_idx = np.random.choice(positive_idx, m, replace=False)
im = make_grid(dataset.denormalize(series[slice_idx]))
ax.imshow(im.permute(1, 2, 0), cmap="gray")
ax.axis("off")
ax.set_title(f"{m} random positive slices")
plt.show()

# Here:
# Get the attention weights for the examination

### YOUR CODE HERE ###

# Here:
# Plot the ground truth slice-level labels and attention weights

### YOUR CODE HERE ###

## Explaining Examination-level Predictions with Shapley Values

Here, we will use [h-Shap](https://www.computer.org/csdl/journal/tp/2023/04/09826424/1EVdAz76rC0) to compute each slice's contribution to the model's examination-level prediction.

---

Objectives:
1. Use the `explain_examination` function to compute the Shapley value of each slice in the examination, for example:
```python
explanation = explain_examination(series)
```
2. Plot the ground truth image-level labels, the attention weights, and the Shapley values.
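
To build intuition for what a slice-level Shapley value means, here is a toy sketch that computes exact Shapley values by enumerating all subsets of slices. It is not the h-Shap implementation: the function `shapley_values` and the idealized `toy_model` (which outputs 1 whenever at least one positive slice is kept) are hypothetical, and exhaustive enumeration is only feasible for a handful of slices.

```python
from itertools import combinations
from math import factorial


def shapley_values(n, value):
    """Exact Shapley values of an n-player game with set function `value`."""
    players = range(n)
    phi = [0.0] * n
    for i in players:
        others = [p for p in players if p != i]
        for size in range(n):
            # Weight of a coalition of this size in the Shapley formula
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                # Marginal contribution of slice i when added to coalition s
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


# Toy examination with 5 slices, where slices 1 and 3 show hemorrhage (assumption)
positive = {1, 3}


def toy_model(kept):
    # Idealized examination-level prediction: positive if any positive slice is kept
    return 1.0 if positive & kept else 0.0


phi = shapley_values(5, toy_model)
print([round(p, 3) for p in phi])
```

In this toy game the two positive slices share the prediction equally, the negative slices receive zero, and the values sum to the difference between the prediction on the full examination and on the empty one.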

In [None]:
explain_examination = get_hshap_examination_explainer(model, device)

# Here:
# Compute the Shapley values

### YOUR CODE HERE ###

# Here:
# Plot everything!

### YOUR CODE HERE ###