Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 6 additions & 28 deletions awswrangler/athena/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,7 @@ def _resolve_query_without_cache_ctas(
boto3_session: boto3.Session,
pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
fully_qualified_name: str = f'"{alt_database}"."{name}"' if alt_database else f'"{database}"."{name}"'
ctas_query_info: Dict[str, str] = create_ctas_table(
ctas_query_info: Dict[str, Union[str, _QueryMetadata]] = create_ctas_table(
sql=sql,
database=database,
ctas_table=name,
Expand All @@ -272,35 +271,14 @@ def _resolve_query_without_cache_ctas(
workgroup=workgroup,
encryption=encryption,
kms_key=kms_key,
wait=True,
boto3_session=boto3_session,
)
ctas_query_id: str = ctas_query_info["ctas_query_id"]
_logger.debug("ctas_query_id: %s", ctas_query_id)
try:
query_metadata: _QueryMetadata = _get_query_metadata(
query_execution_id=ctas_query_id,
boto3_session=boto3_session,
categories=categories,
metadata_cache_manager=_cache_manager,
)
except exceptions.QueryFailed as ex:
msg: str = str(ex)
if "Column name" in msg and "specified more than once" in msg:
raise exceptions.InvalidCtasApproachQuery(
f"Please, define distinct names for your columns OR pass ctas_approach=False. Root error message: {msg}"
)
if "Column name not specified" in msg:
raise exceptions.InvalidArgumentValue(
"Please, define all columns names in your query. (E.g. 'SELECT MAX(col1) AS max_col1, ...')"
)
if "Column type is unknown" in msg:
raise exceptions.InvalidArgumentValue(
"Please, don't leave undefined columns types in your query. You can cast to ensure it. "
"(E.g. 'SELECT CAST(NULL AS INTEGER) AS MY_COL, ...')"
)
raise ex
fully_qualified_name: str = f'"{ctas_query_info["ctas_database"]}"."{ctas_query_info["ctas_table"]}"'
ctas_query_metadata: _QueryMetadata = ctas_query_info["ctas_query_metadata"] # type: ignore
_logger.debug("ctas_query_metadata: %s", ctas_query_metadata)
return _fetch_parquet_result(
query_metadata=query_metadata,
query_metadata=ctas_query_metadata,
keep_files=keep_files,
categories=categories,
chunksize=chunksize,
Expand Down
48 changes: 42 additions & 6 deletions awswrangler/athena/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def describe_table(


@apply_configs
def create_ctas_table(
def create_ctas_table( # pylint: disable=too-many-locals
sql: str,
database: str,
ctas_table: Optional[str] = None,
Expand All @@ -658,8 +658,10 @@ def create_ctas_table(
data_source: Optional[str] = None,
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
categories: Optional[List[str]] = None,
wait: bool = False,
boto3_session: Optional[boto3.Session] = None,
) -> Dict[str, str]:
) -> Dict[str, Union[str, _QueryMetadata]]:
"""Create a new table populated with the results of a SELECT query.

https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html
Expand Down Expand Up @@ -703,13 +705,19 @@ def create_ctas_table(
Valid values: [None, 'SSE_S3', 'SSE_KMS']. Note: 'CSE_KMS' is not supported.
kms_key : str, optional
For SSE-KMS, this is the KMS key ARN or ID.
categories: List[str], optional
List of columns names that should be returned as pandas.Categorical.
Recommended for memory restricted environments.
wait : bool, default False
Whether to wait for the query to finish and return a dictionary with the Query metadata.
boto3_session : Optional[boto3.Session], optional
Boto3 Session. The default boto3 session is used if boto3_session is None.

Returns
-------
Dict[str, str]
A dictionary with the ID of the query, and the CTAS database and table names
Dict[str, Union[str, _QueryMetadata]]
A dictionary with the the CTAS database and table names.
If `wait` is `False`, the query ID is included, otherwise a Query metadata object is added instead.
"""
ctas_table = catalog.sanitize_table_name(ctas_table) if ctas_table else f"temp_table_{uuid.uuid4().hex}"
ctas_database = ctas_database if ctas_database else database
Expand Down Expand Up @@ -753,7 +761,7 @@ def create_ctas_table(
_logger.debug("ctas sql: %s", ctas_sql)

try:
query_id: str = _start_query_execution(
query_execution_id: str = _start_query_execution(
sql=ctas_sql,
wg_config=wg_config,
database=database,
Expand All @@ -775,7 +783,35 @@ def create_ctas_table(
f"It is not possible to wrap this query into a CTAS statement. Root error message: {error['Message']}"
)
raise ex
return {"ctas_database": ctas_database, "ctas_table": ctas_table, "ctas_query_id": query_id}

response: Dict[str, Union[str, _QueryMetadata]] = {"ctas_database": ctas_database, "ctas_table": ctas_table}
if wait:
try:
response["ctas_query_metadata"] = _get_query_metadata(
query_execution_id=query_execution_id,
boto3_session=boto3_session,
categories=categories,
metadata_cache_manager=_cache_manager,
)
except exceptions.QueryFailed as ex:
msg: str = str(ex)
if "Column name" in msg and "specified more than once" in msg:
raise exceptions.InvalidCtasApproachQuery(
f"Please, define distinct names for your columns. Root error message: {msg}"
)
if "Column name not specified" in msg:
raise exceptions.InvalidArgumentValue(
"Please, define all columns names in your query. (E.g. 'SELECT MAX(col1) AS max_col1, ...')"
)
if "Column type is unknown" in msg:
raise exceptions.InvalidArgumentValue(
"Please, don't leave undefined columns types in your query. You can cast to ensure it. "
"(E.g. 'SELECT CAST(NULL AS INTEGER) AS MY_COL, ...')"
)
raise ex
else:
response["ctas_query_id"] = query_execution_id
return response


@apply_configs
Expand Down
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ Amazon Athena
repair_table
start_query_execution
stop_query_execution
unload
wait_query

AWS Lake Formation
Expand Down
14 changes: 9 additions & 5 deletions tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
from datetime import datetime
from decimal import Decimal
from typing import Dict, Iterator
from typing import Any, Dict, Iterator

import boto3
import botocore.exceptions
Expand Down Expand Up @@ -501,12 +501,16 @@ def ensure_data_types_csv(df, governed=False):
assert str(df["par1"].dtype) == "string"


def ensure_athena_ctas_table(ctas_query_info: Dict[str, str], boto3_session: boto3.Session) -> None:
query_metadata = wr.athena._utils._get_query_metadata(
query_execution_id=ctas_query_info["ctas_query_id"], boto3_session=boto3_session
def ensure_athena_ctas_table(ctas_query_info: Dict[str, Any], boto3_session: boto3.Session) -> None:
query_metadata = (
wr.athena._utils._get_query_metadata(
query_execution_id=ctas_query_info["ctas_query_id"], boto3_session=boto3_session
)
if "ctas_query_id" in ctas_query_info
else ctas_query_info["ctas_query_metadata"]
)
assert query_metadata.raw_payload["Status"]["State"] == "SUCCEEDED"
wr.catalog.delete_table_if_exists(table=ctas_query_info["ctas_table"], database=ctas_query_info["ctas_database"])
wr.catalog.delete_table_if_exists(database=ctas_query_info["ctas_database"], table=ctas_query_info["ctas_table"])


def ensure_athena_query_metadata(df, ctas_approach=True, encrypted=False):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_c
database=glue_database,
encryption="SSE_KMS",
kms_key=kms_key,
wait=False,
)
ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session)

Expand All @@ -178,6 +179,7 @@ def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_c
database=glue_database,
ctas_table=glue_table2,
schema_only=True,
wait=True,
)
ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session)

Expand All @@ -187,6 +189,7 @@ def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_c
database=glue_database,
storage_format="avro",
write_compression="snappy",
wait=False,
)
ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session)

Expand All @@ -196,6 +199,7 @@ def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_c
database=glue_database,
ctas_database=glue_ctas_database,
partitioning_info=["par0", "par1"],
wait=True,
)
ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session)

Expand Down