# Thermostat demo

In [1]:
import warnings
# Suppress warnings
warnings.filterwarnings('ignore')

In [2]:
!pip3 install --upgrade pip
!pip3 install cmake
!pip3 install cython
!pip3 install numpy
!pip3 install torch
!pip3 install datasets
!pip3 install spacy
!pip3 install transformers
!pip3 install overrides
!pip3 install jsonnet

In [3]:
import sys
# Make the local "src" directory (relative to the working directory) importable
sys.path.append('src')

from datasets import load_dataset
from pprint import pprint

import thermostat  # Accompanying library

# Load dataset

This will use the dataset script ("thermostat.py") in the "thermostat" directory.
In this example, we use the test split of the "imdb-bert-lgxa" configuration.
This refers to Layer Gradient x Activation (LGxA) explanations of the predictions of a BERT model that has been fine-tuned on the IMDb train split and evaluated on the IMDb test split.
In other words, we load the 25k examples of the IMDb test split together with the BERT predictions and the feature attributions produced by a Layer Gradient x Activation explainer.
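
To make the explainer name more concrete: Layer Gradient x Activation multiplies the gradient of the model output with respect to a layer's activations by those activations. The following is a toy illustration, not Thermostat's implementation. It assumes the layer is the token embedding layer, uses a made-up linear scorer F = sum over i of (w . e_i), so that the gradient with respect to every token embedding e_i is simply w, and sums over the hidden dimension to get one score per token (how Thermostat reduces the hidden dimension is an assumption here). A real explainer obtains the gradient by backpropagation through the network.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def gradient_x_activation(embeddings: List[Vector], weights: Vector) -> List[float]:
    """Per-token attribution: (dF/de_i) * e_i, summed over the hidden dimension.

    Assumes the toy linear scorer F = sum_i weights . e_i, whose gradient with
    respect to each token embedding e_i is just `weights`.
    """
    attributions = []
    for e in embeddings:
        grad = weights  # dF/de_i for the linear toy scorer
        attributions.append(dot(grad, e))
    return attributions


# Made-up 2-dimensional embeddings for three tokens
embeddings = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0]]
weights = [2.0, -1.0]
print(gradient_x_activation(embeddings, weights))
```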

In [4]:
data = load_dataset("thermostat", "imdb-bert-lgxa", split="test")

Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/imdb-bert-lgxa/1.0.0/82ada9d63d3c6b421a4ade89adc656b856fe9924abbc5cc94f20d472f5c71e99)


Now let's inspect a single instance of the loaded dataset.
Here, we stick to the functionality that the datasets library already provides.
For readability, we do not print the whole instance.
Instead, we show only the first few entries of the input_ids and the attributions.

In [5]:
instance = data[0]

print(f'Keys: {instance.keys()}\n')
print(f'Index: {instance["idx"]}')
print(f'Input IDs (first 15): {instance["input_ids"][:15]}')
print(f'Attributions (first 4): {instance["attributions"][:4]}')
print(f'True label: {instance["label"]}')
print(f'Predictions (logits): {instance["predictions"]}')

Keys: dict_keys(['attributions', 'idx', 'input_ids', 'label', 'predictions'])

Index: 0
Input IDs (first 15): [101, 2092, 1010, 1045, 7166, 2000, 3422, 3152, 2005, 2028, 1997, 2093, 4436, 1012, 6854]
Attributions (first 4): [-0.18760254979133606, -0.0315956249833107, 0.04854373633861542, 0.00658783596009016]
True label: 1
Predictions (logits): [-3.4371631145477295, 4.042327404022217]
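
The "predictions" field holds one raw logit per class. Here is a minimal sketch of turning them into a predicted label: apply a softmax and take the most probable class. The label order ["neg", "pos"] is taken from the "label_names" attribute that the Thermopack reports further down. The helper names are our own, not part of the thermostat library.

```python
import math
from typing import List, Tuple

LABEL_NAMES = ["neg", "pos"]  # order as reported by the Thermopack's label_names


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predicted_label(logits: List[float]) -> Tuple[int, str, float]:
    """Return (class index, class name, probability) of the most likely class."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    return idx, LABEL_NAMES[idx], probs[idx]


logits = [-3.4371631145477295, 4.042327404022217]  # instance["predictions"] from above
print(predicted_label(logits))
```

The argmax of the softmax equals the argmax of the raw logits, so the softmax only adds a probability. The result agrees with the predicted label of this instance shown later ({'index': 1, 'name': 'pos'}).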


# Visualize data
Can we make this a bit more readable?  
Of course! First, let's select a small subset using the ".select" method of the datasets library:

In [6]:
lgxa_head = data.select(range(20))

Next, we can import the "Thermopack" class from our accompanying library. It inherits all properties from a Hugging Face Dataset, but also instantiates the tokenizer of the downstream model and automatically decodes the Input IDs to words.
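
Thermopack takes care of the decoding, but it helps to know why tokens and words do not always line up. BERT's WordPiece tokenizer splits rare words into pieces and marks continuation pieces with a "##" prefix. Here is a minimal sketch of merging such pieces back into words. It assumes that summing the piece attributions is an acceptable score for the whole word, which is one common choice; Thermostat's own handling may differ. The token list and scores are made up.

```python
from typing import List, Tuple


def merge_wordpieces(tokens: List[str], attributions: List[float]) -> List[Tuple[str, float]]:
    """Join WordPiece continuations ("##...") onto the preceding token, summing their scores."""
    words: List[Tuple[str, float]] = []
    for token, score in zip(tokens, attributions):
        if token.startswith("##") and words:
            prev_word, prev_score = words[-1]
            words[-1] = (prev_word + token[2:], prev_score + score)
        else:
            words.append((token, score))
    return words


# Made-up tokens and attributions, for illustration only
tokens = ["[CLS]", "the", "cinema", "was", "un", "##watch", "##able", "[SEP]"]
scores = [0.0, 0.125, 0.25, 0.0, -0.5, -0.25, -0.125, 0.0]
print(merge_wordpieces(tokens, scores))
```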

In [7]:
tp = thermostat.Thermopack(lgxa_head)
print(tp)

IMDb dataset, BERT model, Layer Gradient x Activation explanations
Explainer: LayerGradientXActivation
Model: textattack/bert-base-uncased-imdb
Dataset: imdb



In [8]:
pprint({k: v for k, v in vars(tp).items() if not k.startswith('_')})

{'dataset': Dataset({
    features: ['attributions', 'idx', 'input_ids', 'label', 'predictions'],
    num_rows: 20
}),
 'dataset_name': 'imdb',
 'explainer_name': 'LayerGradientXActivation',
 'label_names': ['neg', 'pos'],
 'model_name': 'textattack/bert-base-uncased-imdb',
 'tokenizer': PreTrainedTokenizerFast(name_or_path='textattack/bert-base-uncased-imdb', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}),
 'units': [<thermostat.data.dataset_utils.Thermounit object at 0x7fa4872794a8>,
           <thermostat.data.dataset_utils.Thermounit object at 0x7fa4297a6128>,
           <thermostat.data.dataset_utils.Thermounit object at 0x7fa422d89b00>,
           <thermostat.data.dataset_utils.Thermounit object at 0x7fa4234e4e80>,
           <thermostat.data.dataset_utils.Thermounit object at 0x7fa4234faba8>,
           <thermostat.data.datase

In [9]:
tu0 = tp[0]
pprint({k: v for k, v in vars(tu0).items() if not k.startswith('_') and k not in ['heatmap', 'instance', 'tokens']})

{'dataset_name': 'imdb',
 'explainer_name': 'LayerGradientXActivation',
 'index': 0,
 'model_name': 'textattack/bert-base-uncased-imdb',
 'predicted_label': {'index': 1, 'name': 'pos'},
 'text': 'well, i tend to watch films for one of three reasons. unfortunately, '
         'there are no transformers in this film, so i can recommend it only '
         'on comedy value and pretty women ( read girls ) < br / > < br / > '
         'yes, it is funny, i know this due to the number of people in the '
         'cinema who were laughing on a regular basis throughout. personally '
         'though, i loved it for laura fraser, who imho is fit!',
 'tokenizer': PreTrainedTokenizerFast(name_or_path='textattack/bert-base-uncased-imdb', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}),
 'true_label': {'index': 1, 'name': 'pos'}}


In [10]:
heatmap = tu0.render(jupyter=True)
heatmap
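
The rendered heatmap colors each token according to its attribution score. The idea can be sketched as mapping every score to a color intensity. The red/blue scheme and the normalization by the largest absolute score are assumptions for illustration, not necessarily what Thermostat's renderer does.

```python
from typing import List, Tuple


def heat_colors(attributions: List[float]) -> List[Tuple[int, int, int]]:
    """Map scores to RGB: positive -> red, negative -> blue, zero -> white.

    Scores are scaled by the largest absolute value, so the strongest token
    gets full saturation.
    """
    max_abs = max((abs(a) for a in attributions), default=0.0) or 1.0  # avoid division by zero
    colors = []
    for a in attributions:
        strength = abs(a) / max_abs          # in [0, 1]
        fade = round(255 * (1 - strength))   # 255 = white, 0 = fully saturated
        colors.append((255, fade, fade) if a >= 0 else (fade, fade, 255))
    return colors


print(heat_colors([0.5, -0.25, 0.0, 1.0]))
```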

# Aggregate data
Let us first compare the heatmaps of two different models, ELECTRA and BERT, on the same dataset and explainer configuration: MNLI + Occlusion.
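
Unlike LGxA, Occlusion is perturbation-based. It replaces part of the input with a baseline and records how much the model output changes. Here is a minimal sketch with a window of one token. The "[MASK]" baseline, the window size, and the toy scoring function are assumptions for illustration and need not match the settings behind the "occ" configurations.

```python
from typing import Callable, List, Sequence


def occlusion_attributions(
    tokens: Sequence[str],
    score: Callable[[Sequence[str]], float],
    baseline: str = "[MASK]",
) -> List[float]:
    """Attribution of token i = score(original input) - score(input with token i occluded)."""
    full = score(tokens)
    attributions = []
    for i in range(len(tokens)):
        occluded = list(tokens)
        occluded[i] = baseline  # occlude a window of exactly one token
        attributions.append(full - score(occluded))
    return attributions


# Hypothetical toy "model": number of positive words minus number of negative words
POSITIVE, NEGATIVE = {"funny", "loved"}, {"boring"}


def toy_score(tokens: Sequence[str]) -> float:
    return float(sum(t in POSITIVE for t in tokens) - sum(t in NEGATIVE for t in tokens))


print(occlusion_attributions(["funny", "but", "boring"], toy_score))
```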

In [12]:
electra = load_dataset("thermostat", "mnli-electra-occ", split="test")
electra_head = electra.select(range(20))
tp_electra = thermostat.Thermopack(electra_head)

bert = load_dataset("thermostat", "mnli-bert-occ", split="test")
bert_head = bert.select(range(20))
tp_bert = thermostat.Thermopack(bert_head)

Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/mnli-electra-occ/1.0.0/82ada9d63d3c6b421a4ade89adc656b856fe9924abbc5cc94f20d472f5c71e99)


Downloading and preparing dataset thermostat/mnli-bert-occ (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/nfel/.cache/huggingface/datasets/thermostat/mnli-bert-occ/1.0.0/82ada9d63d3c6b421a4ade89adc656b856fe9924abbc5cc94f20d472f5c71e99...


Dataset thermostat downloaded and prepared to /home/nfel/.cache/huggingface/datasets/thermostat/mnli-bert-occ/1.0.0/82ada9d63d3c6b421a4ade89adc656b856fe9924abbc5cc94f20d472f5c71e99. Subsequent calls will reuse this data.
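
One simple way to compare the two models' explanations instance by instance is to check whether they agree on the most important tokens, for example using the "attributions" field of the same instance in each dataset. Here is a minimal sketch computing the Jaccard overlap of the top-k token positions. It assumes both explanations are aligned to the same token sequence; if the two models tokenize an input differently, the attributions must be aligned first. The attribution values below are made up.

```python
from typing import List, Set


def top_k(attributions: List[float], k: int) -> Set[int]:
    """Positions of the k tokens with the largest absolute attribution."""
    ranked = sorted(range(len(attributions)), key=lambda i: abs(attributions[i]), reverse=True)
    return set(ranked[:k])


def top_k_overlap(a: List[float], b: List[float], k: int) -> float:
    """Jaccard overlap of the top-k token positions of two explanations."""
    if len(a) != len(b):
        raise ValueError("explanations must cover the same token sequence")
    sa, sb = top_k(a, k), top_k(b, k)
    union = sa | sb
    return len(sa & sb) / len(union) if union else 1.0


# Made-up attributions for the same five tokens under two models
model_a = [0.9, -0.1, 0.4, 0.0, -0.7]
model_b = [0.8, 0.6, 0.1, 0.0, -0.5]
print(top_k_overlap(model_a, model_b, k=2))
```

Using absolute values treats strongly negative and strongly positive tokens as equally important. A signed comparison would need a different measure.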

