Skip to content

Commit

Permalink
Merge 27e59aa into 6a80881
Browse files Browse the repository at this point in the history
  • Loading branch information
mgarod committed Oct 16, 2020
2 parents 6a80881 + 27e59aa commit 24f35fc
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 2 deletions.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
66 changes: 66 additions & 0 deletions tests/test_model_selection/test_importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,72 @@ def test_with_fitted(self):
oz.fit(X, y)
mockfit.assert_called_once_with(X, y)

def test_topn_stacked(self):
"""
Test stack plot with only the three most important features by sum of
each feature's importance across all classes
"""
X, y = load_iris(True)

viz = FeatureImportances(
LogisticRegression(solver="liblinear", random_state=222),
stack=True, topn=3
)
viz.fit(X, y)
viz.finalize()

npt.assert_equal(viz.feature_importances_.shape, (3, 3))
# Appveyor and Linux conda non-text-based differences
self.assert_images_similar(viz, tol=17.5)

def test_topn_negative_stacked(self):
"""
Test stack plot with only the three least important features by sum of
each feature's importance across all classes
"""
X, y = load_iris(True)

viz = FeatureImportances(
LogisticRegression(solver="liblinear", random_state=222),
stack=True, topn=-3
)
viz.fit(X, y)
viz.finalize()

npt.assert_equal(viz.feature_importances_.shape, (3, 3))
# Appveyor and Linux conda non-text-based differences
self.assert_images_similar(viz, tol=17.5)

def test_topn(self):
"""
Test plot with only top three important features by absolute value
"""
X, y = load_iris(True)

viz = FeatureImportances(
GradientBoostingClassifier(random_state=42), topn=3
)
viz.fit(X, y)
viz.finalize()

# Appveyor and Linux conda non-text-based differences
self.assert_images_similar(viz, tol=17.5)

def test_topn_negative(self):
"""
Test plot with only the three least important features by absolute value
"""
X, y = load_iris(True)

viz = FeatureImportances(
GradientBoostingClassifier(random_state=42), topn=-3
)
viz.fit(X, y)
viz.finalize()

# Appveyor and Linux conda non-text-based differences
self.assert_images_similar(viz, tol=17.5)


##########################################################################
## Mock Estimator
Expand Down
48 changes: 46 additions & 2 deletions yellowbrick/model_selection/importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from yellowbrick.base import ModelVisualizer
from yellowbrick.style.colors import resolve_colors
from yellowbrick.utils import is_dataframe, is_classifier
from yellowbrick.exceptions import YellowbrickTypeError, NotFitted, YellowbrickWarning
from yellowbrick.exceptions import YellowbrickTypeError, NotFitted, YellowbrickWarning, YellowbrickValueError

##########################################################################
## Feature Visualizer
Expand Down Expand Up @@ -92,6 +92,10 @@ class FeatureImportances(ModelVisualizer):
modified. If 'auto' (default), a helper method will check if the estimator
is fitted before fitting it again.
topn : int, default=None
Display only the top N results with a positive integer, or the bottom N
results with a negative integer. If None or 0, all results are shown.
kwargs : dict
Keyword arguments that are passed to the base class and may influence
the visualization as defined in other Visualizers.
Expand Down Expand Up @@ -128,6 +132,7 @@ def __init__(
colors=None,
colormap=None,
is_fitted="auto",
topn=None,
**kwargs
):
# Initialize the visualizer bases
Expand All @@ -144,6 +149,7 @@ def __init__(
stack=stack,
colors=colors,
colormap=colormap,
topn=topn
)

def fit(self, X, y=None, **kwargs):
Expand Down Expand Up @@ -218,12 +224,33 @@ def fit(self, X, y=None, **kwargs):
else:
self.features_ = np.array(self.labels)

if self.topn and self.topn > self.features_.shape[0]:
raise YellowbrickValueError(
"topn '{}' cannot be greater than the number of "
"features '{}'".format(self.topn, self.features_.shape[0])
)

# Sort the features and their importances
if self.stack:
sort_idx = np.argsort(np.mean(self.feature_importances_, 0))
if self.topn:
abs_sort_idx = np.argsort(
np.sum(np.absolute(self.feature_importances_), 0)
)
sort_idx = self._reduce_topn(abs_sort_idx)
else:
sort_idx = np.argsort(np.mean(self.feature_importances_, 0))

self.features_ = self.features_[sort_idx]
self.feature_importances_ = self.feature_importances_[:, sort_idx]
else:
if self.topn:
abs_sort_idx = np.argsort(np.absolute(self.feature_importances_))
abs_sort_idx = self._reduce_topn(abs_sort_idx)

self.features_ = self.features_[abs_sort_idx]
self.feature_importances_ = self.feature_importances_[abs_sort_idx]

# Sort features by value (sorting a second time if topn)
sort_idx = np.argsort(self.feature_importances_)
self.features_ = self.features_[sort_idx]
self.feature_importances_ = self.feature_importances_[sort_idx]
Expand Down Expand Up @@ -346,6 +373,17 @@ def _is_fitted(self):
"""
return hasattr(self, "feature_importances_") and hasattr(self, "features_")

def _reduce_topn(self, arr):
"""
Return only the top or bottom N items within a sliceable array/list.
Assumes that arr is in ascending order.
"""
if self.topn > 0:
arr = arr[-self.topn:]
elif self.topn < 0:
arr = arr[:-self.topn]
return arr

##########################################################################
## Quick Method
Expand All @@ -365,6 +403,7 @@ def feature_importances(
colors=None,
colormap=None,
is_fitted="auto",
topn=None,
show=True,
**kwargs
):
Expand Down Expand Up @@ -431,6 +470,10 @@ def feature_importances(
call ``plt.savefig`` from this signature, nor ``clear_figure``. If False, simply
calls ``finalize()``
topn : int, default=None
Display only the top N results with a positive integer, or the bottom N
results with a negative integer. If None or 0, all results are shown.
kwargs : dict
Keyword arguments that are passed to the base class and may influence
the visualization as defined in other Visualizers.
Expand All @@ -452,6 +495,7 @@ def feature_importances(
colors=colors,
colormap=colormap,
is_fitted=is_fitted,
topn=topn,
**kwargs
)

Expand Down

0 comments on commit 24f35fc

Please sign in to comment.