In this demo file, we investigate methods for explainability in medical imaging using pre-trained TorchXRayVision models ( https://github.com/mlmed/torchxrayvision ).

# Setup

In [None]:
!git clone https://github.com/mlmed/gifsplanation

import numpy as np
import torchxrayvision as xrv
import skimage, torch, torchvision
import matplotlib.pyplot as plt
import sys,os
sys.path.insert(0,"./gifsplanation/")
#sys.path.insert(0,"./torchxrayvision/torchxrayvision")

from captum.attr import IntegratedGradients, Saliency, InputXGradient
import attribution

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

fatal: destination path 'gifsplanation' already exists and is not an empty directory.


# Investigating the models

In the following code cell, we investigate the models included in the TorchXRayVision package. Our ultimate goal is to build an app to help us visualize and interpret model predictions based on the input X-ray image. This necessarily requires us to expose the model and access its layers.

In [None]:
# first, we define a test image
!wget https://raw.githubusercontent.com/mlmed/torchxrayvision/master/tests/16747_3_1.jpg

In [None]:
# Load the downloaded test image and preview it.
img = skimage.io.imread("16747_3_1.jpg")
# Rescale the 8-bit pixel values (max 255) to the [-1024, 1024] range.
img = xrv.datasets.normalize(img, 255)
# Collapse the colour axis by averaging, then prepend a length-1 leading axis.
img = np.expand_dims(img.mean(axis=2), axis=0)
plt.imshow(img[0], cmap="gray")
plt.show()

In [None]:
# Centre-crop and resize the image to the 224x224 input size used by the
# models, then wrap the result in a torch.Tensor.
transform = torchvision.transforms.Compose([
    xrv.datasets.XRayCenterCrop(),
    xrv.datasets.XRayResizer(224),
])
img = torch.from_numpy(transform(img))

plt.imshow(img[0].numpy(), cmap="gray")  # resized image
plt.show()

In [None]:
# Identifiers of the pre-trained weights considered here; the first entry is
# the one actually loaded below.
model_names = [
    "densenet121-res224-all",
    "densenet121-res224-rsna",
    "densenet121-res224-nih",
    "densenet121-res224-pc",
    "densenet121-res224-chex",
    "densenet121-res224-mimic_nb",
    "densenet121-res224-mimic_ch",
    "resnet50-res512-all",
]
model = xrv.models.DenseNet(weights=model_names[0])
model = model.to(device)
model.eval()  # inference mode
print(model)

In [None]:
# Forward pass on the test image with a batch axis added in front.
# The model was moved to `device` when it was loaded, so the input has to be
# moved there too; otherwise the call fails with a device mismatch on CUDA.
preds = model(img[None, ...].to(device))
# The original note describes these scores as logits, i.e. values before the
# softmax layer.
# NOTE(review): confirm against xrv.models.DenseNet.forward -- models loaded
# with pre-trained weights may already post-process the classifier output.
# .cpu() is needed before .numpy() whenever `preds` lives on the GPU.
dict(zip(model.pathologies, preds[0].detach().cpu().numpy()))
# This information should be presented as a horizontal bar chart

In [None]:
# Horizontal bar chart of the model output for every pathology.
plt.figure(figsize=(10, 12))
# style.use updates matplotlib's global settings, so it also affects figures
# drawn by later cells.
plt.style.use('ggplot')
# Copy the scores to host memory first: .numpy() raises on a CUDA tensor.
plt.barh(model.pathologies, preds[0].detach().cpu().numpy())
plt.title('Outputs')
plt.ylabel('Pathology')
plt.xlabel('Score')
plt.show()

# Gradient-based saliency maps

In [None]:
##
## It turns out, we don't actually have to implement these methods by hand. 
## The Captum package has all the functionality we need
##

# def input_gradient(input, model):
#     # Computes the gradient of the model output (logit) wrt the input image;
#     # basically local sensitivity analysis of model predictions wrt input images
#     # input is (num_batch=1,num_channels=1, height, width) to match the model input

#     model.eval() #checking we are in the appropriate mode

#     #no training, so no grads wrt params needed
#     for param in model.parameters():
#         param.requires_grad = False

#     input.requires_grad = True #size (1,1,224,224)

#     preds = model(input)
    
#     output_of_interest = preds.max()
#     ## TODO: In reality, the user may want to see the input gradient wrt other outputs / classes.
#     # This function should also accept an index for which class we want the saliency map.
#     # sorted, indices = preds.sort()
#     # output_of_interest = sorted.squeeze()[-1]
    
#     output_of_interest.backward()
#     #for sensitivities, we are not interested in the sign of the gradient, just its magnitude
#     sens = torch.abs(input.grad) #size (1,1,224,224)
#     sens = (sens - sens.min())/(sens.max()-sens.min()) #normalize local sensitivities for convenience
#     return sens

In [None]:
# Gradient saliency (Captum `Saliency`) for the highest-scoring output.
# The image is moved to `device` so that it sits on the same device as the
# model, matching what the Gradient x Input and Integrated Gradients cells do.
# The name `input` shadows the builtin but is kept because the other cells in
# this notebook use the same name.
input = img[None, ...].to(device)
input.requires_grad_()

saliency = Saliency(model)
sens_attr = saliency.attribute(input, target=preds.argmax())
# in our main implementation, the target should be user defined

In [None]:
# Side by side: the resized input image (left) and its saliency map (right).
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.imshow(img.numpy()[0], cmap = 'gray') #resized image
plt.subplot(1, 2, 2)
# detach + cpu before .numpy(): consistent with the other attribution plots and
# required when the attribution tensor lives on the GPU.
plt.imshow(sens_attr[0,0].detach().cpu().numpy(), cmap=plt.cm.hot)
plt.show()

# Gradient X Input Saliency Map

In [None]:
# Gradient x Input attribution for the highest-scoring output.
# requires_grad_() works in place and returns the tensor, so it can be chained.
input = img[None, ...].requires_grad_()

input_x_gradient = InputXGradient(model)
input_x_gradient_attr = input_x_gradient.attribute(
    input.to(device),
    target=preds.argmax(),
)

In [None]:
# Side by side: the resized input image (left) and the magnitude of the
# Gradient x Input attribution (right).
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.imshow(img.numpy()[0], cmap = 'gray') #resized image
plt.subplot(1, 2, 2)
# .cpu() is required before .numpy() because the attribution was computed on
# `device`, which may be the GPU.
plt.imshow(np.abs(input_x_gradient_attr[0,0].detach().cpu().numpy()), cmap=plt.cm.hot)
plt.show()

# Integrated gradient-based saliency maps

In [None]:
# Integrated Gradients attribution for the highest-scoring output.
# requires_grad_() works in place and returns the tensor, so it can be chained.
input = img[None, ...].requires_grad_()

ig = IntegratedGradients(model)
ig_attr = ig.attribute(
    input.to(device),
    target=preds.argmax(),
)

In [None]:
# Side by side: the resized input image (left) and the magnitude of the
# Integrated Gradients attribution (right).
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.imshow(img.numpy()[0], cmap = 'gray') #resized image
plt.subplot(1, 2, 2)
# .cpu() is required before .numpy() because the attribution was computed on
# `device`, which may be the GPU.
plt.imshow(np.abs(ig_attr[0,0].detach().cpu().numpy()), cmap=plt.cm.hot)
plt.show()

# Gifsplanations

In [None]:
# Latent-shift ("gifsplanation") attribution for one fixed target pathology.
input = img[None, ...]
input.requires_grad_(False)  # this method is run without input gradients

ae = xrv.autoencoders.ResNetAE(weights="101-elastic").to(device)
target = "Nodule"
params = attribution.compute_attribution(
    input.to(device), "latentshift", model, target, ret_params=True, ae=ae
)

# Join the returned "dimgs" entries along axis 1, then lay the resulting
# images out next to each other in a single strip.
# NOTE(review): the exact layout of params["dimgs"] is defined in
# gifsplanation's attribution module -- confirm there.
dimgs = np.concatenate(params["dimgs"], axis=1)[0]
fig, ax = plt.subplots(1, 1, figsize=(8, 3), dpi=350)
plt.imshow(np.concatenate(dimgs, axis=1), interpolation="none", cmap="gray");
plt.axis("off");

In [None]:
%matplotlib inline
attribution.generate_video(input, model, target, ae, target_filename="test", border=False, show=True,
                           ffmpeg_path="ffmpeg")