# Thermostat demo
### This notebook introduces Thermostat: both the data hub, a collection of explanation data (feature attribution maps), and the accompanying convenience functions for analyzing these maps.

To start off, we have to install the dependencies.

In [2]:
!pip3 install --upgrade pip
!pip3 install cmake
!pip3 install cython
!pip3 install numpy
!pip3 install torch
!pip3 install datasets
!pip3 install spacy
!pip3 install sentencepiece
!pip3 install transformers
!pip3 install overrides
!pip3 install jsonnet
!pip3 install scikit-learn
!pip3 install scipy
!pip3 install pandas

Next, we import some utilities.

In [1]:
import warnings
# Suppress warnings
warnings.filterwarnings('ignore')

import sys
# Add the local `src` directory to the module search path
sys.path.append('src')

from pprint import pprint

Now we can import our library.

In [2]:
import thermostat

# Load dataset

Let's use the `load` function, which is a wrapper around the `load_dataset` function from Hugging Face `datasets`.
In the background, it uses the dataset script ("hf_dataset.py") in the "thermostat" directory.

In this example, we use the `ag_news-bert-lime-100` configuration.
This refers to **LIME** explanations of the predictions by a **BERT** model (`textattack/bert-base-uncased-ag-news`) that has been fine-tuned on the **AG News** dataset.
In other words, we load AG News examples together with the BERT predictions and the feature attributions from a LIME explainer. Note that despite the variable name `lig`, the attributions below are computed by LIME.

In [3]:
lig = thermostat.load("ag_news-bert-lime-100")

Loading Thermostat configuration: ag_news-bert-lime-100
Downloading and preparing dataset thermostat/ag_news-bert-lime-100 to C:\Users\49176\.cache\huggingface\datasets\thermostat\ag_news-bert-lime-100\1.0.1\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b...



Dataset thermostat downloaded and prepared to C:\Users\49176\.cache\huggingface\datasets\thermostat\ag_news-bert-lime-100\1.0.1\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b. Subsequent calls will reuse this data.


Let's see what's inside the loaded dataset:

In [4]:
print(lig)

AG News dataset, BERT model, LIME explanations, 100 samples
Dataset: ag_news
Model: textattack/bert-base-uncased-ag-news
Explainer: LimeBase



Now let's inspect a single instance of the loaded dataset.
For readability, we will not print the whole instance.
Instead, we show only the first four attribution scores along with the true and predicted labels.

In [5]:
example = lig[429]

print(f'Index: {example.idx}')
print(f'Attributions (first 4): {example.attributions[:4]}')
print(f'True label: {example.true_label}')
print(f'Predicted label: {example.predicted_label}')

Index: 429
Attributions (first 4): [0.0, -2.546515861467924e-05, 0.00011155286483699456, 0.0002800458169076592]
True label: Sci/Tech
Predicted label: Sci/Tech


We can also print the explanation as a list of `(token, attribution, token index)` tuples via the `explanation` attribute of an instance.

In [6]:
pprint(example.explanation)

[('[CLS]', 0.0, 0),
 ('stunt', -2.546515861467924e-05, 1),
 ('pilots', 0.00011155286483699456, 2),
 ('to', 0.0002800458169076592, 3),
 ('s', 0.0005430751480162144, 4),
 ('##na', 0.00019418592273723334, 5),
 ('##g', -0.00031486275838688016, 6),
 ('a', -0.0006967579247429967, 7),
 ('falling', 3.421621613597381e-06, 8),
 ('nasa', 0.0003259600780438632, 9),
 ('craft', 0.00020356148888822645, 10),
 ('nasa', 0.00030730877188034356, 11),
 ('#', 0.0011188226053491235, 12),
 ('39', 0.00013333074457477778, 13),
 (';', 4.2671450501075014e-05, 14),
 ('s', 0.0009039245778694749, 15),
 ('three', -3.066950375796296e-05, 16),
 ('-', -0.00011021247337339446, 17),
 ('year', 0.0002888077578973025, 18),
 ('effort', -6.449802458519116e-05, 19),
 ('to', -5.040947507950477e-05, 20),
 ('bring', -0.00012397475074976683, 21),
 ('some', 0.00015303328109439462, 22),
 ('genuine', 1.9381795937079005e-05, 23),
 ('star', -3.429375965424697e-06, 24),
 ('dust', 5.8175814046990126e-05, 25),
 ('back', 0.00014667658251710
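These tuples make it easy to rank tokens by importance. Below is a minimal sketch, assuming `example.explanation` is a list of `(token, attribution, token_index)` tuples as printed above. The helper name `top_tokens` is hypothetical, and the sample list simply reuses a few of the printed tuples:

```python
def top_tokens(explanation, k=3):
    """Return the k (token, attribution, index) tuples with the largest absolute attribution."""
    return sorted(explanation, key=lambda t: abs(t[1]), reverse=True)[:k]


# A few tuples copied from the output above (not the full explanation).
sample = [
    ('s', 0.0005430751480162144, 4),
    ('a', -0.0006967579247429967, 7),
    ('nasa', 0.0003259600780438632, 9),
    ('craft', 0.00020356148888822645, 10),
    ('#', 0.0011188226053491235, 12),
    ('39', 0.00013333074457477778, 13),
    ('s', 0.0009039245778694749, 15),
]

for token, score, idx in top_tokens(sample):
    print(f'{idx:>3} {token:<6} {score:+.6f}')
```

On a real instance, `top_tokens(example.explanation, k=5)` would rank all tokens, including `[CLS]`. Using `abs()` treats strongly negative evidence as just as important as strongly positive evidence.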

Another option is to print it as a `pandas` DataFrame, which is accessible via the `.heatmap` attribute.

In [7]:
import pandas as pd
pd.set_option('display.max_columns', None)

print(example.heatmap)

token_index     0         1         2         3         4        7         8   \
token        [CLS]     stunt    pilots        to      snag        a   falling   
attribution    0.0 -0.022761  0.099706  0.250304  0.485399 -0.62276  0.003058   
text_field    text      text      text      text      text     text      text   

token_index        9         10        11    12        13       14        15  \
token            nasa     craft      nasa     #        39        ;         s   
attribution  0.291342  0.181943  0.274672   1.0  0.119171  0.03814  0.807925   
text_field       text      text      text  text      text     text      text   

token_index        16        17        18        19        20        21  \
token           three         -      year    effort        to     bring   
attribution -0.027412 -0.098508  0.258135 -0.057648 -0.045056 -0.110808   
text_field       text      text      text      text      text      text   

token_index        22        23        24        25  
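Comparing the two outputs, the `.heatmap` scores appear to be the raw attributions divided by the largest absolute attribution of the instance: `#` (0.0011188…) becomes 1.0, `s` (0.000543…) becomes 0.485399, and `a` (-0.000697…) becomes -0.62276. The DataFrame also merges word pieces (`s`, `##na`, `##g` become `snag`). The sketch below reproduces only the scaling. It is inferred from these numbers rather than taken from Thermostat's code, and `max_abs_normalize` is a hypothetical name:

```python
def max_abs_normalize(scores):
    """Scale scores into [-1, 1] by dividing by the largest absolute score (inferred behavior)."""
    peak = max((abs(s) for s in scores), default=0.0)
    if peak == 0.0:
        return [0.0 for _ in scores]  # avoid division by zero for all-zero maps
    return [s / peak for s in scores]


# Raw attributions of [CLS], stunt, s, a, # copied from the `explanation` output.
raw = [0.0, -2.546515861467924e-05, 0.0005430751480162144,
       -0.0006967579247429967, 0.0011188226053491235]
normalized = max_abs_normalize(raw)
print([round(s, 6) for s in normalized])
```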

# Visualize data
A much more visually pleasing way is to map the attribution scores to colors and display the heatmap with displaCy, spaCy's visualizer. We can do this with the `.render()` method.

In [8]:
example.render()
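If you want to build a similar visualization yourself, displaCy's manual mode accepts pre-computed character spans. The sketch below is a hypothetical approach, not how `.render()` is implemented. It colors each token from red (negative) through white (zero) to blue (positive), assuming the scores are already scaled to [-1, 1] (e.g. the `.heatmap` values), and uses each rounded score as the span label so that equal labels share a color:

```python
def score_to_hex(score: float) -> str:
    """Map a score in [-1, 1] to a hex color: blue for positive, red for negative, white for 0."""
    s = max(-1.0, min(1.0, score))       # clip to [-1, 1]
    fade = round(255 * (1 - abs(s)))     # 0 = fully saturated, 255 = white
    if s >= 0:
        return f"#{fade:02x}{fade:02x}ff"
    return f"#ff{fade:02x}{fade:02x}"


def build_displacy_doc(tokens, scores):
    """Build a displaCy manual-mode 'ent' document plus a label -> color mapping."""
    ents, colors, offset = [], {}, 0
    for token, score in zip(tokens, scores):
        label = f"{score:+.2f}"           # the rounded score doubles as the span label
        ents.append({"start": offset, "end": offset + len(token), "label": label})
        colors[label] = score_to_hex(float(label))
        offset += len(token) + 1          # +1 for the space that joins the tokens
    return {"text": " ".join(tokens), "ents": ents, "title": None}, colors


doc, colors = build_displacy_doc(["nasa", "craft", "falling"], [1.0, -0.5, 0.003058])
```

With spaCy installed, the result can be shown in a notebook via `from spacy import displacy` followed by `displacy.render(doc, style="ent", manual=True, options={"colors": colors})`.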

# Aggregate data
Let us first compare the heatmaps of two different models on the same dataset+explainer configuration, AG News + LIME.

In [9]:
bert = thermostat.load("ag_news-bert-lime-100")
# Note: despite the variable name, this configuration contains ALBERT explanations
electra = thermostat.load("ag_news-albert-lime-100")

Reusing dataset thermostat (C:\Users\49176\.cache\huggingface\datasets\thermostat\ag_news-bert-lime-100\1.0.1\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b)


Loading Thermostat configuration: ag_news-bert-lime-100
Loading Thermostat configuration: ag_news-albert-lime-100
Downloading and preparing dataset thermostat/ag_news-albert-lime-100 to C:\Users\49176\.cache\huggingface\datasets\thermostat\ag_news-albert-lime-100\1.0.1\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b...



Dataset thermostat downloaded and prepared to C:\Users\49176\.cache\huggingface\datasets\thermostat\ag_news-albert-lime-100\1.0.1\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b. Subsequent calls will reuse this data.


Now let's build a list of instances for which the predicted labels of BERT and ALBERT (stored in `electra`) do not align:

In [10]:
disagreement = [(b, e) for (b, e) in zip(bert, electra) if b.predicted_label != e.predicted_label]

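`zip` pairs the two datasets purely by position, so the comprehension above assumes that both configurations list the same instances in the same order. A slightly more defensive variant pairs instances by their `idx` attribute instead. This is a sketch: `Instance` is a hypothetical stand-in for Thermostat's instance objects that exposes only the two attributes the comparison needs, and `disagreements_by_idx` is a hypothetical helper:

```python
from collections import namedtuple

# Hypothetical stand-in for a Thermostat instance; only idx and predicted_label are used.
Instance = namedtuple('Instance', ['idx', 'predicted_label'])


def disagreements_by_idx(pack_a, pack_b):
    """Pair instances that share the same idx and keep pairs whose predictions differ."""
    by_idx = {inst.idx: inst for inst in pack_b}
    return [(a, by_idx[a.idx]) for a in pack_a
            if a.idx in by_idx and a.predicted_label != by_idx[a.idx].predicted_label]


bert_preds = [Instance(0, 'Sports'), Instance(1, 'Sci/Tech'), Instance(2, 'World')]
albert_preds = [Instance(2, 'Business'), Instance(0, 'Sports'), Instance(1, 'Sci/Tech')]
result = disagreements_by_idx(bert_preds, albert_preds)
print(result)
```

With plain `zip`, the shuffled order in this toy example would report three disagreements instead of one. On the real data, the call would be `disagreements_by_idx(bert, electra)`.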

We choose one disagreement (position 51 in the list) and render both heatmaps:

In [11]:
u_b, u_e = disagreement[51]
print(f'Instance index: {u_b.idx}')

print(f'Model: {u_b.model_name} | Pred: {u_b.predicted_label} | True: {u_b.true_label}')
u_b.render()
print(f'Model: {u_e.model_name} | Pred: {u_e.predicted_label} | True: {u_e.true_label}')
u_e.render()

Instance index: 1513
Model: textattack/bert-base-uncased-ag-news | Pred: Sci/Tech | True: Sci/Tech


Model: textattack/albert-base-v2-ag-news | Pred: Business | True: Sci/Tech


Here, BERT correctly predicts *Sci/Tech*, whereas ALBERT misclassifies the instance as *Business*. Rendering both LIME heatmaps side by side lets us check whether the explanation of the misclassification assigns less importance to typical Sci/Tech signal words than the explanation of the correct prediction does.

### Bonus: Print a classification report from sklearn
We also exposed scikit-learn's `classification_report` function as a method of a Thermopack, the dataset object returned by `thermostat.load`.

In [13]:
for model_name, data in zip(["bert", "electra"], [bert, electra]):
    print(model_name)
    data.classification_report()
    print('=====================\n\n')

bert
               precision    recall  f1-score   support

contradiction       0.86      0.87      0.86      3213
   entailment       0.89      0.85      0.87      3479
      neutral       0.79      0.82      0.81      3123

     accuracy                           0.85      9815
    macro avg       0.85      0.85      0.85      9815
 weighted avg       0.85      0.85      0.85      9815



electra
               precision    recall  f1-score   support

   entailment       0.93      0.87      0.90      3479
      neutral       0.83      0.87      0.85      3123
contradiction       0.91      0.92      0.91      3213

     accuracy                           0.89      9815
    macro avg       0.89      0.89      0.89      9815
 weighted avg       0.89      0.89      0.89      9815
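For reference, the per-class rows of such a report can be computed by hand. Below is a minimal pure-Python sketch; the function name `per_class_scores` and the toy labels are hypothetical. scikit-learn's `classification_report` additionally prints accuracy as well as macro and support-weighted averages, and, like this sketch, it reports 0.0 when a class is never predicted (with a warning):

```python
from collections import Counter


def per_class_scores(y_true, y_pred):
    """Return {label: (precision, recall, f1, support)}, like the rows of a classification report."""
    labels = sorted(set(y_true) | set(y_pred))
    true_pos = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_counts, true_counts = Counter(y_pred), Counter(y_true)
    scores = {}
    for label in labels:
        precision = true_pos[label] / pred_counts[label] if pred_counts[label] else 0.0
        recall = true_pos[label] / true_counts[label] if true_counts[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[label] = (precision, recall, f1, true_counts[label])
    return scores


y_true = ['Sci/Tech', 'Sci/Tech', 'Business', 'Sports']
y_pred = ['Sci/Tech', 'Business', 'Business', 'Sports']
report = per_class_scores(y_true, y_pred)
for label, (p, r, f, s) in report.items():
    print(f'{label:>10}  {p:.2f}  {r:.2f}  {f:.2f}  {s}')
```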





# Explainer comparison

We can also compare the heatmaps of multiple explainers. For this, let us load the MultiNLI + BERT configurations for the Occlusion, LIG, and LIME explainers and pick the same instance from each.

In [14]:
unit_index = 378
u_occ = thermostat.load("multi_nli-bert-occ")[unit_index]
u_intg = thermostat.load("multi_nli-bert-lig")[unit_index]
u_lime = thermostat.load("multi_nli-bert-lime")[unit_index]

Reusing dataset thermostat (C:\Users\49176\.cache\huggingface\datasets\thermostat\multi_nli-bert-occ\1.0.1\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b)


Loading Thermostat configuration: multi_nli-bert-occ
Dataset path is D:\Working Student\repo\thermostat\src\thermostat\dataset.py
Additional parameters for loading: {}
Loading Thermostat configuration: multi_nli-bert-lig
Dataset path is D:\Working Student\repo\thermostat\src\thermostat\dataset.py
Additional parameters for loading: {}
Downloading and preparing dataset thermostat/multi_nli-bert-lig to C:\Users\49176\.cache\huggingface\datasets\thermostat\multi_nli-bert-lig\1.0.1\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b...



Dataset thermostat downloaded and prepared to C:\Users\49176\.cache\huggingface\datasets\thermostat\multi_nli-bert-lig\1.0.1\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b. Subsequent calls will reuse this data.
Loading Thermostat configuration: multi_nli-bert-lime
Dataset path is D:\Working Student\repo\thermostat\src\thermostat\dataset.py
Additional parameters for loading: {}
Downloading and preparing dataset thermostat/multi_nli-bert-lime to C:\Users\49176\.cache\huggingface\datasets\thermostat\multi_nli-bert-lime\1.0.1\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b...



Dataset thermostat downloaded and prepared to C:\Users\49176\.cache\huggingface\datasets\thermostat\multi_nli-bert-lime\1.0.1\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b. Subsequent calls will reuse this data.


In [15]:
print(u_occ.explainer_name)
u_occ.render()
print(u_intg.explainer_name)
u_intg.render()
print(u_lime.explainer_name)
u_lime.render()

Occlusion


LayerIntegratedGradients


LimeBase
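Beyond eyeballing the three renderings, one simple way to quantify how much two explainers agree on a single instance is the overlap of their top-k tokens. The sketch below uses a Jaccard overlap of the top-k positions by absolute score. It assumes the attribution vectors of the explainers are aligned token by token (plausible here, since all three use the same BERT tokenizer); `top_k_indices` and `top_k_jaccard` are hypothetical helpers:

```python
def top_k_indices(scores, k):
    """Indices of the k scores with the largest absolute value."""
    ranked = sorted(range(len(scores)), key=lambda i: abs(scores[i]), reverse=True)
    return set(ranked[:k])


def top_k_jaccard(scores_a, scores_b, k=5):
    """Jaccard overlap of the top-k token positions of two attribution maps."""
    top_a, top_b = top_k_indices(scores_a, k), top_k_indices(scores_b, k)
    union = top_a | top_b
    return len(top_a & top_b) / len(union) if union else 1.0


overlap = top_k_jaccard([0.9, 0.1, -0.8, 0.0], [0.7, 0.6, 0.0, -0.1], k=2)
print(overlap)
```

Applied to the instance above, this would read e.g. `top_k_jaccard(u_occ.attributions, u_lime.attributions, k=5)`.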


# Rank correlation
A more interesting empirical question is how strongly two explainers agree in their attribution scores on the same dataset+model combination.

In [None]:
imdb_lime = thermostat.load("imdb-bert-lime")
imdb_intg = thermostat.load("imdb-bert-lig")

After loading the LIME and LIG explanations for IMDb + BERT, we import SciPy's `kendalltau` function, which computes Kendall's tau rank correlation between the two sets of attribution scores.

We merge all attributions into a single one-dimensional array via `.flatten()`. This works because the `.attributions` attribute of an entire dataset (Thermopack) returns a NumPy array.

In [18]:
from scipy.stats import kendalltau
il_atts = imdb_lime.attributions.flatten()
ig_atts = imdb_intg.attributions.flatten()

kendall_imdb = kendalltau(il_atts, ig_atts)
print(kendall_imdb)

KendalltauResult(correlation=0.025657302000906455, pvalue=0.0)
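Kendall's tau compares rankings: it counts pairs of positions that the two explainers order the same way (concordant) versus oppositely (discordant). SciPy's `kendalltau` computes the tau-b variant by default, which corrects for ties. The pure-Python sketch below implements tau-b to make the computation explicit. It also adds a hypothetical `strip_padding` helper: if the per-instance attribution vectors of a Thermopack are zero-padded to a common length (an assumption we have not verified), the padded positions enter the flattened comparison as extra data points, so one may prefer to drop them first:

```python
from math import sqrt


def kendall_tau_b(x, y):
    """Kendall's tau-b: (C - D) / sqrt((C + D + T_x) * (C + D + T_y))."""
    concordant = discordant = ties_x = ties_y = 0
    for i in range(len(x)):
        for j in range(i + 1, len(x)):
            dx, dy = x[i] - x[j], y[i] - y[j]
            if dx == 0 and dy == 0:
                continue                     # tied in both: counted in neither term
            if dx == 0:
                ties_x += 1                  # tied only in x
            elif dy == 0:
                ties_y += 1                  # tied only in y
            elif (dx > 0) == (dy > 0):
                concordant += 1              # both rankings order the pair the same way
            else:
                discordant += 1
    denom = sqrt((concordant + discordant + ties_x) * (concordant + discordant + ties_y))
    return (concordant - discordant) / denom if denom else float("nan")


def strip_padding(rows, lengths):
    """Hypothetical helper: keep the first `length` scores of each padded row and flatten."""
    return [score for row, length in zip(rows, lengths) for score in row[:length]]


# Toy attribution maps, zero-padded to length 4 (true lengths: 2 and 3).
lime_rows = [[0.5, -0.1, 0.0, 0.0], [0.2, 0.4, 0.3, 0.0]]
lig_rows = [[0.3, 0.2, 0.0, 0.0], [0.1, 0.6, 0.2, 0.0]]
lengths = [2, 3]

lime_flat = strip_padding(lime_rows, lengths)
lig_flat = strip_padding(lig_rows, lengths)
print(kendall_tau_b(lime_flat, lig_flat))
```

The double loop is O(n²) and only meant for small inputs; on real data, the trimmed lists would be passed to SciPy's `kendalltau` instead.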


Let's also consider the MultiNLI + BERT explanations from LIME and LIG.

In [19]:
mnli_lime = thermostat.load("multi_nli-bert-lime")
mnli_intg = thermostat.load("multi_nli-bert-lig")

Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lime/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)


Loading Thermostat configuration: multi_nli-bert-lime


Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)


Loading Thermostat configuration: multi_nli-bert-lig


In [20]:
ml_atts = mnli_lime.attributions.flatten()
mg_atts = mnli_intg.attributions.flatten()
kendall_mnli = kendalltau(ml_atts, mg_atts)
print(kendall_mnli)

KendalltauResult(correlation=0.10327941961925725, pvalue=0.0)


We find that Kendall's tau between LIG and LIME is higher for MultiNLI (0.103) than for IMDb (0.026), although both correlations are weak.
This aligns well with the findings reported in the "Order in the Court" paper by Neely, Schouten et al. (2021): https://api.semanticscholar.org/CorpusID:234096057