
BUG: SHAP DeepExplainer cannot get SHAP values from TorchScript model #3532

Open
DarrelYee opened this issue Mar 1, 2024 · 1 comment
Issue Description

DeepExplainer currently seems unable to handle a PyTorch model loaded from TorchScript and throws RuntimeError: register_forward_hook is not supported on ScriptModules. DeepExplainer has no problem with the original source nn.Module, but throws the error once the model is converted to a ScriptModule. Wrapping the ScriptModule inside a plain nn.Module gives the same error.

I'm using shap 0.42.1, but the problem occurs on the latest release as well.

Minimal Reproducible Example

# Set up a dummy model and convert it to TorchScript
import torch
import torch.nn as nn
import shap

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

simple_model = SimpleModel()
ts = torch.jit.script(simple_model)

eg_input_data = torch.arange(0, 10, 0.1, dtype=torch.float32).reshape((10, 10))

# Running this gives RuntimeError: register_forward_hook is not supported on ScriptModules
explainer = shap.DeepExplainer(ts, eg_input_data)
print(explainer.shap_values(eg_input_data))

# Wrapping ts in a plain nn.Module gives the same error
class ReconModel(nn.Module):
    def __init__(self, ts):
        super().__init__()
        self.ts = ts

    def forward(self, x):
        return self.ts(x)

recon_model = ReconModel(ts)

explainer = shap.DeepExplainer(recon_model, torch.arange(0, 1, 0.1, dtype = torch.float32).reshape((-1,10)))
print(explainer.shap_values(eg_input_data))
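
As a sanity check (this snippet is an illustration added here, assuming the same PyTorch build as in the traceback below), the failure can be reproduced without shap at all: registering a plain forward hook on the ScriptModule raises the same error, and that is exactly the call DeepExplainer makes internally.

# Hypothetical sanity check, reusing `ts` from the example above: hook
# registration itself is what fails on this PyTorch version, independently of shap.
try:
    ts.register_forward_hook(lambda module, inputs, output: None)
except RuntimeError as e:
    print(e)  # register_forward_hook is not supported on ScriptModules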

Traceback

RuntimeError                              Traceback (most recent call last)
<ipython-input-43-fd62703e7cd4> in <module>
     28 
     29 explainer = shap.DeepExplainer(recon_model, torch.arange(0, 1, 0.1, dtype = torch.float32).reshape((-1,10)))
---> 30 print(explainer.shap_values(eg_input_data))

~/.local/lib/python3.7/site-packages/shap/explainers/_deep/__init__.py in shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
    122             were chosen as "top".
    123         """
--> 124         return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)

~/.local/lib/python3.7/site-packages/shap/explainers/_deep/deep_pytorch.py in shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
    164 
    165         # add the gradient handles
--> 166         handles = self.add_handles(self.model, add_interim_values, deeplift_grad)
    167         if self.interim:
    168             self.add_target_handle(self.layer)

~/.local/lib/python3.7/site-packages/shap/explainers/_deep/deep_pytorch.py in add_handles(self, model, forward_handle, backward_handle)
     77         if model_children:
     78             for child in model_children:
---> 79                 handles_list.extend(self.add_handles(child, forward_handle, backward_handle))
     80         else:  # leaves
     81             handles_list.append(model.register_forward_hook(forward_handle))

~/.local/lib/python3.7/site-packages/shap/explainers/_deep/deep_pytorch.py in add_handles(self, model, forward_handle, backward_handle)
     77         if model_children:
     78             for child in model_children:
---> 79                 handles_list.extend(self.add_handles(child, forward_handle, backward_handle))
     80         else:  # leaves
     81             handles_list.append(model.register_forward_hook(forward_handle))

~/.local/lib/python3.7/site-packages/shap/explainers/_deep/deep_pytorch.py in add_handles(self, model, forward_handle, backward_handle)
     79                 handles_list.extend(self.add_handles(child, forward_handle, backward_handle))
     80         else:  # leaves
---> 81             handles_list.append(model.register_forward_hook(forward_handle))
     82             handles_list.append(model.register_backward_hook(backward_handle))
     83         return handles_list

~/.local/lib/python3.7/site-packages/torch/jit/_script.py in fail(self, *args, **kwargs)
    941     def _make_fail(name):
    942         def fail(self, *args, **kwargs):
--> 943             raise RuntimeError(name + " is not supported on ScriptModules")
    944 
    945         return fail

RuntimeError: register_forward_hook is not supported on ScriptModules

Expected Behavior

No response

Bug report checklist

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest release of shap.
  • I have confirmed this bug exists on the master branch of shap.
  • I'd be interested in making a PR to fix this bug

Installed Versions

0.42.1

@DarrelYee DarrelYee added the bug Indicates an unexpected problem or unintended behaviour label Mar 1, 2024
@connortann connortann added the deep explainer Relating to DeepExplainer, tensorflow or pytorch label Mar 1, 2024

CloseChoice (Collaborator) commented Mar 9, 2024

Thanks for reporting the issue.

This seems like an inherent limitation of PyTorch: DeepExplainer needs to overwrite the model's gradients, and it does so via hooks. If that is not possible with TorchScript models (or if there is no known, acceptable workaround), we won't support this.
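
For context, here is a rough sketch of the hook pattern deep_pytorch.py relies on (illustrative only, not the actual shap code): it walks the module tree and registers a forward and a backward hook on every leaf module, and those registration calls are what ScriptModules reject.

import torch.nn as nn

def forward_hook(module, inputs, output):
    # record the inputs so the backward pass can rescale gradients (DeepLIFT-style)
    module.saved_input = [i.detach() for i in inputs]

def backward_hook(module, grad_input, grad_output):
    # this is the point where the explainer substitutes its own gradient computation
    return grad_input

leaf = nn.Linear(10, 1)
handles = [
    leaf.register_forward_hook(forward_hook),    # fine on an nn.Module leaf
    leaf.register_backward_hook(backward_hook),  # the same calls raise on a ScriptModule
]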

I'll leave this open for a while, but if there is no further progress we'll close this issue as a NO FIX.

Edit: Two ideas to check:

  • Does captum support TorchScript models?
  • We should at least throw an error explaining why we don't support this (a rough sketch of such a check follows below).
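
Something along these lines in the PyTorch explainer's setup would already help; the names below are illustrative, not actual shap internals.

import torch

def _check_model_supported(model):
    # Hypothetical guard: on the affected PyTorch versions, ScriptModules refuse
    # hook registration, so fail early with an actionable message.
    if isinstance(model, torch.jit.ScriptModule):
        raise TypeError(
            "DeepExplainer does not support TorchScript (ScriptModule) models "
            "because it relies on register_forward_hook/register_backward_hook. "
            "Please pass the original nn.Module instead."
        )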

@CloseChoice CloseChoice added the awaiting feedback Indicates that further information is required from the issue creator label Mar 9, 2024