Skip to content

Commit

Permalink
update notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
angela97lin committed Nov 17, 2020
1 parent 86273fe commit 2b8e910
Showing 1 changed file with 63 additions and 24 deletions.
87 changes: 63 additions & 24 deletions docs/source/user_guide/data_checks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from evalml.data_checks import DataCheckMessageType\n",
"\n",
"from evalml.data_checks import HighlyNullDataCheck\n",
"\n",
Expand All @@ -43,9 +44,13 @@
" [8, 6, np.nan]])\n",
"\n",
"null_check = HighlyNullDataCheck(pct_null_threshold=0.8)\n",
"results = null_check.validate(X)\n",
"\n",
"for message in null_check.validate(X):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand Down Expand Up @@ -84,9 +89,13 @@
" \"good col\":[0, 4, 1]})\n",
"y = pd.Series([1, 0, 1])\n",
"no_variance_data_check = NoVarianceDataCheck()\n",
"results = no_variance_data_check.validate(X, y)\n",
"\n",
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in no_variance_data_check.validate(X, y):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -110,9 +119,13 @@
"y = pd.Series([1, 0, 1])\n",
"\n",
"no_variance_data_check = NoVarianceDataCheck(count_nan_as_value=True)\n",
"results = no_variance_data_check.validate(X, y)\n",
"\n",
"for message in no_variance_data_check.validate(X, y):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -139,10 +152,15 @@
" [4, 4, 8, 3],\n",
" [9, 2, 7, 1]])\n",
"y = pd.Series([0, 1, 1, 1, 1])\n",
"\n",
"class_imbalance_check = ClassImbalanceDataCheck(threshold=0.25, num_cv_folds=4)\n",
"results = class_imbalance_check.validate(X, y)\n",
"\n",
"for message in class_imbalance_check.validate(X, y):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -167,8 +185,13 @@
"y = pd.Series([10, 42, 31, 51, 40])\n",
"\n",
"target_leakage_check = TargetLeakageDataCheck(pct_corr_threshold=0.8)\n",
"for message in target_leakage_check.validate(X, y):\n",
" print(message.message)"
"results = target_leakage_check.validate(X, y)\n",
"\n",
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -194,10 +217,15 @@
" \n",
"X = pd.DataFrame({})\n",
"y = pd.Series([0, 1, None, None])\n",
"\n",
"invalid_target_check = InvalidTargetDataCheck('binary')\n",
"results = invalid_target_check.validate(X, y)\n",
"\n",
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in invalid_target_check.validate(X, y):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -218,10 +246,15 @@
"from evalml.data_checks import IDColumnsDataCheck\n",
"\n",
"X = pd.DataFrame([[0, 53, 6325, 5],[1, 90, 6325, 10],[2, 90, 18, 20]], columns=['user_number', 'cost', 'revenue', 'id'])\n",
"\n",
"id_col_check = IDColumnsDataCheck(id_threshold=0.9)\n",
"results = id_col_check.validate(X, y)\n",
"\n",
"for message in id_col_check.validate(X):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand Down Expand Up @@ -282,9 +315,13 @@
"from evalml.data_checks import OutliersDataCheck\n",
"\n",
"outliers_check = OutliersDataCheck()\n",
"results = outliers_check.validate(X, y)\n",
"\n",
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in outliers_check.validate(X):\n",
" print(message.message)"
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -310,12 +347,13 @@
"y = pd.Series([1, 0, 1])\n",
"\n",
"no_variance_data_check = NoVarianceDataCheck(count_nan_as_value=True)\n",
"results = no_variance_data_check.validate(X, y)\n",
"\n",
"for message in results[DataCheckMessageType.WARNING]:\n",
" print(\"Warning:\", message.message)\n",
"\n",
"for message in no_variance_data_check.validate(X, y):\n",
" if isinstance(message, DataCheckError):\n",
" print(\"ERROR:\", message.message)\n",
" elif isinstance(message, DataCheckWarning):\n",
" print(\"WARNING:\", message.message)"
"for message in results[DataCheckMessageType.ERROR]:\n",
" print(\"Error:\", message.message)"
]
},
{
Expand All @@ -329,7 +367,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If you would prefer to write your own data check, you can do so by extending the `DataCheck` class and implementing the `validate(self, X, y)` class method. Below, we've created a new `DataCheck`, `ZeroVarianceDataCheck`, which is similar to `NoVarianceDataCheck` defined in `EvalML`."
"If you would prefer to write your own data check, you can do so by extending the `DataCheck` class and implementing the `validate(self, X, y)` class method. Below, we've created a new `DataCheck`, `ZeroVarianceDataCheck`, which is similar to `NoVarianceDataCheck` defined in `EvalML`. The `validate(self, X, y)` method should return a dictionary with `DataCheckMessageType.WARNING` and `DataCheckMessageType.ERROR` as keys mapping to list of warnings and errors, respectively."
]
},
{
Expand All @@ -345,10 +383,11 @@
"\n",
"class ZeroVarianceDataCheck(DataCheck):\n",
" def validate(self, X, y):\n",
" messages = {DataCheckMessageType.WARNING: [], DataCheckMessageType.ERROR: []}\n",
" if not isinstance(X, pd.DataFrame):\n",
" X = pd.DataFrame(X)\n",
" warning_msg = \"Column '{}' has zero variance\"\n",
" return [DataCheckError(warning_msg.format(column), self.name) for column in X.columns if len(X[column].unique()) == 1]"
" messages[DataCheckMessageType.WARNING].extend([DataCheckError(warning_msg.format(column), self.name) for column in X.columns if len(X[column].unique()) == 1])"
]
},
{
Expand Down Expand Up @@ -458,4 +497,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

0 comments on commit 2b8e910

Please sign in to comment.