diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 81361573d..9e1027319 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -230,6 +230,17 @@ or ``./cloudformation/delete-databases.sh`` +### Enabling Lake Formation: +If your feature is related to AWS Lake Formation, there are a number of additional steps required in order to complete testing: + +1. In the AWS console, enable Lake Formation by setting your IAM role as an Administrator and by unchecking the boxes in the ``Data Catalog Settings`` section + +2. In the ``./cloudformation/base.yaml`` template file, set ``EnableLakeFormation`` to ``True``. Then run the ``./deploy-base.sh`` once more to add an AWS Glue Database and an S3 bucket registered with Lake Formation + +3. Back in the console, in the ``Data Locations`` section, grant your IAM role access to the S3 Lake Formation bucket (``s3://aws-wrangler-base-lakeformation...``) + +4. Finally, in the ``Data Permissions`` section, grant your IAM role ``Super`` permissions on both the ``aws_data_wrangler`` and ``aws_data_wrangler_lakeformation`` databases + ## Recommended Visual Studio Code Recommended setting ```json diff --git a/awswrangler/__init__.py b/awswrangler/__init__.py index 25785e433..6249471ff 100644 --- a/awswrangler/__init__.py +++ b/awswrangler/__init__.py @@ -15,6 +15,7 @@ dynamodb, emr, exceptions, + lakeformation, mysql, postgresql, quicksight, @@ -40,6 +41,7 @@ "s3", "sts", "redshift", + "lakeformation", "mysql", "postgresql", "secretsmanager", diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 25e1f2d59..f3f5fe06d 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -42,12 +42,13 @@ class _ConfigArg(NamedTuple): "redshift_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True), "kms_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True), "emr_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True), + "lakeformation_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True), # Botocore config "botocore_config": _ConfigArg(dtype=botocore.config.Config, nullable=True), } -class _Config: # pylint: disable=too-many-instance-attributes +class _Config: # pylint: disable=too-many-instance-attributes,too-many-public-methods """Wrangler's Configuration class.""" def __init__(self) -> None: @@ -60,6 +61,7 @@ def __init__(self) -> None: self.redshift_endpoint_url = None self.kms_endpoint_url = None self.emr_endpoint_url = None + self.lakeformation_endpoint_url = None self.botocore_config = None for name in _CONFIG_ARGS: self._load_config(name=name) @@ -342,6 +344,15 @@ def emr_endpoint_url(self) -> Optional[str]: def emr_endpoint_url(self, value: Optional[str]) -> None: self._set_config_value(key="emr_endpoint_url", value=value) + @property + def lakeformation_endpoint_url(self) -> Optional[str]: + """Property lakeformation_endpoint_url.""" + return cast(Optional[str], self["lakeformation_endpoint_url"]) + + @lakeformation_endpoint_url.setter + def lakeformation_endpoint_url(self, value: Optional[str]) -> None: + self._set_config_value(key="lakeformation_endpoint_url", value=value) + @property def botocore_config(self) -> botocore.config.Config: """Property botocore_config.""" diff --git a/awswrangler/_utils.py b/awswrangler/_utils.py index 50cd8fbd2..f70ea9d9f 100644 --- a/awswrangler/_utils.py +++ b/awswrangler/_utils.py @@ -87,6 +87,8 @@ def _get_endpoint_url(service_name: str) -> Optional[str]: endpoint_url = _config.config.kms_endpoint_url elif service_name == "emr" and 
_config.config.emr_endpoint_url is not None: endpoint_url = _config.config.emr_endpoint_url + elif service_name == "lakeformation" and _config.config.lakeformation_endpoint_url is not None: + endpoint_url = _config.config.lakeformation_endpoint_url return endpoint_url diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index bca15e15f..dba3888fa 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -761,8 +761,8 @@ def read_sql_query( >>> import awswrangler as wr >>> df = wr.athena.read_sql_query( - ... sql="SELECT * FROM my_table WHERE name=:name;", - ... params={"name": "filtered_name"} + ... sql="SELECT * FROM my_table WHERE name=:name; AND city=:city;", + ... params={"name": "'filtered_name'", "city": "'filtered_city'"} ... ) """ diff --git a/awswrangler/catalog/_create.py b/awswrangler/catalog/_create.py index 2dcbc6fc7..50f7f82d0 100644 --- a/awswrangler/catalog/_create.py +++ b/awswrangler/catalog/_create.py @@ -33,6 +33,7 @@ def _create_table( # pylint: disable=too-many-branches,too-many-statements catalog_versioning: bool, boto3_session: Optional[boto3.Session], table_input: Dict[str, Any], + table_type: Optional[str], table_exist: bool, projection_enabled: bool, partitions_types: Optional[Dict[str, str]], @@ -118,7 +119,8 @@ def _create_table( # pylint: disable=too-many-branches,too-many-statements f"{mode} is not a valid mode. It must be 'overwrite', 'append' or 'overwrite_partitions'." ) if table_exist is True and mode == "overwrite": - delete_all_partitions(table=table, database=database, catalog_id=catalog_id, boto3_session=session) + if table_type != "GOVERNED": + delete_all_partitions(table=table, database=database, catalog_id=catalog_id, boto3_session=session) _logger.debug("Updating table (%s)...", mode) client_glue.update_table( **_catalog_id( @@ -214,6 +216,7 @@ def _create_parquet_table( table: str, path: str, columns_types: Dict[str, str], + table_type: Optional[str], partitions_types: Optional[Dict[str, str]], bucketing_info: Optional[Tuple[List[str], int]], catalog_id: Optional[str], @@ -253,6 +256,7 @@ def _create_parquet_table( table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, compression=compression, @@ -269,6 +273,7 @@ def _create_parquet_table( catalog_versioning=catalog_versioning, boto3_session=boto3_session, table_input=table_input, + table_type=table_type, table_exist=table_exist, partitions_types=partitions_types, projection_enabled=projection_enabled, @@ -284,8 +289,9 @@ def _create_parquet_table( def _create_csv_table( database: str, table: str, - path: str, + path: Optional[str], columns_types: Dict[str, str], + table_type: Optional[str], partitions_types: Optional[Dict[str, str]], bucketing_info: Optional[Tuple[List[str], int]], description: Optional[str], @@ -324,6 +330,7 @@ def _create_csv_table( table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, compression=compression, @@ -342,6 +349,7 @@ def _create_csv_table( catalog_versioning=catalog_versioning, boto3_session=boto3_session, table_input=table_input, + table_type=table_type, table_exist=table_exist, partitions_types=partitions_types, projection_enabled=projection_enabled, @@ -519,6 +527,7 @@ def create_parquet_table( table: str, path: str, columns_types: Dict[str, str], + table_type: Optional[str] = None, partitions_types: Optional[Dict[str, str]] = None, bucketing_info: 
Optional[Tuple[List[str], int]] = None, catalog_id: Optional[str] = None, @@ -550,6 +559,8 @@ def create_parquet_table( Amazon S3 path (e.g. s3://bucket/prefix/). columns_types: Dict[str, str] Dictionary with keys as column names and values as data types (e.g. {'col0': 'bigint', 'col1': 'double'}). + table_type: str, optional + The type of the Glue Table (EXTERNAL_TABLE, GOVERNED...). Set to EXTERNAL_TABLE if None partitions_types: Dict[str, str], optional Dictionary with keys as partition names and values as data types (e.g. {'col2': 'date'}). bucketing_info: Tuple[List[str], int], optional @@ -627,6 +638,7 @@ def create_parquet_table( table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, catalog_id=catalog_id, @@ -653,6 +665,7 @@ def create_csv_table( table: str, path: str, columns_types: Dict[str, str], + table_type: Optional[str] = None, partitions_types: Optional[Dict[str, str]] = None, bucketing_info: Optional[Tuple[List[str], int]] = None, compression: Optional[str] = None, @@ -686,6 +699,8 @@ def create_csv_table( Amazon S3 path (e.g. s3://bucket/prefix/). columns_types: Dict[str, str] Dictionary with keys as column names and values as data types (e.g. {'col0': 'bigint', 'col1': 'double'}). + table_type: str, optional + The type of the Glue Table (EXTERNAL_TABLE, GOVERNED...). Set to EXTERNAL_TABLE if None partitions_types: Dict[str, str], optional Dictionary with keys as partition names and values as data types (e.g. {'col2': 'date'}). bucketing_info: Tuple[List[str], int], optional @@ -767,6 +782,7 @@ def create_csv_table( table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, catalog_id=catalog_id, diff --git a/awswrangler/catalog/_definitions.py b/awswrangler/catalog/_definitions.py index 778d428dd..97aea2eac 100644 --- a/awswrangler/catalog/_definitions.py +++ b/awswrangler/catalog/_definitions.py @@ -31,6 +31,7 @@ def _parquet_table_definition( table: str, path: str, columns_types: Dict[str, str], + table_type: Optional[str], partitions_types: Dict[str, str], bucketing_info: Optional[Tuple[List[str], int]], compression: Optional[str], @@ -39,7 +40,7 @@ def _parquet_table_definition( return { "Name": table, "PartitionKeys": [{"Name": cname, "Type": dtype} for cname, dtype in partitions_types.items()], - "TableType": "EXTERNAL_TABLE", + "TableType": "EXTERNAL_TABLE" if table_type is None else table_type, "Parameters": {"classification": "parquet", "compressionType": str(compression).lower(), "typeOfData": "file"}, "StorageDescriptor": { "Columns": [{"Name": cname, "Type": dtype} for cname, dtype in columns_types.items()], @@ -98,8 +99,9 @@ def _parquet_partition_definition( def _csv_table_definition( table: str, - path: str, + path: Optional[str], columns_types: Dict[str, str], + table_type: Optional[str], partitions_types: Dict[str, str], bucketing_info: Optional[Tuple[List[str], int]], compression: Optional[str], @@ -120,7 +122,7 @@ def _csv_table_definition( return { "Name": table, "PartitionKeys": [{"Name": cname, "Type": dtype} for cname, dtype in partitions_types.items()], - "TableType": "EXTERNAL_TABLE", + "TableType": "EXTERNAL_TABLE" if table_type is None else table_type, "Parameters": parameters, "StorageDescriptor": { "Columns": [{"Name": cname, "Type": dtype} for cname, dtype in columns_types.items()], diff --git a/awswrangler/lakeformation/__init__.py 
b/awswrangler/lakeformation/__init__.py new file mode 100644 index 000000000..8b8c3084e --- /dev/null +++ b/awswrangler/lakeformation/__init__.py @@ -0,0 +1,20 @@ +"""Amazon Lake Formation Module.""" + +from awswrangler.lakeformation._read import read_sql_query, read_sql_table # noqa +from awswrangler.lakeformation._utils import ( # noqa + abort_transaction, + begin_transaction, + commit_transaction, + extend_transaction, + wait_query, +) + +__all__ = [ + "read_sql_query", + "read_sql_table", + "abort_transaction", + "begin_transaction", + "commit_transaction", + "extend_transaction", + "wait_query", +] diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py new file mode 100644 index 000000000..d08c7a5d9 --- /dev/null +++ b/awswrangler/lakeformation/_read.py @@ -0,0 +1,355 @@ +"""Amazon Lake Formation Module gathering all read functions.""" +import concurrent.futures +import itertools +import logging +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import boto3 +import pandas as pd +from pyarrow import NativeFile, RecordBatchStreamReader, Table + +from awswrangler import _data_types, _utils, catalog, exceptions +from awswrangler._config import apply_configs +from awswrangler.catalog._utils import _catalog_id +from awswrangler.lakeformation._utils import abort_transaction, begin_transaction, wait_query + +_logger: logging.Logger = logging.getLogger(__name__) + + +def _execute_query( + query_id: str, + token_work_unit: Tuple[str, int], + categories: Optional[List[str]], + safe: bool, + use_threads: bool, + boto3_session: boto3.Session, +) -> pd.DataFrame: + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=boto3_session) + token, work_unit = token_work_unit + messages: NativeFile = client_lakeformation.execute(QueryId=query_id, Token=token, WorkUnitId=work_unit)["Messages"] + table: Table = RecordBatchStreamReader(messages.read()).read_all() + args: Dict[str, Any] = {} + if table.num_rows > 0: + args = { + "use_threads": use_threads, + "split_blocks": True, + "self_destruct": True, + "integer_object_nulls": False, + "date_as_object": True, + "ignore_metadata": True, + "strings_to_categorical": False, + "categories": categories, + "safe": safe, + "types_mapper": _data_types.pyarrow2pandas_extension, + } + df: pd.DataFrame = _utils.ensure_df_is_mutable(df=table.to_pandas(**args)) + return df + + +def _resolve_sql_query( + query_id: str, + chunked: Optional[bool], + categories: Optional[List[str]], + safe: bool, + use_threads: bool, + boto3_session: boto3.Session, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=boto3_session) + + wait_query(query_id=query_id, boto3_session=boto3_session) + + # The LF Query Engine distributes the load across workers + # Retrieve the tokens and their associated work units until NextToken is '' + # One Token can span multiple work units + # PageSize determines the size of the "Units" array in each call + scan_kwargs: Dict[str, Union[str, int]] = {"QueryId": query_id, "PageSize": 10} + next_token: str = "init_token" # Dummy token + token_work_units: List[Tuple[str, int]] = [] + while next_token: + response = client_lakeformation.get_work_units(**scan_kwargs) + token_work_units.extend( # [(Token0, WorkUnitId0), (Token0, WorkUnitId1), (Token1, WorkUnitId2) ... 
] + [ + (unit["Token"], unit_id) + for unit in response["Units"] + for unit_id in range(unit["WorkUnitIdMin"], unit["WorkUnitIdMax"] + 1) # Max is inclusive + ] + ) + next_token = response.get("NextToken", None) + scan_kwargs["NextToken"] = next_token + + dfs: List[pd.DataFrame] = list() + if use_threads is False: + dfs = list( + _execute_query( + query_id=query_id, + token_work_unit=token_work_unit, + categories=categories, + safe=safe, + use_threads=use_threads, + boto3_session=boto3_session, + ) + for token_work_unit in token_work_units + ) + else: + cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor: + dfs = list( + executor.map( + _execute_query, + itertools.repeat(query_id), + token_work_units, + itertools.repeat(categories), + itertools.repeat(safe), + itertools.repeat(use_threads), + itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)), + ) + ) + dfs = [df for df in dfs if not df.empty] + if (not chunked) and dfs: + return pd.concat(dfs, sort=False, copy=False, ignore_index=False) + return dfs + + +@apply_configs +def read_sql_query( + sql: str, + database: str, + transaction_id: Optional[str] = None, + query_as_of_time: Optional[str] = None, + catalog_id: Optional[str] = None, + chunked: bool = False, + categories: Optional[List[str]] = None, + safe: bool = True, + use_threads: bool = True, + boto3_session: Optional[boto3.Session] = None, + params: Optional[Dict[str, Any]] = None, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + """Execute PartiQL query on AWS Glue Table (Transaction ID or time travel timestamp). Return Pandas DataFrame. + + Note + ---- + ORDER BY operations are not honoured. + i.e. sql="SELECT * FROM my_table ORDER BY my_column" is NOT valid + + Note + ---- + The database must NOT be explicitely defined in the PartiQL statement. + i.e. sql="SELECT * FROM my_table" is valid + but sql="SELECT * FROM my_db.my_table" is NOT valid + + Note + ---- + Pass one of `transaction_id` or `query_as_of_time`, not both. + + Note + ---- + `chunked` argument (memory-friendly): + If set to `True`, return an Iterable of DataFrames instead of a regular DataFrame. + + Parameters + ---------- + sql : str + partiQL query. + database : str + AWS Glue database name + transaction_id : str, optional + The ID of the transaction at which to read the table contents. + Cannot be specified alongside query_as_of_time + query_as_of_time : str, optional + The time as of when to read the table contents. Must be a valid Unix epoch timestamp. + Cannot be specified alongside transaction_id + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + chunked : bool, optional + If `True`, Wrangler returns an Iterable of DataFrames with no guarantee of chunksize. + categories: Optional[List[str]], optional + List of columns names that should be returned as pandas.Categorical. + Recommended for memory restricted environments. + safe : bool, default True + For certain data types, a cast is needed in order to store the + data in a pandas DataFrame or Series (e.g. timestamps are always + stored as nanoseconds in pandas). This option controls whether it + is a safe cast or not. + use_threads : bool + True to enable concurrent requests, False to disable multiple threads. + When enabled, os.cpu_count() is used as the max number of threads. + boto3_session : boto3.Session(), optional + Boto3 Session. 
The default boto3 session is used if boto3_session receives None. + params: Dict[str, any], optional + Dict of parameters used to format the partiQL query. Only named parameters are supported. + The dict must contain the information in the form {"name": "value"} and the SQL query must contain + `:name`. + + Returns + ------- + Union[pd.DataFrame, Iterator[pd.DataFrame]] + Pandas DataFrame or Generator of Pandas DataFrames if chunked is passed. + + Examples + -------- + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_query( + ... sql="SELECT * FROM my_table;", + ... database="my_db", + ... catalog_id="111111111111" + ... ) + + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_query( + ... sql="SELECT * FROM my_table LIMIT 10;", + ... database="my_db", + ... transaction_id="1b62811fa3e02c4e5fdbaa642b752030379c4a8a70da1f8732ce6ccca47afdc9" + ... ) + + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_query( + ... sql="SELECT * FROM my_table WHERE name=:name; AND city=:city;", + ... database="my_db", + ... query_as_of_time="1611142914", + ... params={"name": "'filtered_name'", "city": "'filtered_city'"} + ... ) + + """ + if transaction_id is not None and query_as_of_time is not None: + raise exceptions.InvalidArgumentCombination( + "Please pass only one of `transaction_id` or `query_as_of_time`, not both" + ) + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + if params is None: + params = {} + for key, value in params.items(): + sql = sql.replace(f":{key};", str(value)) + + args: Dict[str, Optional[str]] = _catalog_id(catalog_id=catalog_id, **{"DatabaseName": database, "Statement": sql}) + if query_as_of_time: + args["QueryAsOfTime"] = query_as_of_time + elif transaction_id: + args["TransactionId"] = transaction_id + else: + _logger.debug("Neither `transaction_id` nor `query_as_of_time` were specified, beginning transaction") + transaction_id = begin_transaction(read_only=True, boto3_session=session) + args["TransactionId"] = transaction_id + query_id: str = client_lakeformation.plan_query(**args)["QueryId"] + try: + return _resolve_sql_query( + query_id=query_id, + chunked=chunked, + categories=categories, + safe=safe, + use_threads=use_threads, + boto3_session=session, + ) + except Exception as ex: + _logger.debug("Aborting transaction with ID: %s.", transaction_id) + if transaction_id: + abort_transaction(transaction_id=transaction_id, boto3_session=session) + _logger.error(ex) + raise + + +@apply_configs +def read_sql_table( + table: str, + database: str, + transaction_id: Optional[str] = None, + query_as_of_time: Optional[str] = None, + catalog_id: Optional[str] = None, + chunked: bool = False, + categories: Optional[List[str]] = None, + safe: bool = True, + use_threads: bool = True, + boto3_session: Optional[boto3.Session] = None, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + """Extract all rows from AWS Glue Table (Transaction ID or time travel timestamp). Return Pandas DataFrame. + + Note + ---- + ORDER BY operations are not honoured. + i.e. sql="SELECT * FROM my_table ORDER BY my_column" is NOT valid + + Note + ---- + Pass one of `transaction_id` or `query_as_of_time`, not both. + + Note + ---- + `chunked` argument (memory-friendly): + If set to `True`, return an Iterable of DataFrames instead of a regular DataFrame. + + Parameters + ---------- + table : str + AWS Glue table name. 
+ database : str + AWS Glue database name + transaction_id : str, optional + The ID of the transaction at which to read the table contents. + Cannot be specified alongside query_as_of_time + query_as_of_time : str, optional + The time as of when to read the table contents. Must be a valid Unix epoch timestamp. + Cannot be specified alongside transaction_id + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + chunked : bool, optional + If `True`, Wrangler returns an Iterable of DataFrames with no guarantee of chunksize. + categories: Optional[List[str]], optional + List of columns names that should be returned as pandas.Categorical. + Recommended for memory restricted environments. + safe : bool, default True + For certain data types, a cast is needed in order to store the + data in a pandas DataFrame or Series (e.g. timestamps are always + stored as nanoseconds in pandas). This option controls whether it + is a safe cast or not. + use_threads : bool + True to enable concurrent requests, False to disable multiple threads. + When enabled, os.cpu_count() is used as the max number of threads. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session is used if boto3_session receives None. + + Returns + ------- + Union[pd.DataFrame, Iterator[pd.DataFrame]] + Pandas DataFrame or Generator of Pandas DataFrames if chunked is passed. + + Examples + -------- + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_table( + ... table="my_table", + ... database="my_db", + ... catalog_id="111111111111", + ... ) + + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_table( + ... table="my_table", + ... database="my_db", + ... transaction_id="1b62811fa3e02c4e5fdbaa642b752030379c4a8a70da1f8732ce6ccca47afdc9", + ... chunked=True, + ... ) + + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_table( + ... table="my_table", + ... database="my_db", + ... query_as_of_time="1611142914", + ... use_threads=True, + ... 
) + + """ + table = catalog.sanitize_table_name(table=table) + return read_sql_query( + sql=f"SELECT * FROM {table}", + database=database, + transaction_id=transaction_id, + query_as_of_time=query_as_of_time, + safe=safe, + catalog_id=catalog_id, + categories=categories, + chunked=chunked, + use_threads=use_threads, + boto3_session=boto3_session, + ) diff --git a/awswrangler/lakeformation/_utils.py b/awswrangler/lakeformation/_utils.py new file mode 100644 index 000000000..ea94101bd --- /dev/null +++ b/awswrangler/lakeformation/_utils.py @@ -0,0 +1,264 @@ +"""Utilities Module for Amazon Lake Formation.""" +import logging +import time +from typing import Any, Dict, List, Optional, Union + +import boto3 + +from awswrangler import _utils, exceptions +from awswrangler.catalog._utils import _catalog_id +from awswrangler.s3._describe import describe_objects + +_QUERY_FINAL_STATES: List[str] = ["ERROR", "FINISHED"] +_QUERY_WAIT_POLLING_DELAY: float = 2 # SECONDS + +_logger: logging.Logger = logging.getLogger(__name__) + + +def _without_keys(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]: + return {x: d[x] for x in d if x not in keys} + + +def _build_partition_predicate( + partition_cols: List[str], + partitions_types: Dict[str, str], + partitions_values: List[str], +) -> str: + partition_predicates: List[str] = [] + for col, val in zip(partition_cols, partitions_values): + if partitions_types[col].startswith(("tinyint", "smallint", "int", "bigint", "float", "double", "decimal")): + partition_predicates.append(f"{col}={str(val)}") + else: + partition_predicates.append(f"{col}='{str(val)}'") + return " AND ".join(partition_predicates) + + +def _build_table_objects( + paths: List[str], + partitions_values: Dict[str, List[str]], + use_threads: bool, + boto3_session: Optional[boto3.Session], +) -> List[Dict[str, Any]]: + table_objects: List[Dict[str, Any]] = [] + paths_desc: Dict[str, Dict[str, Any]] = describe_objects( + path=paths, use_threads=use_threads, boto3_session=boto3_session + ) + for path, path_desc in paths_desc.items(): + table_object: Dict[str, Any] = { + "Uri": path, + "ETag": path_desc["ETag"], + "Size": path_desc["ContentLength"], + } + if partitions_values: + table_object["PartitionValues"] = partitions_values[f"{path.rsplit('/', 1)[0].rstrip('/')}/"] + table_objects.append(table_object) + return table_objects + + +def _get_table_objects( + catalog_id: Optional[str], + database: str, + table: str, + transaction_id: str, + boto3_session: Optional[boto3.Session], + partition_cols: Optional[List[str]] = None, + partitions_types: Optional[Dict[str, str]] = None, + partitions_values: Optional[List[str]] = None, +) -> List[Dict[str, Any]]: + """Get Governed Table Objects from Lake Formation Engine.""" + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + scan_kwargs: Dict[str, Union[str, int]] = _catalog_id( + catalog_id=catalog_id, TransactionId=transaction_id, DatabaseName=database, TableName=table, MaxResults=100 + ) + if partition_cols and partitions_types and partitions_values: + scan_kwargs["PartitionPredicate"] = _build_partition_predicate( + partition_cols=partition_cols, partitions_types=partitions_types, partitions_values=partitions_values + ) + + next_token: str = "init_token" # Dummy token + table_objects: List[Dict[str, Any]] = [] + while next_token: + response = client_lakeformation.get_table_objects(**scan_kwargs) + for objects in 
response["Objects"]: + for table_object in objects["Objects"]: + if objects["PartitionValues"]: + table_object["PartitionValues"] = objects["PartitionValues"] + table_objects.append(table_object) + next_token = response.get("NextToken", None) + scan_kwargs["NextToken"] = next_token + return table_objects + + +def _update_table_objects( + catalog_id: Optional[str], + database: str, + table: str, + transaction_id: str, + boto3_session: Optional[boto3.Session], + add_objects: Optional[List[Dict[str, Any]]] = None, + del_objects: Optional[List[Dict[str, Any]]] = None, +) -> None: + """Register Governed Table Objects changes to Lake Formation Engine.""" + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + update_kwargs: Dict[str, Union[str, int, List[Dict[str, Dict[str, Any]]]]] = _catalog_id( + catalog_id=catalog_id, TransactionId=transaction_id, DatabaseName=database, TableName=table + ) + + write_operations: List[Dict[str, Dict[str, Any]]] = [] + if add_objects: + write_operations.extend({"AddObject": obj} for obj in add_objects) + elif del_objects: + write_operations.extend({"DeleteObject": _without_keys(obj, ["Size"])} for obj in del_objects) + update_kwargs["WriteOperations"] = write_operations + + client_lakeformation.update_table_objects(**update_kwargs) + + +def abort_transaction(transaction_id: str, boto3_session: Optional[boto3.Session] = None) -> None: + """Abort the specified transaction. Returns exception if the transaction was previously committed. + + Parameters + ---------- + transaction_id : str + The ID of the transaction. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session received None. + + Returns + ------- + None + None. + + Examples + -------- + >>> import awswrangler as wr + >>> wr.lakeformation.abort_transaction(transaction_id="...") + + """ + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + client_lakeformation.abort_transaction(TransactionId=transaction_id) + + +def begin_transaction(read_only: Optional[bool] = False, boto3_session: Optional[boto3.Session] = None) -> str: + """Start a new transaction and returns its transaction ID. + + Parameters + ---------- + read_only : bool, optional + Indicates that that this transaction should be read only. + Writes made using a read-only transaction ID will be rejected. + Read-only transactions do not need to be committed. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session received None. + + Returns + ------- + str + An opaque identifier for the transaction. + + Examples + -------- + >>> import awswrangler as wr + >>> transaction_id = wr.lakeformation.begin_transaction(read_only=False) + + """ + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + transaction_id: str = client_lakeformation.begin_transaction(ReadOnly=read_only)["TransactionId"] + return transaction_id + + +def commit_transaction(transaction_id: str, boto3_session: Optional[boto3.Session] = None) -> None: + """Commit the specified transaction. Returns exception if the transaction was previously aborted. 
+ + Parameters + ---------- + transaction_id : str + The ID of the transaction. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session received None. + + Returns + ------- + None + None. + + Examples + -------- + >>> import awswrangler as wr + >>> wr.lakeformation.commit_transaction(transaction_id="...") + + """ + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + client_lakeformation.commit_transaction(TransactionId=transaction_id) + + +def extend_transaction(transaction_id: str, boto3_session: Optional[boto3.Session] = None) -> None: + """Indicate to the service that the specified transaction is still active and should not be aborted. + + Parameters + ---------- + transaction_id : str + The ID of the transaction. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session received None. + + Returns + ------- + None + None. + + Examples + -------- + >>> import awswrangler as wr + >>> wr.lakeformation.extend_transaction(transaction_id="...") + + """ + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + client_lakeformation.extend_transaction(TransactionId=transaction_id) + + +def wait_query(query_id: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]: + """Wait for the query to end. + + Parameters + ---------- + query_id : str + Lake Formation query execution ID. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session received None. + + Returns + ------- + Dict[str, Any] + Dictionary with the get_query_state response. + + Examples + -------- + >>> import awswrangler as wr + >>> res = wr.lakeformation.wait_query(query_id='query-id') + + """ + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + response: Dict[str, Any] = client_lakeformation.get_query_state(QueryId=query_id) + state: str = response["State"] + while state not in _QUERY_FINAL_STATES: + time.sleep(_QUERY_WAIT_POLLING_DELAY) + response = client_lakeformation.get_query_state(QueryId=query_id) + state = response["State"] + _logger.debug("state: %s", state) + if state == "ERROR": + raise exceptions.QueryFailed(response.get("Error")) + return response diff --git a/awswrangler/s3/_write.py b/awswrangler/s3/_write.py index e94a71288..666035bb6 100644 --- a/awswrangler/s3/_write.py +++ b/awswrangler/s3/_write.py @@ -47,7 +47,7 @@ def _validate_args( table: Optional[str], database: Optional[str], dataset: bool, - path: str, + path: Optional[str], partition_cols: Optional[List[str]], bucketing_info: Optional[Tuple[List[str], int]], mode: Optional[str], @@ -58,6 +58,8 @@ def _validate_args( if df.empty is True: raise exceptions.EmptyDataFrame() if dataset is False: + if path is None: + raise exceptions.InvalidArgumentValue("If dataset is False, the argument `path` must be passed.") if path.endswith("/"): raise exceptions.InvalidArgumentValue( "If , the argument should be a file path, not a directory." 
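
The transaction helpers above and the governed write path in the following hunks compose into a begin/write/commit flow. The snippet below is a minimal illustrative sketch (not code from this patch): the database and table names are placeholders, and the governed table is assumed to already exist in the Glue Catalog (otherwise `path` must be supplied so `to_parquet` can create it).

```python
import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]})

# Begin an explicit write transaction, append to the governed table, then commit.
transaction_id = wr.lakeformation.begin_transaction(read_only=False)
try:
    wr.s3.to_parquet(
        df=df,
        dataset=True,
        mode="append",
        database="my_db",            # hypothetical database
        table="my_governed_table",   # hypothetical governed table, assumed to exist
        table_type="GOVERNED",
        transaction_id=transaction_id,
    )
except Exception:
    # On failure, abort so the uncommitted table objects are discarded.
    wr.lakeformation.abort_transaction(transaction_id=transaction_id)
    raise
else:
    wr.lakeformation.commit_transaction(transaction_id=transaction_id)

# Reads may omit the transaction ID; read_sql_query/read_sql_table then begin
# a read-only transaction on the caller's behalf.
df2 = wr.lakeformation.read_sql_table(table="my_governed_table", database="my_db")
```

This mirrors what `_to_dataset` does internally when no `transaction_id` is passed: it begins a transaction, registers the written objects, and commits (or aborts on error).
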
diff --git a/awswrangler/s3/_write_dataset.py b/awswrangler/s3/_write_dataset.py index bf3a7a1f4..3bc05cf2d 100644 --- a/awswrangler/s3/_write_dataset.py +++ b/awswrangler/s3/_write_dataset.py @@ -9,6 +9,14 @@ import pandas as pd from awswrangler import exceptions +from awswrangler.lakeformation._utils import ( + _build_table_objects, + _get_table_objects, + _update_table_objects, + abort_transaction, + begin_transaction, + commit_transaction, +) from awswrangler.s3._delete import delete_objects from awswrangler.s3._write_concurrent import _WriteProxy @@ -23,6 +31,12 @@ def _to_partitions( use_threads: bool, mode: str, partition_cols: List[str], + partitions_types: Optional[Dict[str, str]], + catalog_id: Optional[str], + database: Optional[str], + table: Optional[str], + table_type: Optional[str], + transaction_id: Optional[str], bucketing_info: Optional[Tuple[List[str], int]], boto3_session: boto3.Session, **func_kwargs: Any, @@ -37,12 +51,33 @@ def _to_partitions( subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)]) prefix: str = f"{path_root}{subdir}/" if mode == "overwrite_partitions": - delete_objects( - path=prefix, - use_threads=use_threads, - boto3_session=boto3_session, - s3_additional_kwargs=func_kwargs.get("s3_additional_kwargs"), - ) + if (table_type == "GOVERNED") and (table is not None) and (database is not None): + del_objects: List[Dict[str, Any]] = _get_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, # type: ignore + partition_cols=partition_cols, + partitions_values=keys, + partitions_types=partitions_types, + boto3_session=boto3_session, + ) + if del_objects: + _update_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, # type: ignore + del_objects=del_objects, + boto3_session=boto3_session, + ) + else: + delete_objects( + path=prefix, + use_threads=use_threads, + boto3_session=boto3_session, + s3_additional_kwargs=func_kwargs.get("s3_additional_kwargs"), + ) if bucketing_info: _to_buckets( func=func, @@ -137,24 +172,51 @@ def _to_dataset( use_threads: bool, mode: str, partition_cols: Optional[List[str]], + partitions_types: Optional[Dict[str, str]], + catalog_id: Optional[str], + database: Optional[str], + table: Optional[str], + table_type: Optional[str], + transaction_id: Optional[str], bucketing_info: Optional[Tuple[List[str], int]], boto3_session: boto3.Session, **func_kwargs: Any, ) -> Tuple[List[str], Dict[str, List[str]]]: path_root = path_root if path_root.endswith("/") else f"{path_root}/" + commit_trans: bool = False + if table_type == "GOVERNED": + # Check whether to skip committing the transaction (i.e. multiple read/write operations) + if transaction_id is None: + _logger.debug("`transaction_id` not specified, beginning transaction") + transaction_id = begin_transaction(read_only=False, boto3_session=boto3_session) + commit_trans = True + # Evaluate mode if mode not in ["append", "overwrite", "overwrite_partitions"]: raise exceptions.InvalidArgumentValue( f"{mode} is a invalid mode, please use append, overwrite or overwrite_partitions." 
) if (mode == "overwrite") or ((mode == "overwrite_partitions") and (not partition_cols)): - delete_objects( - path=path_root, - use_threads=use_threads, - boto3_session=boto3_session, - s3_additional_kwargs=func_kwargs.get("s3_additional_kwargs"), - ) + if (table_type == "GOVERNED") and (table is not None) and (database is not None): + del_objects: List[Dict[str, Any]] = _get_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, # type: ignore + boto3_session=boto3_session, + ) + if del_objects: + _update_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, # type: ignore + del_objects=del_objects, + boto3_session=boto3_session, + ) + else: + delete_objects(path=path_root, use_threads=use_threads, boto3_session=boto3_session) # Writing partitions_values: Dict[str, List[str]] = {} @@ -167,8 +229,14 @@ def _to_dataset( path_root=path_root, use_threads=use_threads, mode=mode, + catalog_id=catalog_id, + database=database, + table=table, + table_type=table_type, + transaction_id=transaction_id, bucketing_info=bucketing_info, partition_cols=partition_cols, + partitions_types=partitions_types, boto3_session=boto3_session, index=index, **func_kwargs, @@ -190,4 +258,27 @@ def _to_dataset( ) _logger.debug("paths: %s", paths) _logger.debug("partitions_values: %s", partitions_values) + if (table_type == "GOVERNED") and (table is not None) and (database is not None): + add_objects: List[Dict[str, Any]] = _build_table_objects( + paths, partitions_values, use_threads=use_threads, boto3_session=boto3_session + ) + try: + if add_objects: + _update_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, # type: ignore + add_objects=add_objects, + boto3_session=boto3_session, + ) + if commit_trans: + commit_transaction(transaction_id=transaction_id, boto3_session=boto3_session) # type: ignore + except Exception as ex: + _logger.debug("Aborting transaction with ID: %s.", transaction_id) + if transaction_id: + abort_transaction(transaction_id=transaction_id, boto3_session=boto3_session) + _logger.error(ex) + raise + return paths, partitions_values diff --git a/awswrangler/s3/_write_parquet.py b/awswrangler/s3/_write_parquet.py index 5e9311c2b..5ea2ff9c0 100644 --- a/awswrangler/s3/_write_parquet.py +++ b/awswrangler/s3/_write_parquet.py @@ -196,9 +196,9 @@ def _to_parquet( @apply_configs -def to_parquet( # pylint: disable=too-many-arguments,too-many-locals +def to_parquet( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements df: pd.DataFrame, - path: str, + path: Optional[str] = None, index: bool = False, compression: Optional[str] = "snappy", max_rows_by_file: Optional[int] = None, @@ -215,6 +215,8 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals schema_evolution: bool = True, database: Optional[str] = None, table: Optional[str] = None, + table_type: Optional[str] = None, + transaction_id: Optional[str] = None, dtype: Optional[Dict[str, str]] = None, description: Optional[str] = None, parameters: Optional[Dict[str, str]] = None, @@ -252,7 +254,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals ---------- df: pandas.DataFrame Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html - path : str + path : str, optional S3 path (for file e.g. ``s3://bucket/prefix/filename.parquet``) (for dataset e.g. ``s3://bucket/prefix``). 
index : bool True to store the DataFrame index in file, otherwise False to ignore it. @@ -306,6 +308,10 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals Glue/Athena catalog: Database name. table : str, optional Glue/Athena catalog: Table name. + table_type: str, optional + The type of the Glue Table. Set to EXTERNAL_TABLE if None. + transaction_id: str, optional + The ID of the transaction when writing to a Governed Table. dtype : Dict[str, str], optional Dictionary of columns names and Athena/Glue types to be casted. Useful when you have columns with undetermined or mixed data types. @@ -451,6 +457,28 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals } } + Writing dataset to Glue governed table + + >>> import awswrangler as wr + >>> import pandas as pd + >>> wr.s3.to_parquet( + ... df=pd.DataFrame({ + ... 'col': [1, 2, 3], + ... 'col2': ['A', 'A', 'B'], + ... 'col3': [None, None, None] + ... }), + ... dataset=True, + ... mode='append', + ... database='default', # Athena/Glue database + ... table='my_table', # Athena/Glue table + ... table_type='GOVERNED', + ... transaction_id="xxx", + ... ) + { + 'paths': ['s3://.../x.parquet'], + 'partitions_values: {} + } + Writing dataset casting empty column data type >>> import awswrangler as wr @@ -497,6 +525,8 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals dtype = dtype if dtype else {} partitions_values: Dict[str, List[str]] = {} mode = "append" if mode is None else mode + if transaction_id: + table_type = "GOVERNED" cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) session: boto3.Session = _utils.ensure_session(session=boto3_session) @@ -510,6 +540,15 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access database=database, table=table, boto3_session=session, catalog_id=catalog_id ) + if catalog_table_input: + table_type = catalog_table_input["TableType"] + if path is None: + if catalog_table_input: + path = catalog_table_input["StorageDescriptor"]["Location"] + else: + raise exceptions.InvalidArgumentValue( + "Glue table does not exist. Please pass the `path` argument to create it." 
+ ) df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode) schema: pa.Schema = _data_types.pyarrow_schema_from_pandas( df=df, index=index, ignore_cols=partition_cols, dtype=dtype @@ -540,17 +579,51 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals ) if schema_evolution is False: _check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode) + if (catalog_table_input is None) and (table_type == "GOVERNED"): + catalog._create_parquet_table( # pylint: disable=protected-access + database=database, + table=table, + path=path, # type: ignore + columns_types=columns_types, + table_type=table_type, + partitions_types=partitions_types, + bucketing_info=bucketing_info, + compression=compression, + description=description, + parameters=parameters, + columns_comments=columns_comments, + boto3_session=session, + mode=mode, + catalog_versioning=catalog_versioning, + projection_enabled=projection_enabled, + projection_types=projection_types, + projection_ranges=projection_ranges, + projection_values=projection_values, + projection_intervals=projection_intervals, + projection_digits=projection_digits, + catalog_id=catalog_id, + catalog_table_input=catalog_table_input, + ) + catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access + database=database, table=table, boto3_session=session, catalog_id=catalog_id + ) paths, partitions_values = _to_dataset( func=_to_parquet, concurrent_partitioning=concurrent_partitioning, df=df, - path_root=path, + path_root=path, # type: ignore index=index, compression=compression, compression_ext=compression_ext, + catalog_id=catalog_id, + database=database, + table=table, + table_type=table_type, + transaction_id=transaction_id, cpus=cpus, use_threads=use_threads, partition_cols=partition_cols, + partitions_types=partitions_types, bucketing_info=bucketing_info, dtype=dtype, mode=mode, @@ -564,8 +637,9 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals catalog._create_parquet_table( # pylint: disable=protected-access database=database, table=table, - path=path, + path=path, # type: ignore columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, compression=compression, @@ -584,7 +658,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals catalog_id=catalog_id, catalog_table_input=catalog_table_input, ) - if partitions_values and (regular_partitions is True): + if partitions_values and (regular_partitions is True) and (table_type != "GOVERNED"): _logger.debug("partitions_values:\n%s", partitions_values) catalog.add_parquet_partitions( database=database, diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py index d8e8d2adb..75d9324e4 100644 --- a/awswrangler/s3/_write_text.py +++ b/awswrangler/s3/_write_text.py @@ -72,9 +72,9 @@ def _to_text( @apply_configs -def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements +def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-branches df: pd.DataFrame, - path: str, + path: Optional[str] = None, sep: str = ",", index: bool = True, columns: Optional[List[str]] = None, @@ -90,6 +90,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state catalog_versioning: bool = False, database: Optional[str] = None, table: Optional[str] = None, + table_type: Optional[str] = None, + transaction_id: Optional[str] = None, 
dtype: Optional[Dict[str, str]] = None, description: Optional[str] = None, parameters: Optional[Dict[str, str]] = None, @@ -137,7 +139,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state ---------- df: pandas.DataFrame Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html - path : str + path : str, optional Amazon S3 path (e.g. s3://bucket/filename.csv). sep : str String of length 1. Field delimiter for the output file. @@ -183,6 +185,10 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state Glue/Athena catalog: Database name. table : str, optional Glue/Athena catalog: Table name. + table_type: str, optional + The type of the Glue Table. Set to EXTERNAL_TABLE if None + transaction_id: str, optional + The ID of the transaction when writing to a Governed Table. dtype : Dict[str, str], optional Dictionary of columns names and Athena/Glue types to be casted. Useful when you have columns with undetermined or mixed data types. @@ -349,6 +355,28 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state } } + Writing dataset to Glue governed table + + >>> import awswrangler as wr + >>> import pandas as pd + >>> wr.s3.to_csv( + ... df=pd.DataFrame({ + ... 'col': [1, 2, 3], + ... 'col2': ['A', 'A', 'B'], + ... 'col3': [None, None, None] + ... }), + ... dataset=True, + ... mode='append', + ... database='default', # Athena/Glue database + ... table='my_table', # Athena/Glue table + ... table_type='GOVERNED', + ... transaction_id="xxx", + ... ) + { + 'paths': ['s3://.../x.csv'], + 'partitions_values: {} + } + Writing dataset casting empty column data type >>> import awswrangler as wr @@ -401,6 +429,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state dtype = dtype if dtype else {} partitions_values: Dict[str, List[str]] = {} mode = "append" if mode is None else mode + if transaction_id: + table_type = "GOVERNED" session: boto3.Session = _utils.ensure_session(session=boto3_session) # Sanitize table to respect Athena's standards @@ -413,6 +443,15 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access database=database, table=table, boto3_session=session, catalog_id=catalog_id ) + if catalog_table_input: + table_type = catalog_table_input["TableType"] + if path is None: + if catalog_table_input: + path = catalog_table_input["StorageDescriptor"]["Location"] + else: + raise exceptions.InvalidArgumentValue( + "Glue table does not exist. Please pass the `path` argument to create it." + ) if pandas_kwargs.get("compression") not in ("gzip", "bz2", None): raise exceptions.InvalidArgumentCombination( "If database and table are given, you must use one of these compressions: gzip, bz2 or None." 
@@ -420,6 +459,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode) + paths: List[str] = [] if dataset is False: pandas_kwargs["sep"] = sep pandas_kwargs["index"] = index @@ -433,7 +473,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state s3_additional_kwargs=s3_additional_kwargs, **pandas_kwargs, ) - paths = [path] + paths = [path] # type: ignore else: if database and table: quoting: Optional[int] = csv.QUOTE_NONE @@ -456,16 +496,58 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state pd_kwargs.pop("compression", None) df = df[columns] if columns else df + columns_types: Dict[str, str] = {} + partitions_types: Dict[str, str] = {} + if database and table: + columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned( + df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True + ) + if (catalog_table_input is None) and (table_type == "GOVERNED"): + catalog._create_csv_table( # pylint: disable=protected-access + database=database, + table=table, + path=path, + columns_types=columns_types, + table_type=table_type, + partitions_types=partitions_types, + bucketing_info=bucketing_info, + description=description, + parameters=parameters, + columns_comments=columns_comments, + boto3_session=session, + mode=mode, + catalog_versioning=catalog_versioning, + sep=sep, + projection_enabled=projection_enabled, + projection_types=projection_types, + projection_ranges=projection_ranges, + projection_values=projection_values, + projection_intervals=projection_intervals, + projection_digits=projection_digits, + catalog_table_input=catalog_table_input, + catalog_id=catalog_id, + compression=pandas_kwargs.get("compression"), + skip_header_line_count=None, + ) + catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access + database=database, table=table, boto3_session=session, catalog_id=catalog_id + ) paths, partitions_values = _to_dataset( func=_to_text, concurrent_partitioning=concurrent_partitioning, df=df, - path_root=path, + path_root=path, # type: ignore index=index, sep=sep, compression=compression, + catalog_id=catalog_id, + database=database, + table=table, + table_type=table_type, + transaction_id=transaction_id, use_threads=use_threads, partition_cols=partition_cols, + partitions_types=partitions_types, bucketing_info=bucketing_info, mode=mode, boto3_session=session, @@ -479,14 +561,12 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state ) if database and table: try: - columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned( - df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True - ) catalog._create_csv_table( # pylint: disable=protected-access database=database, table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, description=description, @@ -507,7 +587,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state compression=pandas_kwargs.get("compression"), skip_header_line_count=None, ) - if partitions_values and (regular_partitions is True): + if partitions_values and (regular_partitions is True) and (table_type != "GOVERNED"): _logger.debug("partitions_values:\n%s", partitions_values) catalog.add_csv_partitions( database=database, diff --git 
a/cloudformation/base.yaml b/cloudformation/base.yaml index 6e77560d4..76b69acff 100644 --- a/cloudformation/base.yaml +++ b/cloudformation/base.yaml @@ -1,6 +1,19 @@ AWSTemplateFormatVersion: 2010-09-09 Description: | AWS Data Wrangler Development Base Data Lake Infrastructure. VPC, Subnets, S3 Bucket, Glue Database, etc. +Parameters: + EnableLakeFormation: + Type: String + Description: Set to true if Lake Formation is enabled in the account + Default: false + AllowedValues: + - true + - false +Conditions: + CreateLFResources: + Fn::Equals: + - Ref: EnableLakeFormation + - true Resources: VPC: Type: AWS::EC2::VPC @@ -161,6 +174,7 @@ Resources: - Key: Env Value: aws-data-wrangler Description: Aws Data Wrangler Test Key. + EnableKeyRotation: true KeyPolicy: Version: '2012-10-17' Id: aws-data-wrangler-key @@ -217,7 +231,99 @@ Resources: Ref: AWS::AccountId DatabaseInput: Name: aws_data_wrangler - Description: AWS Data Wrangler Test Arena - Glue Database + Description: AWS Data Wrangler Test Athena - Glue Database + LakeFormationBucket: + Type: AWS::S3::Bucket + Condition: CreateLFResources + Properties: + Tags: + - Key: Env + Value: aws-data-wrangler + PublicAccessBlockConfiguration: + BlockPublicAcls: true + BlockPublicPolicy: true + IgnorePublicAcls: true + RestrictPublicBuckets: true + LifecycleConfiguration: + Rules: + - Id: CleaningUp + Status: Enabled + ExpirationInDays: 1 + AbortIncompleteMultipartUpload: + DaysAfterInitiation: 1 + NoncurrentVersionExpirationInDays: 1 + LakeFormationTransactionRole: + Type: AWS::IAM::Role + Condition: CreateLFResources + Properties: + Tags: + - Key: Env + Value: aws-data-wrangler + AssumeRolePolicyDocument: + Version: 2012-10-17 + Statement: + - Effect: Allow + Principal: + Service: + - lakeformation.amazonaws.com + Action: + - sts:AssumeRole + Path: / + Policies: + - PolicyName: Root + PolicyDocument: + Version: 2012-10-17 + Statement: + - Effect: Allow + Action: + - s3:DeleteObject + - s3:GetObject + - s3:PutObject + Resource: + - Fn::Sub: arn:aws:s3:::${LakeFormationBucket}/* + - Effect: Allow + Action: + - s3:ListBucket + Resource: + - Fn::Sub: arn:aws:s3:::${LakeFormationBucket} + - Effect: Allow + Action: + - execute-api:Invoke + Resource: arn:aws:execute-api:*:*:*/*/POST/reportStatus + - Effect: Allow + Action: + - lakeformation:AbortTransaction + - lakeformation:BeginTransaction + - lakeformation:CommitTransaction + - lakeformation:GetTableObjects + - lakeformation:UpdateTableObjects + Resource: '*' + - Effect: Allow + Action: + - glue:GetTable + - glue:GetPartitions + - glue:UpdateTable + Resource: '*' + LakeFormationBucketS3Registration: + Type: AWS::LakeFormation::Resource + Condition: CreateLFResources + Properties: + ResourceArn: + Fn::Sub: arn:aws:s3:::${LakeFormationBucket}/ + RoleArn: + Fn::GetAtt: + - LakeFormationTransactionRole + - Arn + UseServiceLinkedRole: false + LakeFormationGlueDatabase: + Type: AWS::Glue::Database + Condition: CreateLFResources + Properties: + CatalogId: + Ref: AWS::AccountId + DatabaseInput: + Name: aws_data_wrangler_lakeformation + Description: AWS Data Wrangler - Lake Formation Database LogGroup: Type: AWS::Logs::LogGroup Properties: @@ -274,6 +380,11 @@ Outputs: Value: Ref: GlueDatabase Description: Glue Database Name. + LakeFormationGlueDatabaseName: + Condition: CreateLFResources + Value: + Ref: LakeFormationGlueDatabase + Description: Lake Formation Glue Database Name.
LogGroupName: Value: Ref: LogGroup diff --git a/tests/_utils.py index 85df69484..c931445c2 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -17,7 +17,7 @@ CFN_VALID_STATUS = ["CREATE_COMPLETE", "ROLLBACK_COMPLETE", "UPDATE_COMPLETE", "UPDATE_ROLLBACK_COMPLETE"] -def get_df(): +def get_df(governed=False): df = pd.DataFrame( { "iint8": [1, None, 2], @@ -45,10 +45,13 @@ def get_df(): df["float"] = df["float"].astype("float32") df["string"] = df["string"].astype("string") df["category"] = df["category"].astype("category") + + if governed: + df = df.drop(["iint8", "binary"], axis=1) # tinyint & binary currently not supported return df -def get_df_list(): +def get_df_list(governed=False): df = pd.DataFrame( { "iint8": [1, None, 2], @@ -79,10 +82,13 @@ def get_df_list(): df["float"] = df["float"].astype("float32") df["string"] = df["string"].astype("string") df["category"] = df["category"].astype("category") + + if governed: + df = df.drop(["iint8", "binary"], axis=1) # tinyint & binary currently not supported return df -def get_df_cast(): +def get_df_cast(governed=False): df = pd.DataFrame( { "iint8": [None, None, None], @@ -103,6 +109,8 @@ def get_df_cast(): "par1": ["a", "b", "b"], } ) + if governed: + df = df.drop(["iint8", "binary"], axis=1) # tinyint & binary currently not supported return df @@ -418,7 +426,7 @@ def get_query_long(): """ -def ensure_data_types(df, has_list=False): +def ensure_data_types(df, has_list=False, governed=False): if "iint8" in df.columns: assert str(df["iint8"].dtype).startswith("Int") assert str(df["iint16"].dtype).startswith("Int") @@ -430,7 +438,10 @@ def ensure_data_types(df, has_list=False): if "string_object" in df.columns: assert str(df["string_object"].dtype) == "string" assert str(df["string"].dtype) == "string" - assert str(df["date"].dtype) == "object" + if governed: + assert str(df["date"].dtype) == "datetime64[ns]" + else: + assert str(df["date"].dtype) == "object" assert str(df["timestamp"].dtype) == "datetime64[ns]" assert str(df["bool"].dtype) in ("boolean", "Int64", "object") if "binary" in df.columns: @@ -447,7 +458,10 @@ def ensure_data_types(df, has_list=False): if not row.empty: row = row.iloc[0] assert str(type(row["decimal"]).__name__) == "Decimal" + if governed: + assert str(type(row["date"]).__name__) == "Timestamp" + else: + assert str(type(row["date"]).__name__) == "date" if "binary" in df.columns: assert str(type(row["binary"]).__name__) == "bytes" if has_list is True: @@ -468,7 +482,7 @@ def ensure_data_types_category(df): assert str(df["par1"].dtype) == "category" -def ensure_data_types_csv(df): +def ensure_data_types_csv(df, governed=False): if "__index_level_0__" in df: assert str(df["__index_level_0__"].dtype).startswith("Int") assert str(df["id"].dtype).startswith("Int") @@ -480,7 +494,10 @@ def ensure_data_types_csv(df): assert str(df["float"].dtype).startswith("float") if "int" in df: assert str(df["int"].dtype).startswith("Int") - assert str(df["date"].dtype) == "object" + if governed: + assert str(df["date"].dtype).startswith("datetime") + else: + assert str(df["date"].dtype) == "object" assert str(df["timestamp"].dtype).startswith("datetime") if "bool" in df: assert str(df["bool"].dtype) == "boolean" diff --git a/tests/conftest.py index 011fccfca..7bdb19b64 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,11 @@ def glue_database(cloudformation_outputs): return cloudformation_outputs["GlueDatabaseName"]
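
The `governed` switches added to the test helpers above feed the new Lake Formation tests. A condensed sketch of that pattern is shown below; the test name is hypothetical, the helpers are assumed to be imported from `tests/_utils.py` as in the existing test modules, and it relies on the standard `path` and `glue_table` fixtures plus the `lakeformation_glue_database` fixture added just after this point.

```python
import awswrangler as wr

from ._utils import ensure_data_types, get_df


def test_governed_roundtrip(lakeformation_glue_database, glue_table, path):
    # tinyint & binary columns are dropped by get_df(governed=True)
    df = get_df(governed=True)
    wr.s3.to_parquet(
        df=df,
        path=path,
        dataset=True,
        mode="overwrite",
        database=lakeformation_glue_database,
        table=glue_table,
        table_type="GOVERNED",
    )
    # Read back through the Lake Formation query engine and check dtypes.
    df2 = wr.lakeformation.read_sql_table(table=glue_table, database=lakeformation_glue_database)
    assert df.shape == df2.shape
    ensure_data_types(df=df2, governed=True)  # governed reads return dates as datetime64[ns]
```
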
+@pytest.fixture(scope="session") +def lakeformation_glue_database(cloudformation_outputs): + return cloudformation_outputs["LakeFormationGlueDatabaseName"] + + @pytest.fixture(scope="session") def kms_key(cloudformation_outputs): return cloudformation_outputs["KmsKeyArn"] diff --git a/tests/test__routines.py b/tests/test__routines.py index fb08e8d12..96f430059 100644 --- a/tests/test__routines.py +++ b/tests/test__routines.py @@ -10,7 +10,13 @@ @pytest.mark.parametrize("use_threads", [True, False]) @pytest.mark.parametrize("concurrent_partitioning", [True, False]) -def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_partitioning): +@pytest.mark.parametrize("table_type", ["EXTERNAL_TABLE", "GOVERNED"]) +def test_routine_0( + lakeformation_glue_database, glue_database, glue_table, table_type, path, use_threads, concurrent_partitioning +): + + table = f"__{glue_table}" + database = lakeformation_glue_database if table_type == "GOVERNED" else glue_database # Round 1 - Warm up df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") @@ -19,24 +25,28 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part path=path, dataset=True, mode="overwrite", - database=glue_database, - table=glue_table, + database=database, + table=table, + table_type=table_type, description="c0", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c0": "0"}, use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(glue_database, glue_table) == "c0" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c0" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c0"] == "0" @@ -44,27 +54,29 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part df = pd.DataFrame({"c1": [None, 1, None]}, dtype="Int16") wr.s3.to_parquet( df=df, - path=path, dataset=True, mode="overwrite", - database=glue_database, - table=glue_table, + database=database, + table=table, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c1": "1"}, use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(table, database, 
use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert df.shape == df2.shape assert df.c1.sum() == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(glue_database, glue_table) == "c1" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c1" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" @@ -75,25 +87,28 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part path=path, dataset=True, mode="append", - database=glue_database, - table=glue_table, + database=database, + table=table, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index) * 2)}, columns_comments={"c1": "1"}, use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert len(df.columns) == len(df2.columns) assert len(df.index) * 2 == len(df2.index) assert df.c1.sum() + 1 == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(glue_database, glue_table) == "c1" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c1" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" @@ -101,28 +116,30 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part df = pd.DataFrame({"c2": ["a", None, "b"], "c1": [None, None, None]}) wr.s3.to_parquet( df=df, - path=path, dataset=True, mode="append", - database=glue_database, - table=glue_table, + database=database, + table=table, description="c1+c2", parameters={"num_cols": "2", "num_rows": "9"}, columns_comments={"c1": "1", "c2": "2"}, use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert len(df2.columns) == 2 assert len(df2.index) == 9 assert df2.c1.sum() == 3 - parameters = wr.catalog.get_table_parameters(glue_database, 
glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "9" - assert wr.catalog.get_table_description(glue_database, glue_table) == "c1+c2" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c1+c2" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" assert comments["c2"] == "2" @@ -134,39 +151,56 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part path=path, dataset=True, mode="append", - database=glue_database, - table=glue_table, + database=database, + table=table, description="c1+c2+c3", parameters={"num_cols": "3", "num_rows": "10"}, columns_comments={"c1": "1!", "c2": "2!", "c3": "3"}, use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert len(df2.columns) == 3 assert len(df2.index) == 10 assert df2.c1.sum() == 4 - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == "3" assert parameters["num_rows"] == "10" - assert wr.catalog.get_table_description(glue_database, glue_table) == "c1+c2+c3" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c1+c2+c3" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c1"] == "1!" assert comments["c2"] == "2!" 
assert comments["c3"] == "3" - # Round 6 - Overwrite Partitioned + wr.catalog.delete_table_if_exists(database=database, table=table) + + +@pytest.mark.parametrize("use_threads", [True, False]) +@pytest.mark.parametrize("concurrent_partitioning", [True, False]) +@pytest.mark.parametrize("table_type", ["EXTERNAL_TABLE", "GOVERNED"]) +def test_routine_1( + lakeformation_glue_database, glue_database, glue_table, table_type, path, use_threads, concurrent_partitioning +): + + table = f"__{glue_table}" + database = lakeformation_glue_database if table_type == "GOVERNED" else glue_database + + # Round 1 - Overwrite Partitioned df = pd.DataFrame({"c0": ["foo", None], "c1": [0, 1]}) wr.s3.to_parquet( df=df, path=path, dataset=True, mode="overwrite", - database=glue_database, - table=glue_table, + database=database, + table=table, + table_type=table_type, partition_cols=["c1"], description="c0+c1", parameters={"num_cols": "2", "num_rows": "2"}, @@ -174,29 +208,31 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert df.shape == df2.shape assert df.c1.sum() == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "2" - assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c0+c1" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" - # Round 7 - Overwrite Partitions + # Round 2 - Overwrite Partitions df = pd.DataFrame({"c0": [None, None], "c1": [0, 2]}) wr.s3.to_parquet( df=df, - path=path, dataset=True, mode="overwrite_partitions", - database=glue_database, - table=glue_table, + database=database, + table=table, partition_cols=["c1"], description="c0+c1", parameters={"num_cols": "2", "num_rows": "3"}, @@ -204,30 +240,33 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part concurrent_partitioning=concurrent_partitioning, use_threads=use_threads, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert len(df2.columns) == 2 assert len(df2.index) == 3 assert df2.c1.sum() == 3 - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert 
parameters["num_cols"] == "2" assert parameters["num_rows"] == "3" - assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c0+c1" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" - # Round 8 - Overwrite Partitions + New Column + Wrong Type + # Round 3 - Overwrite Partitions + New Column + Wrong Type df = pd.DataFrame({"c0": [1, 2], "c1": ["1", "3"], "c2": [True, False]}) wr.s3.to_parquet( df=df, path=path, dataset=True, mode="overwrite_partitions", - database=glue_database, - table=glue_table, + database=database, + table=table, partition_cols=["c1"], description="c0+c1+c2", parameters={"num_cols": "3", "num_rows": "4"}, @@ -235,24 +274,29 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert len(df2.columns) == 3 assert len(df2.index) == 4 assert df2.c1.sum() == 6 - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == "3" assert parameters["num_rows"] == "4" - assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1+c2" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c0+c1+c2" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" assert comments["c2"] == "two" + wr.catalog.delete_table_if_exists(database=database, table=table) + -def test_routine_1(glue_database, glue_table, path): +def test_routine_2(glue_database, glue_table, path): # Round 1 - Warm up df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") @@ -445,3 +489,5 @@ def test_routine_1(glue_database, glue_table, path): assert comments["c0"] == "zero" assert comments["c1"] == "one" assert comments["c2"] == "two" + + wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) diff --git a/tests/test_lakeformation.py b/tests/test_lakeformation.py new file mode 100644 index 000000000..242cb3a0f --- /dev/null +++ b/tests/test_lakeformation.py @@ -0,0 +1,150 @@ +import calendar +import logging +import time + +import pandas as pd + +import awswrangler as wr + +from ._utils import ensure_data_types, ensure_data_types_csv, get_df, get_df_csv + +logging.getLogger("awswrangler").setLevel(logging.DEBUG) + + +def test_lakeformation(path, path2, lakeformation_glue_database, glue_table, glue_table2, use_threads=False): + table = f"__{glue_table}" + table2 = f"__{glue_table2}" + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table) + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, 
table=table2) + + wr.s3.to_parquet( + df=get_df(governed=True), + path=path, + index=False, + boto3_session=None, + s3_additional_kwargs=None, + dataset=True, + partition_cols=["par0", "par1"], + mode="overwrite", + table=table, + table_type="GOVERNED", + database=lakeformation_glue_database, + ) + + df = wr.lakeformation.read_sql_table( + table=table, + database=lakeformation_glue_database, + use_threads=use_threads, + ) + assert len(df.index) == 3 + assert len(df.columns) == 14 + assert df["iint32"].sum() == 3 + ensure_data_types(df=df, governed=True) + + # Filter query + df2 = wr.lakeformation.read_sql_query( + sql=f"SELECT * FROM {table} WHERE iint16 = :iint16;", + database=lakeformation_glue_database, + params={"iint16": 1}, + ) + assert len(df2.index) == 1 + + wr.s3.to_csv( + df=get_df_csv(), + path=path2, + index=False, + boto3_session=None, + s3_additional_kwargs=None, + dataset=True, + partition_cols=["par0", "par1"], + mode="append", + table=table2, + table_type="GOVERNED", + database=lakeformation_glue_database, + ) + # Read within a transaction + transaction_id = wr.lakeformation.begin_transaction(read_only=True) + df3 = wr.lakeformation.read_sql_table( + table=table2, + database=lakeformation_glue_database, + transaction_id=transaction_id, + use_threads=use_threads, + ) + assert df3["int"].sum() == 3 + ensure_data_types_csv(df3, governed=True) + + # Read within a query as of time + query_as_of_time = calendar.timegm(time.gmtime()) + df4 = wr.lakeformation.read_sql_table( + table=table2, + database=lakeformation_glue_database, + query_as_of_time=query_as_of_time, + use_threads=use_threads, + ) + assert len(df4.index) == 3 + + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table) + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table2) + + +def test_lakeformation_multi_transaction( + path, path2, lakeformation_glue_database, glue_table, glue_table2, use_threads=True +): + table = f"__{glue_table}" + table2 = f"__{glue_table2}" + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table) + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table2) + + df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") + transaction_id = wr.lakeformation.begin_transaction(read_only=False) + wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + mode="append", + database=lakeformation_glue_database, + table=table, + table_type="GOVERNED", + transaction_id=transaction_id, + description="c0", + parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, + columns_comments={"c0": "0"}, + use_threads=use_threads, + ) + + df2 = pd.DataFrame({"c1": [None, 1, None]}, dtype="Int16") + wr.s3.to_parquet( + df=df2, + path=path2, + dataset=True, + mode="append", + database=lakeformation_glue_database, + table=table2, + table_type="GOVERNED", + transaction_id=transaction_id, + description="c1", + parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, + columns_comments={"c1": "1"}, + use_threads=use_threads, + ) + wr.lakeformation.commit_transaction(transaction_id=transaction_id) + + df3 = wr.lakeformation.read_sql_table( + table=table, + database=lakeformation_glue_database, + use_threads=use_threads, + ) + df4 = wr.lakeformation.read_sql_table( + table=table2, + database=lakeformation_glue_database, + use_threads=use_threads, + ) + + assert df.shape == df3.shape + assert df.c0.sum() == df3.c0.sum() + + assert df2.shape == df4.shape + assert 
df2.c1.sum() == df4.c1.sum() + + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table) + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table2) diff --git a/tutorials/029 - Lake Formation Governed Tables.ipynb b/tutorials/029 - Lake Formation Governed Tables.ipynb new file mode 100644 index 000000000..571b78a89 --- /dev/null +++ b/tutorials/029 - Lake Formation Governed Tables.ipynb @@ -0,0 +1,441 @@ +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.9.1 64-bit ('.venv': venv)", + "metadata": { + "interpreter": { + "hash": "2878c7ae46413c5ab07cafef85a7415922732432fa2f847b9105997e244ed975" + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "source": [ + "[![AWS Data Wrangler](_static/logo.png \"AWS Data Wrangler\")](https://github.com/awslabs/aws-data-wrangler)" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "# AWS Lake Formation - Glue Governed tables" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "### This tutorial assumes that your IAM user/role has the required Lake Formation permissions to create and read AWS Glue Governed tables" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## Table of Contents\n", + "* [1. Read Governed table](#1.-Read-Governed-table)\n", + " * [1.1 Read PartiQL query](#1.1-Read-PartiQL-query)\n", + " * [1.1.1 Read within transaction](#1.1.1-Read-within-transaction)\n", + " * [1.1.2 Read within query as of time](#1.1.2-Read-within-query-as-of-time)\n", + " * [1.2 Read full table](#1.2-Read-full-table)\n", + "* [2. Write Governed table](#2.-Write-Governed-table)\n", + " * [2.1 Create new Governed table](#2.1-Create-new-Governed-table)\n", + " * [2.1.1 CSV table](#2.1.1-CSV-table)\n", + " * [2.1.2 Parquet table](#2.1.2-Parquet-table)\n", + " * [2.2 Overwrite operations](#2.2-Overwrite-operations)\n", + " * [2.2.1 Overwrite](#2.2.1-Overwrite)\n", + " * [2.2.2 Append](#2.2.2-Append)\n", + " * [2.2.3 Create partitioned Governed table](#2.2.3-Create-partitioned-Governed-table)\n", + " * [2.2.4 Overwrite partitions](#2.2.4-Overwrite-partitions)\n", + "* [3. Multiple read/write operations within a transaction](#2.-Multiple-read/write-operations-within-a-transaction)" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "# 1. 
Read Governed table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## 1.1 Read PartiQL query" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import awswrangler as wr\n", + "\n", + "database = \"gov_db\" # Assumes a Glue database registered with Lake Formation exists in the account\n", + "table = \"gov_table\" # Assumes a Governed table exists in the account\n", + "catalog_id = \"111111111111\" # AWS Account Id\n", + "\n", + "# Note 1: If a transaction_id is not specified, a new transaction is started\n", + "df = wr.lakeformation.read_sql_query(\n", + " sql=f\"SELECT * FROM {table};\",\n", + " database=database,\n", + " catalog_id=catalog_id\n", + ")" + ] + }, + { + "source": [ + "### 1.1.1 Read within transaction" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transaction_id = wr.lakeformation.begin_transaction(read_only=True)\n", + "df = wr.lakeformation.read_sql_query(\n", + " sql=f\"SELECT * FROM {table};\",\n", + " database=database,\n", + " transaction_id=transaction_id\n", + ")" + ] + }, + { + "source": [ + "### 1.1.2 Read within query as of time" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import calendar\n", + "import time\n", + "\n", + "query_as_of_time = calendar.timegm(time.gmtime())\n", + "df = wr.lakeformation.read_sql_query(\n", + " sql=f\"SELECT * FROM {table} WHERE id=:id; AND name=:name;\",\n", + " database=database,\n", + " query_as_of_time=query_as_of_time,\n", + " params={\"id\": 1, \"name\": \"Ayoub\"}\n", + ")" + ] + }, + { + "source": [ + "## 1.2 Read full table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = wr.lakeformation.read_sql_table(\n", + " table=table,\n", + " database=database\n", + ")" + ] + }, + { + "source": [ + "# 2. Write Governed table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## 2.1 Create new Governed table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## Enter your bucket name:" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "\n", + "bucket = getpass.getpass()" + ] + }, + { + "source": [ + "### If a governed table does not exist, it can be created by passing an S3 `path` argument. Make sure your IAM user/role has sufficient permissions on the Lake Formation database" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "### 2.1.1 CSV table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "table = \"gov_table_csv\"\n", + "\n", + "df = pd.DataFrame({\n", + " \"col\": [1, 2, 3],\n", + " \"col2\": [\"A\", \"A\", \"B\"],\n", + " \"col3\": [None, \"test\", None]\n", + "})\n", + "# Note 1: If a transaction_id is not specified, a new transaction is started\n", + "# Note 2: When creating a new Governed table, `table_type=\"GOVERNED\"` must be specified.
Otherwise the default is to create an EXTERNAL_TABLE\n", + "wr.s3.to_csv(\n", + " df=df,\n", + " path=f\"s3://{bucket}/{database}/{table}/\", # S3 path\n", + " dataset=True,\n", + " database=database,\n", + " table=table,\n", + " table_type=\"GOVERNED\"\n", + ")" + ] + }, + { + "source": [ + "### 2.1.2 Parquet table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "table = \"gov_table_parquet\"\n", + "\n", + "df = pd.DataFrame({\"c0\": [0, None]}, dtype=\"Int64\")\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " path=f\"s3://{bucket}/{database}/{table}/\",\n", + " dataset=True,\n", + " database=database,\n", + " table=table,\n", + " table_type=\"GOVERNED\",\n", + " description=\"c0\",\n", + " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index))},\n", + " columns_comments={\"c0\": \"0\"}\n", + ")" + ] + }, + { + "source": [ + "## 2.2 Overwrite operations" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "### 2.2.1 Overwrite" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.DataFrame({\"c1\": [None, 1, None]}, dtype=\"Int16\")\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " dataset=True,\n", + " mode=\"overwrite\",\n", + " database=database,\n", + " table=table,\n", + " description=\"c1\",\n", + " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index))},\n", + " columns_comments={\"c1\": \"1\"}\n", + ")" + ] + }, + { + "source": [ + "### 2.2.2 Append" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.DataFrame({\"c1\": [None, 2, None]}, dtype=\"Int8\")\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " dataset=True,\n", + " mode=\"append\",\n", + " database=database,\n", + " table=table,\n", + " description=\"c1\",\n", + " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index) * 2)},\n", + " columns_comments={\"c1\": \"1\"}\n", + ")" + ] + }, + { + "source": [ + "### 2.2.3 Create partitioned Governed table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "table = \"gov_table_parquet_partitioned\"\n", + "\n", + "df = pd.DataFrame({\"c0\": [\"foo\", None], \"c1\": [0, 1]})\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " path=f\"s3://{bucket}/{database}/{table}/\",\n", + " dataset=True,\n", + " database=database,\n", + " table=table,\n", + " table_type=\"GOVERNED\",\n", + " partition_cols=[\"c1\"],\n", + " description=\"c0+c1\",\n", + " parameters={\"num_cols\": \"2\", \"num_rows\": \"2\"},\n", + " columns_comments={\"c0\": \"zero\", \"c1\": \"one\"}\n", + ")" + ] + }, + { + "source": [ + "### 2.2.4 Overwrite partitions" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.DataFrame({\"c0\": [None, None], \"c1\": [0, 2]})\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " dataset=True,\n", + " mode=\"overwrite_partitions\",\n", + " database=database,\n", + " table=table,\n", + " partition_cols=[\"c1\"],\n", + " description=\"c0+c1\",\n", + " parameters={\"num_cols\": \"2\", \"num_rows\": \"3\"},\n", + " columns_comments={\"c0\": 
\"zero\", \"c1\": \"one\"}\n", + ")" + ] + }, + { + "source": [ + "# 3. Multiple read/write operations within a transaction" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "read_table = \"gov_table_parquet\"\n", + "write_table = \"gov_table_multi_parquet\"\n", + "\n", + "transaction_id = wr.lakeformation.begin_transaction(read_only=False)\n", + "\n", + "df = pd.DataFrame({\"c0\": [0, None]}, dtype=\"Int64\")\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " path=f\"s3://{bucket}/{database}/{write_table}_1\",\n", + " dataset=True,\n", + " database=database,\n", + " table=f\"{write_table}_1\",\n", + " table_type=\"GOVERNED\",\n", + " transaction_id=transaction_id,\n", + ")\n", + "\n", + "df2 = wr.lakeformation.read_sql_table(\n", + " table=read_table,\n", + " database=database,\n", + " transaction_id=transaction_id,\n", + " use_threads=True\n", + ")\n", + "\n", + "df3 = pd.DataFrame({\"c1\": [None, 1, None]}, dtype=\"Int16\")\n", + "wr.s3.to_parquet(\n", + " df=df2,\n", + " path=f\"s3://{bucket}/{database}/{write_table}_2\",\n", + " dataset=True,\n", + " mode=\"append\",\n", + " database=database,\n", + " table=f\"{write_table}_2\",\n", + " table_type=\"GOVERNED\",\n", + " transaction_id=transaction_id,\n", + ")\n", + "\n", + "wr.lakeformation.commit_transaction(transaction_id=transaction_id)" + ] + } + ] +} \ No newline at end of file