diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 770b97f22..3657e2589 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -374,6 +374,7 @@ def _resolve_query_without_cache_ctas( workgroup: Optional[str], kms_key: Optional[str], wg_config: _WorkGroupConfig, + alt_database: Optional[str], name: Optional[str], use_threads: bool, s3_additional_kwargs: Optional[Dict[str, Any]], @@ -381,8 +382,9 @@ def _resolve_query_without_cache_ctas( ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: path: str = f"{s3_output}/{name}" ext_location: str = "\n" if wg_config.enforced is True else f",\n external_location = '{path}'\n" + fully_qualified_name: str = f'"{alt_database}"."{name}"' if alt_database else f'"{database}"."{name}"' sql = ( - f'CREATE TABLE "{name}"\n' + f"CREATE TABLE {fully_qualified_name}\n" f"WITH(\n" f" format = 'Parquet',\n" f" parquet_compression = 'SNAPPY'" @@ -507,6 +509,7 @@ def _resolve_query_without_cache( encryption: Optional[str], kms_key: Optional[str], keep_files: bool, + ctas_database_name: Optional[str], ctas_temp_table_name: Optional[str], use_threads: bool, s3_additional_kwargs: Optional[Dict[str, Any]], @@ -538,6 +541,7 @@ def _resolve_query_without_cache( workgroup=workgroup, kms_key=kms_key, wg_config=wg_config, + alt_database=ctas_database_name, name=name, use_threads=use_threads, s3_additional_kwargs=s3_additional_kwargs, @@ -575,6 +579,7 @@ def read_sql_query( encryption: Optional[str] = None, kms_key: Optional[str] = None, keep_files: bool = True, + ctas_database_name: Optional[str] = None, ctas_temp_table_name: Optional[str] = None, use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, @@ -709,6 +714,9 @@ def read_sql_query( For SSE-KMS, this is the KMS key ARN or ID. keep_files : bool Should Wrangler delete or keep the staging files produced by Athena? + ctas_database_name : str, optional + The name of the alternative database where the CTAS temporary table is stored. + If None, the default `database` is used. ctas_temp_table_name : str, optional The name of the temporary table and also the directory name on S3 where the CTAS result is stored. If None, it will use the follow random pattern: `f"temp_table_{uuid.uuid4().hex()}"`. @@ -820,6 +828,7 @@ def read_sql_query( encryption=encryption, kms_key=kms_key, keep_files=keep_files, + ctas_database_name=ctas_database_name, ctas_temp_table_name=ctas_temp_table_name, use_threads=use_threads, s3_additional_kwargs=s3_additional_kwargs, @@ -839,6 +848,7 @@ def read_sql_table( encryption: Optional[str] = None, kms_key: Optional[str] = None, keep_files: bool = True, + ctas_database_name: Optional[str] = None, ctas_temp_table_name: Optional[str] = None, use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, @@ -967,6 +977,9 @@ def read_sql_table( For SSE-KMS, this is the KMS key ARN or ID. keep_files : bool Should Wrangler delete or keep the staging files produced by Athena? + ctas_database_name : str, optional + The name of the alternative database where the CTAS temporary table is stored. + If None, the default `database` is used. ctas_temp_table_name : str, optional The name of the temporary table and also the directory name on S3 where the CTAS result is stored. If None, it will use the follow random pattern: `f"temp_table_{uuid.uuid4().hex}"`. @@ -1027,6 +1040,7 @@ def read_sql_table( encryption=encryption, kms_key=kms_key, keep_files=keep_files, + ctas_database_name=ctas_database_name, ctas_temp_table_name=ctas_temp_table_name, use_threads=use_threads, boto3_session=boto3_session, diff --git a/tests/conftest.py b/tests/conftest.py index 011fccfca..7f44ff12c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -170,6 +170,16 @@ def redshift_external_schema(cloudformation_outputs, databases_parameters, glue_ return "aws_data_wrangler_external" +@pytest.fixture(scope="function") +def glue_ctas_database(): + name = f"db_{get_time_str_with_random_suffix()}" + print(f"Database name: {name}") + wr.catalog.create_database(name=name) + yield name + wr.catalog.delete_database(name=name) + print(f"Database {name} deleted.") + + @pytest.fixture(scope="function") def glue_table(glue_database: str) -> None: name = f"tbl_{get_time_str_with_random_suffix()}" diff --git a/tests/test_athena.py b/tests/test_athena.py index 353adcecc..ae7e3b154 100644 --- a/tests/test_athena.py +++ b/tests/test_athena.py @@ -26,7 +26,7 @@ logging.getLogger("awswrangler").setLevel(logging.DEBUG) -def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database, kms_key): +def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database, glue_ctas_database, kms_key): df = get_df_list() columns_types, partitions_types = wr.catalog.extract_athena_types(df=df, partition_cols=["par0", "par1"]) assert len(columns_types) == 17 @@ -102,6 +102,26 @@ def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database, ensure_athena_query_metadata(df=df, ctas_approach=True, encrypted=False) assert len(wr.s3.list_objects(path=path3)) > 2 + # ctas_database_name + wr.s3.delete_objects(path=path3) + dfs = wr.athena.read_sql_query( + sql=f"SELECT * FROM {glue_table}", + database=glue_database, + ctas_approach=True, + chunksize=1, + keep_files=False, + ctas_database_name=glue_ctas_database, + ctas_temp_table_name=glue_table2, + s3_output=path3, + ) + assert wr.catalog.does_table_exist(database=glue_ctas_database, table=glue_table2) is True + assert len(wr.s3.list_objects(path=path3)) > 2 + assert len(wr.s3.list_objects(path=final_destination)) > 0 + for df in dfs: + ensure_data_types(df=df, has_list=True) + ensure_athena_query_metadata(df=df, ctas_approach=True, encrypted=False) + assert len(wr.s3.list_objects(path=path3)) == 0 + def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1): table = f"__{glue_table}"