# Thermostat demo
### This notebook introduces Thermostat: both its data hub and collection of explanation data (feature attribution maps) and the accompanying convenience functions for analyzing the maps.

To start off, we have to install the dependencies.

In [2]:
!pip3 install --upgrade pip
!pip3 install cmake
!pip3 install cython
!pip3 install numpy
!pip3 install torch
!pip3 install datasets
!pip3 install spacy
!pip3 install sentencepiece
!pip3 install transformers
!pip3 install overrides
!pip3 install jsonnet
!pip3 install scikit-learn
!pip3 install pandas

Defaulting to user installation because normal site-packages is not writeable
Collecting pip
  Downloading pip-21.1.3-py3-none-any.whl (1.5 MB)
     |████████████████████████████████| 1.5 MB 2.1 MB/s eta 0:00:01
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.2
    Uninstalling pip-21.1.2:
      Successfully uninstalled pip-21.1.2
Successfully installed pip-21.1.3


Next, we import some utilities.

In [3]:
import warnings
# Suppress warnings
warnings.filterwarnings('ignore')

import sys
# Make the local 'src' directory importable
sys.path.append('src')

from pprint import pprint

Now we can import our library.

In [4]:
import thermostat

# Load dataset

Let's use the `load` function, which is a wrapper around the `load_dataset` function from the Hugging Face `datasets` library.
In the background, this uses the dataset script ("hf_dataset.py") in the "thermostat" directory.

In this example, we use the `imdb-bert-lig` configuration.
This refers to **Layer Integrated Gradients** (LIG) explanations of the predictions by a **BERT** model that has been fine-tuned on the **IMDb** (train) dataset and evaluated on the **IMDb** test dataset.
In other words, we load the 25k examples of the IMDb test set together with the BERT predictions and the feature attributions from a Layer Integrated Gradients explainer.

In [5]:
lig = thermostat.load("imdb-bert-lig")

Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/imdb-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)


Loading Thermostat configuration: imdb-bert-lig
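
The configuration names used in this notebook appear to follow a `<dataset>-<model>-<explainer>` pattern (for example `imdb-bert-lig` or `multi_nli-bert-occ`). As a minimal sketch, assuming that pattern holds, a name can be split into its three parts. `split_config_name` and `ConfigParts` are hypothetical helpers for illustration and not part of Thermostat.

```python
from typing import NamedTuple


class ConfigParts(NamedTuple):
    dataset: str
    model: str
    explainer: str


def split_config_name(name: str) -> ConfigParts:
    """Split a '<dataset>-<model>-<explainer>' name into its parts.

    Hypothetical helper: assumes exactly two hyphens, as in the names
    used in this notebook (underscores stay inside the dataset part).
    """
    parts = name.split("-")
    if len(parts) != 3:
        raise ValueError(f"Unexpected configuration name: {name!r}")
    return ConfigParts(*parts)


print(split_config_name("imdb-bert-lig"))
print(split_config_name("multi_nli-bert-occ"))
```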


Let's see what's inside the loaded dataset:

In [6]:
print(lig)

IMDb dataset, BERT model, Layer Integrated Gradients explanations
Dataset: imdb
Model: textattack/bert-base-uncased-imdb
Explainer: LayerIntegratedGradients



Now let's inspect a single instance of the loaded dataset.
For readability, we do not print the whole content of that instance.
Instead, we show only its index, the first few entries of the attributions, and the true and predicted labels.

In [7]:
example = lig[429]

print(f'Index: {example.idx}')
print(f'Attributions (first 4): {example.attributions[:4]}')
print(f'True label: {example.true_label}')
print(f'Predicted label: {example.predicted_label}')

Index: 429
Attributions (first 4): [0.0, 2.3141794204711914, 0.06655970215797424, -0.47832658886909485]
True label: pos
Predicted label: pos


We can also print the explanation as a list of (token, attribution, token index) tuples via the `explanation` attribute of an instance.

In [8]:
pprint(example.explanation)

[('[CLS]', 0.0, 0),
 ('amazing', 2.3141794204711914, 1),
 ('movie', 0.06655970215797424, 2),
 ('.', -0.47832658886909485, 3),
 ('some', 0.15708176791667938, 4),
 ('of', -0.02931656688451767, 5),
 ('the', -0.08834744244813919, 6),
 ('script', -0.2660972774028778, 7),
 ('writing', -0.4021594822406769, 8),
 ('could', -0.19280624389648438, 9),
 ('have', -0.015477157197892666, 10),
 ('been', -0.21898044645786285, 11),
 ('better', -0.4095713794231415, 12),
 ('(', 0.05475223436951637, 13),
 ('some', 0.0466572567820549, 14),
 ('cl', 0.08523529022932053, 15),
 ('##iche', 0.05406142398715019, 16),
 ('##d', -0.031489163637161255, 17),
 ('language', -0.3399031162261963, 18),
 (')', -0.11275435984134674, 19),
 ('.', -0.22217823565006256, 20),
 ('joyce', 0.6259628534317017, 21),
 ("'", -0.20313552021980286, 22),
 ('s', -0.22971349954605103, 23),
 ('"', -0.28431516885757446, 24),
 ('the', 0.13832062482833862, 25),
 ('dead', -0.09080619364976883, 26),
 ('"', 0.008070609532296658, 27),
 ('is', -0.09763
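
Since the explanation is a plain list of tuples, it can be processed with ordinary Python. The sketch below copies the first few tuples from the output above and ranks them by absolute attribution. `top_tokens` is a hypothetical helper, not a Thermostat function.

```python
# First entries of `example.explanation` as printed above:
# (token, attribution, token_index)
explanation = [
    ('[CLS]', 0.0, 0),
    ('amazing', 2.3141794204711914, 1),
    ('movie', 0.06655970215797424, 2),
    ('.', -0.47832658886909485, 3),
    ('some', 0.15708176791667938, 4),
    ('of', -0.02931656688451767, 5),
    ('the', -0.08834744244813919, 6),
    ('script', -0.2660972774028778, 7),
    ('writing', -0.4021594822406769, 8),
]


def top_tokens(explanation, k=3):
    """Return the k tuples with the largest absolute attribution."""
    return sorted(explanation, key=lambda t: abs(t[1]), reverse=True)[:k]


for token, score, index in top_tokens(explanation):
    print(f'{index:>3}  {token:<10} {score:+.4f}')
```

Sorting by absolute value keeps strongly negative tokens in view, which matters because negative scores count as evidence against the explained class.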

Another option is to print it as a `pandas` DataFrame, which is accessible via the `.heatmap` attribute.

In [9]:
import pandas as pd
pd.set_option('display.max_columns', None)

print(example.heatmap)

token_index     0        1          2         3         4          5   \
token        [CLS]  amazing      movie         .      some         of   
attribution      0        1  0.0287617 -0.206694  0.067878 -0.0126682   
text_field    text     text       text      text      text       text   

token_index         6         7         8          9           10         11  \
token              the    script   writing      could        have       been   
attribution -0.0381766 -0.114986 -0.173781 -0.0833152 -0.00668797 -0.0946255   
text_field        text      text      text       text        text       text   

token_index        12         13         14         15        18         19  \
token          better          (       some    cliched  language          )   
attribution -0.176983  0.0236595  0.0201615  0.0368318 -0.146878 -0.0487233   
text_field       text       text       text       text      text       text   

token_index         20       21         22         23        24      
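
Comparing the two outputs above, the `attribution` row of the heatmap looks like the raw attributions divided by the largest absolute value: `amazing` becomes 1, and 0.0666 / 2.3142 ≈ 0.0288. The sketch below reproduces that arithmetic on the first few raw scores. This is inferred from the printed numbers, not taken from Thermostat's source, and `normalize` is a hypothetical helper.

```python
# Raw attributions of the first five tokens, copied from the output above.
raw = [
    0.0,
    2.3141794204711914,
    0.06655970215797424,
    -0.47832658886909485,
    0.15708176791667938,
]


def normalize(scores):
    """Scale scores so that the largest absolute value becomes 1."""
    max_abs = max(abs(s) for s in scores)
    if max_abs == 0:
        return list(scores)  # nothing to scale
    return [s / max_abs for s in scores]


scaled = normalize(raw)
print([round(s, 6) for s in scaled])
```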

# Visualize data
A more visually pleasing option is to turn the attribution scores into colors and display the heatmap using displaCy from the spaCy library. We can do this with the `.render()` method.

In [10]:
example.render()
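
To illustrate what turning scores into colors can mean, here is a minimal sketch that maps a normalized score in [-1, 1] to a hex color. The red/blue scheme and the `score_to_hex` helper are assumptions made for illustration. This is not how `.render()` or displaCy is implemented.

```python
def score_to_hex(score: float) -> str:
    """Map a score in [-1, 1] to a hex color.

    Assumed scheme: blue for negative, white for zero, red for positive.
    Values outside [-1, 1] are clamped.
    """
    s = max(-1.0, min(1.0, score))
    fade = round(255 * (1 - abs(s)))  # 255 at zero, 0 at the extremes
    if s >= 0:
        r, g, b = 255, fade, fade
    else:
        r, g, b = fade, fade, 255
    return f'#{r:02x}{g:02x}{b:02x}'


for value in (-1.0, 0.0, 1.0):
    print(value, score_to_hex(value))
```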

# Aggregate data
Let us first compare the heatmaps of two different models on the same dataset+explainer configuration, MNLI + Occlusion.

In [11]:
bert = thermostat.load("multi_nli-bert-occ")
electra = thermostat.load("multi_nli-electra-occ")

Loading Thermostat configuration: multi_nli-bert-occ
Downloading and preparing dataset thermostat/multi_nli-bert-occ (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-occ/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f...



Dataset thermostat downloaded and prepared to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-occ/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f. Subsequent calls will reuse this data.
Loading Thermostat configuration: multi_nli-electra-occ
Downloading and preparing dataset thermostat/multi_nli-electra-occ (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-electra-occ/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f...



Dataset thermostat downloaded and prepared to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-electra-occ/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f. Subsequent calls will reuse this data.


Now let's build a list of instances for which the predicted labels of BERT and ELECTRA do not align:

In [12]:
disagreement = [(b, e) for (b, e) in zip(bert, electra) if b.predicted_label != e.predicted_label]
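
Pairing with `zip` relies on both datasets listing the same instances in the same order. The sketch below shows the same filtering pattern on stand-in objects, with an explicit check of that assumption and a second filter for pairs where only the second model is correct. `Unit`, `model_a` and `model_b` are toy placeholders, not Thermostat objects.

```python
from collections import namedtuple

# Stand-in for a Thermostat instance; only the fields used here are modelled.
Unit = namedtuple('Unit', ['idx', 'predicted_label', 'true_label'])

model_a = [
    Unit(0, 'entailment', 'entailment'),
    Unit(1, 'neutral', 'contradiction'),
    Unit(2, 'contradiction', 'contradiction'),
]
model_b = [
    Unit(0, 'entailment', 'entailment'),
    Unit(1, 'contradiction', 'contradiction'),
    Unit(2, 'neutral', 'contradiction'),
]

# zip pairs items by position, so verify that the indices line up.
assert all(a.idx == b.idx for a, b in zip(model_a, model_b))

# Pairs where the two models predict different labels.
toy_disagreement = [
    (a, b) for a, b in zip(model_a, model_b)
    if a.predicted_label != b.predicted_label
]

# Of those, pairs where model B is right and model A is wrong.
only_b_correct = [
    (a, b) for a, b in toy_disagreement
    if b.predicted_label == b.true_label and a.predicted_label != a.true_label
]

print([a.idx for a, _ in toy_disagreement])
print([a.idx for a, _ in only_b_correct])
```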

We choose one pair from this list (position 51, which corresponds to instance index 378 in the dataset) and render both heatmaps:

In [13]:
u_b, u_e = disagreement[51]
print(f'Instance index: {u_b.idx}')

print(f'Model: {u_b.model_name} | Pred: {u_b.predicted_label} | True: {u_b.true_label}')
u_b.render()
print(f'Model: {u_e.model_name} | Pred: {u_e.predicted_label} | True: {u_e.true_label}')
u_e.render()

Instance index: 378
Model: textattack/bert-base-uncased-MNLI | Pred: entailment | True: contradiction


Model: howey/electra-base-mnli | Pred: contradiction | True: contradiction


We observe that, for BERT, the Occlusion explainer does not attribute much importance to the phrase *can be lost in an instant*. This is plausible since the heatmap explains a misclassification: the maximum output activation stands for `entailment`, but the correct label is `contradiction`, and the phrase certainly is a signal for `contradiction`. In contrast, ELECTRA classified the instance correctly, and there the signal phrase receives much higher importance scores.

### Bonus: Print classification report from sklearn
We also added the classification report function from sklearn as a method that can be called on a Thermopack (a loaded Thermostat dataset).

In [14]:
for model_name, data in zip(["bert", "electra"], [bert, electra]):
    print(model_name)
    data.classification_report()
    print('=====================\n\n')

bert
               precision    recall  f1-score   support

contradiction       0.86      0.87      0.86      3213
   entailment       0.89      0.85      0.87      3479
      neutral       0.79      0.82      0.81      3123

     accuracy                           0.85      9815
    macro avg       0.85      0.85      0.85      9815
 weighted avg       0.85      0.85      0.85      9815



electra
               precision    recall  f1-score   support

   entailment       0.93      0.87      0.90      3479
      neutral       0.83      0.87      0.85      3123
contradiction       0.91      0.92      0.91      3213

     accuracy                           0.89      9815
    macro avg       0.89      0.89      0.89      9815
 weighted avg       0.89      0.89      0.89      9815
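
The columns of these reports can be reproduced by hand. The following is a minimal, standard-library sketch of per-class precision, recall, F1 and support on toy labels. `per_class_report` is a hypothetical helper for illustration, not the sklearn or Thermostat implementation.

```python
from collections import Counter


def per_class_report(true_labels, predicted_labels):
    """Compute precision, recall, F1 and support for each true label."""
    support = Counter(true_labels)          # how often each label is correct
    predicted = Counter(predicted_labels)   # how often each label is predicted
    true_positives = Counter(
        t for t, p in zip(true_labels, predicted_labels) if t == p
    )

    report = {}
    for label in sorted(support):
        tp = true_positives[label]
        precision = tp / predicted[label] if predicted[label] else 0.0
        recall = tp / support[label]
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        report[label] = {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': support[label],
        }
    return report


true = ['pos', 'pos', 'neg', 'neg', 'neg']
pred = ['pos', 'neg', 'neg', 'neg', 'pos']
report = per_class_report(true, pred)
for label, row in report.items():
    print(label, row)
```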





# Explainer comparison

We can also compare the heatmaps of multiple explainers. For this, let us load the MultiNLI-BERT-Occlusion dataset together with the associated LIG and LIME explanations.

In [15]:
unit_index = 378
u_occ = thermostat.load("multi_nli-bert-occ")[unit_index]
u_intg = thermostat.load("multi_nli-bert-lig")[unit_index]
u_lime = thermostat.load("multi_nli-bert-lime")[unit_index]

Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-occ/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)


Loading Thermostat configuration: multi_nli-bert-occ
Loading Thermostat configuration: multi_nli-bert-lig
Downloading and preparing dataset thermostat/multi_nli-bert-lig (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f...



Dataset thermostat downloaded and prepared to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f. Subsequent calls will reuse this data.
Loading Thermostat configuration: multi_nli-bert-lime
Downloading and preparing dataset thermostat/multi_nli-bert-lime (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lime/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f...



Dataset thermostat downloaded and prepared to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lime/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f. Subsequent calls will reuse this data.


In [16]:
print(f'{u_occ.explainer_name} --- same map as in previous example')
u_occ.render()
print(u_intg.explainer_name)
u_intg.render()
print(u_lime.explainer_name)
u_lime.render()

Occlusion --- same map as in previous example


LayerIntegratedGradients


LimeBase


# Rank correlation
A far more interesting, empirical investigation is how well two explainers on the same dataset+model combination correlate regarding their attribution scores.

In [17]:
imdb_lime = thermostat.load("imdb-bert-lime")
imdb_intg = thermostat.load("imdb-bert-lig")

Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/imdb-bert-lime/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)


Loading Thermostat configuration: imdb-bert-lime


Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/imdb-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)


Loading Thermostat configuration: imdb-bert-lig


After loading the LIME and LIG explanations for IMDb+BERT, we import the SciPy function that calculates Kendall's tau, a rank correlation coefficient, and apply it to the attributions.

We merge all attributions into a single one-dimensional array via `.flatten()`. This is possible because accessing the attributions of an entire dataset (Thermopack) via the `.attributions` attribute returns a NumPy array.

In [18]:
from scipy.stats import kendalltau
il_atts = imdb_lime.attributions.flatten()
ig_atts = imdb_intg.attributions.flatten()

kendall_imdb = kendalltau(il_atts, ig_atts)
print(kendall_imdb)

KendalltauResult(correlation=0.025657302000906455, pvalue=0.0)
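
To make the statistic concrete: Kendall's tau compares every pair of positions and checks whether the two score lists order that pair the same way (concordant) or the opposite way (discordant). The sketch below implements the simple tau-a variant in pure Python on toy data. SciPy's `kendalltau` computes tau-b by default, which additionally corrects for ties, so the two can differ when ties are present. `kendall_tau_a` is a hypothetical helper for illustration.

```python
from itertools import combinations


def kendall_tau_a(x, y):
    """Kendall's tau-a: (concordant - discordant) / number of pairs.

    Tied pairs count as neither concordant nor discordant.
    Runs in O(n^2), so it is only suitable for short lists.
    """
    if len(x) != len(y):
        raise ValueError('x and y must have the same length')
    n = len(x)
    if n < 2:
        raise ValueError('need at least two observations')

    concordant = discordant = 0
    for i, j in combinations(range(n), 2):
        product = (x[i] - x[j]) * (y[i] - y[j])
        if product > 0:
            concordant += 1
        elif product < 0:
            discordant += 1
    return (concordant - discordant) / (n * (n - 1) / 2)


print(kendall_tau_a([1, 2, 3, 4], [10, 20, 30, 40]))  # same ordering
print(kendall_tau_a([1, 2, 3, 4], [40, 30, 20, 10]))  # reversed ordering
```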


Let's also compare the LIME and LIG explanations for BERT on MultiNLI.

In [19]:
mnli_lime = thermostat.load("multi_nli-bert-lime")
mnli_intg = thermostat.load("multi_nli-bert-lig")

Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lime/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)


Loading Thermostat configuration: multi_nli-bert-lime


Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)


Loading Thermostat configuration: multi_nli-bert-lig


In [20]:
ml_atts = mnli_lime.attributions.flatten()
mg_atts = mnli_intg.attributions.flatten()
kendall_mnli = kendalltau(ml_atts, mg_atts)
print(kendall_mnli)

KendalltauResult(correlation=0.10327941961925725, pvalue=0.0)


We find that the correlation between LIG and LIME is higher for MultiNLI (tau ≈ 0.10) than it is for IMDb (tau ≈ 0.03), although both values are low in absolute terms.
This aligns well with the findings reported in the "Order in the Court" paper by Neely, Schouten et al. (2021): https://api.semanticscholar.org/CorpusID:234096057