Skip to content

Commit

Permalink
Merge be4ea5f into 378672b
Browse files Browse the repository at this point in the history
  • Loading branch information
mgarod committed Oct 5, 2020
2 parents 378672b + be4ea5f commit 9ce2b8f
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion yellowbrick/model_selection/importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from yellowbrick.base import ModelVisualizer
from yellowbrick.style.colors import resolve_colors
from yellowbrick.utils import is_dataframe, is_classifier
from yellowbrick.exceptions import YellowbrickTypeError, NotFitted, YellowbrickWarning
from yellowbrick.exceptions import YellowbrickTypeError, NotFitted, YellowbrickWarning, YellowbrickValueError

##########################################################################
## Feature Visualizer
Expand Down Expand Up @@ -92,6 +92,10 @@ class FeatureImportances(ModelVisualizer):
modified. If 'auto' (default), a helper method will check if the estimator
is fitted before fitting it again.
top_n : int, default=None
If a positive integer, display only the N features with the largest
absolute importance; if a negative integer, display only the abs(N)
features with the smallest absolute importance. If None or 0, all
features are shown.
kwargs : dict
Keyword arguments that are passed to the base class and may influence
the visualization as defined in other Visualizers.
Expand Down Expand Up @@ -128,6 +132,7 @@ def __init__(
colors=None,
colormap=None,
is_fitted="auto",
top_n=None,
**kwargs
):
# Initialize the visualizer bases
Expand All @@ -144,6 +149,7 @@ def __init__(
stack=stack,
colors=colors,
colormap=colormap,
top_n=top_n
)

def fit(self, X, y=None, **kwargs):
Expand Down Expand Up @@ -218,12 +224,30 @@ def fit(self, X, y=None, **kwargs):
else:
self.features_ = np.array(self.labels)

if self.top_n and self.top_n > self.features_.shape[0]:
raise YellowbrickValueError(
"top_n '{}' cannot be greater than the number of "
"features '{}'".format(self.top_n, self.features_.shape[0])
)

# Sort the features and their importances
if self.stack:
sort_idx = np.argsort(np.mean(self.feature_importances_, 0))
self.features_ = self.features_[sort_idx]
self.feature_importances_ = self.feature_importances_[:, sort_idx]
else:
if self.top_n: # Keep only the top/bottom N examples by magnitude
abs_sort_idx = np.argsort(np.absolute(self.feature_importances_))
if self.top_n > 0:
abs_sort_idx = abs_sort_idx[-self.top_n:]
else:
abs_sort_idx = abs_sort_idx[:-self.top_n]

# Filter features to only those top N
self.features_ = self.features_[abs_sort_idx]
self.feature_importances_ = self.feature_importances_[abs_sort_idx]

# Sort features once more to respect position for negative numbers
sort_idx = np.argsort(self.feature_importances_)
self.features_ = self.features_[sort_idx]
self.feature_importances_ = self.feature_importances_[sort_idx]
Expand Down Expand Up @@ -365,6 +389,7 @@ def feature_importances(
colors=None,
colormap=None,
is_fitted="auto",
top_n=None,
show=True,
**kwargs
):
Expand Down Expand Up @@ -431,6 +456,10 @@ def feature_importances(
call ``plt.savefig`` from this signature, nor ``clear_figure``. If False, simply
calls ``finalize()``
top_n : int, default=None
If a positive integer, display only the N features with the largest
absolute importance; if a negative integer, display only the abs(N)
features with the smallest absolute importance. If None or 0, all
features are shown.
kwargs : dict
Keyword arguments that are passed to the base class and may influence
the visualization as defined in other Visualizers.
Expand All @@ -452,6 +481,7 @@ def feature_importances(
colors=colors,
colormap=colormap,
is_fitted=is_fitted,
top_n=top_n,
**kwargs
)

Expand Down

0 comments on commit 9ce2b8f

Please sign in to comment.