(Keras images) Add an optional image argument, and other improvements #329

Merged
merged 21 commits on Aug 10, 2019
Changes from 1 commit

Commits (21)
c4eb039
(keras) Make image argument required
teabolt Aug 5, 2019
1111105
Update README.md to include keras
teabolt Aug 6, 2019
7d6d982
Merge branch 'master' of https://github.com/teabolt/eli5 into keras-g…
teabolt Aug 6, 2019
14452c3
Remove mentions of target_names (not implemented)
teabolt Aug 7, 2019
9e85021
Add dispatch function and image implementation
teabolt Aug 7, 2019
9a0cd53
Update dispatcher and image function docstrings
teabolt Aug 7, 2019
6b002a6
Automatically check if model/input is image-based. Convert input to a…
teabolt Aug 7, 2019
7d82c30
Mock keras.preprocessing.image in docs conf (CI fix)
teabolt Aug 7, 2019
d1af643
Update tests, docs, tutorial with image argument changes
teabolt Aug 7, 2019
6aec486
Blank line between header and list in docstring (CI fix)
teabolt Aug 8, 2019
aaa83a7
Test keras not supported function
teabolt Aug 8, 2019
47b03c1
Docstring typo
teabolt Aug 8, 2019
df25038
Clarify "not supported" error message.
teabolt Aug 8, 2019
6234aaa
Remove TODO: explain Grad-CAM in docstring. (Will be explained in ker…
teabolt Aug 8, 2019
13e1847
Move image extraction call from dispatcher to image function
teabolt Aug 8, 2019
3e875bb
Move Keras to second place in supported package list
teabolt Aug 8, 2019
4192939
Remove warnings for 'maybe image' dispatch and conversion to RGBA
teabolt Aug 8, 2019
2bb7ba5
'not supported' error typo
teabolt Aug 8, 2019
991159b
Test 'maybe image' check with both input and model
teabolt Aug 9, 2019
10921be
Add Grad-CAM image to README
teabolt Aug 9, 2019
0cf31fe
Remove line breaking backslash from docstring
teabolt Aug 10, 2019
Update tests, docs, tutorial with image argument changes

teabolt committed Aug 7, 2019
commit d1af643c859aef5c2a41a8f4c01ed4ab7d7b430a
@@ -103,7 +103,7 @@ dimensions! Let's resize it:

.. parsed-literal::

<PIL.Image.Image image mode=RGB size=224x224 at 0x7FD4FC485DD8>
<PIL.Image.Image image mode=RGB size=224x224 at 0x7FBF0DDE5A20>



@@ -163,33 +163,18 @@ inputting
.. code:: ipython3

# take back the first image from our 'batch'
display(keras.preprocessing.image.array_to_img(doc[0]))



.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_13_0.png


One last thing, to explain image based models, we need to pass the image
as a PIL object explicitly. However, it must have mode 'RGBA'

.. code:: ipython3

print(im) # current mode

image = im.convert(mode='RGBA') # add alpha channel
image = keras.preprocessing.image.array_to_img(doc[0])
print(image)
display(image)


.. parsed-literal::

<PIL.Image.Image image mode=RGB size=224x224 at 0x7FD4FC485DD8>
<PIL.Image.Image image mode=RGBA size=224x224 at 0x7FD4DB62EF28>
<PIL.Image.Image image mode=RGB size=224x224 at 0x7FBF0CF760F0>



.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_15_1.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_13_1.png


Ready to go!
@@ -241,20 +226,32 @@ for a dog with ELI5:

# we need to pass the network
# the input as a numpy array
# and the corresponding input image (RGBA mode)
eli5.show_prediction(model, doc, image=image)
eli5.show_prediction(model, doc)




.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_21_0.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_19_0.png



The dog region is highlighted. Makes sense!

Note that here we made a prediction twice: once when looking at the top
predictions, and a second time when passing the model through ELI5.
When explaining image-based models, we can optionally pass the image
associated with the input as a Pillow image object. If we don't, the
image will be created from ``doc``. This may not work with custom models
or inputs, in which case it's worth passing the image explicitly.

.. code:: ipython3

eli5.show_prediction(model, doc, image=image)




.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_22_0.png



3. Choosing the target class (target prediction)
------------------------------------------------
@@ -265,7 +262,7 @@ classifier looks to find those objects.
.. code:: ipython3

cat_idx = 282 # ImageNet ID for "tiger_cat" class, because we have a cat in the picture
eli5.show_prediction(model, doc, image=image, targets=[cat_idx]) # pass the class id
eli5.show_prediction(model, doc, targets=[cat_idx]) # pass the class id



@@ -283,8 +280,8 @@ Currently only one class can be explained at a time.

window_idx = 904 # 'window screen'
turtle_idx = 35 # 'mud turtle', some nonsense
display(eli5.show_prediction(model, doc, image=image, targets=[window_idx]))
display(eli5.show_prediction(model, doc, image=image, targets=[turtle_idx]))
display(eli5.show_prediction(model, doc, targets=[window_idx]))
display(eli5.show_prediction(model, doc, targets=[turtle_idx]))



@@ -369,7 +366,7 @@ Rough print but okay. Let's pick a few convolutional layers that are

for l in ['block_2_expand', 'block_9_expand', 'Conv_1']:
print(l)
display(eli5.show_prediction(model, doc, image=image, layer=l)) # we pass the layer as an argument
display(eli5.show_prediction(model, doc, layer=l)) # we pass the layer as an argument


.. parsed-literal::
@@ -417,7 +414,7 @@ better understand what is going on.

.. code:: ipython3

expl = eli5.explain_prediction(model, doc, image=image)
expl = eli5.explain_prediction(model, doc)

Examining the structure of the ``Explanation`` object:

@@ -441,7 +438,7 @@ Examining the structure of the ``Explanation`` object:
[0. , 0. , 0. , 0. , 0. ,
0. , 0.05308531],
[0. , 0. , 0. , 0. , 0. ,
0.01124764, 0.06864655]]))], feature_importances=None, decision_tree=None, highlight_spaces=None, transition_features=None, image=<PIL.Image.Image image mode=RGBA size=224x224 at 0x7FD4DB62EF28>)
0.01124764, 0.06864655]]))], feature_importances=None, decision_tree=None, highlight_spaces=None, transition_features=None, image=<PIL.Image.Image image mode=RGB size=224x224 at 0x7FBEFD7F4080>)
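The attribute layout shown above can be explored programmatically. A toy sketch with stand-ins for the real ``eli5.base.Explanation``/``TargetExplanation`` classes (the real classes have more fields, and the model name and score below are illustrative):

```python
from dataclasses import dataclass, field
from typing import List, Optional

# Toy stand-ins for eli5.base.Explanation / TargetExplanation, reduced to
# the attributes the tutorial inspects.
@dataclass
class TargetExplanation:
    target: int                     # class index that was explained
    score: Optional[float] = None   # raw value or probability for the target
    heatmap: Optional[list] = None  # rank 2 float array with values in [0, 1]

@dataclass
class Explanation:
    estimator: str
    targets: List[TargetExplanation] = field(default_factory=list)
    image: Optional[object] = None  # Pillow image associated with the input

# Build a toy explanation and pull out the pieces shown above
expl = Explanation(
    estimator='mobilenetv2_1.00_224',  # illustrative model name
    targets=[TargetExplanation(target=282, score=0.78,
                               heatmap=[[0.0, 0.05], [0.01, 0.07]])],
)
top = expl.targets[0]
print(top.target, top.score)  # -> 282 0.78
```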


We can check the score (raw value) or probability (normalized score) of
@@ -578,7 +575,7 @@ check the explanation:

# first check the explanation *with* softmax
print('with softmax')
display(eli5.show_prediction(model, doc, image=image))
display(eli5.show_prediction(model, doc))


# remove softmax
@@ -590,7 +587,7 @@ check the explanation:
model = keras.models.load_model('tmp_model_save_rmsoftmax')

print('without softmax')
display(eli5.show_prediction(model, doc, image=image))
display(eli5.show_prediction(model, doc))


.. parsed-literal::
@@ -634,9 +631,10 @@ loading another model and explaining a classification of the same image:
nasnet.preprocess_input(doc2)

print(model.name)
display(eli5.show_prediction(model, doc, image=image))
# note that this model is without softmax
display(eli5.show_prediction(model, doc))
print(model2.name)
display(eli5.show_prediction(model2, doc2, image=image))
display(eli5.show_prediction(model2, doc2))


.. parsed-literal::
17 binary files not shown.
@@ -21,7 +21,7 @@ Currently ELI5 supports :func:`eli5.explain_prediction` for Keras image classifi

The returned :class:`eli5.base.Explanation` instance contains some important objects:

* ``image`` represents the image input into the model. A Pillow image with mode "RGBA".
* ``image`` represents the image input into the model. A Pillow image.

* ``targets`` represents the explanation values for each target class (currently only 1 target is supported). A list of :class:`eli5.base.TargetExplanation` objects with the following attributes set:

@@ -44,13 +44,11 @@ Important arguments to :func:`eli5.explain_prediction` for ``Model`` and ``Seque

* ``image`` Pillow image, corresponds to doc input.

- **Must be passed for image explanations.**

- **Must have mode "RGBA".**
- Image over which to overlay the heatmap.

* ``target_names`` are the names of the output classes.
- *Currently not implemented*.
- If not given, the image will be derived from ``doc`` where possible.

- Useful if ELI5 fails in case you have a custom image model or image input.

* ``targets`` are the output classes to focus on. Possible values include:

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import warnings
from typing import Union, Optional, Callable

import numpy as np # type: ignore
@@ -20,24 +21,24 @@ def format_as_image(expl, # type: Explanation
Format a :class:`eli5.base.Explanation` object as an image.
Note that this formatter requires ``matplotlib`` and ``Pillow`` optional dependencies.
:param Explanation expl:
:class:`eli5.base.Explanation` object to be formatted.
It must have an ``image`` attribute that is a Pillow image with mode "RGBA".
It must also have a ``targets`` attribute, a list of :class:`eli5.base.TargetExplanation` \
It must have an ``image`` attribute with a Pillow image that will be overlaid.
It must have a ``targets`` attribute, a list of :class:`eli5.base.TargetExplanation` \
instances that contain the attribute ``heatmap``, \
a rank 2 numpy array with float values in the interval [0, 1].
Currently ``targets`` must be length 1 (only one target is supported).
:raises TypeError: if ``heatmap`` is not a numpy array.
:raises ValueError: if ``heatmap`` does not contain values as floats in the interval [0, 1].
:raises TypeError: if ``image`` is not a Pillow image.
:raises ValueError: if ``image`` does not have mode 'RGBA'.
:param resampling_filter:
Interpolation ID or Pillow filter to use when resizing the image.
Example filters from PIL.Image
* ``NEAREST``
* ``BOX``
@@ -63,7 +64,7 @@ def format_as_image(expl, # type: Explanation
* ``viridis``
* ``jet``
* ``binary``
See also https://matplotlib.org/gallery/color/colormap_reference.html.
Default is ``matplotlib.cm.viridis`` (green/blue to yellow).
@@ -97,8 +98,9 @@ def format_as_image(expl, # type: Explanation
raise TypeError('Explanation image must be a PIL.Image.Image instance. '
'Got: {}'.format(image))
if image.mode != 'RGBA':
raise ValueError('Explanation image must have mode "RGBA". '
'Got image with mode: %s' % image.mode)
# normalize to 'RGBA'
warnings.warn('Converting image to RGBA.', stacklevel=2)
This conversation was marked as resolved by lopuhin

lopuhin (Contributor), Aug 8, 2019:

Why do we want to warn about this? I think usually warnings are used to warn about something that is suboptimal and likely needs to be fixed, but I think in this case it's completely fine to pass any image?

teabolt (Author, Contributor), Aug 8, 2019:

As discussed, the two warnings may help the user/developer in case of unexpected behaviour, i.e. the image implementation is dispatched when something else was wanted, or the image is converted to RGBA but the resulting image doesn't look right. However, it's better to replace them with logging (possibly in a separate PR that adds logging to the library).
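The warning-to-logging replacement discussed here could look roughly as follows (a sketch, not code from this PR; the logger name and helper are illustrative, and a duck-typed stand-in is used so the example runs without Pillow):

```python
import logging

logger = logging.getLogger('eli5.formatters.image')  # illustrative logger name

def normalize_to_rgba(image):
    """Convert an image-like object to RGBA, logging instead of warning.

    Works with anything exposing ``mode`` and ``convert``, e.g. PIL.Image.Image.
    """
    if image.mode != 'RGBA':
        # debug level: routine normalization, not something the user must fix
        logger.debug('converting image from %s to RGBA', image.mode)
        image = image.convert('RGBA')
    return image

# Duck-typed stand-in so the sketch runs without Pillow
class FakeImage:
    def __init__(self, mode):
        self.mode = mode
    def convert(self, mode):
        return FakeImage(mode)

print(normalize_to_rgba(FakeImage('RGB')).mode)   # -> RGBA
print(normalize_to_rgba(FakeImage('RGBA')).mode)  # -> RGBA
```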

image = image.convert('RGBA')

if not expl.targets:
# no heatmaps
@@ -107,14 +109,14 @@ def format_as_image(expl, # type: Explanation
assert len(expl.targets) == 1
heatmap = expl.targets[0].heatmap
_validate_heatmap(heatmap)

# The order of our operations is: 1. colorize 2. resize
# as opposed: 1. resize 2. colorize

# save the original heatmap values
heatvals = heatmap
# apply colours to the grayscale array
heatmap = _colorize(heatmap, colormap=colormap) # -> rank 3 RGBA array
heatmap = _colorize(heatmap, colormap=colormap) # -> rank 3 RGBA array

# make the alpha intensity correspond to the grayscale heatmap values
# cap the intensity so that it's not too opaque when near maximum value
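The "1. colorize 2. resize" ordering above can be sketched without the matplotlib and Pillow dependencies. In this simplified stand-in, a two-colour linear ramp (endpoints roughly matching viridis) replaces a real colormap and nearest-neighbour duplication replaces LANCZOS resampling:

```python
def colorize(heatmap, low=(68, 1, 84), high=(253, 231, 37)):
    """Map a rank 2 heatmap of floats in [0, 1] to rows of RGBA pixels.

    ``low``/``high`` approximate the ends of the viridis colormap;
    the alpha channel is set from the heatmap value itself.
    """
    def lerp(a, b, t):
        return int(round(a + (b - a) * t))
    return [[(lerp(low[0], high[0], v),
              lerp(low[1], high[1], v),
              lerp(low[2], high[2], v),
              int(round(255 * v)))          # alpha tracks intensity
             for v in row]
            for row in heatmap]

def resize_nearest(pixels, factor):
    """Upscale a pixel grid by an integer factor (nearest neighbour)."""
    return [[px for px in row for _ in range(factor)]
            for row in pixels for _ in range(factor)]

heatmap = [[0.0, 1.0], [0.5, 0.25]]
rgba = resize_nearest(colorize(heatmap), factor=2)  # 1. colorize 2. resize
```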
@@ -133,7 +135,7 @@ def heatmap_to_image(heatmap):
Parameters
----------
heatmap : numpy.ndarray
Rank 2 grayscale ('L') array or rank 3 coloured ('RGBA') array,
Rank 2 grayscale ('L') array or rank 3 coloured ('RGB' or 'RGBA') array,
with values in interval [0, 1] as floats.
@@ -180,9 +182,9 @@ def _validate_heatmap(heatmap):
ma = np.max(heatmap)
if not (0 <= mi and ma <= 1):
raise ValueError('heatmap must contain float values '
'between 0 and 1 inclusive. '
'Got array with minimum: {} '
'and maximum: {}'.format(mi, ma))
'between 0 and 1 inclusive. '
'Got array with minimum: {} '
'and maximum: {}'.format(mi, ma))


def _colorize(heatmap, colormap):
@@ -201,7 +203,7 @@ def _update_alpha(image_array, starting_array=None, alpha_limit=None):
"""
Update the alpha channel values of an RGBA rank 3 ndarray ``image_array``,
optionally creating the alpha channel from rank 2 ``starting_array``,
and setting upper limit for alpha values (opacity) to ``alpha_limit``.
and setting upper limit for alpha values (opacity) to ``alpha_limit``.
This function modifies ``image_array`` in-place.
"""
@@ -241,11 +243,11 @@ def _cap_alpha(alpha_arr, alpha_limit):

def expand_heatmap(heatmap, image, resampling_filter=Image.LANCZOS):
# type: (np.ndarray, Image, Union[None, int]) -> Image
"""
"""
Resize the ``heatmap`` image array to fit over the original ``image``,
using the specified ``resampling_filter`` method.
using the specified ``resampling_filter`` method.
The heatmap is converted to an image in the process.
Parameters
----------
heatmap : numpy.ndarray
@@ -259,7 +261,7 @@ def expand_heatmap(heatmap, image, resampling_filter=Image.LANCZOS):
See :func:`eli5.format_as_image` for more details on the `resampling_filter` parameter.
:raises TypeError: if ``image`` is not a Pillow image instance.
@@ -281,8 +283,8 @@ def _overlay_heatmap(heatmap, image):
# type: (Image, Image) -> Image
"""
Blend (combine) ``heatmap`` over ``image``,
using alpha channel values appropriately.
Input and output images have mode 'RGBA'.
using alpha channel values appropriately (must have mode `RGBA`).
Output is 'RGBA'.
"""
# note that the order of alpha_composite arguments matters
overlayed_image = Image.alpha_composite(image, heatmap)
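The "order of alpha_composite arguments matters" note can be demonstrated without Pillow by applying the Porter-Duff "over" formula to single RGBA pixels. In ``Image.alpha_composite(im1, im2)`` the second argument is drawn on top, so ``alpha_composite(image, heatmap)`` overlays the heatmap on the image. A dependency-free sketch:

```python
def over(top, bottom):
    """Composite RGBA pixel ``top`` over ``bottom`` (Porter-Duff 'over').

    Channels are floats in [0, 1]; this mirrors per-pixel what
    Image.alpha_composite does, where the SECOND PIL argument is the top layer.
    """
    tr, tg, tb, ta = top
    br, bg, bb, ba = bottom
    out_a = ta + ba * (1 - ta)
    if out_a == 0:
        return (0.0, 0.0, 0.0, 0.0)
    blend = lambda t, b: (t * ta + b * ba * (1 - ta)) / out_a
    return (blend(tr, br), blend(tg, bg), blend(tb, bb), out_a)

opaque_red = (1.0, 0.0, 0.0, 1.0)        # stand-in for the input image
translucent_blue = (0.0, 0.0, 1.0, 0.5)  # stand-in for the heatmap

# heatmap over image: the tint shows through
print(over(translucent_blue, opaque_red))  # -> (0.5, 0.0, 0.5, 1.0)
# image over heatmap: the opaque image hides the heatmap entirely
print(over(opaque_red, translucent_blue))  # -> (1.0, 0.0, 0.0, 1.0)
```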