Navigation Menu

Skip to content

Commit

Permalink
Better default value for data_checks in AutoSearchBase (#892)
Browse files Browse the repository at this point in the history
* data_checks can now be a list, str, DataChecks, or None.

* Updating changelog for PR 890

* data_checks can now be a list, str, DataChecks, or None.

* Updating changelog for PR 890

* Updating changelog for PR 892.

* DataChecks now validate input list to ensure every element is a DataCheck.

* Fixing lint issues related to data checks in automl search

* DataChecks now checks if the input parameter is a list.
  • Loading branch information
freddyaboulton committed Jun 29, 2020
1 parent d83c70f commit 96aecd8
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 29 deletions.
4 changes: 3 additions & 1 deletion docs/source/changelog.rst
Expand Up @@ -21,13 +21,14 @@ Changelog
* Added SelectColumns transformer :pr:`873`
* Added ability to evaluate additional pipelines for automl search :pr:`874`
* Added `default_parameters` class property to components and pipelines :pr:`879`
* Added better support for disabling data checks in automl search :pr:`892`
* Added ability to save and load AutoML objects to file :pr:`888`
* Fixes
* Fixed bug where SimpleImputer cannot handle dropped columns :pr:`846`
* Fixed bug where PerColumnImputer cannot handle dropped columns :pr:`855`
* Enforce requirement that builtin components save all inputted values in their parameters dict :pr:`847`
* Don't list base classes in `all_components` output :pr:`847`
* Standardize all components to output pandas datastructures, and accept either pandas or numpy :pr:`853`
* Standardize all components to output pandas data structures, and accept either pandas or numpy :pr:`853`
* Changes
* Update `all_pipelines` and `all_components` to try initializing pipelines/components, and on failure exclude them :pr:`849`
* Refactor `handle_components` to `handle_components_class`, standardize to `ComponentBase` subclass instead of instance :pr:`850`
Expand All @@ -48,6 +49,7 @@ Changelog
* Pipelines' static ``component_graph`` field must contain either ``ComponentBase`` subclasses or ``str``, instead of ``ComponentBase`` subclass instances :pr:`850`
* Rename ``handle_component`` to ``handle_component_class``. Now standardizes to ``ComponentBase`` subclasses instead of ``ComponentBase`` subclass instances :pr:`850`
* Pipelines' and classifiers' `feature_importances` is renamed `feature_importance`, `graph_feature_importances` is renamed `graph_feature_importance` :pr:`883`
* Passing data_checks=None to automl search will not perform any data checks as opposed to default checks. :pr:`892`
* Pipelines to search for in AutoML are now determined automatically, rather than using the statically-defined pipeline classes. :pr:`870`


Expand Down
46 changes: 36 additions & 10 deletions evalml/automl/automl_search.py
Expand Up @@ -12,7 +12,7 @@
from .pipeline_search_plots import PipelineSearchPlots

from evalml.automl.automl_algorithm import IterativeAlgorithm
from evalml.data_checks import DataChecks, DefaultDataChecks
from evalml.data_checks import DataChecks, DefaultDataChecks, EmptyDataChecks
from evalml.data_checks.data_check_message_type import DataCheckMessageType
from evalml.objectives import get_objective, get_objectives
from evalml.pipelines import (
Expand Down Expand Up @@ -225,7 +225,35 @@ def _get_funct_name(function):

return search_desc + rankings_desc

def search(self, X, y, data_checks=None, feature_types=None, raise_errors=True, show_iteration_plot=True):
@staticmethod
def _validate_data_checks(data_checks):
    """Coerce the user-supplied ``data_checks`` argument into a DataChecks instance.

    Arguments:
        data_checks (DataChecks, list(DataCheck), str, None): Input to validate. If not of the right type,
            raise an exception.

    Returns:
        An instance of DataChecks used to perform checks before search.
    """
    # Already a DataChecks collection: nothing to do.
    if isinstance(data_checks, DataChecks):
        return data_checks
    # None means "run no checks at all".
    if data_checks is None:
        return EmptyDataChecks()
    # Only two named presets are recognized.
    if isinstance(data_checks, str):
        if data_checks == "auto":
            return DefaultDataChecks()
        if data_checks == "disabled":
            return EmptyDataChecks()
        raise ValueError("If data_checks is a string, it must be either 'auto' or 'disabled'. "
                         f"Received '{data_checks}'.")
    # Lists — and any other value — are handed to the DataChecks constructor,
    # which performs its own validation of the input type and element types.
    return DataChecks(data_checks)

def search(self, X, y, data_checks="auto", feature_types=None, raise_errors=True, show_iteration_plot=True):
"""Find best classifier
Arguments:
Expand All @@ -241,10 +269,12 @@ def search(self, X, y, data_checks=None, feature_types=None, raise_errors=True,
show_iteration_plot (boolean, True): Shows an iteration vs. score plot in Jupyter notebook.
Disabled by default in non-Jupyter environments.
data_checks (DataChecks, None): A collection of data checks to run before searching for the best classifier. If data checks produce any errors, an exception will be thrown before the search begins. If None, uses DefaultDataChecks. Defaults to None.
data_checks (DataChecks, list(DataCheck), str, None): A collection of data checks to run before
automl search. If data checks produce any errors, an exception will be thrown before the
search begins. If "disabled" or None, no data checks will be done.
If set to "auto", DefaultDataChecks will be done. Default value is set to "auto".
Returns:
self
"""
# don't show iteration plot outside of a jupyter notebook
Expand All @@ -261,13 +291,9 @@ def search(self, X, y, data_checks=None, feature_types=None, raise_errors=True,
if not isinstance(y, pd.Series):
y = pd.Series(y)

if data_checks is None:
data_checks = DefaultDataChecks()

if not isinstance(data_checks, DataChecks):
raise ValueError("data_checks parameter must be a DataChecks object!")

data_checks = self._validate_data_checks(data_checks)
data_check_results = data_checks.validate(X, y)

if len(data_check_results) > 0:
self._data_check_results = data_check_results
for message in self._data_check_results:
Expand Down
8 changes: 8 additions & 0 deletions evalml/data_checks/data_checks.py
@@ -1,3 +1,6 @@
from .data_check import DataCheck


class DataChecks:
"""A collection of data checks."""

Expand All @@ -8,6 +11,11 @@ def __init__(self, data_checks=None):
Arguments:
data_checks (list (DataCheck)): list of DataCheck objects
"""
if not isinstance(data_checks, list):
raise ValueError(f"Parameter data_checks must be a list. Received {type(data_checks).__name__}.")
if not all(isinstance(check, DataCheck) for check in data_checks):
raise ValueError("All elements of parameter data_checks must be an instance of DataCheck.")

self.data_checks = data_checks

def validate(self, X, y=None):
Expand Down
61 changes: 43 additions & 18 deletions evalml/tests/automl_tests/test_automl.py
Expand Up @@ -11,8 +11,7 @@
DataCheck,
DataCheckError,
DataChecks,
DataCheckWarning,
EmptyDataChecks
DataCheckWarning
)
from evalml.model_family import ModelFamily
from evalml.pipelines import BinaryClassificationPipeline
Expand Down Expand Up @@ -228,15 +227,26 @@ def test_automl_data_check_results_is_none_before_search():

@patch('evalml.pipelines.BinaryClassificationPipeline.score')
@patch('evalml.pipelines.BinaryClassificationPipeline.fit')
def test_automl_empty_data_checks(mock_fit, mock_score, X_y):
X, y = X_y
def test_automl_empty_data_checks(mock_fit, mock_score):
X = pd.DataFrame({"feature1": [1, 2, 3],
"feature2": [None, None, None]})
y = pd.Series([1, 1, 1])

mock_score.return_value = {'Log Loss Binary': 1.0}
automl = AutoMLSearch(problem_type='binary', max_pipelines=1)
automl.search(X, y, data_checks=EmptyDataChecks())

automl = AutoMLSearch(problem_type="binary", max_pipelines=1)

automl.search(X, y, data_checks=[])
assert automl.data_check_results is None
mock_fit.assert_called()
mock_score.assert_called()

automl.search(X, y, data_checks="disabled")
assert automl.data_check_results is None

automl.search(X, y, data_checks=None)
assert automl.data_check_results is None


@patch('evalml.data_checks.DefaultDataChecks.validate')
@patch('evalml.pipelines.BinaryClassificationPipeline.score')
Expand All @@ -255,31 +265,46 @@ def test_automl_default_data_checks(mock_fit, mock_score, mock_validate, X_y, ca
mock_validate.assert_called()


def test_automl_data_checks_raises_error(caplog):
class MockDataCheckErrorAndWarning(DataCheck):
    """Test double whose validate always reports exactly one error and one warning."""

    def validate(self, X, y):
        messages = [
            DataCheckError("error one", self.name),
            DataCheckWarning("warning one", self.name),
        ]
        return messages


@pytest.mark.parametrize("data_checks",
[[MockDataCheckErrorAndWarning()],
DataChecks([MockDataCheckErrorAndWarning()])])
@patch('evalml.pipelines.BinaryClassificationPipeline.score')
@patch('evalml.pipelines.BinaryClassificationPipeline.fit')
def test_automl_data_checks_raises_error(mock_fit, mock_score, data_checks, caplog):
X = pd.DataFrame()
y = pd.Series()

class MockDataCheckErrorAndWarning(DataCheck):
def validate(self, X, y):
return [DataCheckError("error one", self.name), DataCheckWarning("warning one", self.name)]

data_checks = DataChecks(data_checks=[MockDataCheckErrorAndWarning()])
automl = AutoMLSearch(problem_type='binary', max_pipelines=1)
automl = AutoMLSearch(problem_type="binary", max_pipelines=1)

with pytest.raises(ValueError, match="Data checks raised"):
automl.search(X, y, data_checks=data_checks)

out = caplog.text
assert "error one" in out
assert "warning one" in out
assert automl.data_check_results == data_checks.validate(X, y)
assert automl.data_check_results == MockDataCheckErrorAndWarning().validate(X, y)


def test_automl_not_data_check_object():
def test_automl_bad_data_check_parameter_type():
X = pd.DataFrame()
y = pd.Series()
automl = AutoMLSearch(problem_type='binary', max_pipelines=1)
with pytest.raises(ValueError, match="data_checks parameter must be a DataChecks object!"):

automl = AutoMLSearch(problem_type="binary", max_pipelines=1)

with pytest.raises(ValueError, match="Parameter data_checks must be a list. Received int."):
automl.search(X, y, data_checks=1)
with pytest.raises(ValueError, match="All elements of parameter data_checks must be an instance of DataCheck."):
automl.search(X, y, data_checks=[1])
with pytest.raises(ValueError, match="If data_checks is a string, it must be either 'auto' or 'disabled'. "
"Received 'default'."):
automl.search(X, y, data_checks="default")
with pytest.raises(ValueError, match="All elements of parameter data_checks must be an instance of DataCheck."):
automl.search(X, y, data_checks=[DataChecks([]), 1])


def test_automl_str_no_param_search():
Expand Down Expand Up @@ -424,7 +449,7 @@ def test_obj_matches_problem_type(X_y):
X, y = X_y
with pytest.raises(ValueError, match="is not compatible with a"):
auto = AutoMLSearch(problem_type='binary', objective='R2')
auto.search(X, y, data_checks=EmptyDataChecks())
auto.search(X, y, data_checks=[])


def test_init_problem_type_error():
Expand Down

0 comments on commit 96aecd8

Please sign in to comment.