
WIP: Fix tf lstm #3419

Open · wants to merge 14 commits into master
Conversation

@CloseChoice (Collaborator) commented Dec 2, 2023

Overview

Closes #3344, #3343

Description of the changes proposed in this pull request:
The problem is that the TensorFlow ops were not captured in eager mode, so self._init_between_tensors was not executed, which resulted in an error because self.in_between_ops was not set. This is fixed by introducing an _init_between_tensors_eager method that executes the graph once for the data input and captures all ops that are called. From there on, the flow continues as before. Since TensorFlow 2 seems to introduce a couple more ops, these also need to be added to the op_handlers.
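
For illustration, here is a minimal sketch of the op-capture idea behind _init_between_tensors_eager (this is not the PR's actual code; collect_op_types is a hypothetical helper): trace the model once on the input data and read the executed ops off the traced graph.

    import tensorflow as tf

    def collect_op_types(model, data):
        """Trace `model` on `data` once and return the op types seen in the traced graph."""
        concrete = tf.function(model).get_concrete_function(tf.convert_to_tensor(data))
        return {op.type for op in concrete.graph.get_operations()}

Any op type that shows up here but has no entry in op_handlers would then need a handler (or an explicit pass-through) registered.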

Checklist

I would suggest the following steps in the given order:

  1. The failing test test_tf_keras_lstm2 might be an indication that one of the recently added ops shouldn't be passed through. Understand what is going on here and how we can do this better.
  2. Check in the paper whether additivity is preserved with this approach (I expect it should be).
  3. Get more test cases to make sure the issue is correctly fixed (at least 5 different tests; search previous issues to find some and include feedback from the community). Inspired by this comment, we should at least test RNN, GRU and LSTM layers (see the test sketch after this list).
  4. Clean up this PR and make it ready for review. Decide on the tests that we really want to include in the test suite on a permanent basis.
  5. Check whether this fixes "The SHAP explanations do not sum up to the model's output!" #2765
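
For item 3, a sketch of what a parametrized additivity test could look like (the shapes, layer sizes and parametrization here are assumptions, not the tests that will land in the suite):

    import numpy as np
    import pytest
    import tensorflow as tf

    import shap


    @pytest.mark.parametrize(
        "layer_cls",
        [tf.keras.layers.SimpleRNN, tf.keras.layers.GRU, tf.keras.layers.LSTM],
    )
    def test_recurrent_layer_additivity(layer_cls):
        # Tiny sequence model: (batch, timesteps, features) -> scalar output.
        rng = np.random.default_rng(0)
        X = rng.normal(size=(20, 5, 3)).astype(np.float32)

        model = tf.keras.Sequential([
            tf.keras.Input(shape=(5, 3)),
            layer_cls(4),
            tf.keras.layers.Dense(1),
        ])

        background, samples = X[:10], X[10:]
        explainer = shap.DeepExplainer(model, background)
        # check_additivity asserts that the shap values sum to
        # model(samples) - E[model(background)].
        explainer.shap_values(samples, check_additivity=True)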

Later:

  • All pre-commit checks pass.
  • Unit tests added (if fixing a bug or adding a new feature)

Note: Feel free to review and also post previously failing examples here.

@CloseChoice (Collaborator, Author) commented Dec 7, 2023

@connortann: I would like to ask for your advice. This is a pretty long-standing issue, and I am currently able to evaluate SHAP values for LSTM, GRU and SimpleRNN layers as long as the neural net is linear (output of layer t is input of layer t+1). But there is a more general case, where e.g. multiple inputs can be concatenated, which I currently struggle with. I am confident I can fix this in the upcoming month. Should I prepare a PR where the linear NNs are working and create a separate one for nonlinear NNs? What are your thoughts on that?
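
To make the distinction concrete, a small illustration of the two cases (these toy models are made up for this comment, not taken from the PR's tests):

    import tensorflow as tf

    # Case 1: "linear" -- the output of layer t is the only input of layer t+1.
    linear_model = tf.keras.Sequential([
        tf.keras.Input(shape=(10, 4)),
        tf.keras.layers.LSTM(8),
        tf.keras.layers.Dense(1),
    ])

    # Case 2: more general topology -- a sequence branch and an auxiliary input
    # are merged with Concatenate, which is the case that still causes trouble.
    seq_in = tf.keras.Input(shape=(10, 4))
    aux_in = tf.keras.Input(shape=(3,))
    h = tf.keras.layers.LSTM(8)(seq_in)
    merged = tf.keras.layers.Concatenate()([h, aux_in])
    out = tf.keras.layers.Dense(1)(merged)
    multi_input_model = tf.keras.Model([seq_in, aux_in], out)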

@CloseChoice marked this pull request as ready for review on December 11, 2023 at 16:28
@connortann (Collaborator) commented:

Thanks for taking on this issue, it would be great to get those issues addressed.

In terms of how to arrange the PRs, I don't have anything specific to say about this PR, only general advice: it's helpful to break PRs down into small incremental changes wherever possible. It's up to your judgement how best to split any large PRs.

Your emphasis on adding test cases looks very prudent. This seems like it will be critical, given that there is a huge space of possible models that DeepExplainer has to be able to explain.

@CloseChoice (Collaborator, Author) commented:

The errors seem to happen mainly in two tests, and whether the SHAP values sum up as expected depends heavily on the input. I am not sure if I should mark these tests as xfail for now and create a separate issue. The thing is that we have a lot of issues where people complain about the DeepExplainer outputs not summing up as expected (not passing assert_additivity), and this PR is still a clear improvement over the status quo even without fixing the underlying assert_additivity issue (which I suspect might be a mixture of rounding issues and some wrong assignments in the op handlers). Any feedback on this is very welcome.

@raghavchalapathy commented:

Hi @CloseChoice @connortann, I am following this issue very closely and am happy to test the releases or fixes. Kindly let me know if you need any support from my side; I am happy to contribute to fixing this issue.

@CloseChoice (Collaborator, Author) commented Dec 12, 2023

Thanks for the reply. Your feedback is most welcome. If you could test this branch on your data and tell us whether the results make sense or whether there are strong inconsistencies, that would be amazing.

@raghavchalapathy commented:

Hi @CloseChoice

Sure, happy to do so.

Just double-checking that I am good to run this setup:

this experiment --> https://www.kaggle.com/code/billlucas/explaining-cnn-lstm-using-shap/notebook
with this branch code --> Fix-tf-lstm (Branch)
with Python version 3.11 and the latest TensorFlow version

Is my understanding correct? If there are any specific versions I need to use, please point me to the URL where I can find these details.
Thank you

@CloseChoice (Collaborator, Author) commented Dec 13, 2023

Thanks a lot for helping with this, I really appreciate it. That looks good. With this fix we should be able to run any TensorFlow version above 2.4 as well, so feel free to use the latest version (our test suite tests against the latest either way). I did not test all the layers I can see in that example, so I am pretty curious about the results.

@ANeeK181 commented:

Hi @CloseChoice

I am using it for some LSTM code where I couldn't compute SHAP values before. The code now runs, but I am getting gradients of zero. Looking into it, I found that in the function phi_symbolic (line 365) and then in the function grad_graph (line 352), x_grad is zero. I see you are still calling self._init_between_tensors(out.op, shap_rAnD) instead of self._init_between_tensors_eager(out.op, shap_rAnD, data..?). Maybe that is making the gradient go to zero?

@ANeeK181 commented:

I also tried

    x_grad = tape.gradient(
        out,
        shap_rAnD,
        unconnected_gradients=tf.UnconnectedGradients.NONE,
    )

to check whether there is an issue in the network, but the gradient is still zero, which means the gradients really are going to zero.
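
As a side note on that check (general TensorFlow behaviour, not specific to this branch): with tf.UnconnectedGradients.NONE a disconnected gradient comes back as None, so getting an actual zero tensor means the path exists but the gradient vanishes. A minimal example:

    import tensorflow as tf

    x = tf.Variable([1.0, 2.0])
    unrelated = tf.constant([3.0])
    with tf.GradientTape() as tape:
        y = tf.reduce_sum(unrelated)  # y does not depend on x
    # NONE (the default) yields None for disconnected paths; ZERO would yield zeros.
    print(tape.gradient(y, x, unconnected_gradients=tf.UnconnectedGradients.NONE))  # None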

@CloseChoice (Collaborator, Author) commented:

Thanks for looking into this. I also found that all SHAP values are zero (due to the gradients being zero, as you found out). The problem here is that TensorFlow hides the exact ops from us and just exposes a PartitionedCall, but we need to get the exact ops that are called. I'll try to find a workaround for this.
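
For anyone following along, a small sketch of how to see this (find_opaque_call_ops is a made-up helper, not part of the PR): tracing the model and listing the graph ops typically shows a single (Stateful)PartitionedCall wrapping the recurrent layer instead of the individual ops inside it.

    import tensorflow as tf

    def find_opaque_call_ops(model, data):
        """Return the (Stateful)PartitionedCall ops that hide a layer's inner ops."""
        concrete = tf.function(model).get_concrete_function(tf.convert_to_tensor(data))
        return [
            op for op in concrete.graph.get_operations()
            if op.type in ("PartitionedCall", "StatefulPartitionedCall")
        ]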

@ANeeK181 commented:

What about the call to self._init_between_tensors(out.op, shap_rAnD) in the grad_graph function? I changed it to self._init_between_tensors_eager(out.op, shap_rAnD, data..?) but then got other errors. Maybe if that were fixed, the gradients would be correct.

@CloseChoice (Collaborator, Author) commented:

No, that won't help, since _init_between_tensors_eager also just catches the PartitionedCall; we need to get at the ops underlying that call.

@zlds123 commented Apr 17, 2024

@CloseChoice Do we have a status update on this? I would really appreciate it, because my model uses some TF2 features and requires the SHAP LSTM explainer.

@CloseChoice (Collaborator, Author) commented:

Hey, thanks for keeping your eyes on this. I have not worked on this for quite a while. The problems lie somewhere deep in TensorFlow, which hides its internal ops from us. I can't give any timeframe on this.

Successfully merging this pull request may close these issues:

  • SHAP not working with LSTM!
  • The SHAP explanations do not sum up to the model's output!