From a724fa17b5e1d6a21c08d260e32c948f33a01657 Mon Sep 17 00:00:00 2001 From: kukushking <3997468+kukushking@users.noreply.github.com> Date: Wed, 7 Jul 2021 23:32:50 +0100 Subject: [PATCH 1/3] Add CTAS bucketing to wr.athena_read_sql_query --- awswrangler/athena/_read.py | 17 ++++++++++++++++- tests/test_athena.py | 27 +++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 6c8f431f6..9b275b486 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -6,7 +6,7 @@ import re import sys import uuid -from typing import Any, Dict, Iterator, List, Match, NamedTuple, Optional, Union +from typing import Any, Dict, Iterator, List, Match, NamedTuple, Optional, Union, Tuple import boto3 import botocore.exceptions @@ -385,6 +385,7 @@ def _resolve_query_without_cache_ctas( wg_config: _WorkGroupConfig, alt_database: Optional[str], name: Optional[str], + ctas_bucketing_info: Optional[Tuple[List[str], int]], use_threads: bool, s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: boto3.Session, @@ -392,11 +393,17 @@ def _resolve_query_without_cache_ctas( path: str = f"{s3_output}/{name}" ext_location: str = "\n" if wg_config.enforced is True else f",\n external_location = '{path}'\n" fully_qualified_name: str = f'"{alt_database}"."{name}"' if alt_database else f'"{database}"."{name}"' + bucketing_str = ( + f",\n" + f" bucketed_by = ARRAY{ctas_bucketing_info[0]},\n" + f" bucket_count = {ctas_bucketing_info[1]}" + ) if ctas_bucketing_info else "" sql = ( f"CREATE TABLE {fully_qualified_name}\n" f"WITH(\n" f" format = 'Parquet',\n" f" parquet_compression = 'SNAPPY'" + f"{bucketing_str}" f"{ext_location}" f") AS\n" f"{sql}" @@ -521,6 +528,7 @@ def _resolve_query_without_cache( keep_files: bool, ctas_database_name: Optional[str], ctas_temp_table_name: Optional[str], + ctas_bucketing_info: Optional[Tuple[List[str], int]], use_threads: bool, s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: boto3.Session, @@ -553,6 +561,7 @@ def _resolve_query_without_cache( wg_config=wg_config, alt_database=ctas_database_name, name=name, + ctas_bucketing_info=ctas_bucketing_info, use_threads=use_threads, s3_additional_kwargs=s3_additional_kwargs, boto3_session=boto3_session, @@ -593,6 +602,7 @@ def read_sql_query( keep_files: bool = True, ctas_database_name: Optional[str] = None, ctas_temp_table_name: Optional[str] = None, + ctas_bucketing_info: Optional[Tuple[List[str], int]] = None, use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, max_cache_seconds: int = 0, @@ -729,6 +739,10 @@ def read_sql_query( ctas_database_name : str, optional The name of the alternative database where the CTAS temporary table is stored. If None, the default `database` is used. + ctas_bucketing_info: Tuple[List[str], int], optional + Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the + second element. + Only `str`, `int` and `bool` are supported as column data types for bucketing. ctas_temp_table_name : str, optional The name of the temporary table and also the directory name on S3 where the CTAS result is stored. If None, it will use the follow random pattern: `f"temp_table_{uuid.uuid4().hex()}"`. @@ -841,6 +855,7 @@ def read_sql_query( keep_files=keep_files, ctas_database_name=ctas_database_name, ctas_temp_table_name=ctas_temp_table_name, + ctas_bucketing_info=ctas_bucketing_info, use_threads=use_threads, s3_additional_kwargs=s3_additional_kwargs, boto3_session=session, diff --git a/tests/test_athena.py b/tests/test_athena.py index 7019e782a..0190b1be6 100644 --- a/tests/test_athena.py +++ b/tests/test_athena.py @@ -121,6 +121,33 @@ def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database, assert len(wr.s3.list_objects(path=path3)) == 0 +def test_athena_read_sql_ctas_bucketing(path, path2, glue_table, glue_table2, glue_database, glue_ctas_database): + df = pd.DataFrame({"c0": [0, 1], "c1": ["foo", "bar"]}) + wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + database=glue_database, + table=glue_table, + ) + df_ctas = wr.athena.read_sql_query( + sql=f"SELECT * FROM {glue_table}", + ctas_approach=True, + database=glue_database, + ctas_database_name=glue_ctas_database, + ctas_temp_table_name=glue_table2, + ctas_bucketing_info=(["c0"], 1), + s3_output=path2, + ) + df_no_ctas = wr.athena.read_sql_query( + sql=f"SELECT * FROM {glue_table}", + ctas_approach=False, + database=glue_database, + s3_output=path2, + ) + assert df_ctas.equals(df_no_ctas) + + def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1): wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) wr.s3.to_parquet( From 08d43a7f227bc891ba276834580d9aac62318640 Mon Sep 17 00:00:00 2001 From: kukushking <3997468+kukushking@users.noreply.github.com> Date: Thu, 8 Jul 2021 00:35:38 +0100 Subject: [PATCH 2/3] Add CTAS bucketing to read_sql_table --- awswrangler/athena/_read.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 9b275b486..7258b388a 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -739,14 +739,14 @@ def read_sql_query( ctas_database_name : str, optional The name of the alternative database where the CTAS temporary table is stored. If None, the default `database` is used. - ctas_bucketing_info: Tuple[List[str], int], optional - Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the - second element. - Only `str`, `int` and `bool` are supported as column data types for bucketing. ctas_temp_table_name : str, optional The name of the temporary table and also the directory name on S3 where the CTAS result is stored. If None, it will use the follow random pattern: `f"temp_table_{uuid.uuid4().hex()}"`. On S3 this directory will be under under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`. + ctas_bucketing_info: Tuple[List[str], int], optional + Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the + second element. + Only `str`, `int` and `bool` are supported as column data types for bucketing. use_threads : bool True to enable concurrent requests, False to disable multiple threads. If enabled os.cpu_count() will be used as the max number of threads. @@ -876,6 +876,7 @@ def read_sql_table( keep_files: bool = True, ctas_database_name: Optional[str] = None, ctas_temp_table_name: Optional[str] = None, + ctas_bucketing_info: Optional[Tuple[List[str], int]] = None, use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, max_cache_seconds: int = 0, @@ -1010,6 +1011,10 @@ def read_sql_table( The name of the temporary table and also the directory name on S3 where the CTAS result is stored. If None, it will use the follow random pattern: `f"temp_table_{uuid.uuid4().hex}"`. On S3 this directory will be under under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`. + ctas_bucketing_info: Tuple[List[str], int], optional + Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the + second element. + Only `str`, `int` and `bool` are supported as column data types for bucketing. use_threads : bool True to enable concurrent requests, False to disable multiple threads. If enabled os.cpu_count() will be used as the max number of threads. @@ -1068,6 +1073,7 @@ def read_sql_table( keep_files=keep_files, ctas_database_name=ctas_database_name, ctas_temp_table_name=ctas_temp_table_name, + ctas_bucketing_info=ctas_bucketing_info, use_threads=use_threads, boto3_session=boto3_session, max_cache_seconds=max_cache_seconds, From f452c67f24649cd7cdd380a630f1223c50793a4d Mon Sep 17 00:00:00 2001 From: kukushking <3997468+kukushking@users.noreply.github.com> Date: Thu, 8 Jul 2021 00:39:14 +0100 Subject: [PATCH 3/3] Formatting --- awswrangler/athena/_read.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 7258b388a..795cb5a4b 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -6,7 +6,7 @@ import re import sys import uuid -from typing import Any, Dict, Iterator, List, Match, NamedTuple, Optional, Union, Tuple +from typing import Any, Dict, Iterator, List, Match, NamedTuple, Optional, Tuple, Union import boto3 import botocore.exceptions @@ -394,10 +394,10 @@ def _resolve_query_without_cache_ctas( ext_location: str = "\n" if wg_config.enforced is True else f",\n external_location = '{path}'\n" fully_qualified_name: str = f'"{alt_database}"."{name}"' if alt_database else f'"{database}"."{name}"' bucketing_str = ( - f",\n" - f" bucketed_by = ARRAY{ctas_bucketing_info[0]},\n" - f" bucket_count = {ctas_bucketing_info[1]}" - ) if ctas_bucketing_info else "" + (f",\n" f" bucketed_by = ARRAY{ctas_bucketing_info[0]},\n" f" bucket_count = {ctas_bucketing_info[1]}") + if ctas_bucketing_info + else "" + ) sql = ( f"CREATE TABLE {fully_qualified_name}\n" f"WITH(\n"