From 2b8e91087bb4419a10820f1be738ababab1f5fae Mon Sep 17 00:00:00 2001 From: Angela Lin Date: Tue, 17 Nov 2020 12:57:57 -0500 Subject: [PATCH] update notebook --- docs/source/user_guide/data_checks.ipynb | 87 +++++++++++++++++------- 1 file changed, 63 insertions(+), 24 deletions(-) diff --git a/docs/source/user_guide/data_checks.ipynb b/docs/source/user_guide/data_checks.ipynb index 03e3f0d0a7..0d6c59d2e1 100644 --- a/docs/source/user_guide/data_checks.ipynb +++ b/docs/source/user_guide/data_checks.ipynb @@ -33,6 +33,7 @@ "source": [ "import numpy as np\n", "import pandas as pd\n", + "from evalml.data_checks import DataCheckMessageType\n", "\n", "from evalml.data_checks import HighlyNullDataCheck\n", "\n", @@ -43,9 +44,13 @@ " [8, 6, np.nan]])\n", "\n", "null_check = HighlyNullDataCheck(pct_null_threshold=0.8)\n", + "results = null_check.validate(X)\n", "\n", - "for message in null_check.validate(X):\n", - " print(message.message)" + "for message in results[DataCheckMessageType.WARNING]:\n", + " print(\"Warning:\", message.message)\n", + "\n", + "for message in results[DataCheckMessageType.ERROR]:\n", + " print(\"Error:\", message.message)" ] }, { @@ -84,9 +89,13 @@ " \"good col\":[0, 4, 1]})\n", "y = pd.Series([1, 0, 1])\n", "no_variance_data_check = NoVarianceDataCheck()\n", + "results = no_variance_data_check.validate(X, y)\n", + "\n", + "for message in results[DataCheckMessageType.WARNING]:\n", + " print(\"Warning:\", message.message)\n", "\n", - "for message in no_variance_data_check.validate(X, y):\n", - " print(message.message)" + "for message in results[DataCheckMessageType.ERROR]:\n", + " print(\"Error:\", message.message)" ] }, { @@ -110,9 +119,13 @@ "y = pd.Series([1, 0, 1])\n", "\n", "no_variance_data_check = NoVarianceDataCheck(count_nan_as_value=True)\n", + "results = no_variance_data_check.validate(X, y)\n", "\n", - "for message in no_variance_data_check.validate(X, y):\n", - " print(message.message)" + "for message in results[DataCheckMessageType.WARNING]:\n", + " print(\"Warning:\", message.message)\n", + "\n", + "for message in results[DataCheckMessageType.ERROR]:\n", + " print(\"Error:\", message.message)" ] }, { @@ -139,10 +152,15 @@ " [4, 4, 8, 3],\n", " [9, 2, 7, 1]])\n", "y = pd.Series([0, 1, 1, 1, 1])\n", + "\n", "class_imbalance_check = ClassImbalanceDataCheck(threshold=0.25, num_cv_folds=4)\n", + "results = class_imbalance_check.validate(X, y)\n", "\n", - "for message in class_imbalance_check.validate(X, y):\n", - " print(message.message)" + "for message in results[DataCheckMessageType.WARNING]:\n", + " print(\"Warning:\", message.message)\n", + "\n", + "for message in results[DataCheckMessageType.ERROR]:\n", + " print(\"Error:\", message.message)" ] }, { @@ -167,8 +185,13 @@ "y = pd.Series([10, 42, 31, 51, 40])\n", "\n", "target_leakage_check = TargetLeakageDataCheck(pct_corr_threshold=0.8)\n", - "for message in target_leakage_check.validate(X, y):\n", - " print(message.message)" + "results = target_leakage_check.validate(X, y)\n", + "\n", + "for message in results[DataCheckMessageType.WARNING]:\n", + " print(\"Warning:\", message.message)\n", + "\n", + "for message in results[DataCheckMessageType.ERROR]:\n", + " print(\"Error:\", message.message)" ] }, { @@ -194,10 +217,15 @@ " \n", "X = pd.DataFrame({})\n", "y = pd.Series([0, 1, None, None])\n", + "\n", "invalid_target_check = InvalidTargetDataCheck('binary')\n", + "results = invalid_target_check.validate(X, y)\n", + "\n", + "for message in results[DataCheckMessageType.WARNING]:\n", + " print(\"Warning:\", message.message)\n", "\n", - "for message in invalid_target_check.validate(X, y):\n", - " print(message.message)" + "for message in results[DataCheckMessageType.ERROR]:\n", + " print(\"Error:\", message.message)" ] }, { @@ -218,10 +246,15 @@ "from evalml.data_checks import IDColumnsDataCheck\n", "\n", "X = pd.DataFrame([[0, 53, 6325, 5],[1, 90, 6325, 10],[2, 90, 18, 20]], columns=['user_number', 'cost', 'revenue', 'id'])\n", + "\n", "id_col_check = IDColumnsDataCheck(id_threshold=0.9)\n", + "results = id_col_check.validate(X, y)\n", "\n", - "for message in id_col_check.validate(X):\n", - " print(message.message)" + "for message in results[DataCheckMessageType.WARNING]:\n", + " print(\"Warning:\", message.message)\n", + "\n", + "for message in results[DataCheckMessageType.ERROR]:\n", + " print(\"Error:\", message.message)" ] }, { @@ -282,9 +315,13 @@ "from evalml.data_checks import OutliersDataCheck\n", "\n", "outliers_check = OutliersDataCheck()\n", + "results = outliers_check.validate(X, y)\n", + "\n", + "for message in results[DataCheckMessageType.WARNING]:\n", + " print(\"Warning:\", message.message)\n", "\n", - "for message in outliers_check.validate(X):\n", - " print(message.message)" + "for message in results[DataCheckMessageType.ERROR]:\n", + " print(\"Error:\", message.message)" ] }, { @@ -310,12 +347,13 @@ "y = pd.Series([1, 0, 1])\n", "\n", "no_variance_data_check = NoVarianceDataCheck(count_nan_as_value=True)\n", + "results = no_variance_data_check.validate(X, y)\n", + "\n", + "for message in results[DataCheckMessageType.WARNING]:\n", + " print(\"Warning:\", message.message)\n", "\n", - "for message in no_variance_data_check.validate(X, y):\n", - " if isinstance(message, DataCheckError):\n", - " print(\"ERROR:\", message.message)\n", - " elif isinstance(message, DataCheckWarning):\n", - " print(\"WARNING:\", message.message)" + "for message in results[DataCheckMessageType.ERROR]:\n", + " print(\"Error:\", message.message)" ] }, { @@ -329,7 +367,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If you would prefer to write your own data check, you can do so by extending the `DataCheck` class and implementing the `validate(self, X, y)` class method. Below, we've created a new `DataCheck`, `ZeroVarianceDataCheck`, which is similar to `NoVarianceDataCheck` defined in `EvalML`." + "If you would prefer to write your own data check, you can do so by extending the `DataCheck` class and implementing the `validate(self, X, y)` class method. Below, we've created a new `DataCheck`, `ZeroVarianceDataCheck`, which is similar to `NoVarianceDataCheck` defined in `EvalML`. The `validate(self, X, y)` method should return a dictionary with `DataCheckMessageType.WARNING` and `DataCheckMessageType.ERROR` as keys mapping to list of warnings and errors, respectively." ] }, { @@ -345,10 +383,11 @@ "\n", "class ZeroVarianceDataCheck(DataCheck):\n", " def validate(self, X, y):\n", + " messages = {DataCheckMessageType.WARNING: [], DataCheckMessageType.ERROR: []}\n", " if not isinstance(X, pd.DataFrame):\n", " X = pd.DataFrame(X)\n", " warning_msg = \"Column '{}' has zero variance\"\n", - " return [DataCheckError(warning_msg.format(column), self.name) for column in X.columns if len(X[column].unique()) == 1]" + " messages[DataCheckMessageType.WARNING].extend([DataCheckError(warning_msg.format(column), self.name) for column in X.columns if len(X[column].unique()) == 1])" ] }, { @@ -458,4 +497,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +}