diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 01484c6e5..7948c02e3 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -260,8 +260,7 @@ def _resolve_query_without_cache_ctas( boto3_session: boto3.Session, pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: - fully_qualified_name: str = f'"{alt_database}"."{name}"' if alt_database else f'"{database}"."{name}"' - ctas_query_info: Dict[str, str] = create_ctas_table( + ctas_query_info: Dict[str, Union[str, _QueryMetadata]] = create_ctas_table( sql=sql, database=database, ctas_table=name, @@ -272,35 +271,14 @@ def _resolve_query_without_cache_ctas( workgroup=workgroup, encryption=encryption, kms_key=kms_key, + wait=True, boto3_session=boto3_session, ) - ctas_query_id: str = ctas_query_info["ctas_query_id"] - _logger.debug("ctas_query_id: %s", ctas_query_id) - try: - query_metadata: _QueryMetadata = _get_query_metadata( - query_execution_id=ctas_query_id, - boto3_session=boto3_session, - categories=categories, - metadata_cache_manager=_cache_manager, - ) - except exceptions.QueryFailed as ex: - msg: str = str(ex) - if "Column name" in msg and "specified more than once" in msg: - raise exceptions.InvalidCtasApproachQuery( - f"Please, define distinct names for your columns OR pass ctas_approach=False. Root error message: {msg}" - ) - if "Column name not specified" in msg: - raise exceptions.InvalidArgumentValue( - "Please, define all columns names in your query. (E.g. 'SELECT MAX(col1) AS max_col1, ...')" - ) - if "Column type is unknown" in msg: - raise exceptions.InvalidArgumentValue( - "Please, don't leave undefined columns types in your query. You can cast to ensure it. " - "(E.g. 'SELECT CAST(NULL AS INTEGER) AS MY_COL, ...')" - ) - raise ex + fully_qualified_name: str = f'"{ctas_query_info["ctas_database"]}"."{ctas_query_info["ctas_table"]}"' + ctas_query_metadata: _QueryMetadata = ctas_query_info["ctas_query_metadata"] # type: ignore + _logger.debug("ctas_query_metadata: %s", ctas_query_metadata) return _fetch_parquet_result( - query_metadata=query_metadata, + query_metadata=ctas_query_metadata, keep_files=keep_files, categories=categories, chunksize=chunksize, diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py index f0eb5b6b7..6da1b992c 100644 --- a/awswrangler/athena/_utils.py +++ b/awswrangler/athena/_utils.py @@ -642,7 +642,7 @@ def describe_table( @apply_configs -def create_ctas_table( +def create_ctas_table( # pylint: disable=too-many-locals sql: str, database: str, ctas_table: Optional[str] = None, @@ -658,8 +658,10 @@ def create_ctas_table( data_source: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None, + categories: Optional[List[str]] = None, + wait: bool = False, boto3_session: Optional[boto3.Session] = None, -) -> Dict[str, str]: +) -> Dict[str, Union[str, _QueryMetadata]]: """Create a new table populated with the results of a SELECT query. https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html @@ -703,13 +705,19 @@ def create_ctas_table( Valid values: [None, 'SSE_S3', 'SSE_KMS']. Note: 'CSE_KMS' is not supported. kms_key : str, optional For SSE-KMS, this is the KMS key ARN or ID. + categories: List[str], optional + List of columns names that should be returned as pandas.Categorical. + Recommended for memory restricted environments. + wait : bool, default False + Whether to wait for the query to finish and return a dictionary with the Query metadata. boto3_session : Optional[boto3.Session], optional Boto3 Session. The default boto3 session is used if boto3_session is None. Returns ------- - Dict[str, str] - A dictionary with the ID of the query, and the CTAS database and table names + Dict[str, Union[str, _QueryMetadata]] + A dictionary with the the CTAS database and table names. + If `wait` is `False`, the query ID is included, otherwise a Query metadata object is added instead. """ ctas_table = catalog.sanitize_table_name(ctas_table) if ctas_table else f"temp_table_{uuid.uuid4().hex}" ctas_database = ctas_database if ctas_database else database @@ -753,7 +761,7 @@ def create_ctas_table( _logger.debug("ctas sql: %s", ctas_sql) try: - query_id: str = _start_query_execution( + query_execution_id: str = _start_query_execution( sql=ctas_sql, wg_config=wg_config, database=database, @@ -775,7 +783,35 @@ def create_ctas_table( f"It is not possible to wrap this query into a CTAS statement. Root error message: {error['Message']}" ) raise ex - return {"ctas_database": ctas_database, "ctas_table": ctas_table, "ctas_query_id": query_id} + + response: Dict[str, Union[str, _QueryMetadata]] = {"ctas_database": ctas_database, "ctas_table": ctas_table} + if wait: + try: + response["ctas_query_metadata"] = _get_query_metadata( + query_execution_id=query_execution_id, + boto3_session=boto3_session, + categories=categories, + metadata_cache_manager=_cache_manager, + ) + except exceptions.QueryFailed as ex: + msg: str = str(ex) + if "Column name" in msg and "specified more than once" in msg: + raise exceptions.InvalidCtasApproachQuery( + f"Please, define distinct names for your columns. Root error message: {msg}" + ) + if "Column name not specified" in msg: + raise exceptions.InvalidArgumentValue( + "Please, define all columns names in your query. (E.g. 'SELECT MAX(col1) AS max_col1, ...')" + ) + if "Column type is unknown" in msg: + raise exceptions.InvalidArgumentValue( + "Please, don't leave undefined columns types in your query. You can cast to ensure it. " + "(E.g. 'SELECT CAST(NULL AS INTEGER) AS MY_COL, ...')" + ) + raise ex + else: + response["ctas_query_id"] = query_execution_id + return response @apply_configs diff --git a/docs/source/api.rst b/docs/source/api.rst index e104156d1..0caecc332 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -121,6 +121,7 @@ Amazon Athena repair_table start_query_execution stop_query_execution + unload wait_query AWS Lake Formation diff --git a/tests/_utils.py b/tests/_utils.py index 281cc7a32..fe65ce2be 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -2,7 +2,7 @@ import time from datetime import datetime from decimal import Decimal -from typing import Dict, Iterator +from typing import Any, Dict, Iterator import boto3 import botocore.exceptions @@ -501,12 +501,16 @@ def ensure_data_types_csv(df, governed=False): assert str(df["par1"].dtype) == "string" -def ensure_athena_ctas_table(ctas_query_info: Dict[str, str], boto3_session: boto3.Session) -> None: - query_metadata = wr.athena._utils._get_query_metadata( - query_execution_id=ctas_query_info["ctas_query_id"], boto3_session=boto3_session +def ensure_athena_ctas_table(ctas_query_info: Dict[str, Any], boto3_session: boto3.Session) -> None: + query_metadata = ( + wr.athena._utils._get_query_metadata( + query_execution_id=ctas_query_info["ctas_query_id"], boto3_session=boto3_session + ) + if "ctas_query_id" in ctas_query_info + else ctas_query_info["ctas_query_metadata"] ) assert query_metadata.raw_payload["Status"]["State"] == "SUCCEEDED" - wr.catalog.delete_table_if_exists(table=ctas_query_info["ctas_table"], database=ctas_query_info["ctas_database"]) + wr.catalog.delete_table_if_exists(database=ctas_query_info["ctas_database"], table=ctas_query_info["ctas_table"]) def ensure_athena_query_metadata(df, ctas_approach=True, encrypted=False): diff --git a/tests/test_athena.py b/tests/test_athena.py index f12198c51..9ea2f530e 100644 --- a/tests/test_athena.py +++ b/tests/test_athena.py @@ -169,6 +169,7 @@ def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_c database=glue_database, encryption="SSE_KMS", kms_key=kms_key, + wait=False, ) ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session) @@ -178,6 +179,7 @@ def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_c database=glue_database, ctas_table=glue_table2, schema_only=True, + wait=True, ) ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session) @@ -187,6 +189,7 @@ def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_c database=glue_database, storage_format="avro", write_compression="snappy", + wait=False, ) ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session) @@ -196,6 +199,7 @@ def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_c database=glue_database, ctas_database=glue_ctas_database, partitioning_info=["par0", "par1"], + wait=True, ) ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session)