New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add HighVarianceCVDataCheck #1254
Changes from all commits
88afbc7
6c8ddd8
14b7e90
c00d229
15e05e3
2dc7317
8e93701
05f8400
94c8ca7
30a9c5a
88f37e9
ed6a579
3f80da9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
import copy | ||
import time | ||
import warnings | ||
from collections import OrderedDict, defaultdict | ||
|
||
import cloudpickle | ||
|
@@ -22,7 +21,8 @@ | |
AutoMLDataChecks, | ||
DataChecks, | ||
DefaultDataChecks, | ||
EmptyDataChecks | ||
EmptyDataChecks, | ||
HighVarianceCVDataCheck | ||
) | ||
from evalml.data_checks.data_check_message_type import DataCheckMessageType | ||
from evalml.exceptions import ( | ||
|
@@ -395,7 +395,7 @@ def search(self, X, y, data_checks="auto", feature_types=None, show_iteration_pl | |
data_checks = self._validate_data_checks(data_checks) | ||
data_check_results = data_checks.validate(X, y) | ||
|
||
if len(data_check_results) > 0: | ||
if data_check_results: | ||
self._data_check_results = data_check_results | ||
for message in self._data_check_results: | ||
if message.message_type == DataCheckMessageType.WARNING: | ||
|
@@ -694,16 +694,18 @@ def _add_result(self, trained_pipeline, parameters, training_time, cv_data, cv_s | |
self._baseline_cv_scores.get(obj_name, np.nan)) | ||
percent_better_than_baseline[obj_name] = percent_better | ||
|
||
# calculate high_variance_cv | ||
# if the coefficient of variance is greater than .2 | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter('ignore') | ||
high_variance_cv = (cv_scores.std() / cv_scores.mean()) > .2 | ||
|
||
pipeline_name = trained_pipeline.name | ||
pipeline_summary = trained_pipeline.summary | ||
pipeline_id = len(self._results['pipeline_results']) | ||
|
||
high_variance_cv_check = HighVarianceCVDataCheck(threshold=0.2) | ||
high_variance_cv_check_results = high_variance_cv_check.validate(pipeline_name=pipeline_name, cv_scores=cv_scores) | ||
high_variance_cv = False | ||
|
||
if high_variance_cv_check_results: | ||
logger.warning(high_variance_cv_check_results[0]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jeremyliweishih does this show up in the console in a well-formatted way? I've noticed that |
||
high_variance_cv = True | ||
|
||
self._results['pipeline_results'][pipeline_id] = { | ||
"id": pipeline_id, | ||
"pipeline_name": pipeline_name, | ||
|
@@ -785,10 +787,6 @@ def describe_pipeline(self, pipeline_id, return_dict=False): | |
logger.info("Total training time (including CV): %.1f seconds" % pipeline_results["training_time"]) | ||
log_subtitle(logger, "Cross Validation", underline="-") | ||
|
||
if pipeline_results["high_variance_cv"]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I moved the logging behavior from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this makes sense! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed! |
||
logger.warning("High variance within cross validation scores. " + | ||
"Model may not perform as estimated on unseen data.") | ||
|
||
all_objective_scores = [fold["all_objective_scores"] for fold in pipeline_results["cv_data"]] | ||
all_objective_scores = pd.DataFrame(all_objective_scores) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import pandas as pd | ||
|
||
from .data_check import DataCheck | ||
from .data_check_message import DataCheckWarning | ||
|
||
|
||
class HighVarianceCVDataCheck(DataCheck): | ||
"""Checks if the variance between folds in cross-validation is higher than an acceptable threshold.""" | ||
|
||
def __init__(self, threshold=0.2): | ||
"""Check if there is higher variance among cross-validation results. | ||
|
||
Arguments: | ||
threshold (float): The minimum threshold allowed for high variance before a warning is raised. | ||
Defaults to 0.2 and must be above 0. | ||
""" | ||
if threshold < 0: | ||
raise ValueError(f"Provided threshold {threshold} needs to be greater than 0.") | ||
self.threshold = threshold | ||
|
||
def validate(self, pipeline_name, cv_scores): | ||
"""Checks cross-validation scores and issues a warning if variance is higher than the specified threshold. | ||
|
||
Arguments: | ||
pipeline_name (str): name of pipeline that produced cv_scores | ||
cv_scores (pd.Series, np.array, list): list of scores of each cross-validation fold | ||
|
||
Returns: | ||
list (DataCheckWarning): list with a DataCheckWarning if the coefficient of variation of the cross-validation scores exceeds the threshold. | ||
|
||
Example: | ||
>>> cv_scores = pd.Series([0, 1, 1, 1]) | ||
>>> check = HighVarianceCVDataCheck(threshold=0.10) | ||
>>> assert check.validate("LogisticRegressionPipeline", cv_scores) == [DataCheckWarning("High coefficient of variation (cv >= 0.1) within cross validation scores. LogisticRegressionPipeline may not perform as estimated on unseen data.", "HighVarianceCVDataCheck")] | ||
""" | ||
if not isinstance(cv_scores, pd.Series): | ||
cv_scores = pd.Series(cv_scores) | ||
|
||
messages = [] | ||
high_variance_cv = abs(cv_scores.std() / cv_scores.mean()) > self.threshold | ||
# if the coefficient of variation exceeds the threshold, add a warning to the list of messages | ||
if high_variance_cv: | ||
warning_msg = f"High coefficient of variation (cv >= {self.threshold}) within cross validation scores. {pipeline_name} may not perform as estimated on unseen data." | ||
messages.append(DataCheckWarning(warning_msg, self.name)) | ||
return messages |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from evalml.data_checks import DataCheckWarning, HighVarianceCVDataCheck | ||
|
||
|
||
def test_high_variance_cv_data_check_invalid_threshold(): | ||
with pytest.raises(ValueError, match="needs to be greater than 0."): | ||
HighVarianceCVDataCheck(threshold=-0.1).validate(pipeline_name='LogisticRegressionPipeline', cv_scores=pd.Series([0, 1, 1])) | ||
|
||
|
||
def test_high_variance_cv_data_check(): | ||
high_variance_cv = HighVarianceCVDataCheck() | ||
|
||
assert high_variance_cv.validate(pipeline_name='LogisticRegressionPipeline', cv_scores=[1, 1, 1]) == [] | ||
assert high_variance_cv.validate(pipeline_name='LogisticRegressionPipeline', cv_scores=pd.Series([1, 1, 1])) == [] | ||
assert high_variance_cv.validate(pipeline_name='LogisticRegressionPipeline', cv_scores=pd.Series([0, 1, 2, 3])) == [DataCheckWarning("High coefficient of variation (cv >= 0.2) within cross validation scores. LogisticRegressionPipeline may not perform as estimated on unseen data.", "HighVarianceCVDataCheck")] | ||
|
||
|
||
def test_high_variance_cv_data_check_empty_nan(): | ||
high_variance_cv = HighVarianceCVDataCheck() | ||
assert high_variance_cv.validate(pipeline_name='LogisticRegressionPipeline', cv_scores=pd.Series([0, 1, np.nan, np.nan])) == [DataCheckWarning("High coefficient of variation (cv >= 0.2) within cross validation scores. LogisticRegressionPipeline may not perform as estimated on unseen data.", "HighVarianceCVDataCheck")] | ||
|
||
|
||
def test_high_variance_cv_data_check_negative(): | ||
high_variance_cv = HighVarianceCVDataCheck() | ||
assert high_variance_cv.validate(pipeline_name='LogisticRegressionPipeline', cv_scores=pd.Series([0, -1, -1, -1])) == [DataCheckWarning("High coefficient of variation (cv >= 0.2) within cross validation scores. LogisticRegressionPipeline may not perform as estimated on unseen data.", "HighVarianceCVDataCheck")] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should users be allowed to configure this threshold now that this is a DataCheck and we let users configure other data checks?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good thought! Perhaps we can turn this on and off depending on whether
`data_checks = "auto"`
vs. `data_checks = "disabled"`
. But I don't think it'll fit within the existing API for parameterizing data checks, as all those data checks run before search is called, whereas this check is called during search. I like the idea, but we would need to think about what API changes to make to `AutoMLSearch.search()`
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 yep @freddyaboulton I agree this should be configurable/disable-able.
This PR is essentially porting existing behavior into a new API (data checks). I'll file an issue now to track making this configurable.