
Integrated Gradients target_fn #523

Merged
merged 7 commits from 492/ig-target-callback-argmax into master on Dec 2, 2021

Conversation

@jklaise (Member) commented Nov 9, 2021

This is a PR partially addressing #492. In particular, it introduces the ability to pass a callback via the target_fn argument to calculate the scalar target dimension from the model output within the method. This bypasses the requirement of passing target directly to explain when the target of interest depends on the prediction output, as in the case of the highest-probability class.

With this change, both of the following ways of defining an explainer are possible:

# current
import numpy as np
from alibi.explainers import IntegratedGradients

ig = IntegratedGradients(model)
preds = np.argmax(model(X), axis=1)  # targets need to be computed upfront
exp = ig.explain(X, target=preds)

# proposed
from functools import partial

target_fn = partial(np.argmax, axis=1)  # user-defined callback
ig = IntegratedGradients(model, target_fn=target_fn)
exp = ig.explain(X)  # no need to pass target

TODO:

  • Validate the output of target_fn
  • Tests
  • Docs updates
  • Example updates
  • (Out of scope?) Desired output shape of target_fn. If the output shape of the model is (N,M), as in probabilistic classification, then naturally the output of target_fn is expected to be an (N,) array of integers which specify, for each instance in the batch, which dimension of the output (e.g. the maximum probability) to take; see the shape sketch after this list. However, if the model output has more than 2 dimensions this would not apply: e.g. if the output shape of the model is (N,M,K), then target_fn would have to specify indices for both the M and K dimensions, which would be reflected in the shape of its output. That being said, I'm not certain our current implementation of IntegratedGradients supports >2D model outputs, so the point may be moot (for now).
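
For reference, a minimal sketch of the shape contract in the 2D case (the sizes and random predictions here are illustrative):

import numpy as np

N, M = 4, 3
preds = np.random.rand(N, M)       # model output, shape (N, M)

target = np.argmax(preds, axis=1)  # target_fn output, shape (N,)
assert target.shape == (N,)
assert np.issubdtype(target.dtype, np.integer)  # one column index per instance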

To discuss:

  1. I put target_fn as an argument to __init__ as it seemed more natural; however, a case could be made that this belongs in explain, as target is part of explain (and so is forward_kwargs, although I would personally like to see this shift to __init__). EDIT: Actually forward_kwargs is data dependent, so it belongs in explain.
  2. This PR does not resolve the issue of fully specifying the target_fn from outside Python; in particular, the user application must still define it. We could approach this (outside of the scope of this PR) by introducing function registries within alibi itself for pre-defined functions using catalogue. It could then look something like this:
# alibi.utils.registries
import catalogue
import numpy as np

scalar_reducers = catalogue.create("alibi", "scalar_reducers")

@scalar_reducers.register("argmax")
def argmax(X):
    return np.argmax(X, axis=-1)  # TODO: numpy doesn't support multi-axis reduction for argmax

# user code option 1
# this does not fully solve the issue in #492 as the user still has to fetch the function
from alibi.utils.registries import scalar_reducers

target_fn = scalar_reducers.get("argmax")
explainer = IntegratedGradients(model, target_fn=target_fn)

# user code option 2
# this fully solves the issue in #492 by extending the explainer constructor signature
# to accept strings which are looked up internally in the registry
target_fn = "argmax"  # or provided from some config file
explainer = IntegratedGradients(model, target_fn=target_fn)

# inside IntegratedGradients:
    ...
    target_fn = scalar_reducers.get(target_fn)

We can split the registry discussion into a separate issue in the future.

It would be nice, imo, to provide both callback functionality for maximum flexibility and strings mapping to alibi built-in functions for ease of use (in this example, avoiding the use of partial and defining target_fn manually). Although, we may want to think about separating interfaces that act on pure callbacks (e.g. the public IntegratedGradients interface) from interfaces that act on string names to construct an explainer, for greater clarity and separation of concerns.
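
For illustration, a hypothetical helper (the name _resolve_target_fn is made up here and is not part of this PR) could dispatch on the argument type so that a single constructor accepts both, assuming the scalar_reducers registry sketched above:

from typing import Callable, Union

def _resolve_target_fn(target_fn: Union[str, Callable]) -> Callable:
    # strings are looked up in the registry, callables pass through unchanged
    if isinstance(target_fn, str):
        return scalar_reducers.get(target_fn)
    return target_fn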

@jklaise jklaise marked this pull request as draft November 25, 2021 15:42
@jklaise force-pushed the 492/ig-target-callback-argmax branch from 0bc2673 to fa06597 on November 25, 2021 15:45

@jklaise jklaise marked this pull request as ready for review November 29, 2021 17:24
@jklaise jklaise changed the title WIP: Integrated Gradients target_fn poc Integrated Gradients target_fn Nov 29, 2021
@jklaise (Member, Author) commented Nov 29, 2021

@sakoush @ascillitoe @arnaudvl ready for review.

This PR does not address the issue of specifying the target via e.g. a string key from outside Python; this can be explored in follow-up work.

I've kept the example notebooks unchanged but introduced the new usage in the method overview docs. The tests now cover passing either target or target_fn, although for the time being I am not considering test cases where both or neither are passed, due to the complexity of the internal test logic.

For the example notebook, linking to HTML anchors explicitly defined by <a></a> tags seems broken for now. This can be improved in the future, especially if converting the notebooks to MyST.

@codecov bot commented Nov 29, 2021

Codecov Report

Merging #523 (3f3a748) into master (ee356b1) will increase coverage by 0.10%.
The diff coverage is 97.95%.


@@            Coverage Diff             @@
##           master     #523      +/-   ##
==========================================
+ Coverage   82.50%   82.60%   +0.10%     
==========================================
  Files          77       77              
  Lines       10421    10494      +73     
==========================================
+ Hits         8598     8669      +71     
- Misses       1823     1825       +2     
Impacted Files                                         Coverage Δ
alibi/explainers/integrated_gradients.py               89.05% <91.30%> (+0.04%) ⬆️
alibi/explainers/tests/test_integrated_gradients.py    96.92% <100.00%> (+0.48%) ⬆️

@sakoush (Member) left a comment

I left some comments regarding whether it would be a big change to allow a default for target_fn that is not None. If so, the explainer could be initialised on the fly, which I guess would be useful to have, as discussed.

I understand that this will affect downstream logic and testing, as you explained, but I thought I would mention it.

In any case, once we get the registry-based / string config this can be addressed, so no real urgency I think.

@@ -663,6 +713,7 @@ class IntegratedGradients(Explainer):
def __init__(self,
model: tf.keras.Model,
layer: Optional[tf.keras.layers.Layer] = None,
target_fn: Optional[Callable] = None,
@sakoush (Member):

could we have partial(np.argmax, axis=1) as the default instead?

@jklaise (Member, Author):

Unfortunately not, as target_fn depends on the output shape of the predictions, which can be anything; it's up to the user to specify, and we cannot safely assume a default here.
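
To illustrate (with hypothetical model outputs), a default of partial(np.argmax, axis=1) would silently produce meaningless targets for anything that isn't an (N, M) classifier output:

import numpy as np
from functools import partial

default_fn = partial(np.argmax, axis=1)

clf_preds = np.array([[0.1, 0.9], [0.8, 0.2]])  # classifier output, shape (N, 2)
print(default_fn(clf_preds))                    # [1 0] -- sensible targets

reg_preds = np.array([[3.7], [1.2]])            # regressor output, shape (N, 1)
print(default_fn(reg_preds))                    # [0 0] -- meaningless targets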

@@ -761,6 +814,13 @@ def explain(self,
for each feature.

"""
# target handling logic
if self.target_fn and target is not None:
@sakoush (Member):

Another way to look at it is that target takes precedence. I'm just thinking that the user might want to change this at explain time without having to create the explainer object again?

@jklaise (Member, Author):

That's a good point. I didn't want to complicate the logic, but we could also have target override any pre-set target_fn; a warning would need to be emitted. Let me think about it.
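
A minimal sketch of what that precedence might look like (the resolve_target helper is hypothetical, written as a free function for illustration):

import warnings
import numpy as np

def resolve_target(target, target_fn, model_output):
    # an explicit target wins over a pre-set target_fn, with a warning
    if target is not None:
        if target_fn is not None:
            warnings.warn("Both `target_fn` and `target` provided; `target` takes precedence.")
        return np.asarray(target)
    if target_fn is not None:
        return target_fn(model_output)  # derive targets from the predictions
    raise ValueError("Either `target_fn` or `target` must be provided.")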

@ascillitoe (Contributor) commented Dec 1, 2021:

Having detectors/explainers which can be modified post-init generally seems like a nice idea (especially if init is expensive), but it does open up a little bit of a question of where to stop with this. For example, for many of the offline detectors we could actually have p_val as a kwarg of predict. Perhaps this is best dealt with by opening new issue(s)?

@jklaise (Member, Author):

Yeah, it's a design discussion. We've had similar discussions in the past; the line between what you would want to change on every explain call and what you want fixed is very fuzzy.

@@ -722,6 +773,8 @@ def __init__(self,
self._is_np: Optional[bool] = None
self.orig_dummy_input: Optional[Union[list, np.ndarray]] = None

self.target_fn = target_fn
@sakoush (Member) commented Dec 1, 2021:

I personally do not want to add more dummy calls, but in this case we have the keras model and target_fn (if set), so we could check that this is OK at init time?

I guess you left it to explain time to have the check in one place, but checking target_fn would go towards validating the parameters passed and giving the user some early hints about whether things are OK.

@jklaise (Member, Author) commented Dec 1, 2021:

I think this is quite hard without knowing the shape of the inputs, which we only know at explain time. If we wanted to deeply validate user models and callables at construction time, it would involve a sizable refactor, as we would also need to ask the user for the type of data that can be put through the model.

Practically, this could be solved by asking the user to provide a sample batch of data at __init__ that we can use to validate the model and any callables (see the sketch below). In theory this could be an optional kwarg, so that people would only use it when they want some guarantees about the functionality of the explainer post-__init__ (we wouldn't want to make it mandatory, I think).

(I think, once again, this is where practical Python, i.e. trusting users to do the right thing, clashes with production, i.e. wanting to validate as much as possible before calling things on production data.)

I think this should be discussed in a different place (maybe a new issue?). Would be keen to hear thoughts from @ascillitoe and @arnaudvl.
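
A minimal sketch of that optional validation (the validate_target_fn helper and its error message are hypothetical):

import numpy as np

def validate_target_fn(model, target_fn, sample_data):
    # hypothetical check run at __init__ when a sample batch is supplied
    preds = np.asarray(model(sample_data))
    targets = np.asarray(target_fn(preds))
    if targets.shape != (len(sample_data),):
        raise ValueError(
            f"target_fn returned shape {targets.shape}, expected ({len(sample_data)},)."
        )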

@ascillitoe (Contributor):

Agree with the sentiments of @sakoush and @jklaise here; as @sakoush hints at, any checks we can do to enable more helpful error messages are great. But as @jklaise alludes to, the complexity needed for this quickly ramps up if we don't assume the inputs pass a certain sanity threshold to begin with. This is made even more challenging by the number of ways tensorflow and pytorch models can fail!

Providing a sample batch of data could be a nice solution, but it might be a little unintuitive. If we were going to go down this route, perhaps it should be a well-documented optional feature, i.e. an option to provide sample data so that target_fn can be checked at init.

@sakoush (Member):

If we have the keras model + layer, we can get the input data shape (i.e. model.get_layer(layer).input_shape) and would therefore be able to construct a valid dummy? (See the sketch below.)

This could also help in constructing target_fn, taking argmax as an example computation.

But yeah, out of scope of this PR perhaps.
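
An illustrative sketch of constructing such a dummy from a layer's declared input shape (the dummy_from_layer helper is hypothetical, and this assumes input_shape is fully defined, which, as the reply below notes, is not always the case):

import numpy as np
import tensorflow as tf

def dummy_from_layer(model: tf.keras.Model, layer_name: str) -> np.ndarray:
    # e.g. input_shape == (None, 28, 28, 1) -> dummy of shape (1, 28, 28, 1)
    shape = model.get_layer(layer_name).input_shape
    return np.zeros((1,) + tuple(shape[1:]), dtype=np.float32)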

@jklaise (Member, Author):

If I remember correctly, input_shape being present is limited to a subset of tensorflow models, as the model may be in a state where input_shape is unknown (there are some hints of this in the integrated_gradients.py module already). More generally, thinking ahead, for torch models input shapes are generally unknown unless provided by the user.

@ascillitoe (Contributor) left a comment

This looks good to me. There appear to be two wider design questions to consider (which I've left comments on), but I wonder if these are best left as future issues?
