Explain predictions of Keras image classifiers (Grad-CAM) #315

Merged: 165 commits, Aug 5, 2019

teabolt (Contributor) commented Jun 10, 2019

This PR adds explanations for Keras models that are used to classify images. Specifically we implement Grad-CAM.

For example, the following piece of code:

import keras
import numpy as np
import eli5

# load model
xception = keras.applications.xception.Xception(include_top=True, weights='imagenet', classes=1000)

# load image
im = keras.preprocessing.image.load_img('../eli5_examples/motorcycle.jpg', target_size=(299, 299))
doc = keras.preprocessing.image.img_to_array(im)
doc = np.expand_dims(doc, axis=0)
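# scale pixel values for Xception; assign the result instead of relying on in-place modification
doc = keras.applications.xception.preprocess_input(doc)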

# explain
eli5.show_prediction(xception, doc)

produces this explanation for the class 'motor scooter':
[Image: motorcycle_pr_2]
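
For context, the kind of heatmap shown above can be sketched in a few lines. This is a minimal pure-Python illustration of the general Grad-CAM recipe (channel weights from globally averaged gradients, then a ReLU over the weighted sum of activation maps), not the code added in this PR. The function name `gradcam_heatmap` and the toy inputs are assumptions made for the example.

```python
def gradcam_heatmap(activations, gradients):
    """Illustrative Grad-CAM on nested lists of shape (height, width, channels).

    activations: feature maps of the chosen convolutional layer.
    gradients: gradient of the target class score w.r.t. those feature maps.
    """
    height, width = len(activations), len(activations[0])
    channels = len(activations[0][0])

    # Channel weights: global average pooling of the gradients.
    weights = [
        sum(gradients[i][j][k] for i in range(height) for j in range(width))
        / (height * width)
        for k in range(channels)
    ]

    # Weighted sum of the feature maps over channels, followed by ReLU.
    return [
        [
            max(0.0, sum(weights[k] * activations[i][j][k] for k in range(channels)))
            for j in range(width)
        ]
        for i in range(height)
    ]


# Toy 2x2 feature maps with two channels (made-up numbers).
toy_activations = [[[1, 0], [0, 1]], [[2, 1], [0, 0]]]
toy_gradients = [[[1.0, -1.0], [1.0, -1.0]], [[1.0, -1.0], [1.0, -1.0]]]
toy_heatmap = gradcam_heatmap(toy_activations, toy_gradients)
```

In practice the activations and gradients come from the model, and the low-resolution heatmap is resized and overlaid on the input image.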

The following features are added:

  • Add keras package with explain_prediction_keras() and formatters.image module with format_as_image() (requires matplotlib and Pillow).
  • Add .image attribute to base.Explanation and .heatmap to base.TargetExplanation.
  • Make ipython.show_prediction() dispatch to an image display function for image explanations.
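
The dispatch described in the last bullet could look roughly like the sketch below. It is a hypothetical illustration only: the `Explanation` class here is a stand-in carrying just the new `image` attribute, `choose_formatter` is an invented helper, and the fallback formatter name is an assumption rather than something stated in this PR.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Explanation:
    """Stand-in for base.Explanation; only the field relevant here."""
    image: Optional[Any] = None  # set for image explanations, None otherwise


def choose_formatter(explanation: Explanation) -> str:
    """Hypothetical rule: image explanations are routed to the image formatter."""
    if explanation.image is not None:
        return "format_as_image"
    # Assumed fallback for non-image explanations.
    return "format_as_html"
```

Keying the decision on the presence of `.image` keeps the caller-facing function the same for text and image explanations.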

TODO items before this PR is finalized:

  • Resolve reviews
  • Coverage
  • Pass CI
  • Mypy type annotations
  • Docs (formatting, tutorial)
  • Integration and unit tests
teabolt added 30 commits May 27, 2019
…ications preprocess_input. Fix using callable to find target layer
…ain_prediction via approximate attention over area
@@ -13,10 +13,11 @@


 def get_weighted_spans(doc, vec, feature_weights):
-    # type: (Any, Any, FeatureWeights) -> Optional[WeightedSpans]
+    # type: (Any, Any, Union[FeatureWeights, None]) -> Optional[WeightedSpans]

kmike (Contributor) commented Jul 10, 2019

It seems this function requires FeatureWeights to be not None, so maybe it makes sense to keep the type signature the same and move the assert to the calling code.

teabolt (Author, Contributor) commented Jul 11, 2019

Moved the assert to add_weighted_spans. Good one. Strange that https://github.com/TeamHG-Memex/eli5/search?q=get_weighted_spans did not show add_weighted_spans's call to get_weighted_spans.

913d415
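
The pattern discussed in this thread (strict signature on the callee, None check in the caller) can be sketched as follows. This is a simplified illustration, not the eli5 code: `FeatureWeights` is a stand-in type alias, the `vec` parameter is dropped, and the function bodies are invented.

```python
from typing import List, Optional, Tuple

# Stand-in for eli5's FeatureWeights class (assumption for this sketch).
FeatureWeights = List[Tuple[str, float]]
WeightedSpans = List[Tuple[str, float]]


def get_weighted_spans(doc: str, feature_weights: FeatureWeights) -> Optional[WeightedSpans]:
    """Callee keeps the narrow signature: feature_weights is never None here."""
    found = [(name, weight) for name, weight in feature_weights if name in doc]
    return found or None


def add_weighted_spans(doc: str, feature_weights: Optional[FeatureWeights]) -> Optional[WeightedSpans]:
    """Caller owns the None check, so the callee's type annotation stays strict."""
    assert feature_weights is not None
    return get_weighted_spans(doc, feature_weights)
```

With this split, a type checker flags any new call site that passes a possibly-None value directly to `get_weighted_spans`.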



def _get_target_prediction(targets, estimator):
    # type: (Union[None, list], Model) -> K.variable

kmike (Contributor) commented Jul 10, 2019

Optional[List]

teabolt (Author, Contributor) commented Jul 11, 2019

Ah I was thinking whether to use this. My thinking was that targets is not actually optional in the parameter list so it might be confusing. But I see that the rest of the library uses Optional (i.e. for functions that could return None) so I will change it.

e8a34f1
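
As a quick check of the point above, the two spellings describe the same type to the `typing` module. `Optional[X]` is shorthand for `Union[X, None]`, so the change is purely about consistency with the rest of the library. The variable name below is only for illustration.

```python
from typing import List, Optional, Union

# Optional[X] is equivalent to Union[X, None]; member order does not matter.
same_type = (
    Optional[List] == Union[None, List]
    and Optional[List] == Union[List, None]
)
```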

eli5/keras/gradcam.py (outdated)
a valid keras layer name, layer index, or an instance of a Keras layer.
If None, a suitable layer is attempted to be retrieved.
See :func:`eli5.keras._search_layer_backwards` for details.

kmike (Contributor) commented Jul 10, 2019

I think if we're documenting a function, it makes sense to make it public - or just document the behavior without mentioning a function.

teabolt (Author, Contributor) commented Jul 11, 2019

Good catch. Forgot that I made the function private.

bcaf7ca
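
The behavior in the docstring above (accept a layer name, an index, or a layer instance, and search when given None) could be documented with a sketch like the one below. Everything here is hypothetical: `Layer` is a stand-in for a Keras layer, `resolve_layer` is an invented name, and treating a rank-4 output as "suitable" is an assumption, not the rule used in the PR.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Layer:
    """Stand-in for a Keras layer: a name and the rank of its output tensor."""
    name: str
    output_rank: int


def resolve_layer(layers: List[Layer], layer: Union[None, str, int, Layer]) -> Layer:
    """Hypothetical resolver for a `layer` argument of this kind."""
    if isinstance(layer, Layer):
        return layer
    if isinstance(layer, int):
        return layers[layer]
    if isinstance(layer, str):
        for candidate in layers:
            if candidate.name == layer:
                return candidate
        raise ValueError("no layer named {!r}".format(layer))
    # layer is None: walk from the output back towards the input.
    # Assumption: "suitable" means a rank-4 output (batch, height, width, channels).
    for candidate in reversed(layers):
        if candidate.output_rank == 4:
            return candidate
    raise ValueError("no suitable layer found")


toy_model = [Layer("conv1", 4), Layer("conv2", 4), Layer("avg_pool", 2), Layer("predictions", 2)]
default_layer = resolve_layer(toy_model, None)  # picks the last rank-4 layer
```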

An input image as a tensor to ``estimator``,
from which prediction will be done and explained.
For example a ``numpy.ndarray``.

kmike (Contributor) commented Jul 10, 2019

are there other supported data types, why "for example"?

teabolt (Author, Contributor) commented Jul 11, 2019

I will explicitly mention that numpy arrays are required.

I think it is possible to use other input types (https://github.com/keras-team/keras/blob/ed07472bc5fc985982db355135d37059a1f887a9/keras/engine/training.py#L1315), i.e. tensorflow tensor. However, I haven't tested with other types and I think I have some numpy dependencies in my code. Adding more input types could be a separate GitHub issue?

9d2d22a

teabolt (Contributor, Author) commented Jul 12, 2019

Currently the parameter resampling_filter (previously called interpolation) of eli5.format_as_image() takes one of the integer filter constants listed at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters, e.g. the user passes something like resampling_filter=PIL.Image.BOX. It could be clearer to take the filter as a string, i.e. let the user say resampling_filter="BOX"?

lopuhin (Contributor) commented Jul 29, 2019

> resampling_filter=PIL.Image.BOX. It could be clearer to take the filter as a string, i.e. let the user say resampling_filter="BOX"?

I think both options are fine, to me a constant looks a bit better than a string, and I think it's fine to use a PIL constant here.
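
If both spellings were ever wanted, a small adapter could accept either form. The sketch below is hypothetical and not part of this PR: `resolve_resampling_filter` is an invented helper, and a stand-in namespace mirroring Pillow's integer constants is used so the snippet runs without Pillow installed. With Pillow available, `PIL.Image` could be passed instead.

```python
from types import SimpleNamespace

# Filter names from the Pillow "Filters" concepts page.
FILTER_NAMES = ("NEAREST", "BOX", "BILINEAR", "HAMMING", "BICUBIC", "LANCZOS")

# Stand-in for PIL.Image so the example has no third-party dependency.
pil_image_stand_in = SimpleNamespace(
    NEAREST=0, LANCZOS=1, BILINEAR=2, BICUBIC=3, BOX=4, HAMMING=5
)


def resolve_resampling_filter(value, constants):
    """Hypothetical helper: accept a filter name or an integer constant.

    constants is any object exposing the filter names as attributes,
    for example the PIL.Image module.
    """
    if isinstance(value, str):
        name = value.upper()
        if name not in FILTER_NAMES:
            raise ValueError("unknown resampling filter: {!r}".format(value))
        return getattr(constants, name)
    # Already a constant: pass it through unchanged.
    return value
```

Keeping the constant as the primary form, as suggested above, means this adapter stays optional sugar and is not part of the public contract.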

Contributor left a comment

@teabolt sorry for a long review - there is just one minor thing which is preventing the merge now, updating some docs after some attributes were moved to TargetExplanation:

eli5/keras/explain_prediction.py (outdated)
docs/source/libraries/keras.rst
eli5/formatters/image.py
eli5/formatters/image.py (outdated)
lopuhin (Contributor) approved these changes Aug 5, 2019

Perfect, thanks @teabolt 👍

lopuhin (Contributor) commented Aug 5, 2019

Thanks for a great new feature @teabolt , and thanks for review @kmike , merging 🎉

lopuhin merged commit 2497ec3 into TeamHG-Memex:master on Aug 5, 2019
2 checks passed
codecov/patch: 98.64% of diff hit (target 97.19%)
continuous-integration/travis-ci/pr: The Travis CI build passed
@@ -45,6 +45,9 @@ following machine learning frameworks and packages:
* :ref:`library-sklearn-crfsuite`. ELI5 allows to check weights of
sklearn_crfsuite.CRF models.

* :ref:`library-keras` - explain predictions of image classifiers
via Grad-CAM visualizations.

kmike (Contributor) commented Aug 5, 2019

Sorry for a late comment: could you please copy overview.rst changes to README file in the repo root?

teabolt (Author, Contributor) commented Aug 6, 2019

Will do that in #329. Thanks

kmike (Contributor) commented Aug 5, 2019

Thanks @teabolt and @lopuhin, great work!
