From 458bf266f684f096cd26a729ef6eb2d3beffc02d Mon Sep 17 00:00:00 2001
From: igorborgest
Date: Mon, 4 May 2020 19:13:39 -0300
Subject: [PATCH] Add keep_files and ctas_temp_table_name to wr.athena.read_*(). #203

---
 awswrangler/athena.py                      | 76 +++++++++++++++++-----
 awswrangler/torch.py                       | 20 +++---
 testing/test_awswrangler/test_data_lake.py | 62 +++++++++++++++++-
 3 files changed, 130 insertions(+), 28 deletions(-)

diff --git a/awswrangler/athena.py b/awswrangler/athena.py
index 76cb0a108..671dabd42 100644
--- a/awswrangler/athena.py
+++ b/awswrangler/athena.py
@@ -370,7 +370,7 @@ def _fix_csv_types(df: pd.DataFrame, parse_dates: List[str], binaries: List[str]
     return df
 
 
-def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
+def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals,too-many-return-statements,too-many-statements
     sql: str,
     database: str,
     ctas_approach: bool = True,
@@ -380,6 +380,8 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
     workgroup: Optional[str] = None,
     encryption: Optional[str] = None,
     kms_key: Optional[str] = None,
+    keep_files: bool = True,
+    ctas_temp_table_name: Optional[str] = None,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
@@ -454,6 +456,12 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
         Valid values: [None, 'SSE_S3', 'SSE_KMS']. Notice: 'CSE_KMS' is not supported.
     kms_key : str, optional
         For SSE-KMS, this is the KMS key ARN or ID.
+    keep_files : bool
+        Should Wrangler delete or keep the staging files produced by Athena?
+    ctas_temp_table_name : str, optional
+        The name of the temporary table and also the directory name on S3 where the CTAS result is stored.
+        If None, it will use the following random pattern: `f"temp_table_{pyarrow.compat.guid()}"`.
+        On S3 this directory will be under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`.
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -477,7 +485,10 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
     _s3_output = _s3_output[:-1] if _s3_output[-1] == "/" else _s3_output
     name: str = ""
     if ctas_approach is True:
-        name = f"temp_table_{pa.compat.guid()}"
+        if ctas_temp_table_name is not None:
+            name = catalog.sanitize_table_name(ctas_temp_table_name)
+        else:
+            name = f"temp_table_{pa.compat.guid()}"
         path: str = f"{_s3_output}/{name}"
         ext_location: str = "\n" if wg_config["enforced"] is True else f",\n    external_location = '{path}'\n"
         sql = (
@@ -506,25 +517,34 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
         reason: str = query_response["QueryExecution"]["Status"]["StateChangeReason"]
         message_error: str = f"Query error: {reason}"
         raise exceptions.AthenaQueryError(message_error)
-    dfs: Union[pd.DataFrame, Iterator[pd.DataFrame]]
+    ret: Union[pd.DataFrame, Iterator[pd.DataFrame]]
     if ctas_approach is True:
         catalog.delete_table_if_exists(database=database, table=name, boto3_session=session)
         manifest_path: str = f"{_s3_output}/tables/{query_id}-manifest.csv"
+        metadata_path: str = f"{_s3_output}/tables/{query_id}.metadata"
         _logger.debug("manifest_path: %s", manifest_path)
+        _logger.debug("metadata_path: %s", metadata_path)
+        s3.wait_objects_exist(paths=[manifest_path, metadata_path], use_threads=False, boto3_session=session)
         paths: List[str] = _extract_ctas_manifest_paths(path=manifest_path, boto3_session=session)
         chunked: Union[bool, int] = False if chunksize is None else chunksize
         _logger.debug("chunked: %s", chunked)
         if not paths:
             if chunked is False:
-                dfs = pd.DataFrame()
-            else:
-                dfs = _utils.empty_generator()
-        else:
-            s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session)
-            dfs = s3.read_parquet(
-                path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked, categories=categories
-            )
-        return dfs
+                return pd.DataFrame()
+            return _utils.empty_generator()
+        s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session)
+        ret = s3.read_parquet(
+            path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked, categories=categories
+        )
+        paths_delete: List[str] = paths + [manifest_path, metadata_path]
+        _logger.debug(type(ret))
+        if chunked is False:
+            if keep_files is False:
+                s3.delete_objects(path=paths_delete, use_threads=use_threads, boto3_session=session)
+            return ret
+        if keep_files is False:
+            return _delete_after_iterate(dfs=ret, paths=paths_delete, use_threads=use_threads, boto3_session=session)
+        return ret
     dtype, parse_timestamps, parse_dates, converters, binaries = _get_query_metadata(
         query_execution_id=query_id, categories=categories, boto3_session=session
     )
@@ -547,10 +567,26 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
         boto3_session=session,
     )
     _logger.debug("Start type casting...")
-    if chunksize is None:
-        return _fix_csv_types(df=ret, parse_dates=parse_dates, binaries=binaries)
     _logger.debug(type(ret))
-    return _fix_csv_types_generator(dfs=ret, parse_dates=parse_dates, binaries=binaries)
+    if chunksize is None:
+        df = _fix_csv_types(df=ret, parse_dates=parse_dates, binaries=binaries)
+        if keep_files is False:
+            s3.delete_objects(path=[path, f"{path}.metadata"], use_threads=use_threads, boto3_session=session)
+        return df
+    dfs = _fix_csv_types_generator(dfs=ret, parse_dates=parse_dates, binaries=binaries)
+    if keep_files is False:
+        return _delete_after_iterate(
+            dfs=dfs, paths=[path, f"{path}.metadata"], use_threads=use_threads, boto3_session=session
+        )
+    return dfs
+
+
+def _delete_after_iterate(
+    dfs: Iterator[pd.DataFrame], paths: List[str], use_threads: bool, boto3_session: boto3.Session
+) -> Iterator[pd.DataFrame]:
+    for df in dfs:
+        yield df
+    s3.delete_objects(path=paths, use_threads=use_threads, boto3_session=boto3_session)
 
 
 def stop_query_execution(query_execution_id: str, boto3_session: Optional[boto3.Session] = None) -> None:
@@ -638,6 +674,8 @@ def read_sql_table(
     workgroup: Optional[str] = None,
     encryption: Optional[str] = None,
     kms_key: Optional[str] = None,
+    keep_files: bool = True,
+    ctas_temp_table_name: Optional[str] = None,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
@@ -712,6 +750,12 @@ def read_sql_table(
         None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
     kms_key : str, optional
         For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
+    keep_files : bool
+        Should Wrangler delete or keep the staging files produced by Athena?
+    ctas_temp_table_name : str, optional
+        The name of the temporary table and also the directory name on S3 where the CTAS result is stored.
+        If None, it will use the following random pattern: `f"temp_table_{pyarrow.compat.guid()}"`.
+        On S3 this directory will be under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`.
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -740,6 +784,8 @@ def read_sql_table(
         workgroup=workgroup,
         encryption=encryption,
         kms_key=kms_key,
+        keep_files=keep_files,
+        ctas_temp_table_name=ctas_temp_table_name,
         use_threads=use_threads,
         boto3_session=boto3_session,
     )
diff --git a/awswrangler/torch.py b/awswrangler/torch.py
index 7d3c47316..70df93f34 100644
--- a/awswrangler/torch.py
+++ b/awswrangler/torch.py
@@ -28,14 +28,14 @@ class _BaseS3Dataset:
     def __init__(
         self, path: Union[str, List[str]], suffix: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
     ):
-        """PyTorch Map-Style S3 Dataset.
+        r"""PyTorch Map-Style S3 Dataset.
 
         Parameters
         ----------
         path : Union[str, List[str]]
             S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]).
         suffix: str, optional
-            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png).
+            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://\*.png).
         boto3_session : boto3.Session(), optional
             Boto3 Session. The default boto3 session will be used if boto3_session receive None.
 
@@ -160,7 +160,7 @@ def __init__(
         suffix: Optional[str] = None,
         boto3_session: Optional[boto3.Session] = None,
     ):
-        """PyTorch Amazon S3 Lambda Dataset.
+        r"""PyTorch Amazon S3 Lambda Dataset.
 
         Parameters
         ----------
@@ -171,7 +171,7 @@ def __init__(
         label_fn: Callable
             Function that receives object path (str) and return a torch.Tensor
         suffix: str, optional
-            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png).
+            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://\*.png).
         boto3_session : boto3.Session(), optional
             Boto3 Session. The default boto3 session will be used if boto3_session receive None.
 
@@ -212,7 +212,7 @@ def __init__(
         suffix: Optional[str] = None,
         boto3_session: Optional[boto3.Session] = None,
     ):
-        """PyTorch Amazon S3 Audio Dataset.
+        r"""PyTorch Amazon S3 Audio Dataset.
 
         Read individual WAV audio files stores in Amazon S3 and return them as torch tensors.
 
@@ -237,7 +237,7 @@ def __init__(
         path : Union[str, List[str]]
             S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]).
         suffix: str, optional
-            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png).
+            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://\*.png).
         boto3_session : boto3.Session(), optional
             Boto3 Session. The default boto3 session will be used if boto3_session receive None.
 
@@ -302,7 +302,7 @@ class ImageS3Dataset(_S3PartitionedDataset):
     """PyTorch Amazon S3 Image Dataset."""
 
     def __init__(self, path: Union[str, List[str]], suffix: str, boto3_session: boto3.Session):
-        """PyTorch Amazon S3 Image Dataset.
+        r"""PyTorch Amazon S3 Image Dataset.
 
         ImageS3Dataset assumes images are patitioned (within class= folders) in Amazon S3.
         Each lisited object will be loaded by default Pillow library.
@@ -327,7 +327,7 @@ def __init__(self, path: Union[str, List[str]], suffix: str, boto3_session: boto
         path : Union[str, List[str]]
             S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]).
         suffix: str, optional
-            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png).
+            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://\*.png).
         boto3_session : boto3.Session(), optional
             Boto3 Session. The default boto3 session will be used if boto3_session receive None.
 
@@ -350,14 +350,14 @@ def _data_fn(self, data: io.BytesIO) -> Any:
 
 
 class S3IterableDataset(IterableDataset, _BaseS3Dataset):  # pylint: disable=abstract-method
-    """PyTorch Amazon S3 Iterable Dataset.
+    r"""PyTorch Amazon S3 Iterable Dataset.
 
     Parameters
     ----------
     path : Union[str, List[str]]
         S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]).
     suffix: str, optional
-        S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png).
+        S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://\*.png).
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
 
diff --git a/testing/test_awswrangler/test_data_lake.py b/testing/test_awswrangler/test_data_lake.py
index 99c1df1c6..e9c9834df 100644
--- a/testing/test_awswrangler/test_data_lake.py
+++ b/testing/test_awswrangler/test_data_lake.py
@@ -191,14 +191,51 @@ def test_athena_ctas(bucket, database, kms_key):
         encryption="SSE_KMS",
         kms_key=kms_key,
         s3_output=f"s3://{bucket}/test_athena_ctas_result",
+        keep_files=False,
     )
     assert len(df.index) == 3
     ensure_data_types(df=df, has_list=True)
+    temp_table = "test_athena_ctas2"
+    s3_output = f"s3://{bucket}/s3_output/"
+    final_destination = f"{s3_output}{temp_table}/"
+
+    # keep_files=False
+    wr.s3.delete_objects(path=s3_output)
     dfs = wr.athena.read_sql_query(
-        sql=f"SELECT * FROM test_athena_ctas", database=database, ctas_approach=True, chunksize=1
+        sql=f"SELECT * FROM test_athena_ctas",
+        database=database,
+        ctas_approach=True,
+        chunksize=1,
+        keep_files=False,
+        ctas_temp_table_name=temp_table,
+        s3_output=s3_output,
     )
+    assert wr.catalog.does_table_exist(database=database, table=temp_table) is False
+    assert len(wr.s3.list_objects(path=s3_output)) > 2
+    assert len(wr.s3.list_objects(path=final_destination)) > 0
     for df in dfs:
         ensure_data_types(df=df, has_list=True)
+    assert len(wr.s3.list_objects(path=s3_output)) == 0
+
+    # keep_files=True
+    wr.s3.delete_objects(path=s3_output)
+    dfs = wr.athena.read_sql_query(
+        sql=f"SELECT * FROM test_athena_ctas",
+        database=database,
+        ctas_approach=True,
+        chunksize=2,
+        keep_files=True,
+        ctas_temp_table_name=temp_table,
+        s3_output=s3_output,
+    )
+    assert wr.catalog.does_table_exist(database=database, table=temp_table) is False
+    assert len(wr.s3.list_objects(path=s3_output)) > 2
+    assert len(wr.s3.list_objects(path=final_destination)) > 0
+    for df in dfs:
+        ensure_data_types(df=df, has_list=True)
+    assert len(wr.s3.list_objects(path=s3_output)) > 2
+
+    # Cleaning Up
     wr.catalog.delete_table_if_exists(database=database, table="test_athena_ctas")
     wr.s3.delete_objects(path=paths)
     wr.s3.wait_objects_not_exist(paths=paths)
@@ -227,12 +264,17 @@ def test_athena(bucket, database, kms_key, workgroup0, workgroup1):
         encryption="SSE_KMS",
         kms_key=kms_key,
         workgroup=workgroup0,
+        keep_files=False,
     )
     for df2 in dfs:
         print(df2)
         ensure_data_types(df=df2)
     df = wr.athena.read_sql_query(
-        sql="SELECT * FROM __test_athena", database=database, ctas_approach=False, workgroup=workgroup1
+        sql="SELECT * FROM __test_athena",
+        database=database,
+        ctas_approach=False,
+        workgroup=workgroup1,
+        keep_files=False,
     )
     assert len(df.index) == 3
     ensure_data_types(df=df)
@@ -1195,9 +1237,23 @@ def test_athena_encryption(
         df=df, path=path, dataset=True, mode="overwrite", database=database, table=table, s3_additional_kwargs=None
     )["paths"]
     wr.s3.wait_objects_exist(paths=paths, use_threads=False)
+    temp_table = table + "2"
+    s3_output = f"s3://{bucket}/encryptio_s3_output/"
+    final_destination = f"{s3_output}{temp_table}/"
+    wr.s3.delete_objects(path=final_destination)
     df2 = wr.athena.read_sql_table(
-        table=table, ctas_approach=True, database=database, encryption=encryption, workgroup=workgroup, kms_key=kms_key
+        table=table,
+        ctas_approach=True,
+        database=database,
+        encryption=encryption,
+        workgroup=workgroup,
+        kms_key=kms_key,
+        keep_files=True,
+        ctas_temp_table_name=temp_table,
+        s3_output=s3_output,
    )
+    assert wr.catalog.does_table_exist(database=database, table=temp_table) is False
+    assert len(wr.s3.list_objects(path=s3_output)) > 2
     print(df2)
     assert len(df2.index) == 2
     assert len(df2.columns) == 2
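
Usage sketch for the new keep_files and ctas_temp_table_name arguments (illustrative only, not part of the patch; the database, table, and bucket names below are placeholders):

    import awswrangler as wr

    # ctas_approach=True stages the result as Parquet via CTAS under
    # f"{s3_output}/{ctas_temp_table_name}/"; with keep_files=False the staging
    # objects are deleted once the DataFrame has been materialized.
    df = wr.athena.read_sql_query(
        sql="SELECT * FROM my_table",
        database="my_db",
        ctas_approach=True,
        keep_files=False,
        ctas_temp_table_name="my_temp_table",
        s3_output="s3://my-bucket/athena-staging/",
    )

    # With chunksize the call returns an iterator; keep_files=False then defers
    # cleanup to _delete_after_iterate, so the staging objects are removed only
    # after the last chunk has been yielded.
    for chunk in wr.athena.read_sql_query(
        sql="SELECT * FROM my_table",
        database="my_db",
        ctas_approach=True,
        chunksize=1000,
        keep_files=False,
        ctas_temp_table_name="my_temp_table",
        s3_output="s3://my-bucket/athena-staging/",
    ):
        print(len(chunk.index))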