Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions awswrangler/athena/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
wait_query,
)
from awswrangler.athena._spark import create_spark_session, run_spark_calculation
from awswrangler.athena._statements import (
create_prepared_statement,
delete_prepared_statement,
list_prepared_statements,
)
from awswrangler.athena._read import ( # noqa
get_query_results,
read_sql_query,
Expand Down Expand Up @@ -51,5 +56,8 @@
"stop_query_execution",
"unload",
"wait_query",
"create_prepared_statement",
"list_prepared_statements",
"delete_prepared_statement",
"to_iceberg",
]
34 changes: 27 additions & 7 deletions awswrangler/athena/_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,24 @@
from typing import (
Any,
Dict,
List,
Optional,
Union,
cast,
)

import boto3
import botocore
from typing_extensions import Literal

from awswrangler import _utils, exceptions, typing
from awswrangler._config import apply_configs
from awswrangler._sql_formatter import _process_sql_params

from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results
from ._utils import (
_QUERY_FINAL_STATES,
_QUERY_WAIT_POLLING_DELAY,
_apply_formatter,
_get_workgroup_config,
_start_query_execution,
_WorkGroupConfig,
Expand All @@ -36,7 +38,8 @@ def start_query_execution(
workgroup: Optional[str] = None,
encryption: Optional[str] = None,
kms_key: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
params: Union[Dict[str, Any], List[str], None] = None,
paramstyle: Literal["qmark", "named"] = "named",
boto3_session: Optional[boto3.Session] = None,
athena_cache_settings: Optional[typing.AthenaCacheSettings] = None,
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
Expand Down Expand Up @@ -64,10 +67,25 @@ def start_query_execution(
None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
kms_key : str, optional
For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
params: Dict[str, any], optional
Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported.
The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain
`:name`. Note that for varchar columns and similar, you must surround the value in single quotes.
params: Dict[str, any] | List[str], optional
Parameters that will be used for constructing the SQL query.
Only named or question mark parameters are supported.
The parameter style needs to be specified in the ``paramstyle`` parameter.

For ``paramstyle="named"``, this value needs to be a dictionary.
The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain
``:name``.
The formatter will be applied client-side in this scenario.

For ``paramstyle="qmark"``, this value needs to be a list of strings.
The formatter will be applied server-side.
The values are applied sequentially to the parameters in the query in the order in which the parameters occur.
paramstyle: str, optional
Determines the style of ``params``.
Possible values are:

- ``named``
- ``qmark``
boto3_session : boto3.Session(), optional
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
athena_cache_settings: typing.AthenaCacheSettings, optional
Expand Down Expand Up @@ -103,7 +121,8 @@ def start_query_execution(
>>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...', data_source='...')

"""
sql = _process_sql_params(sql, params)
# Substitute query parameters if applicable
sql, execution_params = _apply_formatter(sql, params, paramstyle)
_logger.debug("Executing query:\n%s", sql)

athena_cache_settings = athena_cache_settings if athena_cache_settings else {}
Expand Down Expand Up @@ -139,6 +158,7 @@ def start_query_execution(
workgroup=workgroup,
encryption=encryption,
kms_key=kms_key,
execution_params=execution_params,
boto3_session=boto3_session,
)
if wait:
Expand Down
10 changes: 7 additions & 3 deletions awswrangler/athena/_executions.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Union,
Expand All @@ -19,7 +20,8 @@ def start_query_execution(
workgroup: Optional[str] = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
params: Optional[Dict[str, Any]] = ...,
params: Union[Dict[str, Any], List[str], None] = ...,
paramstyle: Literal["qmark", "named"] = ...,
boto3_session: Optional[boto3.Session] = ...,
athena_cache_settings: Optional[typing.AthenaCacheSettings] = ...,
athena_query_wait_polling_delay: float = ...,
Expand All @@ -35,7 +37,8 @@ def start_query_execution(
workgroup: Optional[str] = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
params: Optional[Dict[str, Any]] = ...,
params: Union[Dict[str, Any], List[str], None] = ...,
paramstyle: Literal["qmark", "named"] = ...,
boto3_session: Optional[boto3.Session] = ...,
athena_cache_settings: Optional[typing.AthenaCacheSettings] = ...,
athena_query_wait_polling_delay: float = ...,
Expand All @@ -51,7 +54,8 @@ def start_query_execution(
workgroup: Optional[str] = ...,
encryption: Optional[str] = ...,
kms_key: Optional[str] = ...,
params: Optional[Dict[str, Any]] = ...,
params: Union[Dict[str, Any], List[str], None] = ...,
paramstyle: Literal["qmark", "named"] = ...,
boto3_session: Optional[boto3.Session] = ...,
athena_cache_settings: Optional[typing.AthenaCacheSettings] = ...,
athena_query_wait_polling_delay: float = ...,
Expand Down
78 changes: 62 additions & 16 deletions awswrangler/athena/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from awswrangler import _utils, catalog, exceptions, s3, typing
from awswrangler._config import apply_configs
from awswrangler._data_types import cast_pandas_with_athena_types
from awswrangler._sql_formatter import _process_sql_params
from awswrangler.athena._utils import (
_QUERY_WAIT_POLLING_DELAY,
_apply_formatter,
_apply_query_metadata,
_empty_dataframe_response,
_get_query_metadata,
Expand Down Expand Up @@ -287,6 +287,7 @@ def _resolve_query_without_cache_ctas(
s3_additional_kwargs: Optional[Dict[str, Any]],
boto3_session: Optional[boto3.Session],
pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
execution_params: Optional[List[str]] = None,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
ctas_query_info: Dict[str, Union[str, _QueryMetadata]] = create_ctas_table(
Expand All @@ -304,6 +305,7 @@ def _resolve_query_without_cache_ctas(
wait=True,
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
boto3_session=boto3_session,
execution_params=execution_params,
)
fully_qualified_name: str = f'"{ctas_query_info["ctas_database"]}"."{ctas_query_info["ctas_table"]}"'
ctas_query_metadata = cast(_QueryMetadata, ctas_query_info["ctas_query_metadata"])
Expand Down Expand Up @@ -342,6 +344,7 @@ def _resolve_query_without_cache_unload(
s3_additional_kwargs: Optional[Dict[str, Any]],
boto3_session: Optional[boto3.Session],
pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
execution_params: Optional[List[str]] = None,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
query_metadata = _unload(
Expand All @@ -358,6 +361,7 @@ def _resolve_query_without_cache_unload(
boto3_session=boto3_session,
data_source=data_source,
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
execution_params=execution_params,
)
if file_format == "PARQUET":
return _fetch_parquet_result(
Expand Down Expand Up @@ -389,6 +393,7 @@ def _resolve_query_without_cache_regular(
athena_query_wait_polling_delay: float,
s3_additional_kwargs: Optional[Dict[str, Any]],
boto3_session: Optional[boto3.Session],
execution_params: Optional[List[str]] = None,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
Expand All @@ -404,6 +409,7 @@ def _resolve_query_without_cache_regular(
workgroup=workgroup,
encryption=encryption,
kms_key=kms_key,
execution_params=execution_params,
boto3_session=boto3_session,
)
_logger.debug("Query id: %s", query_id)
Expand Down Expand Up @@ -450,6 +456,7 @@ def _resolve_query_without_cache(
s3_additional_kwargs: Optional[Dict[str, Any]],
boto3_session: Optional[boto3.Session],
pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
execution_params: Optional[List[str]] = None,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
"""
Expand Down Expand Up @@ -483,6 +490,7 @@ def _resolve_query_without_cache(
s3_additional_kwargs=s3_additional_kwargs,
boto3_session=boto3_session,
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
execution_params=execution_params,
dtype_backend=dtype_backend,
)
finally:
Expand Down Expand Up @@ -510,6 +518,7 @@ def _resolve_query_without_cache(
s3_additional_kwargs=s3_additional_kwargs,
boto3_session=boto3_session,
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
execution_params=execution_params,
dtype_backend=dtype_backend,
)
return _resolve_query_without_cache_regular(
Expand All @@ -527,6 +536,7 @@ def _resolve_query_without_cache(
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
s3_additional_kwargs=s3_additional_kwargs,
boto3_session=boto3_session,
execution_params=execution_params,
dtype_backend=dtype_backend,
)

Expand All @@ -545,6 +555,7 @@ def _unload(
boto3_session: Optional[boto3.Session],
data_source: Optional[str],
athena_query_wait_polling_delay: float,
execution_params: Optional[List[str]],
) -> _QueryMetadata:
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
s3_output: str = _get_s3_output(s3_output=path, wg_config=wg_config, boto3_session=boto3_session)
Expand Down Expand Up @@ -576,6 +587,7 @@ def _unload(
encryption=encryption,
kms_key=kms_key,
boto3_session=boto3_session,
execution_params=execution_params,
)
except botocore.exceptions.ClientError as ex:
msg: str = str(ex)
Expand Down Expand Up @@ -735,7 +747,8 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
athena_cache_settings: Optional[typing.AthenaCacheSettings] = None,
data_source: Optional[str] = None,
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
params: Optional[Dict[str, Any]] = None,
params: Union[Dict[str, Any], List[str], None] = None,
paramstyle: Literal["qmark", "named"] = "named",
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
s3_additional_kwargs: Optional[Dict[str, Any]] = None,
pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -905,10 +918,25 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
athena_query_wait_polling_delay: float, default: 0.25 seconds
Interval in seconds for how often the function will check if the Athena query has completed.
params: Dict[str, any], optional
Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported.
The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain
`:name`. Note that for varchar columns and similar, you must surround the value in single quotes.
params: Dict[str, any] | List[str], optional
Parameters that will be used for constructing the SQL query.
Only named or question mark parameters are supported.
The parameter style needs to be specified in the ``paramstyle`` parameter.

For ``paramstyle="named"``, this value needs to be a dictionary.
The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain
``:name``.
The formatter will be applied client-side in this scenario.

For ``paramstyle="qmark"``, this value needs to be a list of strings.
The formatter will be applied server-side.
The values are applied sequentially to the parameters in the query in the order in which the parameters occur.
paramstyle: str, optional
Determines the style of ``params``.
Possible values are:

- ``named``
- ``qmark``
dtype_backend: str, optional
Which dtype_backend to use, e.g. whether a DataFrame should have NumPy arrays,
nullable dtypes are used for all dtypes that have a nullable implementation when
Expand Down Expand Up @@ -964,15 +992,15 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
raise exceptions.InvalidArgumentCombination("Only PARQUET file format is supported if unload_approach=True")
chunksize = sys.maxsize if ctas_approach is False and chunksize is True else chunksize

# Substitute query parameters if applicable
sql, execution_params = _apply_formatter(sql, params, paramstyle)

athena_cache_settings = athena_cache_settings if athena_cache_settings else {}
max_cache_seconds = athena_cache_settings.get("max_cache_seconds", 0)
max_cache_query_inspections = athena_cache_settings.get("max_cache_query_inspections", 50)
max_remote_cache_entries = athena_cache_settings.get("max_remote_cache_entries", 50)
max_local_cache_entries = athena_cache_settings.get("max_local_cache_entries", 100)

# Substitute query parameters
sql = _process_sql_params(sql, params)

max_remote_cache_entries = min(max_remote_cache_entries, max_local_cache_entries)

_cache_manager.max_cache_size = max_local_cache_entries
Expand Down Expand Up @@ -1032,6 +1060,7 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
s3_additional_kwargs=s3_additional_kwargs,
boto3_session=boto3_session,
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
execution_params=execution_params,
dtype_backend=dtype_backend,
)

Expand Down Expand Up @@ -1288,7 +1317,8 @@ def unload(
kms_key: Optional[str] = None,
boto3_session: Optional[boto3.Session] = None,
data_source: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
params: Union[Dict[str, Any], List[str], None] = None,
paramstyle: Literal["qmark", "named"] = "named",
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
) -> _QueryMetadata:
"""Write query results from a SELECT statement to the specified data format using UNLOAD.
Expand Down Expand Up @@ -1325,10 +1355,25 @@ def unload(
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
data_source : str, optional
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
params: Dict[str, any], optional
Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported.
The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain
`:name`. Note that for varchar columns and similar, you must surround the value in single quotes.
params: Dict[str, any] | List[str], optional
Parameters that will be used for constructing the SQL query.
Only named or question mark parameters are supported.
The parameter style needs to be specified in the ``paramstyle`` parameter.

For ``paramstyle="named"``, this value needs to be a dictionary.
The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain
``:name``.
The formatter will be applied client-side in this scenario.

For ``paramstyle="qmark"``, this value needs to be a list of strings.
The formatter will be applied server-side.
The values are applied sequentially to the parameters in the query in the order in which the parameters occur.
paramstyle: str, optional
Determines the style of ``params``.
Possible values are:

- ``named``
- ``qmark``
athena_query_wait_polling_delay: float, default: 0.25 seconds
Interval in seconds for how often the function will check if the Athena query has completed.

Expand All @@ -1346,8 +1391,8 @@ def unload(
... )

"""
# Substitute query parameters
sql = _process_sql_params(sql, params)
# Substitute query parameters if applicable
sql, execution_params = _apply_formatter(sql, params, paramstyle)
return _unload(
sql=sql,
path=path,
Expand All @@ -1362,4 +1407,5 @@ def unload(
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
boto3_session=boto3_session,
data_source=data_source,
execution_params=execution_params,
)
Loading