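This demo notebook explores explainability methods for medical imaging using pre-trained [TorchXRayVision](https://github.com/mlmed/torchxrayvision) models. For a chosen model and pathology, it visualizes gradient saliency, Input × Gradient and Integrated Gradients attributions (via Captum) alongside a Gifsplanation video, all served through a Gradio interface.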

# Setup

In [1]:
!pip install torchxrayvision captum gradio
!git clone https://github.com/mlmed/gifsplanation

from pathlib import Path
from typing import Callable

import gradio as gr
import numpy as np
import torchxrayvision as xrv
import skimage, torch, torchvision
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import sys,os
sys.path.insert(0,"./gifsplanation/")

from captum.attr import IntegratedGradients, Saliency, InputXGradient
import attribution

!wget https://raw.githubusercontent.com/mlmed/torchxrayvision/master/tests/16747_3_1.jpg #download test image

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchxrayvision
  Downloading torchxrayvision-0.0.38-py3-none-any.whl (29.0 MB)
[K     |████████████████████████████████| 29.0 MB 67.2 MB/s 
[?25hCollecting captum
  Downloading captum-0.5.0-py3-none-any.whl (1.4 MB)
[K     |████████████████████████████████| 1.4 MB 26.3 MB/s 
[?25hCollecting gradio
  Downloading gradio-3.4.0-py3-none-any.whl (5.3 MB)
[K     |████████████████████████████████| 5.3 MB 1.6 MB/s 
Collecting httpx
  Downloading httpx-0.23.0-py3-none-any.whl (84 kB)
[K     |████████████████████████████████| 84 kB 2.3 MB/s 
[?25hCollecting paramiko
  Downloading paramiko-2.11.0-py2.py3-none-any.whl (212 kB)
[K     |████████████████████████████████| 212 kB 49.1 MB/s 
[?25hCollecting pycryptodome
  Downloading pycryptodome-3.15.0-cp35-abi3-manylinux2010_x86_64.whl (2.3 MB)
[K     |████████████████████████████████| 2.3 MB 33.9 MB/s 
[?25hCollecting orjson
  Dow

In [2]:
def make_fig(plot_matrix, title=None, cmap="hot"):
    """Render a 2-D attribution map as a heat-map figure with hidden axes.

    The figure is built with the object-oriented ``matplotlib.figure.Figure``
    API instead of ``plt.figure()``. pyplot keeps every figure it creates
    alive until ``plt.close`` is called, so creating one per Gradio request
    would leak memory for as long as the server runs.

    Args:
        plot_matrix: 2-D array-like of values to display (e.g. absolute
            attributions for one image).
        title: Optional axes title; nothing is drawn when ``None``.
        cmap: Matplotlib colormap name or object (default ``"hot"``, same as
            the previous ``plt.cm.hot``).

    Returns:
        matplotlib.figure.Figure: the rendered figure, suitable for ``gr.Plot``.
    """
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.add_subplot()
    ax.imshow(plot_matrix, cmap=cmap)
    if title is not None:
        ax.set_title(title)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    return fig

In [3]:
def xrv_prepare_image(image):
    """Convert an 8-bit image from ``gr.Image`` into a TorchXRayVision input tensor.

    Steps:
    1. Rescale pixel values to the [-1024, 1024] range.
    2. Average colour channels into a single channel.
    3. Centre-crop and resize to 224.
    4. Add a batch dimension.

    Args:
        image: numpy array of 8-bit values, either (H, W, C) colour or
            (H, W) grayscale.

    Returns:
        torch.Tensor with a leading batch dimension of 1, e.g.
        (1, 1, 224, 224).

    Raises:
        ValueError: if ``image`` is None (nothing uploaded yet) or is neither
            2-D nor 3-D.
    """
    if image is None:
        raise ValueError("No image provided: please upload an X-ray image first.")

    img = xrv.datasets.normalize(image, 255)  # convert 8-bit image to [-1024, 1024] range
    if img.ndim == 3:
        img = img.mean(2)  # collapse colour channels into one
    elif img.ndim != 2:
        raise ValueError(f"Expected a 2-D or 3-D image array, got shape {img.shape}.")
    img = img[None, ...]  # leading channel axis, as the original pipeline used

    transform = torchvision.transforms.Compose(
        [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)]
    )
    img = transform(img)

    return torch.from_numpy(img)[None, ...]  # add batch dimension

In [4]:
def predict(image, model_choice):
    """Score an X-ray with a pre-trained TorchXRayVision DenseNet.

    Args:
        image: numpy image from ``gr.Image`` (prepared by ``xrv_prepare_image``).
        model_choice: weights name passed to ``xrv.models.DenseNet``,
            e.g. ``"densenet121-res224-all"``.

    Returns:
        dict mapping pathology name -> score as a Python float, the format
        ``gr.Label`` displays.

    Raises:
        ValueError: if no model has been selected.
    """
    if not model_choice:
        # Without a weights name there is nothing pretrained to load, so the
        # scores would be meaningless; fail loudly instead.
        raise ValueError("No model selected: please choose a model first.")

    img = xrv_prepare_image(image)
    model = xrv.models.DenseNet(weights=model_choice)
    model.eval()

    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(img)

    # .tolist() yields plain Python floats, which gr.Label needs. This replaces
    # the old .astype(np.float) cast; np.float was removed in NumPy 1.24.
    scores = outputs[0].tolist()
    # Skip unnamed outputs (empty pathology names), if the model has any, so
    # they don't collapse into a single "" label.
    return {name: score for name, score in zip(model.pathologies, scores) if name}

In [5]:
def _captum_attribution_figure(image, model_choice, target, method_cls):
    """Shared implementation of the Captum-based explanation tabs.

    Loads the selected DenseNet in eval mode and attributes its ``target``
    output to the input pixels with ``method_cls`` (a Captum attribution
    class such as ``Saliency``). Returns the absolute attribution of the first
    batch item / channel rendered by ``make_fig``.

    Raises:
        ValueError: if the image, model or target is missing, or the selected
            model does not predict ``target``.
    """
    if image is None:
        raise ValueError("No image provided: please upload an X-ray image first.")
    if not model_choice:
        raise ValueError("No model selected: please choose a model first.")
    if not target:
        raise ValueError("No target selected: please choose a target first.")

    model = xrv.models.DenseNet(weights=model_choice)
    # Match predict(): DenseNets contain BatchNorm layers, which in train mode
    # normalise with the current batch's statistics. IntegratedGradients feeds
    # many interpolated inputs as one batch, so train mode would distort both
    # the predictions and the attributions.
    model.eval()

    if target not in model.pathologies:
        raise ValueError(f"Model {model_choice!r} does not predict {target!r}.")

    img = xrv_prepare_image(image)
    attr = method_cls(model).attribute(img, target=model.pathologies.index(target))
    # [0, 0] selects the single batch item and its single channel.
    return make_fig(np.abs(attr[0, 0].detach().numpy()))


def explain_gradient(image, model_choice, target):
    """Return a heat map of gradient saliency (Captum ``Saliency``) for ``target``."""
    return _captum_attribution_figure(image, model_choice, target, Saliency)


def explain_input_x_gradient(image, model_choice, target):
    """Return a heat map of |input x gradient| (Captum ``InputXGradient``) for ``target``."""
    return _captum_attribution_figure(image, model_choice, target, InputXGradient)


def explain_integrated_gradients(image, model_choice, target):
    """Return a heat map of |Integrated Gradients| (Captum defaults) for ``target``."""
    return _captum_attribution_figure(image, model_choice, target, IntegratedGradients)

def explain_gifsplanation(image, model_choice, target):
    """Generate a Gifsplanation video for ``target`` on the uploaded image.

    Uses ``attribution.generate_video`` from the cloned gifsplanation repo,
    together with the TorchXRayVision ``"101-elastic"`` ResNet autoencoder.

    Returns:
        Whatever ``attribution.generate_video`` returns. It is wired to a
        ``gr.Video`` output, so presumably a path to the rendered video file
        (TODO confirm against the gifsplanation repo).

    Raises:
        ValueError: if the image, model or target is missing, or the selected
            model does not predict ``target``.
    """
    if image is None:
        raise ValueError("No image provided: please upload an X-ray image first.")
    if not model_choice:
        raise ValueError("No model selected: please choose a model first.")
    if not target:
        raise ValueError("No target selected: please choose a target first.")

    model = xrv.models.DenseNet(weights=model_choice)
    model.eval()  # inference-mode layers, consistent with predict()
    if target not in model.pathologies:
        raise ValueError(f"Model {model_choice!r} does not predict {target!r}.")

    img = xrv_prepare_image(image)
    img.requires_grad = False

    ae = xrv.autoencoders.ResNetAE(weights="101-elastic")
    ae.eval()

    # NOTE(review): the fixed target_filename means concurrent requests may
    # overwrite each other's output; consider a per-request temporary name.
    movie = attribution.generate_video(img, model, target, ae, target_filename="test", border=False, show=False,
                        ffmpeg_path="ffmpeg")

    return movie

In [6]:
# Pre-trained weight names offered in the "Select model" dropdown; each one is
# passed unchanged to xrv.models.DenseNet(weights=...).
model_choices = [
    "densenet121-res224-all",
    "densenet121-res224-rsna",
    "densenet121-res224-nih",
    "densenet121-res224-pc",
    "densenet121-res224-chex",
    "densenet121-res224-mimic_nb",
    "densenet121-res224-mimic_ch",
]

# Pathology names offered in the Explanation tab's "Select target" dropdown.
# A given model may predict only a subset of these.
target_choices = [
    "Atelectasis",
    "Consolidation",
    "Infiltration",
    "Pneumothorax",
    "Edema",
    "Emphysema",
    "Fibrosis",
    "Effusion",
    "Pneumonia",
    "Pleural_Thickening",
    "Cardiomegaly",
    "Nodule",
    "Mass",
    "Hernia",
    "Lung Lesion",
    "Fracture",
    "Lung Opacity",
    "Enlarged Cardiomediastinum",
]

In [23]:
# Browser UI: a "Prediction" tab showing multi-label scores, and an
# "Explanation" tab comparing four attribution methods for one target.
frontend = gr.Blocks()

with frontend:
    gr.Markdown(
    """
    # X-Ray Diagnosis
    Explore TorchXRayVision model predictions and data using this demo.
    """
    )

    # Prediction tab: image + model in, per-pathology scores out.
    with gr.Tab("Prediction"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="X-ray image")
                select_model = gr.Dropdown(label="Select model", choices=model_choices)
                with gr.Row():
                    submit_button = gr.Button("Submit")
                gr.Examples(["16747_3_1.jpg"], inputs=input_image)

            with gr.Column():
                label = gr.Label(label="Multiclass predictions")

    submit_button.click(predict, [input_image, select_model], label)

    # Explanation tab: original image next to one sub-tab per method.
    with gr.Tab("Explanation"):

        with gr.Row():
            with gr.Column():
                select_target = gr.Dropdown(label="Select target", choices=target_choices)
                with gr.Row(equal_height=True):
                    with gr.Column():
                        with gr.Tab("Im..."):
                            original_image = gr.Image(label='Original', interactive=False)
                    with gr.Column():
                        with gr.Tab("Gra..."):
                            gradient_plot = gr.Plot(label="Gradient")
                        with gr.Tab("XGr..."):
                            input_x_gradient_plot = gr.Plot(label='InputXGradient')
                        with gr.Tab("Int..."):
                            integrated_gradients_plot = gr.Plot(label="IntegratedGradients")
                        with gr.Tab("Gif..."):
                            gifsplanation_vid = gr.Video(label="Gifsplanation")

    # Mirror the uploaded image into the Explanation tab unchanged.
    input_image.change(lambda s: s, inputs=input_image, outputs=original_image)

    # Selecting a target recomputes every explanation; the listeners are
    # registered in the same order as the sub-tabs.
    explain_inputs = [input_image, select_model, select_target]
    for explain_fn, explain_output in (
        (explain_gradient, gradient_plot),
        (explain_input_x_gradient, input_x_gradient_plot),
        (explain_integrated_gradients, integrated_gradients_plot),
        (explain_gifsplanation, gifsplanation_vid),
    ):
        select_target.change(explain_fn, inputs=explain_inputs, outputs=explain_output)

frontend.launch(share=True, show_error=True)

Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`
Running on public URL: https://11397.gradio.app

This share link expires in 72 hours. For free permanent hosting, check out Spaces: https://huggingface.co/spaces


(<gradio.routes.App at 0x7fe10e852850>,
 'http://127.0.0.1:7860/',
 'https://11397.gradio.app')

In [24]:
# Shut down the Gradio server started by frontend.launch() in the cell above.
frontend.close()

Closing server running on port: 7860
