Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enh: deprecate legacy plots and code #3209

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions shap/plots/_bar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import matplotlib.pyplot as pl
import numpy as np
import scipy
from sklearn.utils import deprecated

from .. import Cohorts, Explanation
from ..utils import format_value, ordinal_str
Expand Down Expand Up @@ -372,6 +373,7 @@ def bar(shap_values, max_display=10, order=Explanation.abs, clustering=None, clu

# return max(left_val, right_val) + 1, max(left_sum, right_sum)

@deprecated("bar_legacy is being deprecated in Version 0.43.0. This will be removed in Version 0.44")
def bar_legacy(shap_values, features=None, feature_names=None, max_display=None, show=True):

# unwrap pandas series
Expand Down
2 changes: 2 additions & 0 deletions shap/plots/_beeswarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import scipy.sparse
import scipy.spatial
from scipy.stats import gaussian_kde
from sklearn.utils import deprecated

from .. import Explanation
from ..utils import safe_isinstance
Expand Down Expand Up @@ -435,6 +436,7 @@ def is_color_map(color):

# TODO: remove unused title argument / use title argument
# TODO: Add support for hclustering based explanations where we sort the leaf order by magnitude and then show the dendrogram to the left
@deprecated("summary_legacy is being deprecated in Version 0.43.0. This will be removed in Version 0.44")
def summary_legacy(shap_values, features=None, feature_names=None, max_display=None, plot_type=None,
color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True,
color_bar=True, plot_size="auto", layered_violin_max_num_bins=20, class_names=None,
Expand Down
3 changes: 2 additions & 1 deletion shap/plots/_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import matplotlib
import matplotlib.pyplot as pl
import numpy as np
from sklearn.utils import deprecated

from .._explanation import Explanation
from ..utils import approximate_interactions, convert_name
Expand Down Expand Up @@ -470,7 +471,7 @@ def scatter(shap_values, color="#1E88E5", hist=True, axis_color="#333333", cmap=
warnings.simplefilter("ignore", RuntimeWarning)
pl.show()


@deprecated("dependence_legacy is being deprecated in Version 0.43.0. This will be removed in Version 0.44")
def dependence_legacy(ind, shap_values=None, features=None, feature_names=None, display_features=None,
interaction_index="auto",
color="#1E88E5", axis_color="#333333", cmap=None,
Expand Down
3 changes: 2 additions & 1 deletion shap/plots/_waterfall.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.utils import deprecated

from .. import Explanation
from ..utils import format_value, safe_isinstance
Expand Down Expand Up @@ -317,7 +318,7 @@ def waterfall(shap_values, max_display=10, show=True):
else:
return plt.gcf()


@deprecated("waterfall_legacy is being deprecated in Version 0.43.0. This will be removed in Version 0.44")
def waterfall_legacy(expected_value, shap_values=None, features=None, feature_names=None, max_display=10, show=True):
""" Plots an explanation of a single prediction as a waterfall plot.

Expand Down
2 changes: 2 additions & 0 deletions shap/utils/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import scipy.sparse
from sklearn.cluster import KMeans
from sklearn.impute import SimpleImputer
from sklearn.utils import deprecated


@deprecated("kmeans is being deprecated in Version 0.43.0. This will be removed in Version 0.44")
def kmeans(X, k, round_values=True):
""" Summarize a dataset with k mean samples weighted by the number of data points they
each represent.
Expand Down
6 changes: 6 additions & 0 deletions tests/plots/test_beeswarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,9 @@ def test_simple_beeswarm(explainer):
shap.plots.beeswarm(shap_values)
plt.tight_layout()
return fig

def test_summary_legacy_deprecation_warning(explainer):
shap_values = explainer(explainer.data)
plt.figure()
with pytest.warns(FutureWarning, match="summary_legacy is being deprecated in Version 0.43.0"):
shap.plots._beeswarm.summary_legacy(shap_values)
5 changes: 5 additions & 0 deletions tests/plots/test_dependence.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib
import numpy as np
import pytest

matplotlib.use('Agg')
import shap # pylint: disable=wrong-import-position
Expand All @@ -14,3 +15,7 @@ def test_random_dependence_no_interaction():
""" Make sure a dependence plot does not crash when we are not showing interactions.
"""
shap.dependence_plot(0, np.random.randn(20, 5), np.random.randn(20, 5), show=False, interaction_index=None)

def test_dependence_legacy_deprecation_warning(explainer):
with pytest.warns(FutureWarning, match="dependence_legacy is being deprecated in Version 0.43.0. This will be removed in Version 0.44"):
shap.plots._scatter.dependence_legacy(0, np.random.randn(20, 5), np.random.randn(20, 5), show=False, interaction_index=None)
5 changes: 5 additions & 0 deletions tests/plots/test_waterfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,8 @@ def test_waterfall_plot_for_decision_tree_explanation():
explainer = shap.TreeExplainer(model)
explanation = explainer(X)
shap.plots.waterfall(explanation[0], show=False)

def test_waterfall_legacy_deprecation_warning(explainer):
shap_values = explainer.shap_values(explainer.data)
with pytest.warns(FutureWarning, match="waterfall_legacy is being deprecated in Version 0.43.0. This will be removed in Version 0.44"):
shap.plots._waterfall.waterfall_legacy(explainer.expected_value, shap_values[0])