Add elbow detection using the "kneedle" method to Elbow Visualizer #813
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ | |
|
||
from .base import ClusteringScoreVisualizer | ||
from ..exceptions import YellowbrickValueError | ||
from ..utils import KneeLocator | ||
|
||
from sklearn.metrics import silhouette_score | ||
from sklearn.metrics import calinski_harabaz_score | ||
|
@@ -170,6 +171,9 @@ class KElbowVisualizer(ClusteringScoreVisualizer): | |
Display the fitting time per k to evaluate the amount of time required | ||
to train the clustering model. | ||
|
||
knee : bool, default=True | ||
Display the vertical line corresponding to the optimal value of k. | ||
|
||
kwargs : dict | ||
Keyword arguments that are passed to the base class and may influence | ||
the visualization as defined in other Visualizers. | ||
|
@@ -206,7 +210,7 @@ class KElbowVisualizer(ClusteringScoreVisualizer): | |
""" | ||
|
||
def __init__(self, model, ax=None, k=10, | ||
metric="distortion", timings=True, **kwargs): | ||
metric="distortion", timings=True, knee=True, **kwargs): | ||
super(KElbowVisualizer, self).__init__(model, ax=ax, **kwargs) | ||
|
||
# Get the scoring method | ||
|
@@ -219,6 +223,7 @@ def __init__(self, model, ax=None, k=10, | |
# Store the arguments | ||
self.scoring_metric = KELBOW_SCOREMAP[metric] | ||
self.timings = timings | ||
self.knee=knee | ||
If we do change the name of the argument, this should also be updated to match.

Yeah! Sure.
||
|
||
# Convert K into a tuple argument if an integer | ||
if isinstance(k, int): | ||
|
@@ -247,6 +252,9 @@ def fit(self, X, y=None, **kwargs): | |
|
||
self.k_scores_ = [] | ||
self.k_timers_ = [] | ||
self.kneedle=None | ||
Does the

@bbengfort For the moment I don't think there's any need to include

I agree, also after more consideration; the
||
self.knee_value=None | ||
self.score=None | ||
This is great, thank you so much for storing these on the visualizer! I think users will be interested in directly accessing them. I've got two requests:
Finally (more on this later), these properties should only exist iff

@bbengfort Thanks so much for your kind words.
And as suggested, it makes sense for these properties to exist iff
One more question... I am not sure what you mean by "learned" attributes, so will

I appreciate it! The learned attributes are more a scikit-learn thing than a Yellowbrick thing. If you're interested in learning more, check out the sklearn developer guide. I guess they call them "estimated attributes" there. A learned/estimated attribute is any data that is created when
Because you're making the
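As a rough illustration of the learned/estimated attribute convention discussed in this thread, the sketch below separates constructor arguments (no trailing underscore) from values created during `fit` (trailing underscore, as with `k_scores_` in this diff). `ToyElbow` and its lowest-score rule are hypothetical stand-ins, not Yellowbrick or scikit-learn code.

```python
class ToyElbow(object):
    """Hypothetical estimator-like class, for illustration only."""

    def __init__(self, locate_elbow=True):
        # Constructor arguments are stored as-is, without a trailing underscore
        self.locate_elbow = locate_elbow

    def fit(self, k_values, k_scores):
        # Learned/estimated attributes are created in fit() and end with "_"
        self.k_values_ = list(k_values)
        self.k_scores_ = list(k_scores)

        if self.locate_elbow:
            # Stand-in rule (an assumption): pick the k with the lowest score
            best = min(range(len(self.k_scores_)), key=self.k_scores_.__getitem__)
            self.elbow_value_ = self.k_values_[best]
            self.elbow_score_ = self.k_scores_[best]

        return self
```

With this layout, `elbow_value_` and `elbow_score_` only exist after `fit` has run with `locate_elbow=True`, so `hasattr(model, "elbow_value_")` can serve as a simple "has an elbow been computed" check.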
||
|
||
for k in self.k_values_: | ||
# Compute the start time for each model | ||
|
@@ -262,6 +270,9 @@ def fit(self, X, y=None, **kwargs): | |
self.scoring_metric(X, self.estimator.labels_) | ||
) | ||
|
||
self.kneedle=KneeLocator(self.k_values_,self.k_scores_,curve='convex',direction='decreasing') | ||
I'm a bit concerned about the
This leads me to believe that we need to do one of the following:
Potentially the second bullet point is the easiest option, and we can go that route and open an issue to explore this further. What do you think? Are there any other options?

@bbengfort I was also concerned about the same thing; it didn't feel right to me to pass these same two parameters all the time. Regarding the first bullet point, I think it would not be convenient for users (new to ML) to specify the expected curve and direction. I think the second point is the most viable option for now, given that we have only three metrics, and for two metrics

    if self.locate_elbow:
        if self.metric == 'distortion':
            self.curve_shape = 'convex'
            self.curve_direction = 'decreasing'
        elif self.metric == 'silhouette' or self.metric == 'calinski_harabaz':
            self.curve_shape = 'concave'
            self.curve_direction = 'increasing'
        self.elbow_locator = KneeLocator(self.k_values_, self.k_scores_, curve=self.curve_shape, direction=self.curve_direction)
        self.elbow_value_ = self.elbow_locator.knee
        if self.elbow_value_ is not None:
            self.elbow_score_ = self.k_scores_[self.k_values_.index(self.elbow_value_)]

Where
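One possible alternative to the `if`/`elif` chain proposed above is a lookup table keyed by metric name. This is only a sketch: `CURVE_PARAMS` and `curve_params_for` are hypothetical names, and the pairings are copied from the snippet in the comment above rather than verified independently.

```python
# Hypothetical lookup table; pairings taken from the proposal in this thread
CURVE_PARAMS = {
    "distortion": ("convex", "decreasing"),
    "silhouette": ("concave", "increasing"),
    "calinski_harabaz": ("concave", "increasing"),
}


def curve_params_for(metric):
    """Return the (curve, direction) pair assumed for a scoring metric."""
    try:
        return CURVE_PARAMS[metric]
    except KeyError:
        # Fail loudly for metrics whose curve shape has not been decided
        raise ValueError(
            "no curve shape registered for metric '{}'".format(metric)
        )


curve, direction = curve_params_for("distortion")
```

Adding a new metric would then mean adding one dictionary entry, and an unsupported metric raises an error instead of silently reusing stale values.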
||
self.knee_value=self.kneedle.find_knee()[0] | ||
After taking a look at the
Also below you use

@bbengfort Thanks for such an in-depth review pointing out these small mistakes. Greatly appreciated. I agree
||
self.score=self.k_scores_[self.k_values_.index(self.knee_value)] | ||
|
||
self.draw() | ||
|
||
return self | ||
|
@@ -271,8 +282,10 @@ def draw(self): | |
Draw the elbow curve for the specified scores and values of K. | ||
""" | ||
# Plot the silhouette score against k | ||
self.ax.plot(self.k_values_, self.k_scores_, marker="D", label="score") | ||
|
||
self.ax.plot(self.k_values_, self.k_scores_, marker="D") | ||
if self.knee: | ||
self.ax.axvline(self.knee_value,c='black',linestyle='--') | ||
For the color, would you mind using

I'll use that.
||
self.ax.legend(['Score={}'.format(round(self.score,3)),'Optimal k={}'.format(self.knee_value)],loc='best') | ||
To create a text-only legend, you could use

    elbow_label = "$elbow at k={}, score={:0.2f}$".format(self.elbow_value_, self.elbow_score_)
    self.ax.axvline(self.elbow_value_, c=LINE_COLOR, linestyle="--", label=elbow_label)

Then later, in

Note also the

Done! I'll change it according to the above suggestion.
||
# If we're going to plot the timings, create a twinx axis | ||
if self.timings: | ||
self.axes = [self.ax, self.ax.twinx()] | ||
|
@@ -281,6 +294,7 @@ def draw(self): | |
c='g', marker="o", linestyle="--", alpha=0.75, | ||
) | ||
|
||
|
||
return self.ax | ||
|
||
def finalize(self): | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,3 +22,4 @@ | |
|
||
from .helpers import * | ||
from .types import * | ||
from .kneed import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
# yellowbrick.utils.kneed | ||
# A port of the knee-point detection package, kneed. | ||
# | ||
# Author: Kevin Arvai | ||
# Author: Pradeep Singh | ||
# Created: Mon Apr 15 09:43:18 2019 -0400 | ||
# | ||
# Copyright (C) 2017 Kevin Arvai | ||
# All rights reserved. | ||
# Redistribution and use in source and binary forms, with or without modification, | ||
# are permitted provided that the following conditions are met: | ||
# | ||
# 1. Redistributions of source code must retain the above copyright notice, this list | ||
# of conditions and the following disclaimer. | ||
# | ||
# 2. Redistributions in binary form must reproduce the above copyright notice, this | ||
# list of conditions and the following disclaimer in the documentation and/or other | ||
# materials provided with the distribution. | ||
# | ||
# 3. Neither the name of the copyright holder nor the names of its contributors may | ||
# be used to endorse or promote products derived from this software without specific | ||
# prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | ||
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | ||
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR | ||
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS | ||
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | ||
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | ||
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN | ||
# IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
# | ||
# ID: kneed.py [] pswaldia@no-reply.github.com $ | ||
|
||
""" | ||
This package contains a port of the knee-point detection package, kneed, by | ||
Kevin Arvai and hosted at https://github.com/arvkevi/kneed. This port is maintained | ||
with permission by the Yellowbrick contributors. | ||
""" | ||
|
||
import numpy as np | ||
from scipy import interpolate | ||
from scipy.signal import argrelextrema | ||
import warnings | ||
|
||
|
||
class KneeLocator(object): | ||
|
||
def __init__(self, x, y, S=1.0, curve='concave', direction='increasing'): | ||
|
||
""" | ||
Once instantiated, this class attempts to find the point of maximum | ||
curvature on a line. The knee is accessible via the `.knee` attribute. | ||
:param x: x values. | ||
:type x: list or array. | ||
:param y: y values. | ||
:type y: list or array. | ||
:param S: Sensitivity, original paper suggests default of 1.0 | ||
:type S: float | ||
:param curve: If 'concave', algorithm will detect knees. If 'convex', it | ||
will detect elbows. | ||
:type curve: string | ||
:param direction: one of {"increasing", "decreasing"} | ||
:type direction: string | ||
Would it be possible to change the docstring to our docstring format to make sure it's documented correctly? If this is too much effort - no worries, I can always do it. Basically, this would involve:
Also, we should probably add this to the documentation by including it as an RST file under the utils API!

I was about to do the same thing but thought it would be great to get feedback from you first. I will move the docstring as directed and add it to the documentation. Thanks.
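For reference, here is a sketch of how the constructor parameters might look in the layout that `KElbowVisualizer`'s own docstring uses in this diff (`name : type, default=value` followed by an indented description). It assumes that layout is the "docstring format" the request refers to; the class name `KneeLocatorDocSketch` is hypothetical and the descriptions are carried over from the existing `:param:` entries.

```python
class KneeLocatorDocSketch(object):
    """
    Attempts to find the point of maximum curvature on a line.

    Parameters
    ----------
    x : list or array
        The x values.

    y : list or array
        The y values.

    S : float, default=1.0
        Sensitivity; the original paper suggests a default of 1.0.

    curve : string, default='concave'
        If 'concave', the algorithm will detect knees. If 'convex', it
        will detect elbows.

    direction : string, default='increasing'
        One of {"increasing", "decreasing"}.
    """

    def __init__(self, x, y, S=1.0, curve='concave', direction='increasing'):
        # Only stores the arguments; the detection steps are omitted here
        self.x = x
        self.y = y
        self.S = S
        self.curve = curve
        self.direction = direction
```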
||
""" | ||
# Step 0: Raw Input | ||
self.x = x | ||
self.y = y | ||
self.curve = curve | ||
self.direction = direction | ||
self.N = len(self.x) | ||
self.S = S | ||
|
||
# Step 1: fit a smooth line | ||
uspline = interpolate.interp1d(self.x, self.y) | ||
self.Ds_x = np.linspace(np.min(self.x), np.max(self.x), self.N) | ||
self.Ds_y = uspline(self.Ds_x) | ||
|
||
# Step 2: normalize values | ||
self.xsn = self.__normalize(self.Ds_x) | ||
self.ysn = self.__normalize(self.Ds_y) | ||
|
||
# Step 3: Calculate difference curve | ||
self.xd = self.xsn | ||
if self.curve == 'convex' and direction == 'decreasing': | ||
self.yd = self.ysn + self.xsn | ||
self.yd = 1 - self.yd | ||
elif self.curve == 'concave' and direction == 'decreasing': | ||
self.yd = self.ysn + self.xsn | ||
elif self.curve == 'concave' and direction == 'increasing': | ||
self.yd = self.ysn - self.xsn | ||
if self.curve == 'convex' and direction == 'increasing': | ||
self.yd = abs(self.ysn - self.xsn) | ||
|
||
# Step 4: Identify local maxima/minima | ||
# local maxima | ||
self.xmx_idx = argrelextrema(self.yd, np.greater)[0] | ||
self.xmx = self.xd[self.xmx_idx] | ||
self.ymx = self.yd[self.xmx_idx] | ||
|
||
# local minima | ||
self.xmn_idx = argrelextrema(self.yd, np.less)[0] | ||
self.xmn = self.xd[self.xmn_idx] | ||
self.ymn = self.yd[self.xmn_idx] | ||
|
||
# Step 5: Calculate thresholds | ||
self.Tmx = self.__threshold(self.ymx) | ||
|
||
# Step 6: find knee | ||
self.knee, self.norm_knee, self.knee_x = self.find_knee() | ||
|
||
|
||
@staticmethod | ||
def __normalize(a): | ||
"""normalize an array | ||
:param a: The array to normalize | ||
:type a: array | ||
""" | ||
return (a - min(a)) / (max(a) - min(a)) | ||
|
||
def __threshold(self, ymx_i): | ||
"""Calculates the difference threshold for a | ||
given difference local maximum | ||
:param ymx_i: the normalized y value of a local maximum | ||
""" | ||
return ymx_i - (self.S * np.diff(self.xsn).mean()) | ||
|
||
def find_knee(self, ): | ||
"""This function finds and returns the knee value, the normalized knee | ||
value, and the x value where the knee is located. | ||
:returns: tuple(knee, norm_knee, knee_x) | ||
:rtype: (float, float, int) | ||
) | ||
""" | ||
if not self.xmx_idx.size: | ||
warnings.warn("No local maxima found in the distance curve\n" | ||
"The line is probably not polynomial, try plotting\n" | ||
"the distance curve with plt.plot(knee.xd, knee.yd)\n" | ||
"Also check that you aren't mistakenly setting the curve argument", RuntimeWarning) | ||
Unfortunately, I think we may have to do something about this warning, otherwise, it could be very confusing to our users. First, it is probably preferable to raise an exception here - that way we can catch the exception in the
If we do go with a warning, instead of a RuntimeWarning, could we please issue a

Whew, this is very, very legacy... the warning started as a sanity check long ago.

We are running into it because we can be a bit all over the place with convex/increasing or concave/decreasing depending on the metric we're using -- and if the clustering is terrible (e.g. bad features, wrong algorithm, no actual clusters) then things get really wild at that point. Any advice or thoughts you have would be welcome!

@bbengfort I am trying to get familiar with exception handling. Once done, I'll accommodate these changes in a PR.
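The raise-then-catch option mentioned in this thread could look roughly like the sketch below. All names here (`KneeNotFoundError`, `ElbowWarning`, `locate_knee`, `fit_elbow`) are hypothetical stand-ins written for illustration; they are not Yellowbrick's or kneed's actual classes, and the knee rule is deliberately trivial.

```python
import warnings


class KneeNotFoundError(ValueError):
    """Hypothetical exception raised when no knee can be located."""


class ElbowWarning(UserWarning):
    """Hypothetical library-specific warning category."""


def locate_knee(local_maxima):
    # Stand-in for the locator: raise instead of warning and returning None
    if not local_maxima:
        raise KneeNotFoundError("no local maxima found in the difference curve")
    return local_maxima[0]


def fit_elbow(local_maxima):
    # Stand-in for the caller: translate the low-level error into a
    # user-facing warning and fall back to "no elbow"
    try:
        return locate_knee(local_maxima)
    except KneeNotFoundError as exc:
        warnings.warn("No elbow could be located: {}".format(exc), ElbowWarning)
        return None
```

The design choice being illustrated is that the locator stays strict, while the calling code decides how much of the failure to surface and in what wording.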
||
return None, None, None | ||
|
||
mxmx_iter = np.arange(self.xmx_idx[0], len(self.xsn)) | ||
xmx_idx_iter = np.append(self.xmx_idx, len(self.xsn)) | ||
|
||
knee_, norm_knee_, knee_x = 0.0, 0.0, None | ||
for mxmx_i, mxmx in enumerate(xmx_idx_iter): | ||
# stopping criteria for exhausting array | ||
if mxmx_i == len(xmx_idx_iter) - 1: | ||
break | ||
# indices between maxima/minima | ||
idxs = (mxmx_iter > xmx_idx_iter[mxmx_i]) * \ | ||
(mxmx_iter < xmx_idx_iter[mxmx_i + 1]) | ||
between_local_mx = mxmx_iter[np.where(idxs)] | ||
|
||
for j in between_local_mx: | ||
if j in self.xmn_idx: | ||
# reached a minima, x indices are unique | ||
# only need to check if j is a min | ||
if self.yd[j + 1] > self.yd[j]: | ||
self.Tmx[mxmx_i] = 0 | ||
knee_x = None # reset x where yd crossed Tmx | ||
elif self.yd[j + 1] <= self.yd[j]: | ||
warnings.warn("If this is a minima, " | ||
"how would you ever get here:", RuntimeWarning) | ||
|
||
if self.yd[j] < self.Tmx[mxmx_i] or self.Tmx[mxmx_i] < 0: | ||
# declare a knee | ||
if not knee_x: | ||
knee_x = j | ||
knee_ = self.x[self.xmx_idx[mxmx_i]] | ||
norm_knee_ = self.xsn[self.xmx_idx[mxmx_i]] | ||
return knee_, norm_knee_, knee_x | ||
|
||
def plot_knee_normalized(self, ): | ||
"""Plot the normalized curve, the distance curve (xd, ysn) and the | ||
knee, if it exists. | ||
""" | ||
import matplotlib.pyplot as plt | ||
|
||
plt.figure(figsize=(8, 8)) | ||
plt.plot(self.xsn, self.ysn) | ||
plt.plot(self.xd, self.yd, 'r') | ||
plt.xticks(np.arange(min(self.xsn), max(self.xsn) + 0.1, 0.1)) | ||
plt.yticks(np.arange(min(self.xd), max(self.ysn) + 0.1, 0.1)) | ||
|
||
plt.vlines(self.norm_knee, plt.ylim()[0], plt.ylim()[1]) | ||
|
||
def plot_knee(self, ): | ||
"""Plot the curve and the knee, if it exists""" | ||
import matplotlib.pyplot as plt | ||
|
||
plt.figure(figsize=(8, 8)) | ||
plt.plot(self.x, self.y) | ||
plt.vlines(self.knee, plt.ylim()[0], plt.ylim()[1]) | ||
|
||
|
||
# Niceties for users working with elbows rather than knees | ||
|
||
@property | ||
def elbow(self): | ||
|
||
return self.knee | ||
|
||
@property | ||
def norm_elbow(self): | ||
return self.norm_knee | ||
|
||
@property | ||
def elbow_x(self): | ||
return self.knee_x | ||
|
||
|
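To make Steps 2 to 4 of `KneeLocator.__init__` above more concrete, here is a small pure-Python sketch of the convex, decreasing case only. It is a simplification under stated assumptions: it skips the interpolation in Step 1 and the threshold logic in Step 5, and simply reports the first strict local maximum of the difference curve. The function names are hypothetical.

```python
def normalize(values):
    """Min-max scale a sequence to the range [0, 1] (Step 2)."""
    lo, hi = float(min(values)), float(max(values))
    return [(v - lo) / (hi - lo) for v in values]


def difference_curve(x, y):
    """Difference curve for a convex, decreasing curve: 1 - (ysn + xsn) (Step 3)."""
    xsn, ysn = normalize(x), normalize(y)
    return [1 - (ys + xs) for xs, ys in zip(xsn, ysn)]


def first_local_maximum(x, yd):
    """Return the x value at the first strict local maximum of yd (simplified Step 4)."""
    for i in range(1, len(yd) - 1):
        if yd[i] > yd[i - 1] and yd[i] > yd[i + 1]:
            return x[i]
    return None


# Toy scores that drop sharply and then flatten out
k_values = [1, 2, 3, 4, 5]
k_scores = [16, 4, 2, 1, 0]
yd = difference_curve(k_values, k_scores)
elbow = first_local_maximum(k_values, yd)
```

For this toy data the difference curve peaks at the second point, so `elbow` is `2`. Real score curves are noisier, which is where the sensitivity parameter `S` and the threshold in the ported code come in.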
Hehe, I do think it's funny to have a knee parameter in an "elbow" visualizer. I really do want to keep calling this "knee", but I'm worried it might be a bit confusing, particularly to students. May I propose the following:
Does that seem a bit more understandable?

Yes, this is definitely a good name and a nice description too. I'll replace that with the suggested one.
Thanks.