# Revisiting Robust Interpretability of Self-Explaining Neural Networks
This notebook generates all results in the paper "Revisiting Robust Interpretability of Self-Explaining Neural Networks" by J. Goedhart, L. Jansen, H. Lim, and D. Nobbe. Please read the description of each cell before running it.

# Training
Run the following three cells to train the models. Pre-trained models are included in the GitHub repository, so training is not necessary to generate the figures.
Note that training reinitializes the model weights, so the generated figures will change slightly.
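To run all three trainings outside the notebook (for example from a plain terminal), the commands can be built programmatically. A minimal sketch: `training_commands` is a hypothetical helper, the flags are the ones used in the cells below, and it assumes the COMPAS model is trained with `scripts/main_compas.py --h_type input`, the script the later COMPAS cells call. As in the cells, `--cuda` is only passed to the MNIST script.

```python
from typing import List


def training_commands(cuda: bool) -> List[List[str]]:
    """Build the three training commands (hypothetical helper).

    Assumption: --cuda is only passed to the MNIST script, mirroring the
    notebook cells; the COMPAS training command never receives it.
    """
    mnist = [
        ["python", "scripts/main_mnist.py", "--nconcepts", str(n), "--train"]
        for n in (5, 22)
    ]
    if cuda:
        for cmd in mnist:
            cmd.append("--cuda")
    compas = ["python", "scripts/main_compas.py", "--h_type", "input", "--train"]
    return mnist + [compas]
```

Each command list can then be passed to `subprocess.run(cmd, check=True)` from the repository root, with `cuda=torch.cuda.is_available()`.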

In [None]:
# Only run this cell when the MNIST model needs to be trained for 5 concepts.
import torch
if torch.cuda.is_available():
    !python scripts/main_mnist.py --nconcepts 5 --train --cuda
else:
    !python scripts/main_mnist.py --nconcepts 5 --train 

In [None]:
# Only run this cell when the MNIST model needs to be trained for 22 concepts.
import torch
if torch.cuda.is_available():
    !python scripts/main_mnist.py --nconcepts 22 --train --cuda
else:
    !python scripts/main_mnist.py --nconcepts 22 --train 

In [None]:
# Only run this cell when the COMPAS model needs to be trained.
!python scripts/main_compas.py --h_type input --train

# All plots
Running the following three cells generates faithfulness and dependency plots for every data point in the test datasets. Only run them if you want to see results beyond those used in the paper. The results are saved to disk (make sure about 10 to 20 GB is free), and interrupting a script *does not* delete files that have already been generated.
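Because these runs can fill a disk, it may help to check free space first. A minimal sketch using the standard library's `shutil.disk_usage`; `has_free_space` is a hypothetical helper, and the default 20 GB threshold is simply the upper end of the estimate above.

```python
import shutil

GB = 10 ** 9  # decimal gigabytes, matching the "GB" estimate above


def has_free_space(path: str = ".", required_gb: float = 20.0) -> bool:
    """Return True if the filesystem holding `path` has at least `required_gb` GB free."""
    usage = shutil.disk_usage(path)  # named tuple of (total, used, free) in bytes
    return usage.free >= required_gb * GB
```

Calling `has_free_space("out")` (once the `out` directory exists) checks the drive the figures are written to.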

In [None]:
# Running this script will save a faithfulness plot, a graph of concepts, and a box plot of the correlations
# for our method ('alternative'), and for the method of Alvarez-Melis and Jaakkola ('original'),
# for the MNIST model with 5 concepts.
!python scripts/main_mnist.py --nconcepts 5

In [None]:
# Running this script will save a faithfulness plot, a graph of concepts, and a box plot of the correlations
# for our method ('alternative'), and for the method of Alvarez-Melis and Jaakkola ('original'),
# for the MNIST model with 22 concepts.
!python scripts/main_mnist.py --nconcepts 22

In [None]:
# Running this script will save a faithfulness plot, a graph of concepts, and a box plot of the correlations
# for our method ('alternative'), and for the method of Alvarez-Melis and Jaakkola ('original'),
# for the COMPAS model.
!python scripts/main_compas.py --h_type input

# Paper figures
Run the following cells to generate and show the figures used in the paper and presentation.
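All MNIST figures below live in a directory whose name differs only in the concept count. A minimal sketch that rebuilds those paths, assuming the other name fields (`grad3`, `Hcnn`, `Thsimple`, `Reg1e-02`, `Sp0.0001`, `LR0.001`) stay as they appear in the paths of this notebook; `mnist_out_dir` and `COMPAS_OUT_DIR` are hypothetical names.

```python
def mnist_out_dir(nconcepts: int) -> str:
    """Rebuild the MNIST output directory used in this notebook (hypothetical helper).

    Assumption: only the concept count varies; every other field is copied
    verbatim from the paths below.
    """
    name = f"grad3_Hcnn_Thsimple_Cpts{nconcepts}_Reg1e-02_Sp0.0001_LR0.001"
    return f"out/mnist/{name}"


# COMPAS output directory, copied verbatim from the COMPAS cells below.
COMPAS_OUT_DIR = "out/compas/unreg_Hinput_Thsimple_Reg1e-02_LR0.001"
```

For example, `Image(filename=mnist_out_dir(5) + '/concept_grid.png', retina=True)` points at the same file as the first `Image` in the next cell.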

In [None]:
# This cell generates all paper results for MNIST with 5 concepts, except the faithfulness box plot.
# Note that running this cell takes about 20 minutes.

from IPython.display import Image, display

!python scripts/main_mnist.py --nconcepts 5 --demo

out_dir = 'out/mnist/grad3_Hcnn_Thsimple_Cpts5_Reg1e-02_Sp0.0001_LR0.001'
concept_grid = Image(filename=f'{out_dir}/concept_grid.png', retina=True)
dependency_plots = Image(filename=f'{out_dir}/faithfulness0/dependencies/0.png', retina=True)
original = Image(filename=f'{out_dir}/faithfulness0/_0/original.png', retina=True)
alternative = Image(filename=f'{out_dir}/faithfulness0/_0/alternative.png', retina=True)
digit = Image(filename=f'{out_dir}/faithfulness0/dependencies/0digit.png', retina=True)
stability_graph = Image(filename=f'{out_dir}/stability.png', retina=True)
histogramthetaxhx = Image(filename=f'{out_dir}/histogramthetaxhx.png', retina=True)
histogramthetax = Image(filename=f'{out_dir}/histogramthetax.png', retina=True)
histogramhx = Image(filename=f'{out_dir}/histogramhx.png', retina=True)

display(concept_grid, dependency_plots, original, alternative,
        digit, stability_graph, histogramthetaxhx, histogramthetax, histogramhx)


In [None]:
# Run this cell to generate the faithfulness box plot for MNIST with 5 concepts that appears in the paper.
# It first gathers the faithfulness correlations for the whole MNIST test set, which takes about 5 to 10 minutes.

from IPython.display import Image

!python scripts/main_mnist.py --nconcepts 5 --noplot

faithfulness_box_plot = Image(filename='out/mnist/grad3_Hcnn_Thsimple_Cpts5_Reg1e-02_Sp0.0001_LR0.001/faithfulness_box_plot.png', retina=True)
display(faithfulness_box_plot)
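The notebook does not define the faithfulness correlation itself. In Alvarez-Melis and Jaakkola's evaluation, faithfulness is assessed by removing each concept in turn, measuring the drop in the predicted class probability, and correlating those drops with the concepts' relevance scores. A minimal sketch of that idea, assuming Pearson correlation and hypothetical inputs (`relevances`, `prob_full`, `prob_without`); the scripts' 'original' and 'alternative' computations may differ in detail.

```python
import math
from typing import Sequence


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation coefficient of two equal-length sequences."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mx = sum(xs) / len(xs)
    my = sum(ys) / len(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    if sx == 0 or sy == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / (sx * sy)


def faithfulness_correlation(relevances: Sequence[float],
                             prob_full: float,
                             prob_without: Sequence[float]) -> float:
    """Correlate each concept's relevance with the probability drop observed
    when that concept is removed (hypothetical inputs, one entry per concept)."""
    drops = [prob_full - p for p in prob_without]
    return pearson(relevances, drops)
```

A value close to 1 means that concepts with higher relevance cause larger probability drops when removed; a box plot then summarizes one such value per test example.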

In [None]:
# This cell generates all paper results for MNIST with 22 concepts, except the faithfulness box plot.
# Note that running this cell takes about 5 minutes.
from IPython.display import Image

!python scripts/main_mnist.py --nconcepts 22 --demo

original = Image(filename='out/mnist/grad3_Hcnn_Thsimple_Cpts22_Reg1e-02_Sp0.0001_LR0.001/faithfulness0/_1/original.png', retina=True)
alternative = Image(filename='out/mnist/grad3_Hcnn_Thsimple_Cpts22_Reg1e-02_Sp0.0001_LR0.001/faithfulness0/_1/alternative.png', retina=True)

display(original, alternative) 


In [None]:
# This cell generates all paper results for COMPAS, except the faithfulness box plot.
# Note that running this cell takes about 2 minutes.
from IPython.display import Image

!python scripts/main_compas.py --h_type input --demo

dependencies = Image(filename='out/compas/unreg_Hinput_Thsimple_Reg1e-02_LR0.001/faithfulness0/_0/dependencies/0.png', retina=True)
original = Image(filename='out/compas/unreg_Hinput_Thsimple_Reg1e-02_LR0.001/faithfulness0/_0/original.png', retina=True)
alternative = Image(filename='out/compas/unreg_Hinput_Thsimple_Reg1e-02_LR0.001/faithfulness0/_0/alternative.png', retina=True)
# The histogram figures are commented out for COMPAS; to show them, uncomment these lines
# and add the three variables to the display() call below.
# histogramthetaxhx = Image(filename='out/compas/unreg_Hinput_Thsimple_Reg1e-02_LR0.001/histogramthetaxhx.png', retina=True)
# histogramthetax = Image(filename='out/compas/unreg_Hinput_Thsimple_Reg1e-02_LR0.001/histogramthetax.png', retina=True)
# histogramhx = Image(filename='out/compas/unreg_Hinput_Thsimple_Reg1e-02_LR0.001/histogramhx.png', retina=True)

display(dependencies, original, alternative)
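If a script was interrupted or a figure was not generated, the display cells fail or show stale output. A minimal sketch of a guard that partitions figure paths into existing and missing files before displaying them; `split_existing` is a hypothetical helper that uses only the standard library.

```python
import os
from typing import Iterable, List, Tuple


def split_existing(paths: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Partition figure paths into (existing, missing) files, preserving order."""
    existing: List[str] = []
    missing: List[str] = []
    for path in paths:
        (existing if os.path.isfile(path) else missing).append(path)
    return existing, missing
```

In a cell, `existing, missing = split_existing([...])` followed by `display(*[Image(filename=p, retina=True) for p in existing])` shows only the figures that are on disk, and `missing` lists what still needs to be generated.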

In [None]:
# This cell generates the faithfulness correlation and stability box plots for the COMPAS dataset.
# Note that it takes about 1 minute to run.
from IPython.display import Image

!python scripts/main_compas.py --h_type input --noplot

faithfulness_box_plot = Image(filename='out/compas/unreg_Hinput_Thsimple_Reg1e-02_LR0.001/faithfulness_box_plot.png', retina=True)
stability_box_plot = Image(filename='out/compas/unreg_Hinput_Thsimple_Reg1e-02_LR0.001/stability.png', retina=True)

display(faithfulness_box_plot, stability_box_plot)
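The stability box plot is not explained in this notebook either. Alvarez-Melis and Jaakkola quantify the robustness of explanations with a local Lipschitz estimate: the largest ratio ||theta(x) - theta(x')|| / ||x - x'|| over points x' near x. A minimal sketch by random sampling, assuming a hypothetical `theta` callable that maps an input vector to its relevance scores and sampling x' from a box of half-width `eps`; the scripts may estimate this quantity differently.

```python
import math
import random
from typing import Callable, Sequence


def _l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((u - v) ** 2 for u, v in zip(a, b)))


def local_lipschitz_estimate(
    theta: Callable[[Sequence[float]], Sequence[float]],
    x: Sequence[float],
    eps: float = 0.1,
    n_samples: int = 200,
    seed: int = 0,
) -> float:
    """Sampled estimate of max ||theta(x) - theta(x')|| / ||x - x'||
    over random points x' within `eps` of x in every coordinate."""
    rng = random.Random(seed)
    base = theta(x)
    best = 0.0
    for _ in range(n_samples):
        x_pert = [xi + rng.uniform(-eps, eps) for xi in x]
        dist = _l2(x, x_pert)
        if dist == 0.0:
            continue  # skip the (unlikely) unperturbed sample
        best = max(best, _l2(base, theta(x_pert)) / dist)
    return best
```

Lower values indicate more stable explanations; because random sampling only finds a lower bound on the true maximum, more samples can only raise the estimate.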