Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distribution data check #21

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
155 changes: 155 additions & 0 deletions checkmates/data_checks/checks/distribution_data_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Data check that checks if the target data contains certain distributions that may need to be transformed prior training to improve model performance."""
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
import diptest
import numpy as np
import woodwork as ww

from checkmates.data_checks import (
DataCheck,
DataCheckActionCode,
DataCheckActionOption,
DataCheckError,
DataCheckMessageCode,
DataCheckWarning,
)
from checkmates.utils import infer_feature_types


class DistributionDataCheck(DataCheck):
"""Check if the overall data contains certain distributions that may need to be transformed prior training to improve model performance. Uses the skew test and yeojohnson transformation."""

def validate(self, X, y):
"""Check if the overall data has a skewed or bimodal distribution.

Args:
X (pd.DataFrame, np.ndarray): Overall data to check for skewed or bimodal distributions.
y (pd.Series, np.ndarray): Target data to check for underlying distributions.

Returns:
dict (DataCheckError): List with DataCheckErrors if certain distributions are found in the overall data.

Examples:
>>> import pandas as pd

Features and target data that exhibit a skewed distribution will raise a warning for the user to transform the data.

>>> X = [5, 7, 8, 9, 10, 11, 12, 15, 20]
>>> data_check = DistributionDataCheck()
>>> assert data_check.validate(X, y) == [
... {
... "message": "Data may have a skewed distribution.",
... "data_check_name": "DistributionDataCheck",
... "level": "warning",
... "code": "SKEWED_DISTRIBUTION",
... "details": {"distribution type": "positive skew", "Skew Value": 0.7939, "Bimodal Coefficient": 1.0,},
... "action_options": [
... {
... "code": "TRANSFORM_TARGET",
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
... "data_check_name": "DistributionDataCheck",
... "parameters": {},
... "metadata": {
"is_skew": True,
"transformation_strategy": "yeojohnson",
... }
... }
... ]
... }
... ]
...
>>> X = pd.Series([1, 1, 1, 2, 2, 3, 4, 4, 5, 5, 5])
>>> assert target_check.validate(X, y) == []
...
...
>>> X = pd.Series(pd.date_range("1/1/21", periods=10))
>>> assert target_check.validate(X, y) == [
... {
... "message": "Target is unsupported datetime type. Valid Woodwork logical types include: integer, double, age, age_fractional",
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
... "data_check_name": "DistributionDataCheck",
... "level": "error",
... "details": {"columns": None, "rows": None, "unsupported_type": "datetime"},
... "code": "TARGET_UNSUPPORTED_TYPE",
... "action_options": []
... }
... ]
"""
messages = []

if y is None:
messages.append(
DataCheckError(
message="Data is None",
data_check_name=self.name,
message_code=DataCheckMessageCode.TARGET_IS_NONE,
details={},
).to_dict(),
)
return messages

y = infer_feature_types(y)
allowed_types = [
ww.logical_types.Integer.type_string,
ww.logical_types.Double.type_string,
ww.logical_types.Age.type_string,
ww.logical_types.AgeFractional.type_string,
]
is_supported_type = y.ww.logical_type.type_string in allowed_types

if not is_supported_type:
messages.append(
DataCheckError(
message="Target is unsupported {} type. Valid Woodwork logical types include: {}".format(
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
y.ww.logical_type.type_string,
", ".join([ltype for ltype in allowed_types]),
),
data_check_name=self.name,
message_code=DataCheckMessageCode.TARGET_UNSUPPORTED_TYPE,
details={"unsupported_type": X.ww.logical_type.type_string},
).to_dict(),
)
return messages

(
is_skew,
distribution_type,
skew_value,
coef,
) = _detect_skew_distribution_helper(X)

if is_skew:
details = {
"distribution type": distribution_type,
"Skew Value": skew_value,
"Bimodal Coefficient": coef,
}
messages.append(
DataCheckWarning(
message="Data may have a skewed distribution.",
data_check_name=self.name,
message_code=DataCheckMessageCode.SKEWED_DISTRIBUTION,
details=details,
action_options=[
DataCheckActionOption(
DataCheckActionCode.TRANSFORM_TARGET,
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
data_check_name=self.name,
metadata={
"is_skew": True,
"transformation_strategy": "yeojohnson",
},
),
],
).to_dict(),
)
return messages


def _detect_skew_distribution_helper(X):
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
"""Helper method to detect skewed or bimodal distribution. Returns boolean, distribution type, the skew value, and bimodal coefficient."""
skew_value = np.stats.skew(X)
coef = diptest.diptest(X)[1]

if coef < 0.05:
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
return True, "bimodal distribution", skew_value, coef
if skew_value < -0.5:
return True, "negative skew", skew_value, coef
if skew_value > 0.5:
return True, "positive skew", skew_value, coef
return False, "no skew", skew_value, coef
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class DataCheckMessageCode(Enum):
TARGET_LOGNORMAL_DISTRIBUTION = "target_lognormal_distribution"
"""Message code for target data with a lognormal distribution."""

SKEWED_DISTRIBUTION = "skewed_distribution"
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
"""Message code for data with a skewed distribution."""

HIGH_VARIANCE = "high_variance"
"""Message code for when high variance is detected for cross-validation."""

Expand Down
41 changes: 41 additions & 0 deletions checkmates/pipelines/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pandas as pd
import woodwork
from scipy.stats import yeojohnson
from sklearn.impute import SimpleImputer as SkImputer

from checkmates.exceptions import MethodPropertyNotFoundError
Expand Down Expand Up @@ -83,6 +84,46 @@ def _get_feature_provenance(self):
return {}


"""Component that normalizes skewed distributions using the Yeo-Johnson method"""


class SimpleNormalizer(Transformer):
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
"""Normalizes skewed data according to the Yeo-Johnson method."""

def __init__(self):
super().__init__(
parameters=None,
)

def transform(self, X, y=None):
"""Transforms input by normalizing distribution.

Args:
X (pd.DataFrame): Data to transform.
y (pd.Series, optional): Target Data

Returns:
pd.DataFrame: Transformed X
"""
# Transform the data
X_t = yeojohnson(X)

# Reinit woodwork
X_t.ww.init()

def fit_transform(self, X, y=None):
"""Fits on X and transforms X.

Args:
X (pd.DataFrame): Data to fit and transform
y (pd.Series, optional): Target data.

Returns:
pd.DataFrame: Transformed X
"""
return self.fit(X, y).transform(X, y)


"""Component that imputes missing data according to a specified imputation strategy."""


Expand Down
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ Release Notes
-------------
**Future Releases**
* Enhancements
* Created ``distribution_data_check`` to screen for positive and negative skews as well as bimodal distributions :pr:`21`
* Fixes
* Changes
* Documentation Changes
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ dependencies = [
"woodwork>=0.22.0",
"click>=8.0.0",
"black[jupyter]>=22.3.0",
"diptest>=0.5.2",
"scipy>=1.9.3",
NabilFayak marked this conversation as resolved.
Show resolved Hide resolved
]
requires-python = ">=3.8,<4.0"
readme = "README.md"
Expand Down