Skip to content

Commit

Permalink
[python] fix shap wrapper tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hbaniecki committed May 5, 2024
1 parent fab61b8 commit f6b0bf7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/dalex/NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### development

* added a way to pass `sample_weight` to loss functions in `model_parts()` (variable importance) using `weights` from `dx.Explainer` ([#563](https://github.com/ModelOriented/DALEX/issues/563))
* fixed the visualization of `shap_wrapper` for `shap==0.45.0`

### v1.7.0 (2024-02-28)

Expand Down
9 changes: 8 additions & 1 deletion python/dalex/dalex/wrappers/_shap/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,14 @@ def plot(self, **kwargs):
else:
base_value = self.shap_explainer.expected_value

shap_values = self.result[1] if isinstance(self.result, list) else self.result
if isinstance(self.result, list):
shap_values = self.result[1]
elif isinstance(self.result, np.ndarray):
if len(self.result.shape) == 3:
shap_values = self.result[:, :, 1]
else:
shap_values = self.result

force_plot(base_value=base_value,
shap_values=shap_values,
features=self.new_observation.values,
Expand Down

0 comments on commit f6b0bf7

Please sign in to comment.