Skip to content

Commit

Permalink
Handle fit kwargs in KElbowVisualizer quick method
Browse files Browse the repository at this point in the history
  • Loading branch information
Express50 committed Apr 18, 2020
1 parent ef1647d commit d298bc2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
13 changes: 11 additions & 2 deletions yellowbrick/cluster/elbow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import pairwise_distances

from yellowbrick.utils import KneeLocator
from yellowbrick.utils import KneeLocator, get_param_names
from yellowbrick.style.palettes import LINE_COLOR
from yellowbrick.cluster.base import ClusteringScoreVisualizer
from yellowbrick.exceptions import YellowbrickValueError, YellowbrickWarning
Expand Down Expand Up @@ -487,6 +487,15 @@ def kelbow_visualizer(
viz : KElbowVisualizer
The kelbow visualizer, fitted and finalized.
"""
klass = type(model)

# figure out which kwargs correspond to fit method
fit_params = get_param_names(klass.fit)

fit_kwargs = {
key: kwargs.pop(key) for key in fit_params if key in kwargs
}

oz = KElbow(
model,
ax=ax,
Expand All @@ -496,7 +505,7 @@ def kelbow_visualizer(
locate_elbow=locate_elbow,
**kwargs
)
oz.fit(X, y)
oz.fit(X, y, **fit_kwargs)

if show:
oz.show()
Expand Down
23 changes: 23 additions & 0 deletions yellowbrick/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
##########################################################################

import re
import inspect
import sklearn
import numpy as np

Expand Down Expand Up @@ -185,6 +186,28 @@ def is_monotonic(a, increasing=True):
return np.all(a[1:] <= a[:-1], axis=0)


def get_param_names(method):
"""
Returns a list of keyword-only parameter names that may be
passed into method.
Parameters
----------
method : function
The method for which to return keyword-only parameters.
Returns
-------
parameters : list
A list of keyword-only parameter names for method.
"""
signature = inspect.signature(method)
parameters = [p for p in signature.parameters.values()
if p.name != 'self' and p.kind != p.VAR_KEYWORD]

return sorted([p.name for p in parameters])


##########################################################################
## Numeric Computations
##########################################################################
Expand Down

0 comments on commit d298bc2

Please sign in to comment.