In [1]:
%load_ext autoreload
%autoreload 2

# Explaining Keras text classifier predictions with Grad-CAM

We will explain text classification predictions using Grad-CAM. We will use the IMDB dataset available in Keras and a financial complaints dataset, loading pretrained models.

Grad-CAM shows which parts of the input are important for a target class, using the activations and gradients of a hidden layer.
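For reference, the core Grad-CAM computation over a text model's time axis can be sketched in a few lines of NumPy. This is a minimal sketch, not ELI5's implementation: it assumes you already have the hidden layer's activations for one input, shaped `(time_steps, channels)`, and the gradients of the target class score with respect to those activations (in practice these come from the framework's automatic differentiation). `grad_cam_1d` is a hypothetical name.

```python
import numpy as np

def grad_cam_1d(activations, grads, relu=True):
    """Grad-CAM heatmap over the time axis of a hidden layer.

    activations: (time_steps, channels) output of the chosen hidden layer
    grads: (time_steps, channels) gradient of the target score w.r.t. activations
    """
    activations = np.asarray(activations, dtype=float)
    grads = np.asarray(grads, dtype=float)
    # one importance weight per channel: the gradients averaged over time
    channel_weights = grads.mean(axis=0)
    # weighted sum of the channels at every time step
    heatmap = activations @ channel_weights
    if relu:
        # keep only the evidence that increases the target score
        heatmap = np.maximum(heatmap, 0.0)
    return heatmap

# toy example: 4 time steps, 2 channels
A = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [2.0, 0.0],
              [0.0, 3.0]])
G = np.array([[0.5, -0.5]] * 4)  # channel 0 helps the target, channel 1 hurts it
print(grad_cam_1d(A, G).tolist())              # [0.5, 0.0, 1.0, 0.0]
print(grad_cam_1d(A, G, relu=False).tolist())  # [0.5, -0.5, 1.0, -1.5]
```

The resulting heatmap has one value per time step of the hidden layer; mapping it back onto the input tokens is a separate step.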

First, some imports.

In [2]:
import os

import numpy as np
import pandas as pd
from IPython.display import display, HTML

# you may want to keep logging enabled when doing your own work
import logging
import tensorflow as tf
tf.get_logger().setLevel(logging.ERROR) # disable Tensorflow warnings for this tutorial
import warnings
warnings.simplefilter("ignore") # disable Keras warnings for this tutorial
import keras

import eli5

Using TensorFlow backend.


In [3]:
# we need this to load some of the local modules

old = os.getcwd()
os.chdir('..')

## Explaining sentiment classification

This is a common tutorial task: binary classification with a single output, where a high score (close to 1) means positive and a low score (close to 0) means negative. We will use the IMDB dataset and a recurrent model with word-level tokenization.

Load our model (available in ELI5).

In [4]:
model = keras.models.load_model('tests/estimators/keras_sentiment_classifier/keras_sentiment_classifier.h5')
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (None, None, 8)           80000     
_________________________________________________________________
masking_1 (Masking)          (None, None, 8)           0         
_________________________________________________________________
masking_2 (Masking)          (None, None, 8)           0         
_________________________________________________________________
masking_3 (Masking)          (None, None, 8)           0         
_________________________________________________________________
bidirectional_1 (Bidirection (None, None, 128)         37376     
_________________________________________________________________
bidirectional_2 (Bidirection (None, None, 64)          41216     
_________________________________________________________________
bidirectional_3 (Bidirection (None, 32)                10368     
__________

Load some sample data. A helper module does the preprocessing for us; check it to learn more. For your own models you will have to do your own preprocessing.

In [5]:
import tests.estimators.keras_sentiment_classifier.keras_sentiment_classifier \
as keras_sentiment_classifier

In [6]:
(x_train, y_train), (x_test, y_test) = keras_sentiment_classifier.prepare_train_test_dataset()

Confirm the accuracy of the model.

In [7]:
print(model.metrics_names)
model.evaluate(x_test, y_test)

['loss', 'acc']


[0.4319177031707764, 0.81504]

About 81.5% test accuracy looks good. Let's go on and check one of the test samples.

In [8]:
doc = x_test[0:1]
print(doc)

tokens = keras_sentiment_classifier.vectorized_to_tokens(doc)
print(tokens)

[[   1  591  202   14   31    6  717   10   10    2    2    5    4  360
     7    4  177 5760  394  354    4  123    9 1035 1035 1035   10   10
    13   92  124   89  488 7944  100   28 1668   14   31   23   27 7479
    29  220  468    8  124   14  286  170    8  157   46    5   27  239
    16  179    2   38   32   25 7944  451  202   14    6  717    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0]]
[['<START>', 'please', 'give', 'this', 'one', 'a', 'miss', 'br', 'br', '<OOV>', '<OOV>', 'and', 'the', 'rest', 'of', 'the', 'cast', 'rendered', 'terrible', 'performances', 'the', 'show', 'is', 'flat', 'flat', 'flat', 'br', 'br', 'i', "don't", 'know', 'how', 'michael', 'madison', 'could', 'have', 'allowed', 'this', 'one', 'on', 'his', 'p

Check the prediction

In [9]:
model.predict(doc)

array([[0.1622659]], dtype=float32)

As expected, the score is low: the model classifies this review as negative.

Now let's explain what got us this result with ELI5. We need to pass the model, the input, and the associated tokens that will be highlighted.

In [10]:
eli5.show_prediction(model, doc, tokens=tokens)

Let's try a custom input.

In [11]:
s = "hello this is great but not so great"
doc_s, tokens_s = keras_sentiment_classifier.string_to_vectorized(s)
print(doc_s, tokens_s)

[[   1 4825   14    9   87   21   24   38   87]] [['<START>' 'hello' 'this' 'is' 'great' 'but' 'not' 'so' 'great']]


Notice that this model does not require fixed-length input, so we do not need to pad this sample.

In [12]:
model.predict(doc_s)

array([[0.5912496]], dtype=float32)

In [13]:
eli5.show_prediction(model, doc_s, tokens=tokens_s)

## The `counterfactual` and `relu` arguments

What did we see in the last section? Grad-CAM shows what makes a class score "go up", so we only see the "positive" parts.

To "fix" this, we can pass two boolean arguments.

`counterfactual` shows the "opposite", what makes the score "go down" (set to `True` to enable).

In [22]:
eli5.show_prediction(model, doc_s, tokens=tokens_s, counterfactual=True)

For the test sample

In [14]:
eli5.show_prediction(model, doc, tokens=tokens, counterfactual=True)

`relu` filters out the negative scores and only shows what makes the predicted score go up (set to `False` to disable).

In [15]:
eli5.show_prediction(model, doc, tokens=tokens, relu=False)

Green is positive, red is negative, white is neutral. We can see what made the network decide that it is a negative example.

What happens if we pass both `counterfactual=True` and `relu=False`?

In [16]:
eli5.show_prediction(model, doc, tokens=tokens, relu=False, counterfactual=True)

Notice how the colors (green and red) are inverted.
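The effect of the two flags can be reproduced on a toy signed heatmap. The sketch below is an assumption about the mechanics, following the counterfactual variant described in the Grad-CAM paper rather than ELI5's source: `counterfactual=True` negates the gradients, which negates every channel weight and therefore the whole weighted sum, and `relu=True` then clips negative values to zero. With `relu=False` the counterfactual map is just the original map with its sign flipped, which is consistent with the swapped colors above. `apply_flags` is a hypothetical helper.

```python
import numpy as np

def apply_flags(raw_heatmap, relu=True, counterfactual=False):
    """Post-process a raw (signed, pre-ReLU) Grad-CAM map.

    Negating the gradients negates the channel weights and hence the
    weighted sum, so the counterfactual map equals -raw_heatmap.
    """
    heatmap = np.asarray(raw_heatmap, dtype=float)
    if counterfactual:
        heatmap = -heatmap
    if relu:
        heatmap = np.maximum(heatmap, 0.0)
    return heatmap

raw = [0.5, -0.5, 1.0, -1.5]  # toy map: positive values push the score up
for relu in (True, False):
    for counterfactual in (False, True):
        print(f"relu={relu!s:5} counterfactual={counterfactual!s:5}",
              apply_flags(raw, relu, counterfactual).tolist())
```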

## Removing padding with `pad_value` and `padding` arguments

Text examples are often padded, either because the model expects input of a certain length or so that all samples in a batch have the same length.

We can remove padding by specifying two arguments. The first is `pad_value`: either a padding token such as `<PAD>`, or a numeric value such as `0` as it appears in `doc`. The second is `padding`, which should be set to either `pre` (padding comes before the actual text) or `post` (padding comes after the actual text).
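What removing padding amounts to can be sketched in plain Python. The helper below is hypothetical (not part of ELI5): it drops the run of `pad_value` entries at the end (`padding='post'`) or start (`padding='pre'`) of a token sequence and trims a per-token heatmap to match. The toy tokens and weights are made up for illustration.

```python
def strip_padding(tokens, heatmap, pad_value, padding="post"):
    """Remove trailing ('post') or leading ('pre') padding from tokens and heatmap.

    Hypothetical helper; assumes one heatmap value per token.
    """
    if padding not in ("pre", "post"):
        raise ValueError("padding must be 'pre' or 'post'")
    if len(tokens) != len(heatmap):
        raise ValueError("tokens and heatmap must have the same length")
    start, end = 0, len(tokens)
    if padding == "post":
        while end > start and tokens[end - 1] == pad_value:
            end -= 1
    else:
        while start < end and tokens[start] == pad_value:
            start += 1
    return list(tokens[start:end]), list(heatmap[start:end])

toy_tokens = ["the", "show", "is", "flat", "<PAD>", "<PAD>"]
toy_weights = [0.1, 0.0, 0.2, 0.9, 0.3, 0.3]
print(strip_padding(toy_tokens, toy_weights, "<PAD>", padding="post"))
# (['the', 'show', 'is', 'flat'], [0.1, 0.0, 0.2, 0.9])
```

The same idea works with a numeric `pad_value=0` applied to a vectorized document.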

In [17]:
eli5.show_prediction(model, doc, tokens=tokens, relu=False, pad_value='<PAD>', padding='post')

Now the explanation is shorter. This is useful if the input has a lot of padding.

## Choosing a hidden layer to do Grad-CAM on

Grad-CAM requires a hidden layer to do its calculations on. This is controlled by the `layer` argument. We can pass the layer (as an int index, string name, or a keras Layer instance) explicitly, or let ELI5 attempt to find a good layer to do Grad-CAM on automatically.
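To illustrate the three accepted forms of `layer`, here is a hypothetical resolver over a list of layer-like objects. With a real Keras model, `model.get_layer(index=...)` and `model.get_layer(name=...)` play this role; the `Layer` dataclass below is a stand-in, not Keras's class.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Layer:
    """Stand-in for a Keras layer; only the name matters here."""
    name: str

def resolve_layer(layers, layer):
    """Return the layer selected by an int index, a string name, or the object itself."""
    if isinstance(layer, int):
        return layers[layer]
    if isinstance(layer, str):
        for candidate in layers:
            if candidate.name == layer:
                return candidate
        raise ValueError(f"no layer named {layer!r}")
    if layer in layers:
        return layer
    raise ValueError("layer is not part of this model")

model_layers = [Layer("embedding_1"), Layer("masking_1"), Layer("bidirectional_1")]
print(resolve_layer(model_layers, 2).name)            # bidirectional_1
print(resolve_layer(model_layers, "masking_1").name)  # masking_1
print(resolve_layer(model_layers, model_layers[0]).name)  # embedding_1
```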

In [18]:
for layer in model.layers:
    print(layer.name)
    if 'masking' not in layer.name:  # skip the Masking layers
        e = eli5.show_prediction(model,
                                 doc,
                                 tokens=tokens,
                                 layer=layer,
                                 relu=False,
                                 pad_value='<PAD>',
                                 padding='post')
        # show_prediction() only renders automatically as the last expression of a cell,
        # so inside a loop we need an explicit IPython display() call
        display(e)

embedding_1


masking_1
masking_2
masking_3
bidirectional_1


bidirectional_2


bidirectional_3


dense_1


dense_2


If you don't get good explanations from ELI5 out of the box, it may be worth looking into this parameter. We advise picking layers that contain "spatial or temporal" information, i.e. NOT dense/fully-connected or merge layers.

Notice that when explaining the final dense layer (which has only one output node), we get an "all green" explanation. Hover over the explanation to see the actual values. It looks off because there are no negative values here and the coloring is not gradual.
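The advice above can be turned into a simple automatic rule: pick the last layer whose output still has a time axis, i.e. a 3-D shape `(batch, time, channels)`. This is a hypothetical heuristic for illustration, not necessarily how ELI5 chooses a layer. The shapes are copied from `model.summary()` earlier; with a real model you would read each layer's `output_shape` instead.

```python
def pick_temporal_layer(layer_shapes):
    """Return the name of the last layer with a (batch, time, channels) output.

    layer_shapes: list of (name, output_shape) pairs in model order.
    """
    candidates = [name for name, shape in layer_shapes if len(shape) == 3]
    if not candidates:
        raise ValueError("no layer with a temporal axis found")
    return candidates[-1]

summary_shapes = [
    ("embedding_1", (None, None, 8)),
    ("masking_1", (None, None, 8)),
    ("masking_2", (None, None, 8)),
    ("masking_3", (None, None, 8)),
    ("bidirectional_1", (None, None, 128)),
    ("bidirectional_2", (None, None, 64)),
    ("bidirectional_3", (None, 32)),  # no time axis left
]
print(pick_temporal_layer(summary_shapes))  # bidirectional_2
```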

## Explaining multiple classes

Next, a multi-class model trained on the financial complaints dataset: a convolutional network with character-level tokenization.

In [19]:
# multiclass model (*target, layer - conv/others, diff. types of expls, padding and its effect)

In [20]:
model2 = keras.models.load_model('tests/estimators/keras_multiclass_text_classifier/keras_multiclass_text_classifier.h5')
model2.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (None, 3193, 8)           816       
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 3179, 128)         15488     
_________________________________________________________________
dropout_1 (Dropout)          (None, 3179, 128)         0         
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 1589, 128)         0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 1580, 128)         163968    
_________________________________________________________________
dropout_2 (Dropout)          (None, 1580, 128)         0         
_________________________________________________________________
average_pooling1d_1 (Average (None, 790, 128)          0         
__________

In [24]:
import tests.estimators.keras_multiclass_text_classifier.keras_multiclass_text_classifier \
as keras_multiclass_text_classifier

In [25]:
(x_train, x_test), (y_train, y_test) = keras_multiclass_text_classifier.prepare_train_test_dataset()

Possible classes

In [35]:
keras_multiclass_text_classifier.labels_index

{'Debt collection': 0,
 'Consumer Loan': 1,
 'Mortgage': 2,
 'Credit card': 3,
 'Credit reporting': 4,
 'Student loan': 5,
 'Bank account or service': 6,
 'Payday loan': 7,
 'Money transfers': 8,
 'Other financial service': 9,
 'Prepaid card': 10}

Again check the metrics.

In [26]:
print(model2.metrics_names)
model2.evaluate(x_test, y_test)

['loss', 'acc']


[0.6319513120651246, 0.7999999990463257]

Let's explain one of the test samples

In [72]:
doc = x_test[0:1]
tokens = keras_multiclass_text_classifier.vectorized_to_tokens(doc)
s = keras_multiclass_text_classifier.tokens_to_string(tokens)

print(len(doc[0]))
limit = 150
print(doc[0, :limit])
print(tokens[0, :limit])
print(s[0][:limit+800])

3193
[38 15 21  3  7  2 20  8  7  5  7 15  8  5 14  2 11  3  9 25  8 15  3 11
  2 15 14  5  8 16 11  2 11  8 16 17 14  4  5  7  3  6 17 11 14 18  2  4
  6  2 12  5 25  3  2  5  2 14  6  5  7  2 21  8  4 12  2 16  3  2 58  2
 13  3 11 19  8  4  3  2 16 18  2  7  3 25  3  9  2 12  5 25  8  7 22  2
 13  6  7  3  2 24 17 11  8  7  3 11 11  2 21  8  4 12  2  4 12  3 16  2
  6  9  2 12  5 25  8  7 22  2 24  3  3  7  2  7  6  4  8 20  8  3 13  2
  6 20  2 11  5  8]
['O' 'c' 'w' 'e' 'n' ' ' 'f' 'i' 'n' 'a' 'n' 'c' 'i' 'a' 'l' ' ' 's' 'e'
 'r' 'v' 'i' 'c' 'e' 's' ' ' 'c' 'l' 'a' 'i' 'm' 's' ' ' 's' 'i' 'm' 'u'
 'l' 't' 'a' 'n' 'e' 'o' 'u' 's' 'l' 'y' ' ' 't' 'o' ' ' 'h' 'a' 'v' 'e'
 ' ' 'a' ' ' 'l' 'o' 'a' 'n' ' ' 'w' 'i' 't' 'h' ' ' 'm' 'e' ' ' '(' ' '
 'd' 'e' 's' 'p' 'i' 't' 'e' ' ' 'm' 'y' ' ' 'n' 'e' 'v' 'e' 'r' ' ' 'h'
 'a' 'v' 'i' 'n' 'g' ' ' 'd' 'o' 'n' 'e' ' ' 'b' 'u' 's' 'i' 'n' 'e' 's'
 's' ' ' 'w' 'i' 't' 'h' ' ' 't' 'h' 'e' 'm' ' ' 'o' 'r' ' ' 'h' 'a' 'v'
 'i' 'n' 'g' ' ' 'b' 'e' '

Notice that the padded input is quite long (3193 tokens). We are also dealing with character-level tokenization: our tokens are single characters, not words.

Let's check what the model predicts (to which category the financial complaint belongs).

In [28]:
preds = model2.predict(doc)
print(preds)
y = np.argmax(preds)
print(y)
keras_multiclass_text_classifier.decode_output(y)

[[7.4966592e-03 9.7562626e-08 9.9250317e-01 9.1982411e-12 5.3569739e-08
  4.8417964e-10 9.6964792e-10 4.0114050e-09 5.9291594e-10 3.4063903e-13
  3.9474773e-19]]
2


'Mortgage'

And the ground truth:

In [29]:
y_truth = y_test[0]
print(y_truth)
keras_multiclass_text_classifier.decode_output(y_truth)

[0 0 1 0 0 0 0 0 0 0 0]


'Mortgage'
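`decode_output` is called above with a class index, a one-hot vector, and (later) a batch of probabilities. A minimal sketch of such a mapping, assuming it takes the argmax and inverts `labels_index` (the real helper in the tests module may differ), is shown below. `decode_label` is a hypothetical name; the prediction values are copied from the outputs in this notebook.

```python
import numpy as np

labels_index = {'Debt collection': 0, 'Consumer Loan': 1, 'Mortgage': 2,
                'Credit card': 3, 'Credit reporting': 4, 'Student loan': 5,
                'Bank account or service': 6, 'Payday loan': 7,
                'Money transfers': 8, 'Other financial service': 9,
                'Prepaid card': 10}
# invert the mapping: class index -> label name
index_to_label = {idx: label for label, idx in labels_index.items()}

def decode_label(y):
    """Accept an int, a one-hot/probability vector, or a (1, n_classes) batch."""
    y = np.asarray(y)
    if y.ndim > 0:
        y = np.argmax(y)  # argmax over the flattened array
    return index_to_label[int(y)]

print(decode_label(2))                                  # Mortgage
print(decode_label([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))  # Mortgage
print(decode_label([[0.09576575, 0.27872923, 0.10852851, 0.03327851,
                     0.11653358, 0.1867436, 0.02678595, 0.13854526,
                     0.00900717, 0.00178243, 0.00429991]]))  # Consumer Loan
```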

Now let's explain this prediction with ELI5. We leave `relu` enabled (the default), so only what pushes the predicted class's score up is highlighted.

In [132]:
eli5.show_prediction(model2, doc, tokens=tokens, pad_value='<PAD>', padding='post',
                    layer=3, 
                    )
# FIXME: layer choice + sensible explanation

Our own example

In [133]:
s = "the IRS is afterr my car loan"
doc_s, tokens_s = keras_multiclass_text_classifier.string_to_vectorized(s)
print(doc_s)
print(tokens_s[0, :50]) # note that this model requires fixed length input

[[ 4 12  3 ...  0  0  0]]
['t' 'h' 'e' ' ' 'I' 'R' 'S' ' ' 'i' 's' ' ' 'a' 'f' 't' 'e' 'r' 'r' ' '
 'm' 'y' ' ' 'c' 'a' 'r' ' ' 'l' 'o' 'a' 'n' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>']


In [134]:
preds = model2.predict(doc_s)
print(preds)
keras_multiclass_text_classifier.decode_output(preds)

[[0.09576575 0.27872923 0.10852851 0.03327851 0.11653358 0.1867436
  0.02678595 0.13854526 0.00900717 0.00178243 0.00429991]]


'Consumer Loan'

In [135]:
eli5.show_prediction(model2, doc_s, tokens=tokens_s, pad_value='<PAD>', padding='post',
                    layer=3,
                    )
# FIXME: is this explanation sensible?

# TODO: would be good to show predicted label

## Choosing a classification target to focus on

In [136]:
debt_idx = 0
loan_idx = 1

In [137]:
eli5.show_prediction(model2, doc_s, tokens=tokens_s, pad_value='<PAD>', padding='post',
                    layer=3,
                    targets=[debt_idx],
                    )

Sensible?
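Passing `targets=[debt_idx]` makes Grad-CAM differentiate the score of class 0 instead of the predicted class. A toy illustration on a hypothetical model (not `model2`): if each class score is a linear readout of time-averaged activations, the gradient at every time step equals that class's readout column divided by the number of steps, so different targets weight the channels differently and highlight different positions.

```python
import numpy as np

def toy_grad_cam(A, W, target):
    """Grad-CAM for a toy model whose class scores are A.mean(axis=0) @ W.

    A: (time_steps, channels) activations; W: (channels, n_classes) readout.
    """
    T = A.shape[0]
    # analytic gradient of scores[target] w.r.t. A: the row W[:, target] / T at every step
    grads = np.tile(W[:, target] / T, (T, 1))
    heatmap = A @ grads.mean(axis=0)
    return np.maximum(heatmap, 0.0)

A = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
W = np.array([[2.0, -1.0],   # channel 0 supports class 0
              [-1.0, 2.0]])  # channel 1 supports class 1
print(toy_grad_cam(A, W, target=0))  # step 0 is most important
print(toy_grad_cam(A, W, target=1))  # step 1 is most important
```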

## How it works - `explain_prediction` and `format_as_html`.

In [32]:
# heatmap, tokens, weighted_spans, interpolation_kind, etc.

In [104]:
E = eli5.explain_prediction(model2, doc_s, tokens=tokens_s, pad_value='<PAD>', padding='post', layer=3)

Looking at the `Explanation` object

In [96]:
repr(E)

"Explanation(estimator='sequential_1', description='\\nGrad-CAM visualization for classification tasks; \\noutput is explanation object that contains a heatmap.\\n', error='', method='Grad-CAM', is_regression=False, targets=[TargetExplanation(target=1, feature_weights=None, proba=None, score=0.26196888, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='the IRS is after my car loan', spans=[('t', [(0, 1)], 2.465094823378422e-06), ('h', [(1, 2)], 0.0), ('e', [(2, 3)], 0.0), (' ', [(3, 4)], 0.0), ('I', [(4, 5)], 0.0), ('R', [(5, 6)], 3.983062561019324e-05), ('S', [(6, 7)], 0.0), (' ', [(7, 8)], 0.0), ('i', [(8, 9)], 0.0), ('s', [(9, 10)], 0.0), (' ', [(10, 11)], 0.0), ('a', [(11, 12)], 0.0), ('f', [(12, 13)], 5.141518772688869e-05), ('t', [(13, 14)], 2.465094823378422e-06), ('e', [(14, 15)], 0.0), ('r', [(15, 16)], 0.00021527814971022963), (' ', [(16, 17)], 0.0), ('m', [(17, 18)], 0.0), ('y', [(18, 19)], 0.0), (' ', [(19, 20)], 0.0), ('c', [(20, 21)], 0.0), ('a'

We can get the predicted class and its score.

In [97]:
target = E.targets[0]
print(target.target, target.score)

1 0.26196888


The highlighting for each token is stored in a `WeightedSpans` object (specifically the `DocWeightedSpans` object)

In [130]:
weighted_spans = target.weighted_spans
print(weighted_spans)

doc_ws = weighted_spans.docs_weighted_spans[0]
print(doc_ws)

WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='the IRS is after my car loan', spans=[('t', [(0, 1)], 2.465094823378422e-06), ('h', [(1, 2)], 0.0), ('e', [(2, 3)], 0.0), (' ', [(3, 4)], 0.0), ('I', [(4, 5)], 0.0), ('R', [(5, 6)], 3.983062561019324e-05), ('S', [(6, 7)], 0.0), (' ', [(7, 8)], 0.0), ('i', [(8, 9)], 0.0), ('s', [(9, 10)], 0.0), (' ', [(10, 11)], 0.0), ('a', [(11, 12)], 0.0), ('f', [(12, 13)], 5.141518772688869e-05), ('t', [(13, 14)], 2.465094823378422e-06), ('e', [(14, 15)], 0.0), ('r', [(15, 16)], 0.00021527814971022963), (' ', [(16, 17)], 0.0), ('m', [(17, 18)], 0.0), ('y', [(18, 19)], 0.0), (' ', [(19, 20)], 0.0), ('c', [(20, 21)], 0.0), ('a', [(21, 22)], 0.0), ('r', [(22, 23)], 0.00021527814971022963), (' ', [(23, 24)], 0.0), ('l', [(24, 25)], 0.0), ('o', [(25, 26)], 0.0), ('a', [(26, 27)], 0.0), ('n', [(27, 28)], 0.0)], preserve_density=None, vec_name=None)], other=None)
DocWeightedSpans(document='the IRS is after my car loan', spans=[('t', [(0, 1)], 2.4

Observe the `document` attribute and `spans`

In [101]:
print(doc_ws.document)
print(doc_ws.spans)

the IRS is after my car loan
[('t', [(0, 1)], 2.465094823378422e-06), ('h', [(1, 2)], 0.0), ('e', [(2, 3)], 0.0), (' ', [(3, 4)], 0.0), ('I', [(4, 5)], 0.0), ('R', [(5, 6)], 3.983062561019324e-05), ('S', [(6, 7)], 0.0), (' ', [(7, 8)], 0.0), ('i', [(8, 9)], 0.0), ('s', [(9, 10)], 0.0), (' ', [(10, 11)], 0.0), ('a', [(11, 12)], 0.0), ('f', [(12, 13)], 5.141518772688869e-05), ('t', [(13, 14)], 2.465094823378422e-06), ('e', [(14, 15)], 0.0), ('r', [(15, 16)], 0.00021527814971022963), (' ', [(16, 17)], 0.0), ('m', [(17, 18)], 0.0), ('y', [(18, 19)], 0.0), (' ', [(19, 20)], 0.0), ('c', [(20, 21)], 0.0), ('a', [(21, 22)], 0.0), ('r', [(22, 23)], 0.00021527814971022963), (' ', [(23, 24)], 0.0), ('l', [(24, 25)], 0.0), ('o', [(25, 26)], 0.0), ('a', [(26, 27)], 0.0), ('n', [(27, 28)], 0.0)]


The `document` is the "stringified" version of `tokens`. If you have a custom "tokens -> string" algorithm you may want to set this attribute yourself.

`spans` is a list of `(token, [(start, end)], weight)` tuples, one per token. The `(start, end)` pairs are character offsets into the `document` string, indicating which characters are highlighted with that weight.

The weights come from the `heatmap` object found on each item in `targets`.

In [131]:
heatmap = target.heatmap
print(heatmap)
print(len(heatmap))

print(len(doc_ws.spans))

[2.46509482e-06 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 3.98306256e-05 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 5.14151877e-05 2.46509482e-06 0.00000000e+00 2.15278150e-04
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 2.15278150e-04 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
28
28


You can think of `heatmap` as an array of "importances", one per token (after padding is removed): here both `heatmap` and `spans` have 28 entries.
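Going from tokens and a per-token heatmap to a `document` plus `spans` can be sketched as below. The helper is hypothetical, not ELI5's code: it assumes character-level tokens joined with an empty separator (use `' '` for word tokens) and mirrors the `(token, [(start, end)], weight)` layout printed above.

```python
def build_weighted_spans(tokens, heatmap, sep=""):
    """Join tokens into a document and attach each token's weight to its character range."""
    if len(tokens) != len(heatmap):
        raise ValueError("need one weight per token")
    pieces, spans = [], []
    offset = 0
    for i, (token, weight) in enumerate(zip(tokens, heatmap)):
        if i > 0:
            pieces.append(sep)
            offset += len(sep)
        pieces.append(token)
        spans.append((token, [(offset, offset + len(token))], float(weight)))
        offset += len(token)
    return "".join(pieces), spans

char_doc, char_spans = build_weighted_spans(list("car"), [0.0, 0.5, 0.25])
print(char_doc)    # car
print(char_spans)  # [('c', [(0, 1)], 0.0), ('a', [(1, 2)], 0.5), ('r', [(2, 3)], 0.25)]

word_doc, word_spans = build_weighted_spans(["my", "car", "loan"], [0.1, 0.2, 0.3], sep=" ")
print(word_spans[2])  # ('loan', [(7, 11)], 0.3)
```

If you have a custom tokens-to-string algorithm, only the offset bookkeeping needs to change.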

Let's format this. The HTML formatter is the one to use here.

In [119]:
import eli5.formatters.fields as fields
F = eli5.format_as_html(E, show=fields.WEIGHTS)

We pass a `show` argument so that the method name and its description (Grad-CAM) are not displayed. See `eli5.format_as_html()` for a list of all supported arguments.

The output is an HTML-encoded string.

In [109]:
repr(F)

'\'\\n    <style>\\n    table.eli5-weights tr:hover {\\n        filter: brightness(85%);\\n    }\\n</style>\\n\\n\\n\\n    \\n        <p>Explained as: Grad-CAM</p>\\n    \\n\\n    \\n\\n    \\n\\n    \\n\\n    \\n\\n    \\n\\n\\n    \\n\\n    \\n        \\n        <pre>\\nGrad-CAM visualization for classification tasks; \\noutput is explanation object that contains a heatmap.\\n</pre>\\n    \\n\\n    \\n\\n    \\n\\n    \\n\\n    \\n\\n\\n    \\n\\n    \\n\\n    \\n\\n    \\n\\n    \\n\\n    \\n\\n\\n    \\n\\n    \\n\\n    \\n\\n    \\n        \\n\\n    \\n\\n        \\n            \\n                \\n                \\n            \\n        \\n\\n        \\n\\n\\n    <p style="margin-bottom: 2.5em; margin-top:-0.5em;">\\n        <span style="background-color: hsl(120, 100.00%, 94.19%); opacity: 0.81" title="0.000">t</span><span style="background-color: hsl(120, 100.00%, 94.17%); opacity: 0.81" title="0.000">h</span><span style="background-color: hsl(120, 100.00%, 99.86%); opacity:

Display it in an IPython notebook

In [117]:
display(HTML(F))
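The HTML above colors each character with an inline `hsl(...)` background, where hue 120 is green. A hypothetical mapping from a weight to such a color is sketched below; it uses a simple linear lightness scale chosen for illustration and is not ELI5's exact formula. In this sketch negative weights get hue 0 (red).

```python
import html

def weight_to_hsl(weight, max_abs_weight):
    """Map a signed weight to a CSS color: green if >= 0, red if < 0, darker if larger."""
    hue = 120 if weight >= 0 else 0
    intensity = abs(weight) / max_abs_weight if max_abs_weight > 0 else 0.0
    lightness = 100.0 - 50.0 * intensity  # 100% is white, 50% is the full color
    return f"hsl({hue}, 100.00%, {lightness:.2f}%)"

def spans_to_html(spans):
    """Render (token, ranges, weight) tuples as colored <span> elements."""
    max_abs = max((abs(w) for _, _, w in spans), default=0.0)
    parts = []
    for token, _, weight in spans:
        color = weight_to_hsl(weight, max_abs)
        parts.append(f'<span style="background-color: {color}" '
                     f'title="{weight:.3f}">{html.escape(token)}</span>')
    return "".join(parts)

demo = [("c", [(0, 1)], 0.0), ("a", [(1, 2)], 0.5), ("r", [(2, 3)], -0.25)]
print(spans_to_html(demo))
```

In a notebook, the resulting string can be shown with `display(HTML(...))`, just like `F` above.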

## The `interpolation_kind` argument

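The Grad-CAM heatmap has one value per time step of the chosen hidden layer, which generally does not match the number of tokens, so it has to be resized to the length of the input. The `interpolation_kind` argument controls how this resizing is done.

We continue with the financial complaints model.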

In [129]:
print(tokens.shape, len(heatmap))

(1, 3193) 28


In [106]:
model2.get_layer(index=3).output_shape

(None, 1589, 128)
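Here the chosen layer has 1589 time steps while the padded input has 3193 tokens, so the heatmap has to be stretched. The names in `kinds` below match the `kind` options of SciPy's `scipy.interpolate.interp1d`. Assuming the resizing works along those lines, a NumPy-only sketch of two of the kinds (`linear` and `nearest`) follows; `resize_heatmap` is a hypothetical helper.

```python
import numpy as np

def resize_heatmap(heatmap, new_length, kind="linear"):
    """Stretch a 1-D heatmap to new_length points spanning the same range."""
    heatmap = np.asarray(heatmap, dtype=float)
    old_x = np.linspace(0.0, 1.0, num=len(heatmap))
    new_x = np.linspace(0.0, 1.0, num=new_length)
    if kind == "linear":
        return np.interp(new_x, old_x, heatmap)
    if kind == "nearest":
        # index of the closest original sample (np.rint rounds ties to even)
        idx = np.rint(new_x * (len(heatmap) - 1)).astype(int)
        return heatmap[idx]
    raise ValueError(f"unsupported kind: {kind!r}")

small = np.array([0.0, 1.0, 0.0])  # e.g. 3 time steps in the hidden layer
print(resize_heatmap(small, 5, "linear").tolist())   # [0.0, 0.5, 1.0, 0.5, 0.0]
print(resize_heatmap(small, 5, "nearest").tolist())  # [0.0, 0.0, 1.0, 0.0, 0.0]
```

Note that a one-element heatmap becomes a constant array under either kind, which is one plausible reason why explaining a layer with a single output gives uniform highlighting.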

In [122]:
kinds = ['linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'previous', 'next']

In [126]:
for kind in kinds:
    print(kind)
    H = eli5.show_prediction(model2, doc_s, tokens=tokens_s, pad_value='<PAD>', padding='post',
                        layer=3,
                        interpolation_kind=kind,
                        )
    display(H)

linear


nearest


zero


slinear


quadratic


cubic


previous


next


The results are roughly the same here. If the highlighting seems off, this argument may be worth experimenting with.

## Notes on results

### Multi-label classification

Explaining multi-label classification models does not work.