diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 6c8f431f6..795cb5a4b 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -6,7 +6,7 @@ import re import sys import uuid -from typing import Any, Dict, Iterator, List, Match, NamedTuple, Optional, Union +from typing import Any, Dict, Iterator, List, Match, NamedTuple, Optional, Tuple, Union import boto3 import botocore.exceptions @@ -385,6 +385,7 @@ def _resolve_query_without_cache_ctas( wg_config: _WorkGroupConfig, alt_database: Optional[str], name: Optional[str], + ctas_bucketing_info: Optional[Tuple[List[str], int]], use_threads: bool, s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: boto3.Session, @@ -392,11 +393,17 @@ def _resolve_query_without_cache_ctas( path: str = f"{s3_output}/{name}" ext_location: str = "\n" if wg_config.enforced is True else f",\n external_location = '{path}'\n" fully_qualified_name: str = f'"{alt_database}"."{name}"' if alt_database else f'"{database}"."{name}"' + bucketing_str = ( + (f",\n" f" bucketed_by = ARRAY{ctas_bucketing_info[0]},\n" f" bucket_count = {ctas_bucketing_info[1]}") + if ctas_bucketing_info + else "" + ) sql = ( f"CREATE TABLE {fully_qualified_name}\n" f"WITH(\n" f" format = 'Parquet',\n" f" parquet_compression = 'SNAPPY'" + f"{bucketing_str}" f"{ext_location}" f") AS\n" f"{sql}" @@ -521,6 +528,7 @@ def _resolve_query_without_cache( keep_files: bool, ctas_database_name: Optional[str], ctas_temp_table_name: Optional[str], + ctas_bucketing_info: Optional[Tuple[List[str], int]], use_threads: bool, s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: boto3.Session, @@ -553,6 +561,7 @@ def _resolve_query_without_cache( wg_config=wg_config, alt_database=ctas_database_name, name=name, + ctas_bucketing_info=ctas_bucketing_info, use_threads=use_threads, s3_additional_kwargs=s3_additional_kwargs, boto3_session=boto3_session, @@ -593,6 +602,7 @@ def read_sql_query( keep_files: bool = True, ctas_database_name: Optional[str] = None, ctas_temp_table_name: Optional[str] = None, + ctas_bucketing_info: Optional[Tuple[List[str], int]] = None, use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, max_cache_seconds: int = 0, @@ -733,6 +743,10 @@ def read_sql_query( The name of the temporary table and also the directory name on S3 where the CTAS result is stored. If None, it will use the follow random pattern: `f"temp_table_{uuid.uuid4().hex()}"`. On S3 this directory will be under under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`. + ctas_bucketing_info: Tuple[List[str], int], optional + Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the + second element. + Only `str`, `int` and `bool` are supported as column data types for bucketing. use_threads : bool True to enable concurrent requests, False to disable multiple threads. If enabled os.cpu_count() will be used as the max number of threads. @@ -841,6 +855,7 @@ def read_sql_query( keep_files=keep_files, ctas_database_name=ctas_database_name, ctas_temp_table_name=ctas_temp_table_name, + ctas_bucketing_info=ctas_bucketing_info, use_threads=use_threads, s3_additional_kwargs=s3_additional_kwargs, boto3_session=session, @@ -861,6 +876,7 @@ def read_sql_table( keep_files: bool = True, ctas_database_name: Optional[str] = None, ctas_temp_table_name: Optional[str] = None, + ctas_bucketing_info: Optional[Tuple[List[str], int]] = None, use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, max_cache_seconds: int = 0, @@ -995,6 +1011,10 @@ def read_sql_table( The name of the temporary table and also the directory name on S3 where the CTAS result is stored. If None, it will use the follow random pattern: `f"temp_table_{uuid.uuid4().hex}"`. On S3 this directory will be under under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`. + ctas_bucketing_info: Tuple[List[str], int], optional + Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the + second element. + Only `str`, `int` and `bool` are supported as column data types for bucketing. use_threads : bool True to enable concurrent requests, False to disable multiple threads. If enabled os.cpu_count() will be used as the max number of threads. @@ -1053,6 +1073,7 @@ def read_sql_table( keep_files=keep_files, ctas_database_name=ctas_database_name, ctas_temp_table_name=ctas_temp_table_name, + ctas_bucketing_info=ctas_bucketing_info, use_threads=use_threads, boto3_session=boto3_session, max_cache_seconds=max_cache_seconds, diff --git a/tests/test_athena.py b/tests/test_athena.py index 7019e782a..0190b1be6 100644 --- a/tests/test_athena.py +++ b/tests/test_athena.py @@ -121,6 +121,33 @@ def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database, assert len(wr.s3.list_objects(path=path3)) == 0 +def test_athena_read_sql_ctas_bucketing(path, path2, glue_table, glue_table2, glue_database, glue_ctas_database): + df = pd.DataFrame({"c0": [0, 1], "c1": ["foo", "bar"]}) + wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + database=glue_database, + table=glue_table, + ) + df_ctas = wr.athena.read_sql_query( + sql=f"SELECT * FROM {glue_table}", + ctas_approach=True, + database=glue_database, + ctas_database_name=glue_ctas_database, + ctas_temp_table_name=glue_table2, + ctas_bucketing_info=(["c0"], 1), + s3_output=path2, + ) + df_no_ctas = wr.athena.read_sql_query( + sql=f"SELECT * FROM {glue_table}", + ctas_approach=False, + database=glue_database, + s3_output=path2, + ) + assert df_ctas.equals(df_no_ctas) + + def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1): wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) wr.s3.to_parquet(