Update data checks to return DataCheckResults object #1444

Closed · wants to merge 27 commits

Changes from 18 commits

Commits (27)
f560617
init
angela97lin Nov 13, 2020
b6a66ac
docstrs
angela97lin Nov 13, 2020
29284c1
update class imbalance
angela97lin Nov 16, 2020
f50c267
more conversion
angela97lin Nov 16, 2020
86b92ad
docstr
angela97lin Nov 16, 2020
a3e4b8f
more updates
angela97lin Nov 17, 2020
8031cc2
Merge branch 'main' into 1325_data_checks_returns_dict
angela97lin Nov 17, 2020
ad8a15e
cleanup
angela97lin Nov 17, 2020
455dfdc
Merge branch '1325_data_checks_returns_dict' of github.com:alteryx/ev…
angela97lin Nov 17, 2020
cbd9f42
fix some tests for data checks
angela97lin Nov 17, 2020
e4e4b78
fix more tests
angela97lin Nov 17, 2020
aa101da
fix more tests
angela97lin Nov 17, 2020
61489de
test doctest
angela97lin Nov 17, 2020
82b3c3a
fix doctest
angela97lin Nov 17, 2020
ac7181f
fix no variance data check
angela97lin Nov 17, 2020
86273fe
fix data checks tests
angela97lin Nov 17, 2020
2b8e910
update notebook
angela97lin Nov 17, 2020
c4f12ab
Merge branch 'main' into 1325_data_checks_returns_dict
angela97lin Nov 18, 2020
b08447c
Merge branch 'main' into 1325_data_checks_returns_dict
angela97lin Nov 19, 2020
677edf1
update to use data check results class
angela97lin Nov 19, 2020
2a801f0
fix doctests
angela97lin Nov 19, 2020
1038edd
clear notebook outputs
angela97lin Nov 19, 2020
96ceaef
fix notebooks
angela97lin Nov 19, 2020
bc49716
fix doctests
angela97lin Nov 19, 2020
a30bd71
add equality test
angela97lin Nov 19, 2020
aaf52ea
fix docstr
angela97lin Nov 19, 2020
76aaaff
add to_json
angela97lin Nov 19, 2020
6 changes: 5 additions & 1 deletion docs/source/release_notes.rst
Expand Up @@ -23,7 +23,8 @@ Release Notes
* Changes
* Changed ``OutliersDataCheck`` to return the list of columns, rather than rows, that contain outliers :pr:`1377`
* Simplified and cleaned output for Code Generation :pr:`1371`
    * Reverted changes from :pr:`1337` :pr:`1409`
* Updated data checks to return dictionary of warnings and errors instead of a list :pr:`1444`
* Documentation Changes
* Added description of CLA to contributing guide, updated description of draft PRs :pr:`1402`
* Updated documentation to include all data checks, ``DataChecks``, and usage of data checks in AutoML :pr:`1412`
Expand All @@ -40,6 +41,9 @@ Release Notes
**Breaking Changes**
* The ``top_k`` and ``top_k_features`` parameters in ``explain_predictions_*`` functions now return ``k`` features as opposed to ``2 * k`` features :pr:`1374`
* Renamed ``problem_type`` to ``problem_types`` in ``RegressionObjective``, ``BinaryClassificationObjective``, and ``MulticlassClassificationObjective`` :pr:`1319`
* Data checks now return a dictionary of warnings and errors instead of a list :pr:`1444`



**v0.15.0 Oct. 29, 2020**
* Enhancements
Expand Down
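The breaking change recorded in the release notes above ("Data checks now return a dictionary of warnings and errors instead of a list") can be pictured with a minimal sketch. The enum below is a local stand-in for evalml's `DataCheckMessageType`, defined here only for illustration:

```python
from enum import Enum

# Local stand-in for evalml.data_checks.DataCheckMessageType (illustrative only).
class DataCheckMessageType(Enum):
    WARNING = "warning"
    ERROR = "error"

# Before this PR: validate() returned a flat list of messages.
old_style_results = ["Column 'a' is 95.0% or more null"]

# After this PR: validate() returns a dict keyed by message type.
new_style_results = {
    DataCheckMessageType.WARNING: ["Column 'a' is 95.0% or more null"],
    DataCheckMessageType.ERROR: [],
}

# Callers now select the severity they care about instead of filtering a list.
print(new_style_results[DataCheckMessageType.WARNING])
```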
87 changes: 63 additions & 24 deletions docs/source/user_guide/data_checks.ipynb
Expand Up @@ -33,6 +33,7 @@
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from evalml.data_checks import DataCheckMessageType\n",
"\n",
"from evalml.data_checks import HighlyNullDataCheck\n",
"\n",
Expand All @@ -43,9 +44,13 @@
" [8, 6, np.nan]])\n",
"\n",
"null_check = HighlyNullDataCheck(pct_null_threshold=0.8)\n",
"results = null_check.validate(X)\n",
"\n",
"for message in null_check.validate(X):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand Down Expand Up @@ -84,9 +89,13 @@
" \"good col\":[0, 4, 1]})\n",
"y = pd.Series([1, 0, 1])\n",
"no_variance_data_check = NoVarianceDataCheck()\n",
"results = no_variance_data_check.validate(X, y)\n",
"\n",
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in no_variance_data_check.validate(X, y):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -110,9 +119,13 @@
"y = pd.Series([1, 0, 1])\n",
"\n",
"no_variance_data_check = NoVarianceDataCheck(count_nan_as_value=True)\n",
"results = no_variance_data_check.validate(X, y)\n",
"\n",
"for message in no_variance_data_check.validate(X, y):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -139,10 +152,15 @@
" [4, 4, 8, 3],\n",
" [9, 2, 7, 1]])\n",
"y = pd.Series([0, 1, 1, 1, 1])\n",
"\n",
"class_imbalance_check = ClassImbalanceDataCheck(threshold=0.25, num_cv_folds=4)\n",
"results = class_imbalance_check.validate(X, y)\n",
"\n",
"for message in class_imbalance_check.validate(X, y):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -167,8 +185,13 @@
"y = pd.Series([10, 42, 31, 51, 40])\n",
"\n",
"target_leakage_check = TargetLeakageDataCheck(pct_corr_threshold=0.8)\n",
"for message in target_leakage_check.validate(X, y):\n",
" print(message.message)"
"results = target_leakage_check.validate(X, y)\n",
"\n",
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -194,10 +217,15 @@
" \n",
"X = pd.DataFrame({})\n",
"y = pd.Series([0, 1, None, None])\n",
"\n",
"invalid_target_check = InvalidTargetDataCheck('binary')\n",
"results = invalid_target_check.validate(X, y)\n",
"\n",
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in invalid_target_check.validate(X, y):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -218,10 +246,15 @@
"from evalml.data_checks import IDColumnsDataCheck\n",
"\n",
"X = pd.DataFrame([[0, 53, 6325, 5],[1, 90, 6325, 10],[2, 90, 18, 20]], columns=['user_number', 'cost', 'revenue', 'id'])\n",
"\n",
"id_col_check = IDColumnsDataCheck(id_threshold=0.9)\n",
"results = id_col_check.validate(X, y)\n",
"\n",
"for message in id_col_check.validate(X):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand Down Expand Up @@ -282,9 +315,13 @@
"from evalml.data_checks import OutliersDataCheck\n",
"\n",
"outliers_check = OutliersDataCheck()\n",
"results = outliers_check.validate(X, y)\n",
"\n",
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in outliers_check.validate(X):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -310,12 +347,13 @@
"y = pd.Series([1, 0, 1])\n",
"\n",
"no_variance_data_check = NoVarianceDataCheck(count_nan_as_value=True)\n",
"results = no_variance_data_check.validate(X, y)\n",
"\n",
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in no_variance_data_check.validate(X, y):\n",
" if isinstance(message, DataCheckError):\n",
" print(\"ERROR:\", message.message)\n",
" elif isinstance(message, DataCheckWarning):\n",
" print(\"WARNING:\", message.message)"
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -329,7 +367,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If you would prefer to write your own data check, you can do so by extending the `DataCheck` class and implementing the `validate(self, X, y)` class method. Below, we've created a new `DataCheck`, `ZeroVarianceDataCheck`, which is similar to `NoVarianceDataCheck` defined in `EvalML`."
"If you would prefer to write your own data check, you can do so by extending the `DataCheck` class and implementing the `validate(self, X, y)` class method. Below, we've created a new `DataCheck`, `ZeroVarianceDataCheck`, which is similar to `NoVarianceDataCheck` defined in `EvalML`. The `validate(self, X, y)` method should return a dictionary with `DataCheckMessageType.WARNING` and `DataCheckMessageType.ERROR` as keys mapping to list of warnings and errors, respectively."
]
},
{
Expand All @@ -345,10 +383,11 @@
"\n",
"class ZeroVarianceDataCheck(DataCheck):\n",
" def validate(self, X, y):\n",
" messages = {DataCheckMessageType.WARNING: [], DataCheckMessageType.ERROR: []}\n",
" if not isinstance(X, pd.DataFrame):\n",
" X = pd.DataFrame(X)\n",
" warning_msg = \"Column '{}' has zero variance\"\n",
" return [DataCheckError(warning_msg.format(column), self.name) for column in X.columns if len(X[column].unique()) == 1]"
" messages[DataCheckMessageType.WARNING].extend([DataCheckError(warning_msg.format(column), self.name) for column in X.columns if len(X[column].unique()) == 1])"
]
},
{
Expand Down Expand Up @@ -458,4 +497,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
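Every cell in the notebook diff above follows the same consumption pattern: call `validate`, then loop over the warning and error buckets separately. A self-contained sketch of that pattern, with a toy null check standing in for evalml's `HighlyNullDataCheck` (nothing here is the library's real implementation):

```python
from enum import Enum

# Local stand-in for evalml.data_checks.DataCheckMessageType (illustrative only).
class DataCheckMessageType(Enum):
    WARNING = "warning"
    ERROR = "error"

class ToyHighlyNullDataCheck:
    """Toy check mirroring the notebook's usage pattern; not evalml's class."""

    def __init__(self, pct_null_threshold=0.8):
        self.pct_null_threshold = pct_null_threshold

    def validate(self, rows):
        # Every check now returns the same two-key dict.
        messages = {DataCheckMessageType.WARNING: [], DataCheckMessageType.ERROR: []}
        for col in range(len(rows[0])):
            values = [row[col] for row in rows]
            pct_null = sum(v is None for v in values) / len(values)
            if pct_null >= self.pct_null_threshold:
                messages[DataCheckMessageType.WARNING].append(
                    f"Column {col} is {pct_null:.0%} null")
        return messages

# Same data as the first notebook cell, with None in place of np.nan.
X = [[1, 2, None], [3, 4, None], [5, 6, None], [8, 6, None]]
results = ToyHighlyNullDataCheck(pct_null_threshold=0.8).validate(X)

for message in results[DataCheckMessageType.WARNING]:
    print("Warning:", message)
for message in results[DataCheckMessageType.ERROR]:
    print("Error:", message)
```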
21 changes: 9 additions & 12 deletions evalml/automl/automl_search.py
Expand Up @@ -416,16 +416,13 @@ def search(self, X, y, data_checks="auto", show_iteration_plot=True):
self._set_data_split(X)

data_checks = self._validate_data_checks(data_checks)
data_check_results = data_checks.validate(X, y)
if data_check_results:
self._data_check_results = data_check_results
for message in self._data_check_results:
if message.message_type == DataCheckMessageType.WARNING:
logger.warning(message)
elif message.message_type == DataCheckMessageType.ERROR:
logger.error(message)
if any([message.message_type == DataCheckMessageType.ERROR for message in self._data_check_results]):
raise ValueError("Data checks raised some warnings and/or errors. Please see `self.data_check_results` for more information or pass data_checks='disabled' to search() to disable data checking.")
self._data_check_results = data_checks.validate(X, y)
for message in self._data_check_results[DataCheckMessageType.WARNING]:
logger.warning(message)
for message in self._data_check_results[DataCheckMessageType.ERROR]:
logger.error(message)
if self._data_check_results[DataCheckMessageType.ERROR]:
raise ValueError("Data checks raised some warnings and/or errors. Please see `self.data_check_results` for more information or pass data_checks='disabled' to search() to disable data checking.")

if self.allowed_pipelines is None:
logger.info("Generating pipelines to search over...")
Expand Down Expand Up @@ -740,8 +737,8 @@ def _add_result(self, trained_pipeline, parameters, training_time, cv_data, cv_s
high_variance_cv_check_results = high_variance_cv_check.validate(pipeline_name=pipeline_name, cv_scores=cv_scores)
high_variance_cv = False

if high_variance_cv_check_results:
logger.warning(high_variance_cv_check_results[0])
if high_variance_cv_check_results[DataCheckMessageType.WARNING]:
logger.warning(high_variance_cv_check_results[DataCheckMessageType.WARNING][0])
high_variance_cv = True

self._results['pipeline_results'][pipeline_id] = {
Expand Down
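The updated `search()` logic in the hunk above logs every message and raises only when the error bucket is non-empty. That control flow can be sketched in isolation; the function and enum below are stand-ins, not evalml code:

```python
from enum import Enum

# Stand-ins for evalml's types; only the control flow mirrors automl_search.py.
class DataCheckMessageType(Enum):
    WARNING = "warning"
    ERROR = "error"

def surface_data_check_results(results):
    """Log every warning and error, then raise only if any errors exist."""
    for message in results[DataCheckMessageType.WARNING]:
        print("WARNING:", message)
    for message in results[DataCheckMessageType.ERROR]:
        print("ERROR:", message)
    if results[DataCheckMessageType.ERROR]:
        raise ValueError(
            "Data checks raised some warnings and/or errors. "
            "Pass data_checks='disabled' to search() to disable data checking.")

# Warnings alone do not stop the search.
surface_data_check_results(
    {DataCheckMessageType.WARNING: ["class imbalance"], DataCheckMessageType.ERROR: []})
```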
17 changes: 11 additions & 6 deletions evalml/data_checks/class_imbalance_data_check.py
Expand Up @@ -2,6 +2,7 @@

from .data_check import DataCheck
from .data_check_message import DataCheckError, DataCheckWarning
from .data_check_message_type import DataCheckMessageType


class ClassImbalanceDataCheck(DataCheck):
Expand Down Expand Up @@ -33,29 +34,33 @@ def validate(self, X, y):
y: Target labels to check for imbalanced data.

Returns:
list (DataCheckWarning, DataCheckError): list with DataCheckWarnings if imbalance in classes is less than the threshold,
and DataCheckErrors if the number of values for each target is below 2 * num_cv_folds.
dict: Dictionary with DataCheckWarnings if imbalance in classes is less than the threshold,
and DataCheckErrors if the number of values for each target is below 2 * num_cv_folds.

Example:
>>> X = pd.DataFrame({})
>>> y = pd.Series([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
>>> target_check = ClassImbalanceDataCheck(threshold=0.10)
>>> assert target_check.validate(X, y) == [DataCheckError("The number of instances of these targets is less than 2 * the number of cross folds = 6 instances: [0]", "ClassImbalanceDataCheck"), DataCheckWarning("The following labels fall below 10% of the target: [0]", "ClassImbalanceDataCheck")]
>>> assert target_check.validate(X, y) == {DataCheckMessageType.ERROR: [DataCheckError("The number of instances of these targets is less than 2 * the number of cross folds = 6 instances: [0]", "ClassImbalanceDataCheck")],\
DataCheckMessageType.WARNING: [DataCheckWarning("The following labels fall below 10% of the target: [0]", "ClassImbalanceDataCheck")]}
"""
messages = {
freddyaboulton (Contributor) commented Nov 17, 2020:

I think the downside of using DataCheckMessageType as the keys is that the user needs to import the enum in order to look at the warnings and errors. That can be annoying, especially if the user is running automl and trying to access data_check_results. Two alternatives:

  1. Creating a DataCheckResults class.
  2. Using string values like "warning" and "error" as keys in the dict you have here.

The benefits of 1 are that:

  1. We can expose the warnings and errors as instance properties, and the user doesn't have to import another class.
  2. We can also have an API for checking whether there are warnings or errors, or whether the results are empty, as opposed to if self._data_check_results[DataCheckMessageType.ERROR]. The nice thing about this is that if we ever add data to the dict, an is_empty method wouldn't break users' downstream code, but the dict would, because they'd be checking if results == {DataCheckMessageType.WARNING: [], DataCheckMessageType.ERROR: []}.

The downside is that it's another class, but our docs already show how to get data from the data check results.

The benefit of 2 is that it'd be a small change, but it would definitely lead to typos lol.

My vote is for 1. I think it'd be fine to keep it as-is for now too. I saw that you left "use DataCheckMessageType as key?" in the PR description, so I wanted to offer my thoughts lol.

angela97lin (Contributor, Author) replied:

Ooo, I like this suggestion a lot. I agree: the reason I left the comment about using DataCheckMessageType as the key was that, as I was updating the code, I too felt the inconvenience / frustration of having to import the enum everywhere, but as @dsherry had mentioned, since we already have the enums in place we might as well use them.

That being said, I like the idea of creating a separate DataCheckResults class with warnings and errors as attributes 🤔 That way, the user doesn't need to type the keys in directly as strings, and any typos will result in an AttributeError instead.

A contributor replied:

@freddyaboulton @angela97lin sure, I'm on board with having a DataCheckResults class! This would make it easy to access the errors and warnings. We could also define a to_json method which returns native Python types instead of DataCheckError/DataCheckWarning instances.

Is your proposal that we do that instead of merging this PR?

UPDATE: we discussed this; further changes are tracked by #1430.

DataCheckMessageType.WARNING: [],
DataCheckMessageType.ERROR: []
}
if not isinstance(y, pd.Series):
y = pd.Series(y)
messages = []
fold_counts = y.value_counts(normalize=False)
# search for targets that occur less than twice the number of cv folds first
below_threshold_folds = fold_counts.where(fold_counts < self.cv_folds).dropna()
if len(below_threshold_folds):
error_msg = "The number of instances of these targets is less than 2 * the number of cross folds = {} instances: {}"
messages.append(DataCheckError(error_msg.format(self.cv_folds, below_threshold_folds.index.tolist()), self.name))
messages[DataCheckMessageType.ERROR].append(DataCheckError(error_msg.format(self.cv_folds, below_threshold_folds.index.tolist()), self.name))

counts = fold_counts / fold_counts.sum()
below_threshold = counts.where(counts < self.threshold).dropna()
# if there are items that occur less than the threshold, add them to the list of messages
if len(below_threshold):
warning_msg = "The following labels fall below {:.0f}% of the target: {}"
messages.append(DataCheckWarning(warning_msg.format(self.threshold * 100, below_threshold.index.tolist()), self.name))
messages[DataCheckMessageType.WARNING].append(DataCheckWarning(warning_msg.format(self.threshold * 100, below_threshold.index.tolist()), self.name))
return messages
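The two passes in the `validate()` hunk above (targets with fewer than `2 * num_cv_folds` instances become errors; label fractions below the threshold become warnings) can be sketched with plain dicts and strings. Everything here is a hypothetical stand-in for the real `ClassImbalanceDataCheck`, not evalml code:

```python
from collections import Counter

def class_imbalance_messages(y, threshold=0.10, num_cv_folds=3):
    """Hypothetical stand-in for ClassImbalanceDataCheck.validate():
    targets with fewer than 2 * num_cv_folds instances are errors;
    labels below `threshold` of the target are warnings."""
    messages = {"warning": [], "error": []}
    counts = Counter(y)
    below_folds = [label for label, c in counts.items() if c < 2 * num_cv_folds]
    if below_folds:
        messages["error"].append(
            f"Targets with fewer than {2 * num_cv_folds} instances: {below_folds}")
    below_threshold = [label for label, c in counts.items() if c / len(y) < threshold]
    if below_threshold:
        messages["warning"].append(
            f"Labels below {threshold:.0%} of the target: {below_threshold}")
    return messages

# Same target as the docstring example: one 0 and ten 1s.
print(class_imbalance_messages([0] + [1] * 10))
```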
2 changes: 1 addition & 1 deletion evalml/data_checks/data_check.py
Expand Up @@ -21,5 +21,5 @@ def validate(self, X, y=None):
y (pd.Series, optional): the target data of length [n_samples]

Returns:
list (DataCheckMessage): list of DataCheckError and DataCheckWarning objects
dict (DataCheckMessage): Dictionary of DataCheckError and DataCheckWarning messages
"""
11 changes: 8 additions & 3 deletions evalml/data_checks/data_checks.py
@@ -1,6 +1,7 @@
import inspect

from .data_check import DataCheck
from .data_check_message_type import DataCheckMessageType

from evalml.exceptions import DataCheckInitError

Expand Down Expand Up @@ -81,13 +82,17 @@ def validate(self, X, y=None):
y (pd.Series): The target data of length [n_samples]

Returns:
list (DataCheckMessage): List containing DataCheckMessage objects
dict: Dictionary containing DataCheckMessage objects

"""
messages = []
messages = {
DataCheckMessageType.WARNING: [],
DataCheckMessageType.ERROR: []
}
for data_check in self.data_checks:
messages_new = data_check.validate(X, y)
messages.extend(messages_new)
messages[DataCheckMessageType.WARNING].extend(messages_new[DataCheckMessageType.WARNING])
messages[DataCheckMessageType.ERROR].extend(messages_new[DataCheckMessageType.ERROR])
return messages


Expand Down
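The aggregation in the `DataChecks.validate()` hunk above merges each individual check's dict into one. A minimal sketch of that merge, using plain strings as messages and a locally defined enum as a stand-in for evalml's:

```python
from enum import Enum

# Local stand-in for evalml's DataCheckMessageType enum.
class DataCheckMessageType(Enum):
    WARNING = "warning"
    ERROR = "error"

def combine_results(per_check_results):
    """Merge each check's {WARNING: [...], ERROR: [...]} dict into one,
    mirroring the loop in DataChecks.validate() (plain strings as messages)."""
    messages = {DataCheckMessageType.WARNING: [], DataCheckMessageType.ERROR: []}
    for result in per_check_results:
        messages[DataCheckMessageType.WARNING].extend(result[DataCheckMessageType.WARNING])
        messages[DataCheckMessageType.ERROR].extend(result[DataCheckMessageType.ERROR])
    return messages

null_check_result = {DataCheckMessageType.WARNING: ["column 'c' is highly null"],
                     DataCheckMessageType.ERROR: []}
target_check_result = {DataCheckMessageType.WARNING: [],
                       DataCheckMessageType.ERROR: ["target has NaN values"]}
combined = combine_results([null_check_result, target_check_result])
print(combined[DataCheckMessageType.WARNING], combined[DataCheckMessageType.ERROR])
```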
1 change: 1 addition & 0 deletions evalml/data_checks/default_data_checks.py
Expand Up @@ -28,6 +28,7 @@ class DefaultDataChecks(DataChecks):
def __init__(self, problem_type):
"""
A collection of basic data checks.

Arguments:
problem_type (str): The problem type that is being validated. Can be regression, binary, or multiclass.
"""
Expand Down
13 changes: 9 additions & 4 deletions evalml/data_checks/high_variance_cv_data_check.py
Expand Up @@ -2,6 +2,7 @@

from .data_check import DataCheck
from .data_check_message import DataCheckWarning
from .data_check_message_type import DataCheckMessageType


class HighVarianceCVDataCheck(DataCheck):
Expand All @@ -26,23 +27,27 @@ def validate(self, pipeline_name, cv_scores):
cv_scores (pd.Series, np.ndarray, list): list of scores of each cross-validation fold

Returns:
list (DataCheckWarning): list with DataCheckWarnings if imbalance in classes is less than the threshold.
dict: Dictionary with DataCheckWarnings if imbalance in classes is less than the threshold.

Example:
>>> cv_scores = pd.Series([0, 1, 1, 1])
>>> check = HighVarianceCVDataCheck(threshold=0.10)
>>> assert check.validate("LogisticRegressionPipeline", cv_scores) == [DataCheckWarning("High coefficient of variation (cv >= 0.1) within cross validation scores. LogisticRegressionPipeline may not perform as estimated on unseen data.", "HighVarianceCVDataCheck")]
>>> assert check.validate("LogisticRegressionPipeline", cv_scores) == {DataCheckMessageType.WARNING: [DataCheckWarning("High coefficient of variation (cv >= 0.1) within cross validation scores. LogisticRegressionPipeline may not perform as estimated on unseen data.", "HighVarianceCVDataCheck")],\
DataCheckMessageType.ERROR: []}
"""
messages = {
DataCheckMessageType.WARNING: [],
DataCheckMessageType.ERROR: []
}
if not isinstance(cv_scores, pd.Series):
cv_scores = pd.Series(cv_scores)

messages = []
if cv_scores.mean() == 0:
high_variance_cv = 0
else:
high_variance_cv = abs(cv_scores.std() / cv_scores.mean()) > self.threshold
# if there are items that occur less than the threshold, add them to the list of messages
if high_variance_cv:
warning_msg = f"High coefficient of variation (cv >= {self.threshold}) within cross validation scores. {pipeline_name} may not perform as estimated on unseen data."
messages.append(DataCheckWarning(warning_msg, self.name))
messages[DataCheckMessageType.WARNING].append(DataCheckWarning(warning_msg, self.name))
return messages
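The coefficient-of-variation test in the hunk above can be sketched standalone with the `statistics` module. String keys replace the enum here, so this is an illustration of the logic rather than evalml's implementation:

```python
import statistics

def high_variance_cv_messages(pipeline_name, cv_scores, threshold=0.2):
    """Warn when the coefficient of variation (std / mean) of the
    cross-validation scores exceeds `threshold` (sketch, not evalml code)."""
    messages = {"warning": [], "error": []}
    mean = statistics.mean(cv_scores)
    high_variance_cv = mean != 0 and abs(statistics.stdev(cv_scores) / mean) > threshold
    if high_variance_cv:
        messages["warning"].append(
            f"High coefficient of variation (cv >= {threshold}) within cross "
            f"validation scores. {pipeline_name} may not perform as estimated on unseen data.")
    return messages

print(high_variance_cv_messages("LogisticRegressionPipeline", [0, 1, 1, 1], threshold=0.1))
```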