Linear gaussian explainer #372

Draft
wants to merge 41 commits into
base: master

martinju
Member

Adds two user-facing functions: explain_lingauss() and explain_lingauss_precomputed().
These allow fast computation of Shapley values for purely linear models (i.e. no interactions, quadratic terms, etc.) under the assumption that the features follow a Gaussian distribution.

  • The implementation is based on Sec 2 here: https://arxiv.org/pdf/2006.16234.pdf, but uses somewhat simplified formulae for Tmu and Tx, which avoid the need to compute the Q-matrix since it always takes the same form.
  • The permutation-based Shapley estimation approach is used here instead of the kernelSHAP estimation approach used elsewhere in the package. Another PR will make the permutation-based approach universally available.
  • Pairwise sampling is always applied (there is currently no option to disable it).
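
For readers unfamiliar with the approach in the bullets above, here is a minimal sketch of permutation-based Shapley estimation for a linear model with Gaussian features. It is not the code in this PR: it is plain Python rather than the package's R/Rcpp, it evaluates v(S) directly instead of using the Tmu/Tx formulae, and it assumes that "pairwise sampling" means pairing each sampled permutation with its reverse. All function and argument names are hypothetical.

```python
import random


def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [list(A[i]) + [b[i]] for i in range(n)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    z = [0.0] * n
    for i in reversed(range(n)):
        s = sum(M[i][c] * z[c] for c in range(i + 1, n))
        z[i] = (M[i][n] - s) / M[i][i]
    return z


def value(S, x, beta, mu, sigma):
    """v(S) = E[beta' X | X_S = x_S] for X ~ N(mu, sigma); intercept omitted."""
    S = sorted(S)
    rest = [j for j in range(len(x)) if j not in S]
    filled = list(mu)
    for j in S:
        filled[j] = x[j]
    if S and rest:
        # w = Sigma_SS^{-1} (x_S - mu_S)
        w = solve([[sigma[i][k] for k in S] for i in S],
                  [x[i] - mu[i] for i in S])
        for j in rest:
            # Conditional mean of feature j given the features in S
            filled[j] = mu[j] + sum(sigma[j][s] * w_s for s, w_s in zip(S, w))
    return sum(b * z for b, z in zip(beta, filled))


def shapley_paired_permutations(x, beta, mu, sigma, n_pairs, seed=0):
    """Average marginal contributions over sampled permutations and their reverses."""
    rng = random.Random(seed)
    m = len(x)
    phi = [0.0] * m
    for _ in range(n_pairs):
        perm = list(range(m))
        rng.shuffle(perm)
        for order in (perm, perm[::-1]):  # assumed meaning of "pairwise"
            S, prev = [], value([], x, beta, mu, sigma)
            for j in order:
                S.append(j)
                cur = value(S, x, beta, mu, sigma)
                phi[j] += cur - prev
                prev = cur
    return [p / (2 * n_pairs) for p in phi]
```

Since each permutation's contributions telescope, the estimates always sum to beta'(x - mu), whatever the number of sampled pairs. With a diagonal covariance matrix, each estimate reduces to beta_j (x_j - mu_j).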

TODO

  • Implement the permutation sampling in Rcpp.
  • Implement the looping over Tmu/Tmx in Rcpp
  • Add MSE computation? We don't have v(S) directly computed, and probably don't want to compute it either, but can the MSE computation be simplified in this case under the assumption that the model is linear, without assuming that the features are Gaussian (in practice)? I think that might be possible; look at the formulas to verify this. Then we can decide whether it is worth implementing.
  • Implement grouping. I guess the best way to do this is to sample group permutations first, and then translate these to the appropriate
  • Update vignette with an example of how to use the method
  • Add examples
  • Improve documentation
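
One possible reading of the grouping item above, sketched in plain Python: sample a permutation of the groups, then expand each prefix of it into a feature-level coalition. The item is cut off in the description, so the translation target (feature coalitions) is an assumption, and the function names and the group layout in the example are hypothetical.

```python
import random


def sample_group_permutation(groups, rng):
    """Return the group names in a random order."""
    perm = list(groups)
    rng.shuffle(perm)
    return perm


def group_perm_to_feature_coalitions(group_perm, groups):
    """For each prefix of the group permutation, list the features it covers."""
    coalitions = []
    members = []
    for g in group_perm:
        members.extend(groups[g])
        coalitions.append(sorted(members))
    return coalitions


# Hypothetical grouping of five features into three groups
groups = {"a": [0, 1], "b": [2], "c": [3, 4]}
rng = random.Random(1)
perm = sample_group_permutation(groups, rng)
coalitions = group_perm_to_feature_coalitions(perm, groups)
```

Each group's contribution would then be the change in v(S) between consecutive coalitions, just as in the feature-level case.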

no antithetic sampling yet
The issue seems to be incorrect weighting of the different S's. I should try to loop through the permutations within the loop instead, extract the relevant S, to then do the computation. Just to see how they are all weighted.
then need to find the weighting per row in S, to then make it more efficient
Will simplify it all by creating a function which computes Udiffs for a list of perms (perm_dt) instead of pre-computing stuff and extracting them.