Integrated Gradients target_fn
#523
Conversation
@sakoush @ascillitoe @arnaudvl ready for review. This PR does not address specifying the target via e.g. a string key outside of Python; that can be explored in follow-up work. I've kept the example notebooks unchanged but introduced the new usage in the method overview docs. The tests now cover passing either `target` or `target_fn`. For the example notebook, linking to HTML anchors explicitly defined by
Codecov Report
```diff
@@            Coverage Diff             @@
##           master     #523      +/-   ##
==========================================
+ Coverage   82.50%   82.60%   +0.10%
==========================================
  Files          77       77
  Lines       10421    10494      +73
==========================================
+ Hits         8598     8669      +71
- Misses       1823     1825       +2
```
I left some comments regarding whether it would be a big change to allow a default for `target_fn` that is not `None`. If so, the explainer could be initialised on the fly, which I guess is useful to have, as discussed. I understand that this will affect downstream logic and testing as you explained, but I thought I would mention it. In any case, when we get the registry-based / string config this can be addressed, so no real urgency I think.
```diff
@@ -663,6 +713,7 @@ class IntegratedGradients(Explainer):
     def __init__(self,
                  model: tf.keras.Model,
                  layer: Optional[tf.keras.layers.Layer] = None,
                  target_fn: Optional[Callable] = None,
```
Could we have `partial(np.argmax, axis=1)` as the default instead?
Unfortunately not, as `target_fn` depends on the output shape of the predictions, which can be anything, so it's up to the user to specify; we cannot safely assume a default here.
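To make the shape dependence concrete, here is a minimal NumPy-only sketch of the `partial(np.argmax, axis=1)` callable from the suggestion above, applied to `(N, M)` classifier output (the sample predictions are illustrative):

```python
from functools import partial

import numpy as np

# For a classifier with predictions of shape (N, M), target_fn should
# reduce each row to a single integer target, e.g. the class with the
# highest probability:
target_fn = partial(np.argmax, axis=1)

preds = np.array([[0.1, 0.7, 0.2],
                  [0.8, 0.1, 0.1]])  # shape (2, 3)
targets = target_fn(preds)           # array([1, 0]), shape (2,)
```

For a model whose output is not `(N, M)`, this particular reduction would be meaningless, which is why no default can be assumed.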
```diff
@@ -761,6 +814,13 @@ def explain(self,
             for each feature.
         """
         # target handling logic
         if self.target_fn and target is not None:
```
Another way to look at it is that `target` takes precedence. I'm just thinking that the user might want to change this at explain time without having to create the explainer object again?
That's a good point. I didn't want to complicate the logic, but we could also have `target` override any pre-set `target_fn`; a warning would need to be emitted. Let me think about it.
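A sketch of what that precedence could look like (a hypothetical helper, not the PR's actual logic; the name and structure are illustrative):

```python
import warnings


def resolve_target(target_fn, target, predictions):
    # Hypothetical precedence: an explicit `target` passed to explain()
    # overrides any pre-set `target_fn`, with a warning so the user
    # knows their target_fn was ignored for this call.
    if target is not None:
        if target_fn is not None:
            warnings.warn("Both `target_fn` and `target` were provided; "
                          "`target` takes precedence.")
        return target
    if target_fn is not None:
        return target_fn(predictions)
    return None
```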
Having detectors/explainers which can be modified post-init generally seems like a nice idea (especially if init is expensive), but it does open up a bit of a question of where to stop with this. For example, for many of the offline detectors we could actually have `p_val` as a kwarg of `predict`. Perhaps this is best dealt with by opening new issue(s)?
Yeah, it's a design discussion. We've had similar discussions in the past; the line between what you would want to change on every `explain` call vs what you want fixed is very fuzzy.
```diff
@@ -722,6 +773,8 @@ def __init__(self,
         self._is_np: Optional[bool] = None
         self.orig_dummy_input: Optional[Union[list, np.ndarray]] = None

         self.target_fn = target_fn
```
I personally do not want to add more dummy calls, but in this case we have the keras model and `target_fn` (if set), so we could check that this is OK at init time? I guess you left it to explain time to have the check in one place, but checking `target_fn` here goes towards validating the parameters passed and providing some hints back to the user if things are good.
I think this is quite hard without knowing the shape of the inputs, which we only know at `explain` time. If we wanted to deeply validate user models and callables at construction time, it would involve a sizable refactor, as we would also need to ask the user for the type of data that can be put through the model.

Practically, this could be solved by asking a user to provide a sample batch of data at `__init__` that we can use to validate the model and any callables. In theory this could be an optional kwarg, so that people would only use it when they want some guarantees about the functionality of the explainer post `__init__` (we wouldn't want to make it mandatory, I think).

(I think, once again, this is where practical Python, i.e. trust users to do the right thing, clashes with production, i.e. we want to validate as much as possible before calling things on production data.)

I think this should be discussed in a different place (maybe a new issue?). Would be keen to hear thoughts from @ascillitoe and @arnaudvl.
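A rough standalone sketch of that optional check (the helper name and the `sample_batch` argument are hypothetical, not part of this PR):

```python
import numpy as np


def validate_on_sample(model_fn, target_fn, sample_batch):
    # Hypothetical __init__-time check: run the model and target_fn once
    # on a user-provided sample batch so failures surface early rather
    # than at explain() time.
    try:
        preds = model_fn(sample_batch)
    except Exception as err:
        raise ValueError(f"Model failed on the sample batch: {err}") from err
    if target_fn is None:
        return
    try:
        targets = np.asarray(target_fn(preds))
    except Exception as err:
        raise ValueError(f"`target_fn` failed on the model output: {err}") from err
    # One scalar target per instance is the common (N, M) -> (N,) contract.
    if targets.shape != (np.asarray(preds).shape[0],):
        raise ValueError("`target_fn` must return one scalar target per instance.")
```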
Agree with the sentiments of @sakoush and @jklaise here; as @sakoush hints at, any checks we can do to enable more helpful error messages are great. But as @jklaise alludes to, the complexity needed for this quickly ramps up if we don't assume the inputs pass a certain sanity threshold to begin with. This is made even more challenging by the number of ways tensorflow and pytorch models can fail!

Providing a sample batch of data could be a nice solution, but it might be a little unintuitive. If we were going to go down this route, perhaps it should be a well-documented optional functionality, i.e. an option to provide sample data so that `target_fn` can be checked at init.
If we have the keras model + layer, we can get the input data shape (i.e. `model.get_layer(layer).input_shape`) and therefore be able to construct a valid dummy? This could also help with constructing `target_fn`, if we just refer to `argmax` as an example computation. But yeah, out of scope of this PR perhaps.
If I remember correctly, `input_shape` being present is limited to a subset of `tensorflow` models, as the model may be in a state where `input_shape` is unknown (there are some hints of this in the `integrated_gradients.py` module already). More generally, thinking ahead, for `torch` models input shapes are generally unknown unless provided by the user.
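Given that `input_shape` may be unavailable, a defensive version of the lookup could look like the following sketch (duck-typed so it is not tied to TensorFlow; the `None` fallback is an assumption, not existing alibi behaviour):

```python
def try_get_input_shape(model, layer_name):
    """Sketch: return the named layer's input shape if known, else None.

    `input_shape` is only defined once a Keras model has been built;
    unbuilt subclassed models raise AttributeError, and an unknown
    layer name raises ValueError, so fall back to None in both cases.
    """
    try:
        return model.get_layer(layer_name).input_shape
    except (ValueError, AttributeError):
        return None
```

A caller could then only attempt dummy construction when a shape is actually recoverable.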
This looks good to me. There appear to be two wider design questions to consider (which I've left comments on), but I wonder if these are best left as future issues?
This is a PR partially addressing #492. In particular, it introduces the ability to pass a callback via the `target_fn` argument to calculate the scalar target dimension from the model output within the method. This bypasses the requirement of passing `target` directly to `explain` when the `target` of interest may depend on the prediction output, as in the case of the highest-probability class. Both of the following ways of defining an explainer are now possible.
TODO:

- `target_fn`
- `target_fn`

If the output shape of the model is `(N, M)`, as in probabilistic classification, then naturally the output of `target_fn` is expected to be an `(N,)` array of integers which specify, for each instance in the batch, which dimension of the output (e.g. the maximum probability) to take. However, if the model output has more than 2 dimensions this would not apply; e.g. if the output shape of the model is `(N, M, K)`, then `target_fn` would have to specify the indices for both the `M` and `K` dimensions, which would be reflected in the output shape of `target_fn`. That being said, I'm not certain our current implementation of `IntegratedGradients` supports >2-D model outputs, so the point may be moot (for now).

To discuss:
- I've added `target_fn` as an argument to `__init__` as it seemed more natural; however, a case could be made that this belongs in `explain`, as `target` is part of `explain` (and so is `forward_kwargs`, although I would personally like to see this shift to `__init__`). EDIT: Actually `forward_kwargs` is data dependent so belongs in `explain`.
- `target_fn` cannot currently be specified from outside Python; in particular, the user application must define it. We could approach this (outside of the scope of this PR) by introducing function registries within `alibi` itself for pre-defined functions, using catalogue. We can separate the registry discussion into a separate issue in the future.
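As a sketch of the registry idea, here is a minimal dict-based stand-in for what catalogue would provide (the registry name, key and helper are illustrative only):

```python
import numpy as np

# Minimal stand-in for a catalogue-style registry: pre-defined target
# functions live under string keys, so a config outside Python could
# refer to them by name.
TARGET_FNS = {}


def register_target_fn(name):
    def decorator(fn):
        TARGET_FNS[name] = fn
        return fn
    return decorator


@register_target_fn("argmax")
def argmax_target(predictions):
    # Highest-probability class per instance for (N, M) outputs.
    return np.argmax(predictions, axis=1)


# A string key from a config resolves to the callable:
target_fn = TARGET_FNS["argmax"]
```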
It would be nice IMO to provide both callback functionality for maximum flexibility, as well as strings mapping to `alibi` built-in functions for ease of use (in this example avoiding the use of `partial` and defining `target_fn` manually). Although, we may want to think about separating interfaces that act on pure callbacks (e.g. the public `IntegratedGradients` interface) from interfaces that act on string names to construct an explainer, for greater clarity and separation of concerns.