Commit

Change TargetLeakageDataCheck to use Woodwork's mutual_information method (#1616)

* implement ww

* update release notes

* fix implementation

* fix test

* fix runtime and add pearson corr

* fix indexing

* linting

* coverage

* fix release notes

* update pearson implementation

* update documentation

* remove unnecessary comments

* fix input

* undergoing changes:

* simplify implementation

* fixing cast

* change bool to string input for init
bchen1116 committed Jan 14, 2021
1 parent 00eac40 commit ecf8765
Showing 4 changed files with 316 additions and 31 deletions.
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
@@ -13,6 +13,7 @@ Release Notes
* Added support for list inputs for objectives :pr:`1663`
* Added support for ``AutoMLSearch`` to handle time series classification pipelines :pr:`1666`
* Fixes
* Fixed ``TargetLeakageDataCheck`` to use Woodwork ``mutual_information`` rather than using Pandas' Pearson Correlation :pr:`1616`
* Fixed thresholding for pipelines in ``AutoMLSearch`` to only threshold binary classification pipelines :pr:`1622` :pr:`1626`
* Updated ``load_data`` to return Woodwork structures and update default parameter value for ``index`` to ``None`` :pr:`1610`
* Pinned scipy at < 1.6.0 while we work on adding support :pr:`1629`
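
To make the TargetLeakageDataCheck fix noted above concrete, here is a minimal usage sketch assuming the API added in this commit: `method` defaults to 'mutual' (Woodwork mutual information) and 'pearson' restores the previous Pearson-based behavior. The toy data is illustrative only, and the import path assumes `TargetLeakageDataCheck` is exported from `evalml.data_checks` as the module below implies.

import pandas as pd
from evalml.data_checks import TargetLeakageDataCheck

X = pd.DataFrame({
    "leak": [10, 42, 31, 51, 61],   # tracks the target closely
    "x": [42, 54, 12, 64, 12],
})
y = pd.Series([10, 42, 31, 51, 40])

# New default: Woodwork mutual information, which handles all feature/target types.
mutual_check = TargetLeakageDataCheck(pct_corr_threshold=0.95)
results = mutual_check.validate(X, y)
print(results["warnings"])  # expected to flag 'leak', as in the docstring example below

# The previous behavior remains available via method='pearson'
# (numeric and boolean columns only).
pearson_check = TargetLeakageDataCheck(pct_corr_threshold=0.95, method="pearson")
print(pearson_check.validate(X, y)["warnings"])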
60 changes: 43 additions & 17 deletions evalml/data_checks/target_leakage_data_check.py
@@ -1,3 +1,5 @@
import pandas as pd

from evalml.data_checks import (
DataCheck,
DataCheckMessageCode,
@@ -11,25 +13,51 @@


class TargetLeakageDataCheck(DataCheck):
"""Check if any of the features are highly correlated with the target."""
"""Check if any of the features are highly correlated with the target by using mutual information or Pearson correlation."""

def __init__(self, pct_corr_threshold=0.95):
"""Check if any of the features are highly correlated with the target.
def __init__(self, pct_corr_threshold=0.95, method="mutual"):
"""Check if any of the features are highly correlated with the target by using mutual information or Pearson correlation.
Currently only supports binary and numeric targets and features.
If `method='mutual'`, this data check uses mutual information and supports all target and feature types.
If `method='pearson'`, it uses Pearson correlation and only supports targets and features with numeric and boolean dtypes.
Pearson correlation returns a value in [-1, 1], while mutual information returns a value in [0, 1].
Arguments:
pct_corr_threshold (float): The correlation threshold to be considered leakage. Defaults to 0.95.
method (string): The method to determine correlation. Use 'mutual' for mutual information, or 'pearson' for Pearson correlation. Defaults to 'mutual'.
"""
if pct_corr_threshold < 0 or pct_corr_threshold > 1:
raise ValueError("pct_corr_threshold must be a float between 0 and 1, inclusive.")
if method not in ['mutual', 'pearson']:
raise ValueError(f"Method '{method}' not in ['mutual', 'pearson']")
self.pct_corr_threshold = pct_corr_threshold
self.method = method

def _calculate_pearson(self, X, y):
highly_corr_cols = []
X_num = X.select(include=numeric_and_boolean_ww)
if y.logical_type not in numeric_and_boolean_ww or len(X_num.columns) == 0:
return highly_corr_cols
X_num = _convert_woodwork_types_wrapper(X_num.to_dataframe())
y = _convert_woodwork_types_wrapper(y.to_series())
highly_corr_cols = [label for label, col in X_num.iteritems() if abs(y.corr(col)) >= self.pct_corr_threshold]
return highly_corr_cols

def _calculate_mutual_information(self, X, y):
highly_corr_cols = []
for col in X.columns:
cols_to_compare = _convert_to_woodwork_structure(pd.DataFrame({col: X[col], str(col) + "y": y}))
mutual_info = cols_to_compare.mutual_information()
if len(mutual_info) > 0 and mutual_info['mutual_info'].iloc[0] > self.pct_corr_threshold:
highly_corr_cols.append(col)
return highly_corr_cols

def validate(self, X, y):
"""Check if any of the features are highly correlated with the target.
"""Check if any of the features are highly correlated with the target by using mutual information or Pearson correlation.
Currently only supports binary and numeric targets and features.
If `method='mutual'`, supports all target and feature types. If `method='pearson'`, only supports targets and features with numeric and boolean dtypes.
Pearson correlation returns a value in [-1, 1], while mutual information returns a value in [0, 1].
Arguments:
X (ww.DataTable, pd.DataFrame, np.ndarray): The input features to check
@@ -43,11 +71,11 @@ def validate(self, X, y):
>>> X = pd.DataFrame({
... 'leak': [10, 42, 31, 51, 61],
... 'x': [42, 54, 12, 64, 12],
... 'y': [12, 5, 13, 74, 24],
... 'y': [13, 5, 13, 74, 24],
... })
>>> y = pd.Series([10, 42, 31, 51, 40])
>>> target_leakage_check = TargetLeakageDataCheck(pct_corr_threshold=0.8)
>>> assert target_leakage_check.validate(X, y) == {"warnings": [{"message": "Column 'leak' is 80.0% or more correlated with the target",\
>>> target_leakage_check = TargetLeakageDataCheck(pct_corr_threshold=0.95)
>>> assert target_leakage_check.validate(X, y) == {"warnings": [{"message": "Column 'leak' is 95.0% or more correlated with the target",\
"data_check_name": "TargetLeakageDataCheck",\
"level": "warning",\
"code": "TARGET_LEAKAGE",\
@@ -60,16 +88,14 @@
}
X = _convert_to_woodwork_structure(X)
y = _convert_to_woodwork_structure(y)
if y.logical_type not in numeric_and_boolean_ww:
return messages
X_num = X.select(include=numeric_and_boolean_ww)
X_num = _convert_woodwork_types_wrapper(X_num.to_dataframe())
y = _convert_woodwork_types_wrapper(y.to_series())

if len(X_num.columns) == 0:
return messages
if self.method == 'pearson':
highly_corr_cols = self._calculate_pearson(X, y)
else:
X = _convert_woodwork_types_wrapper(X.to_dataframe())
y = _convert_woodwork_types_wrapper(y.to_series())
highly_corr_cols = self._calculate_mutual_information(X, y)

highly_corr_cols = {label: abs(y.corr(col)) for label, col in X_num.iteritems() if abs(y.corr(col)) >= self.pct_corr_threshold}
warning_msg = "Column '{}' is {}% or more correlated with the target"
messages["warnings"].extend([DataCheckWarning(message=warning_msg.format(col_name, self.pct_corr_threshold * 100),
data_check_name=self.name,
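
For context on the `mutual_info['mutual_info'].iloc[0]` access in `_calculate_mutual_information` above: in the Woodwork DataTable API this commit targets, `mutual_information()` is assumed to return a pandas DataFrame with one row per column pair ('column_1', 'column_2', 'mutual_info'), so pairing a single feature with the target yields at most one row. A standalone sketch of that per-column loop, with illustrative data; exact mutual information values depend on Woodwork's binning, so nothing is asserted here.

import pandas as pd
import woodwork as ww

X = pd.DataFrame({"leak": [10, 42, 31, 51, 61], "x": [42, 54, 12, 64, 12]})
y = pd.Series([10, 42, 31, 51, 40])
threshold = 0.95

highly_corr_cols = []
for col in X.columns:
    # Build a two-column table (feature + target) so mutual_information()
    # returns at most a single feature/target row.
    pair = ww.DataTable(pd.DataFrame({col: X[col], str(col) + "y": y}))
    mi = pair.mutual_information()
    # len(mi) can be 0 if Woodwork excludes the column's type from mutual information,
    # which is why the data check guards with `len(mutual_info) > 0`.
    if len(mi) > 0 and mi["mutual_info"].iloc[0] > threshold:
        highly_corr_cols.append(col)
print(highly_corr_cols)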
30 changes: 17 additions & 13 deletions evalml/tests/data_checks_tests/test_data_checks.py
@@ -107,21 +107,17 @@ def test_default_data_checks_classification(input_type):

data_checks = DefaultDataChecks("binary", get_default_primary_search_objective("binary"))

leakage = [DataCheckWarning(message="Column 'has_label_leakage' is 95.0% or more correlated with the target",
data_check_name="TargetLeakageDataCheck",
message_code=DataCheckMessageCode.TARGET_LEAKAGE,
details={"column": "has_label_leakage"}).to_dict()]
imbalance = [DataCheckError(message="The number of instances of these targets is less than 2 * the number of cross folds = 6 instances: [1.0, 0.0]",
data_check_name="ClassImbalanceDataCheck",
message_code=DataCheckMessageCode.CLASS_IMBALANCE_BELOW_FOLDS,
details={"target_values": [1.0, 0.0]}).to_dict()]

assert data_checks.validate(X, y) == {"warnings": messages[:3] + leakage, "errors": messages[3:] + imbalance}
assert data_checks.validate(X, y) == {"warnings": messages[:3], "errors": messages[3:] + imbalance}

data_checks = DataChecks(DefaultDataChecks._DEFAULT_DATA_CHECK_CLASSES,
{"InvalidTargetDataCheck": {"problem_type": "binary",
"objective": get_default_primary_search_objective("binary")}})
assert data_checks.validate(X, y) == {"warnings": messages[:3] + leakage, "errors": messages[3:]}
assert data_checks.validate(X, y) == {"warnings": messages[:3], "errors": messages[3:]}

# multiclass
imbalance = [DataCheckError(message="The number of instances of these targets is less than 2 * the number of cross folds = 6 instances: [0.0, 2.0, 1.0]",
@@ -152,7 +148,7 @@ def test_default_data_checks_regression(input_type):
X = pd.DataFrame({'lots_of_null': [None, None, None, None, "some data"],
'all_null': [None, None, None, None, None],
'also_all_null': [None, None, None, None, None],
'no_null': [1, 2, 3, 4, 5],
'no_null': [1, 2, 3, 5, 5],
'id': [0, 1, 2, 3, 4],
'has_label_leakage': [100, 200, 100, 200, 100]})
y = pd.Series([0.3, 100.0, np.nan, 1.0, 0.2])
@@ -162,19 +158,27 @@
X = ww.DataTable(X)
y = ww.DataColumn(y)
y_no_variance = ww.DataColumn(y_no_variance)
id_leakage = [DataCheckWarning(message="Column 'id' is 95.0% or more correlated with the target",
data_check_name="TargetLeakageDataCheck",
message_code=DataCheckMessageCode.TARGET_LEAKAGE,
details={"column": "id"}).to_dict()]
null_leakage = [DataCheckWarning(message="Column 'lots_of_null' is 95.0% or more correlated with the target",
data_check_name="TargetLeakageDataCheck",
message_code=DataCheckMessageCode.TARGET_LEAKAGE,
details={"column": "lots_of_null"}).to_dict()]
data_checks = DefaultDataChecks("regression", get_default_primary_search_objective("regression"))
assert data_checks.validate(X, y) == {"warnings": messages[:3], "errors": messages[3:]}
assert data_checks.validate(X, y) == {"warnings": messages[:3] + id_leakage, "errors": messages[3:]}

# Skip Invalid Target
assert data_checks.validate(X, y_no_variance) == {"warnings": messages[:3], "errors": messages[4:] + [DataCheckError(message="Y has 1 unique value.",
data_check_name="NoVarianceDataCheck",
message_code=DataCheckMessageCode.NO_VARIANCE,
details={"column": "Y"}).to_dict()]}
assert data_checks.validate(X, y_no_variance) == {"warnings": messages[:3] + null_leakage, "errors": messages[4:] + [DataCheckError(message="Y has 1 unique value.",
data_check_name="NoVarianceDataCheck",
message_code=DataCheckMessageCode.NO_VARIANCE,
details={"column": "Y"}).to_dict()]}

data_checks = DataChecks(DefaultDataChecks._DEFAULT_DATA_CHECK_CLASSES,
{"InvalidTargetDataCheck": {"problem_type": "regression",
"objective": get_default_primary_search_objective("regression")}})
assert data_checks.validate(X, y) == {"warnings": messages[:3], "errors": messages[3:]}
assert data_checks.validate(X, y) == {"warnings": messages[:3] + id_leakage, "errors": messages[3:]}


def test_default_data_checks_time_series_regression():
