diff --git a/awswrangler/data_quality/__init__.py b/awswrangler/data_quality/__init__.py index a88baa097..131744413 100644 --- a/awswrangler/data_quality/__init__.py +++ b/awswrangler/data_quality/__init__.py @@ -1,9 +1,17 @@ """AWS Glue Data Quality package.""" -from awswrangler.data_quality._create import create_ruleset, evaluate_ruleset, update_ruleset +from awswrangler.data_quality._create import ( + create_recommendation_ruleset, + create_ruleset, + evaluate_ruleset, + update_ruleset, +) +from awswrangler.data_quality._get import get_ruleset __all__ = [ + "create_recommendation_ruleset", "create_ruleset", "evaluate_ruleset", + "get_ruleset", "update_ruleset", ] diff --git a/awswrangler/data_quality/_create.py b/awswrangler/data_quality/_create.py index 69f6b29c1..3806fc657 100644 --- a/awswrangler/data_quality/_create.py +++ b/awswrangler/data_quality/_create.py @@ -1,8 +1,9 @@ """AWS Glue Data Quality Create module.""" import logging +import pprint import uuid -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, cast import boto3 import pandas as pd @@ -10,9 +11,11 @@ from awswrangler import _utils, exceptions from awswrangler._config import apply_configs from awswrangler.data_quality._utils import ( + _create_datasource, _get_data_quality_results, + _rules_to_df, _start_ruleset_evaluation_run, - _wait_ruleset_evaluation_run, + _wait_ruleset_run, ) _logger: logging.Logger = logging.getLogger(__name__) @@ -58,9 +61,9 @@ def create_ruleset( description : str Ruleset description. client_token : str, optional - Random id used for idempotency. Will be automatically generated if not provided. + Random id used for idempotency. Is automatically generated if not provided. boto3_session : boto3.Session, optional - Ruleset description. + Boto3 Session. If none, the default boto3 session is used. Examples -------- @@ -93,7 +96,7 @@ >>> df_rules=df_rules, >>>) """ - if df_rules is not None and dqdl_rules: + if (df_rules is not None and dqdl_rules) or (df_rules is None and not dqdl_rules): raise exceptions.InvalidArgumentCombination("You must pass either ruleset `df_rules` or `dqdl_rules`.") client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) @@ -110,8 +113,8 @@ }, ClientToken=client_token if client_token else uuid.uuid4().hex, ) - except client_glue.exceptions.AlreadyExistsException: - raise exceptions.AlreadyExists(f"Ruleset {name} already exists.") + except client_glue.exceptions.AlreadyExistsException as already_exists: + raise exceptions.AlreadyExists(f"Ruleset {name} already exists.") from already_exists @apply_configs @@ -139,9 +142,9 @@ def update_ruleset( description : str Ruleset description. client_token : str, optional - Random id used for idempotency. Will be automatically generated if not provided. + Random id used for idempotency. Is automatically generated if not provided. boto3_session : boto3.Session, optional - Ruleset description. + Boto3 Session. If none, the default boto3 session is used. 
Examples -------- @@ -151,7 +154,7 @@ def update_ruleset( >>> dqdl_rules="Rules = [ RowCount between 1 and 3 ]", >>>) """ - if df_rules is not None and dqdl_rules: + if (df_rules is not None and dqdl_rules) or (df_rules is None and not dqdl_rules): raise exceptions.InvalidArgumentCombination("You must pass either ruleset `df_rules` or `dqdl_rules`.") client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) @@ -165,8 +168,99 @@ def update_ruleset( Ruleset=dqdl_rules, ClientToken=client_token if client_token else uuid.uuid4().hex, ) - except client_glue.exceptions.EntityNotFoundException: - raise exceptions.ResourceDoesNotExist(f"Ruleset {name} does not exist.") + except client_glue.exceptions.EntityNotFoundException as not_found: + raise exceptions.ResourceDoesNotExist(f"Ruleset {name} does not exist.") from not_found + + +@apply_configs +def create_recommendation_ruleset( + database: str, + table: str, + iam_role_arn: str, + name: Optional[str] = None, + catalog_id: Optional[str] = None, + connection_name: Optional[str] = None, + additional_options: Optional[Dict[str, Any]] = None, + number_of_workers: int = 5, + timeout: int = 2880, + client_token: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, +) -> pd.DataFrame: + """Create recommendation Data Quality ruleset. + + Parameters + ---------- + database : str + Glue database name. + table : str + Glue table name. + iam_role_arn : str + IAM Role ARN. + name : str, optional + Ruleset name. + catalog_id : str, optional + Glue Catalog id. + connection_name : str, optional + Glue connection name. + additional_options : dict, optional + Additional options for the table. Supported keys: + `pushDownPredicate`: to filter on partitions without having to list and read all the files in your dataset. + `catalogPartitionPredicate`: to use server-side partition pruning using partition indexes in the + Glue Data Catalog. + number_of_workers: int, optional + The number of G.1X workers to be used in the run. The default is 5. + timeout: int, optional + The timeout for a run in minutes. The default is 2880 (48 hours). + client_token : str, optional + Random id used for idempotency. Is automatically generated if not provided. + boto3_session : boto3.Session, optional + Boto3 Session. If none, the default boto3 session is used. + + Returns + ------- + pd.DataFrame + Data frame with recommended ruleset details. 
+ + Examples + -------- + >>> import awswrangler as wr + + >>> df_recommended_ruleset = wr.data_quality.create_recommendation_ruleset( + >>> database="database", + >>> table="table", + >>> iam_role_arn="arn:...", + >>>) + """ + client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) + + args: Dict[str, Any] = { + "DataSource": _create_datasource( + database=database, + table=table, + catalog_id=catalog_id, + connection_name=connection_name, + additional_options=additional_options, + ), + "Role": iam_role_arn, + "NumberOfWorkers": number_of_workers, + "Timeout": timeout, + "ClientToken": client_token if client_token else uuid.uuid4().hex, + } + if name: + args["CreatedRulesetName"] = name + _logger.debug("args: \n%s", pprint.pformat(args)) + run_id: str = cast(str, client_glue.start_data_quality_rule_recommendation_run(**args)["RunId"]) + + _logger.debug("run_id: %s", run_id) + dqdl_recommended_rules: str = cast( + str, + _wait_ruleset_run( + run_id=run_id, + run_type="recommendation", + boto3_session=boto3_session, + )["RecommendedRuleset"], + ) + return _rules_to_df(rules=dqdl_recommended_rules) @apply_configs @@ -178,7 +272,7 @@ def evaluate_ruleset( database: Optional[str] = None, table: Optional[str] = None, catalog_id: Optional[str] = None, - connection: Optional[str] = None, + connection_name: Optional[str] = None, additional_options: Optional[Dict[str, str]] = None, additional_run_options: Optional[Dict[str, str]] = None, client_token: Optional[str] = None, @@ -188,27 +282,27 @@ def evaluate_ruleset( Parameters ---------- - name : str - Ruleset name. + name : str or list[str] + Ruleset name or list of names. iam_role_arn : str - IAM Role. + IAM Role ARN. number_of_workers: int, optional The number of G.1X workers to be used in the run. The default is 5. timeout: int, optional The timeout for a run in minutes. The default is 2880 (48 hours). database : str, optional Glue database name. Database associated with the ruleset will be used if not provided. - table : str, optinal + table : str, optional Glue table name. Table associated with the ruleset will be used if not provided. catalog_id : str, optional Glue Catalog id. - connection : str, optional - Glue connection. - additional_options : Dict[str, str], optional + connection_name : str, optional + Glue connection name. + additional_options : dict, optional Additional options for the table. Supported keys: `pushDownPredicate`: to filter on partitions without having to list and read all the files in your dataset. - `catalogPartitionPredicate`: to use server-side partition pruning using partition indexes in the - Glue Data Catalog. + `catalogPartitionPredicate`: to use server-side partition pruning using partition indexes in the + Glue Data Catalog. additional_run_options : Dict[str, str], optional Additional run options. Supported keys: `CloudWatchMetricsEnabled`: whether to enable CloudWatch metrics. @@ -216,7 +310,12 @@ def evaluate_ruleset( client_token : str, optional Random id used for idempotency. Will be automatically generated if not provided. boto3_session : boto3.Session, optional - Ruleset description. + Boto3 Session. If none, the default boto3 session is used. + + Returns + ------- + pd.DataFrame + Data frame with ruleset evaluation results. 
Examples -------- @@ -231,8 +330,8 @@ >>> table="table", >>> dqdl_rules="Rules = [ RowCount between 1 and 3 ]", >>>) - >>> wr.data_quality.evaluate_ruleset( - >>> name="ruleset", + >>> df_ruleset_results = wr.data_quality.evaluate_ruleset( + >>> name=["ruleset1", "ruleset2"], >>> iam_role_arn=glue_data_quality_role, >>> ) """ @@ -244,12 +343,19 @@ database=database, table=table, catalog_id=catalog_id, - connection=connection, + connection_name=connection_name, additional_options=additional_options, additional_run_options=additional_run_options, client_token=client_token if client_token else uuid.uuid4().hex, boto3_session=boto3_session, ) _logger.debug("run_id: %s", run_id) - result_ids: List[str] = _wait_ruleset_evaluation_run(run_id=run_id, boto3_session=boto3_session) + result_ids: List[str] = cast( + List[str], + _wait_ruleset_run( + run_id=run_id, + run_type="evaluation", + boto3_session=boto3_session, + )["ResultIds"], + ) return _get_data_quality_results(result_ids=result_ids, boto3_session=boto3_session) diff --git a/awswrangler/data_quality/_get.py b/awswrangler/data_quality/_get.py new file mode 100644 index 000000000..d903f1739 --- /dev/null +++ b/awswrangler/data_quality/_get.py @@ -0,0 +1,36 @@ +"""AWS Glue Data Quality Get module.""" + +from typing import Optional, cast + +import boto3 +import pandas as pd + +from awswrangler.data_quality._utils import _get_ruleset, _rules_to_df + + +def get_ruleset( + name: str, + boto3_session: Optional[boto3.Session] = None, +) -> pd.DataFrame: + """Get a Data Quality ruleset. + + Parameters + ---------- + name : str + Ruleset name. + boto3_session : boto3.Session, optional + Boto3 Session. If none, the default boto3 session is used. + + Returns + ------- + pd.DataFrame + Data frame with ruleset details. 
+ + Examples + -------- + >>> import awswrangler as wr + + >>> df_ruleset = wr.data_quality.get_ruleset(name="my_ruleset") + """ + rules = cast(str, _get_ruleset(ruleset_name=name, boto3_session=boto3_session)["Ruleset"]) + return _rules_to_df(rules=rules) diff --git a/awswrangler/data_quality/_utils.py b/awswrangler/data_quality/_utils.py index 9ce0be031..70706e58b 100644 --- a/awswrangler/data_quality/_utils.py +++ b/awswrangler/data_quality/_utils.py @@ -1,9 +1,11 @@ """AWS Glue Data Quality Utils module.""" +import ast import logging import pprint +import re import time -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast import boto3 import pandas as pd @@ -16,11 +18,34 @@ _RULESET_EVALUATION_WAIT_POLLING_DELAY: float = 0.25 # SECONDS +def _parse_rules(rules: List[str]) -> List[Tuple[str, Optional[str], Optional[str]]]: + parsed_rules: List[Tuple[str, Optional[str], Optional[str]]] = [] + for rule in rules: + rule_type, remainder = tuple(rule.split(maxsplit=1)) + if remainder.startswith('"'): + remainder_split = remainder.split(maxsplit=1) + parameter = remainder_split[0].strip('"') + expression = None if len(remainder_split) == 1 else remainder_split[1] + else: + parameter = None + expression = remainder + parsed_rules.append((rule_type, parameter, expression)) + return parsed_rules + + +def _rules_to_df(rules: str) -> pd.DataFrame: + rules = re.sub(r"^\s*Rules\s*=\s*\[\s*", "", rules) # remove Rules = [\n + rules = re.sub(r"\s*\]\s*$", "", rules) # remove \n] + rules = re.sub(r"\s*,\s*(?![^[]*])", "', '", rules) + list_rules = ast.literal_eval(f"['{rules}']") + return pd.DataFrame(_parse_rules(list_rules), columns=["rule_type", "parameter", "expression"]) + + def _create_datasource( database: str, table: str, catalog_id: Optional[str] = None, - connection: Optional[str] = None, + connection_name: Optional[str] = None, additional_options: Optional[Dict[str, str]] = None, ) -> Dict[str, Dict[str, str]]: datasource: Dict[str, Dict[str, Any]] = { @@ -31,8 +56,8 @@ def _create_datasource( } if catalog_id: datasource["GlueTable"]["CatalogId"] = catalog_id - if connection: - datasource["GlueTable"]["ConnectionName"] = connection + if connection_name: + datasource["GlueTable"]["ConnectionName"] = connection_name if additional_options: datasource["GlueTable"]["AdditionalOptions"] = additional_options return datasource @@ -46,7 +71,7 @@ def _start_ruleset_evaluation_run( database: Optional[str] = None, table: Optional[str] = None, catalog_id: Optional[str] = None, - connection: Optional[str] = None, + connection_name: Optional[str] = None, additional_options: Optional[Dict[str, str]] = None, additional_run_options: Optional[Dict[str, str]] = None, client_token: Optional[str] = None, @@ -63,7 +88,7 @@ def _start_ruleset_evaluation_run( database=database, table=table, catalog_id=catalog_id, - connection=connection, + connection_name=connection_name, additional_options=additional_options, ) args: Dict[str, Any] = { @@ -83,32 +108,38 @@ def _start_ruleset_evaluation_run( return cast(str, response["RunId"]) -def _get_ruleset_evaluation_run( +def _get_ruleset_run( run_id: str, + run_type: str, boto3_session: Optional[boto3.Session] = None, ) -> Dict[str, Any]: - boto3_session = _utils.ensure_session(session=boto3_session) - client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) - return cast(Dict[str, Any], client_glue.get_data_quality_ruleset_evaluation_run(RunId=run_id)) + session: boto3.Session 
= _utils.ensure_session(session=boto3_session) + client_glue: boto3.client = _utils.client(service_name="glue", session=session) + if run_type == "recommendation": + response = client_glue.get_data_quality_rule_recommendation_run(RunId=run_id) + elif run_type == "evaluation": + response = client_glue.get_data_quality_ruleset_evaluation_run(RunId=run_id) + else: + raise exceptions.InvalidArgumentValue(f"run_type must be 'recommendation' or 'evaluation', got {run_type}.") + return cast(Dict[str, Any], response) -def _wait_ruleset_evaluation_run( +def _wait_ruleset_run( run_id: str, + run_type: str, boto3_session: Optional[boto3.Session] = None, -) -> List[str]: +) -> Dict[str, Any]: session: boto3.Session = _utils.ensure_session(session=boto3_session) - response: Dict[str, Any] = _get_ruleset_evaluation_run(run_id=run_id, boto3_session=session) + response: Dict[str, Any] = _get_ruleset_run(run_id=run_id, run_type=run_type, boto3_session=session) status: str = response["Status"] while status not in _RULESET_EVALUATION_FINAL_STATUSES: time.sleep(_RULESET_EVALUATION_WAIT_POLLING_DELAY) - response = _get_ruleset_evaluation_run(run_id=run_id, boto3_session=session) + response = _get_ruleset_run(run_id=run_id, run_type=run_type, boto3_session=session) status = response["Status"] _logger.debug("status: %s", status) if status == "FAILED": raise exceptions.QueryFailed(response.get("ErrorString")) if status == "STOPPED": raise exceptions.QueryCancelled("Ruleset execution stopped") - return cast(List[str], response["ResultIds"]) + return response def _get_ruleset( diff --git a/docs/source/api.rst b/docs/source/api.rst index 5f59baca6..7768a7bfe 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -12,6 +12,7 @@ API Reference * `Oracle`_ * `Data API Redshift`_ * `Data API RDS`_ +* `AWS Glue Data Quality`_ * `OpenSearch`_ * `Amazon Neptune`_ * `DynamoDB`_ @@ -244,6 +245,20 @@ Data API RDS connect read_sql_query +AWS Glue Data Quality +--------------------- + +.. currentmodule:: awswrangler.data_quality + +.. 
autosummary:: :toctree: stubs create_recommendation_ruleset create_ruleset evaluate_ruleset get_ruleset update_ruleset OpenSearch ---------- diff --git a/tests/test_data_quality.py b/tests/test_data_quality.py index 7340a847a..073744c0b 100644 --- a/tests/test_data_quality.py +++ b/tests/test_data_quality.py @@ -17,9 +17,9 @@ def df(path, glue_database, glue_table): def test_ruleset_df(df, path, glue_database, glue_table, glue_ruleset, glue_data_quality_role): df_rules = pd.DataFrame( { - "rule_type": ["RowCount", "IsComplete", "Uniqueness"], - "parameter": [None, "c0", "c0"], - "expression": ["between 1 and 6", None, "> 0.95"], + "rule_type": ["RowCount", "IsComplete", "Uniqueness", "ColumnValues"], + "parameter": [None, "c0", "c0", "c1"], + "expression": ["between 1 and 6", None, "> 0.95", "in [0, 1, 2]"], } ) wr.data_quality.create_ruleset( @@ -28,11 +28,15 @@ database=glue_database, table=glue_table, df_rules=df_rules, ) + df_ruleset = wr.data_quality.get_ruleset(name=glue_ruleset) + assert df_rules.equals(df_ruleset) + df_results = wr.data_quality.evaluate_ruleset( name=glue_ruleset, iam_role_arn=glue_data_quality_role, + number_of_workers=2, ) - assert df_results.shape == (3, 4) + assert df_results.shape == (4, 4) assert df_results["Result"].eq("PASS").all() @@ -59,6 +63,31 @@ df_results = wr.data_quality.evaluate_ruleset( name=glue_ruleset, iam_role_arn=glue_data_quality_role, + number_of_workers=2, ) assert df_results["Result"].eq("PASS").all() + + +def test_recommendation_ruleset(df, path, glue_database, glue_table, glue_ruleset, glue_data_quality_role): + df_recommended_ruleset = wr.data_quality.create_recommendation_ruleset( + database=glue_database, + table=glue_table, + iam_role_arn=glue_data_quality_role, + number_of_workers=2, + ) + df_rules = pd.concat( + [df_recommended_ruleset, pd.DataFrame([{"rule_type": "ColumnValues", "parameter": "c2", "expression": "in [0, 1, 2]"}])], + ignore_index=True, + ) + wr.data_quality.create_ruleset( + name=glue_ruleset, + database=glue_database, + table=glue_table, + df_rules=df_rules, + ) + df_results = wr.data_quality.evaluate_ruleset( + name=glue_ruleset, + iam_role_arn=glue_data_quality_role, + number_of_workers=2, + ) + assert df_results["Result"].eq("PASS").all() @@ -73,6 +102,7 @@ df_results = wr.data_quality.evaluate_ruleset( name=glue_ruleset, iam_role_arn=glue_data_quality_role, + number_of_workers=2, ) assert df_results["Result"][0] == "FAIL"
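
Reviewer note: taken together, the new public surface is recommend -> inspect -> evaluate. A hedged usage sketch of that flow, based only on the docstrings in this diff (the database, table, role ARN, and ruleset name below are placeholders):

```python
import awswrangler as wr

# 1. Ask Glue to profile the table and propose rules; returns the proposed
#    ruleset as a DataFrame (and stores it under `name` when one is given).
df_recommended = wr.data_quality.create_recommendation_ruleset(
    database="my_database",                                    # placeholder
    table="my_table",                                          # placeholder
    iam_role_arn="arn:aws:iam::123456789012:role/GlueDQRole",  # placeholder
    name="my_recommended_ruleset",
)

# 2. Read a stored ruleset back as the same rules DataFrame.
df_ruleset = wr.data_quality.get_ruleset(name="my_recommended_ruleset")

# 3. Run one or more rulesets against the table and inspect pass/fail results.
df_results = wr.data_quality.evaluate_ruleset(
    name="my_recommended_ruleset",
    iam_role_arn="arn:aws:iam::123456789012:role/GlueDQRole",  # placeholder
)
print(df_results["Result"])
```

Note that `create_recommendation_ruleset` and `get_ruleset` both return the same three-column rules frame (`rule_type`, `parameter`, `expression`), which is what lets `test_ruleset_df` round-trip `df_rules` through `get_ruleset` with `DataFrame.equals`.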
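
The subtlest part of the diff is the DQDL-to-DataFrame conversion in `_utils.py` (`_rules_to_df`/`_parse_rules`). Below is a self-contained sketch of the same parsing approach that runs without AWS access; the un-prefixed helper names and the sample DQDL string are illustrative, while the regexes and splitting logic mirror the diff:

```python
import ast
import re
from typing import List, Optional, Tuple

import pandas as pd


def parse_rules(rules: List[str]) -> List[Tuple[str, Optional[str], Optional[str]]]:
    parsed: List[Tuple[str, Optional[str], Optional[str]]] = []
    for rule in rules:
        rule_type, remainder = rule.split(maxsplit=1)
        if remainder.startswith('"'):  # quoted token is a column parameter, e.g. IsComplete "c0"
            remainder_split = remainder.split(maxsplit=1)
            parameter: Optional[str] = remainder_split[0].strip('"')
            expression = None if len(remainder_split) == 1 else remainder_split[1]
        else:  # table-level rule with no parameter, e.g. RowCount between 1 and 6
            parameter = None
            expression = remainder
        parsed.append((rule_type, parameter, expression))
    return parsed


def rules_to_df(rules: str) -> pd.DataFrame:
    rules = re.sub(r"^\s*Rules\s*=\s*\[\s*", "", rules)  # strip leading 'Rules = ['
    rules = re.sub(r"\s*\]\s*$", "", rules)  # strip trailing ']'
    # Turn every top-level comma into a string delimiter, leaving commas
    # inside bracketed lists (e.g. in [0, 1, 2]) untouched.
    rules = re.sub(r"\s*,\s*(?![^[]*])", "', '", rules)
    list_rules = ast.literal_eval(f"['{rules}']")
    return pd.DataFrame(parse_rules(list_rules), columns=["rule_type", "parameter", "expression"])


dqdl = 'Rules = [ RowCount between 1 and 6, IsComplete "c0", ColumnValues "c1" in [0, 1, 2] ]'
print(rules_to_df(dqdl))
# Expected content (alignment approximate):
#      rule_type parameter       expression
# 0     RowCount      None  between 1 and 6
# 1   IsComplete        c0             None
# 2 ColumnValues        c1     in [0, 1, 2]
```

The negative lookahead `(?![^[]*])` does the heavy lifting: a comma is rewritten into a delimiter only when no `]` can be reached from it without first crossing a `[`, which is exactly the "outside any bracketed list" condition for the single-level lists DQDL expressions use.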