From 0b281fd2cda9d12df493bed26dbe499fd1004bdd Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 6 Jun 2023 14:27:25 -0500 Subject: [PATCH 01/13] Add support for execution parameters --- awswrangler/athena/_executions.py | 12 ++++++- awswrangler/athena/_executions.pyi | 4 +++ awswrangler/athena/_read.py | 33 ++++++++++++++++++-- awswrangler/athena/_read.pyi | 11 +++++++ awswrangler/athena/_utils.py | 7 +++++ tests/unit/test_athena.py | 50 ++++++++++++++++++++++++++++++ 6 files changed, 114 insertions(+), 3 deletions(-) diff --git a/awswrangler/athena/_executions.py b/awswrangler/athena/_executions.py index 32652bc07..3f0dd084c 100644 --- a/awswrangler/athena/_executions.py +++ b/awswrangler/athena/_executions.py @@ -4,6 +4,7 @@ from typing import ( Any, Dict, + List, Optional, Union, cast, @@ -42,6 +43,7 @@ def start_query_execution( athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, data_source: Optional[str] = None, wait: bool = False, + execution_params: Optional[List[str]] = None, ) -> Union[str, Dict[str, Any]]: """Start a SQL Query against AWS Athena. @@ -67,7 +69,14 @@ def start_query_execution( params: Dict[str, any], optional Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported. The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain - `:name`. Note that for varchar columns and similar, you must surround the value in single quotes. + `:name`. + + Note that this formatter is applied client-side, and the query sent to Athena will include the parameter values. + For a server-side application of parameters, see ``execution_params``. + execution_params: List[str], optional + A list of values for the parameters in a query. + The values are applied sequentially to the parameters in the query in the order in which the parameters occur. + The parameters will be applied server-side in Athena. boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receive None. athena_cache_settings: typing.AthenaCacheSettings, optional @@ -139,6 +148,7 @@ def start_query_execution( workgroup=workgroup, encryption=encryption, kms_key=kms_key, + execution_params=execution_params, boto3_session=boto3_session, ) if wait: diff --git a/awswrangler/athena/_executions.pyi b/awswrangler/athena/_executions.pyi index 2d38bc641..fd4407ce5 100644 --- a/awswrangler/athena/_executions.pyi +++ b/awswrangler/athena/_executions.pyi @@ -1,6 +1,7 @@ from typing import ( Any, Dict, + List, Literal, Optional, Union, @@ -25,6 +26,7 @@ def start_query_execution( athena_query_wait_polling_delay: float = ..., data_source: Optional[str] = ..., wait: Literal[False] = ..., + execution_params: Optional[List[str]] = ..., ) -> str: ... @overload def start_query_execution( @@ -41,6 +43,7 @@ def start_query_execution( athena_query_wait_polling_delay: float = ..., data_source: Optional[str] = ..., wait: Literal[True], + execution_params: Optional[List[str]] = ..., ) -> Dict[str, Any]: ... @overload def start_query_execution( @@ -57,6 +60,7 @@ def start_query_execution( athena_query_wait_polling_delay: float = ..., data_source: Optional[str] = ..., wait: bool, + execution_params: Optional[List[str]] = ..., ) -> Union[str, Dict[str, Any]]: ... def stop_query_execution(query_execution_id: str, boto3_session: Optional[boto3.Session] = ...) -> None: ... def wait_query( diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index f39b089b2..33dd93ce4 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -287,6 +287,7 @@ def _resolve_query_without_cache_ctas( s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session], pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, + execution_params: Optional[List[str]] = None, dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: ctas_query_info: Dict[str, Union[str, _QueryMetadata]] = create_ctas_table( @@ -304,6 +305,7 @@ def _resolve_query_without_cache_ctas( wait=True, athena_query_wait_polling_delay=athena_query_wait_polling_delay, boto3_session=boto3_session, + execution_params=execution_params, ) fully_qualified_name: str = f'"{ctas_query_info["ctas_database"]}"."{ctas_query_info["ctas_table"]}"' ctas_query_metadata = cast(_QueryMetadata, ctas_query_info["ctas_query_metadata"]) @@ -342,6 +344,7 @@ def _resolve_query_without_cache_unload( s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session], pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, + execution_params: Optional[List[str]] = None, dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: query_metadata = _unload( @@ -358,6 +361,7 @@ def _resolve_query_without_cache_unload( boto3_session=boto3_session, data_source=data_source, athena_query_wait_polling_delay=athena_query_wait_polling_delay, + execution_params=execution_params, ) if file_format == "PARQUET": return _fetch_parquet_result( @@ -389,6 +393,7 @@ def _resolve_query_without_cache_regular( athena_query_wait_polling_delay: float, s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session], + execution_params: Optional[List[str]] = None, dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup) @@ -404,6 +409,7 @@ def _resolve_query_without_cache_regular( workgroup=workgroup, encryption=encryption, kms_key=kms_key, + execution_params=execution_params, boto3_session=boto3_session, ) _logger.debug("Query id: %s", query_id) @@ -450,6 +456,7 @@ def _resolve_query_without_cache( s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session], pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, + execution_params: Optional[List[str]] = None, dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: """ @@ -483,6 +490,7 @@ def _resolve_query_without_cache( s3_additional_kwargs=s3_additional_kwargs, boto3_session=boto3_session, pyarrow_additional_kwargs=pyarrow_additional_kwargs, + execution_params=execution_params, dtype_backend=dtype_backend, ) finally: @@ -510,6 +518,7 @@ def _resolve_query_without_cache( s3_additional_kwargs=s3_additional_kwargs, boto3_session=boto3_session, pyarrow_additional_kwargs=pyarrow_additional_kwargs, + execution_params=execution_params, dtype_backend=dtype_backend, ) return _resolve_query_without_cache_regular( @@ -527,6 +536,7 @@ def _resolve_query_without_cache( athena_query_wait_polling_delay=athena_query_wait_polling_delay, s3_additional_kwargs=s3_additional_kwargs, boto3_session=boto3_session, + execution_params=execution_params, dtype_backend=dtype_backend, ) @@ -545,6 +555,7 @@ def _unload( boto3_session: Optional[boto3.Session], data_source: Optional[str], athena_query_wait_polling_delay: float, + execution_params: Optional[List[str]], ) -> _QueryMetadata: wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup) s3_output: str = _get_s3_output(s3_output=path, wg_config=wg_config, boto3_session=boto3_session) @@ -576,6 +587,7 @@ def _unload( encryption=encryption, kms_key=kms_key, boto3_session=boto3_session, + execution_params=execution_params, ) except botocore.exceptions.ClientError as ex: msg: str = str(ex) @@ -736,6 +748,7 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals data_source: Optional[str] = None, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, params: Optional[Dict[str, Any]] = None, + execution_params: Optional[List[str]] = None, dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", s3_additional_kwargs: Optional[Dict[str, Any]] = None, pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, @@ -908,7 +921,14 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals params: Dict[str, any], optional Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported. The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain - `:name`. Note that for varchar columns and similar, you must surround the value in single quotes. + `:name`. + + Note that this formatter is applied client-side, and the query sent to Athena will include the parameter values. + For a server-side application of parameters, see ``execution_params``. + execution_params: List[str], optional + A list of values for the parameters in a query. + The values are applied sequentially to the parameters in the query in the order in which the parameters occur. + The parameters will be applied server-side in Athena. dtype_backend: str, optional Which dtype_backend to use, e.g. whether a DataFrame should have NumPy arrays, nullable dtypes are used for all dtypes that have a nullable implementation when @@ -1032,6 +1052,7 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals s3_additional_kwargs=s3_additional_kwargs, boto3_session=boto3_session, pyarrow_additional_kwargs=pyarrow_additional_kwargs, + execution_params=execution_params, dtype_backend=dtype_backend, ) @@ -1289,6 +1310,7 @@ def unload( boto3_session: Optional[boto3.Session] = None, data_source: Optional[str] = None, params: Optional[Dict[str, Any]] = None, + execution_params: Optional[List[str]] = None, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, ) -> _QueryMetadata: """Write query results from a SELECT statement to the specified data format using UNLOAD. @@ -1328,7 +1350,14 @@ def unload( params: Dict[str, any], optional Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported. The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain - `:name`. Note that for varchar columns and similar, you must surround the value in single quotes. + `:name`. + + Note that this formatter is applied client-side, and the query sent to Athena will include the parameter values. + For a server-side application of parameters, see ``execution_params``. + execution_params: List[str], optional + A list of values for the parameters in a query. + The values are applied sequentially to the parameters in the query in the order in which the parameters occur. + The parameters will be applied server-side in Athena. athena_query_wait_polling_delay: float, default: 0.25 seconds Interval in seconds for how often the function will check if the Athena query has completed. diff --git a/awswrangler/athena/_read.pyi b/awswrangler/athena/_read.pyi index df5dada13..9de4cce74 100644 --- a/awswrangler/athena/_read.pyi +++ b/awswrangler/athena/_read.pyi @@ -74,6 +74,7 @@ def read_sql_query( # pylint: disable=too-many-arguments data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., params: Optional[Dict[str, Any]] = ..., + execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -100,6 +101,7 @@ def read_sql_query( data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., params: Optional[Dict[str, Any]] = ..., + execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -126,6 +128,7 @@ def read_sql_query( data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., params: Optional[Dict[str, Any]] = ..., + execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -152,6 +155,7 @@ def read_sql_query( data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., params: Optional[Dict[str, Any]] = ..., + execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -178,6 +182,7 @@ def read_sql_query( data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., params: Optional[Dict[str, Any]] = ..., + execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -202,6 +207,7 @@ def read_sql_table( boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., + execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -226,6 +232,7 @@ def read_sql_table( boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., + execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -250,6 +257,7 @@ def read_sql_table( boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., + execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -274,6 +282,7 @@ def read_sql_table( boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., + execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -298,6 +307,7 @@ def read_sql_table( boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., + execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -316,5 +326,6 @@ def unload( boto3_session: Optional[boto3.Session] = ..., data_source: Optional[str] = ..., params: Optional[Dict[str, Any]] = ..., + execution_params: Optional[List[str]] = ..., athena_query_wait_polling_delay: float = ..., ) -> _QueryMetadata: ... diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py index 3f9063c33..cb8dd6149 100644 --- a/awswrangler/athena/_utils.py +++ b/awswrangler/athena/_utils.py @@ -82,6 +82,7 @@ def _start_query_execution( workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None, + execution_params: Optional[List[str]] = None, boto3_session: Optional[boto3.Session] = None, ) -> str: args: Dict[str, Any] = {"QueryString": sql} @@ -112,6 +113,9 @@ def _start_query_execution( if workgroup is not None: args["WorkGroup"] = workgroup + if execution_params: + args["ExecutionParameters"] = execution_params + client_athena = _utils.client(service_name="athena", session=boto3_session) _logger.debug("Starting query execution with args: \n%s", pprint.pformat(args)) response = _utils.try_it( @@ -207,6 +211,7 @@ def _get_query_metadata( # pylint: disable=too-many-statements query_execution_payload: Optional[Dict[str, Any]] = None, metadata_cache_manager: Optional[_LocalMetadataCacheManager] = None, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, + execution_params: Optional[List[str]] = None, dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", ) -> _QueryMetadata: """Get query metadata.""" @@ -568,6 +573,7 @@ def create_ctas_table( # pylint: disable=too-many-locals categories: Optional[List[str]] = None, wait: bool = False, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, + execution_params: Optional[List[str]] = None, boto3_session: Optional[boto3.Session] = None, ) -> Dict[str, Union[str, _QueryMetadata]]: """Create a new table populated with the results of a SELECT query. @@ -721,6 +727,7 @@ def create_ctas_table( # pylint: disable=too-many-locals encryption=encryption, kms_key=kms_key, boto3_session=boto3_session, + execution_params=execution_params, ) except botocore.exceptions.ClientError as ex: error = ex.response["Error"] diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py index a453c1886..3a99f256d 100644 --- a/tests/unit/test_athena.py +++ b/tests/unit/test_athena.py @@ -1,6 +1,7 @@ import datetime import logging import string +from typing import Any from unittest.mock import patch import boto3 @@ -294,6 +295,55 @@ def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1 ) +@pytest.mark.parametrize( + "ctas_approach,unload_approach", + [ + pytest.param(False, False, id="regular"), + pytest.param(True, False, id="ctas"), + pytest.param(False, True, id="unload"), + ], +) +@pytest.mark.parametrize( + "col_name,col_value", [("string", "Washington"), ("iint32", "1"), ("date", "DATE '2020-01-01'")] +) +def test_athena_execution_parameters( + path: str, + path2: str, + glue_database: str, + glue_table: str, + workgroup0: str, + ctas_approach: bool, + unload_approach: bool, + col_name: str, + col_value: Any, +) -> None: + wr.s3.to_parquet( + df=get_df(), + path=path, + index=False, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + partition_cols=["par0", "par1"], + ) + + df_out = wr.athena.read_sql_query( + sql=f"SELECT * FROM {glue_table} WHERE {col_name} = ?", + database=glue_database, + ctas_approach=ctas_approach, + unload_approach=unload_approach, + workgroup=workgroup0, + execution_params=[col_value], + keep_files=False, + s3_output=path2, + ) + ensure_data_types(df=df_out) + ensure_athena_query_metadata(df=df_out, ctas_approach=ctas_approach, encrypted=False) + + assert len(df_out) == 1 + + def test_read_sql_query_parameter_formatting_respects_prefixes(path, glue_database, glue_table, workgroup0): wr.s3.to_parquet( df=get_df(), From a40e346d00562323318813656eff2035835c33a5 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 6 Jun 2023 14:51:10 -0500 Subject: [PATCH 02/13] Add support for prepared statements --- awswrangler/athena/__init__.py | 3 + awswrangler/athena/_statements.py | 57 ++++++++++++++++++ tests/unit/test_athena_prepared.py | 93 ++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+) create mode 100644 awswrangler/athena/_statements.py create mode 100644 tests/unit/test_athena_prepared.py diff --git a/awswrangler/athena/__init__.py b/awswrangler/athena/__init__.py index 24666392b..fd9452ed4 100644 --- a/awswrangler/athena/__init__.py +++ b/awswrangler/athena/__init__.py @@ -7,6 +7,7 @@ wait_query, ) from awswrangler.athena._spark import create_spark_session, run_spark_calculation +from awswrangler.athena._statements import prepare_statement, deallocate_prepared_statement from awswrangler.athena._read import ( # noqa get_query_results, read_sql_query, @@ -51,5 +52,7 @@ "stop_query_execution", "unload", "wait_query", + "prepare_statement", + "deallocate_prepared_statement", "to_iceberg", ] diff --git a/awswrangler/athena/_statements.py b/awswrangler/athena/_statements.py new file mode 100644 index 000000000..8bc45ae13 --- /dev/null +++ b/awswrangler/athena/_statements.py @@ -0,0 +1,57 @@ +"""Amazon Athena Module gathering all functions related to prepared statements.""" + +import logging +from typing import Any, Dict, Optional + +import boto3 + +from awswrangler._config import apply_configs +from awswrangler.athena._executions import start_query_execution +from awswrangler.athena._utils import ( + _QUERY_WAIT_POLLING_DELAY, +) + +_logger: logging.Logger = logging.getLogger(__name__) + + +@apply_configs +def prepare_statement( + sql: str, + statement_name: str, + database: Optional[str] = None, + workgroup: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, + athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, + data_source: Optional[str] = None, +) -> Dict[str, Any]: + sql_statement = f""" +PREPARE "{statement_name}" FROM +{sql} + """ + return start_query_execution( + sql=sql_statement, + database=database, + workgroup=workgroup, + boto3_session=boto3_session, + athena_query_wait_polling_delay=athena_query_wait_polling_delay, + data_source=data_source, + wait=True, + ) + + +@apply_configs +def deallocate_prepared_statement( + statement_name: str, + workgroup: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, + athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, + data_source: Optional[str] = None, +) -> Dict[str, Any]: + return start_query_execution( + sql=f'DEALLOCATE PREPARE "{statement_name}"', + workgroup=workgroup, + boto3_session=boto3_session, + athena_query_wait_polling_delay=athena_query_wait_polling_delay, + data_source=data_source, + wait=True, + ) diff --git a/tests/unit/test_athena_prepared.py b/tests/unit/test_athena_prepared.py new file mode 100644 index 000000000..db505c7b5 --- /dev/null +++ b/tests/unit/test_athena_prepared.py @@ -0,0 +1,93 @@ +import logging + +import pytest + +import awswrangler as wr + +from .._utils import ( + ensure_athena_query_metadata, + ensure_data_types, + get_df, + get_time_str_with_random_suffix, +) + +logging.getLogger("awswrangler").setLevel(logging.DEBUG) + +pytestmark = pytest.mark.distributed + + +@pytest.fixture(scope="function") +def statement(workgroup0: str) -> str: + name = f"prepared_statement_{get_time_str_with_random_suffix()}" + yield name + try: + wr.athena.deallocate_prepared_statement(statement_name=name, workgroup=workgroup0) + except wr.exceptions.QueryFailed as e: + if not str(e).startswith(f"PreparedStatement {name} was not found"): + raise e + + +def test_athena_deallocate_prepared_statement(workgroup0: str, statement: str) -> None: + wr.athena.prepare_statement( + sql="SELECT 1 as col0", + statement_name=statement, + workgroup=workgroup0, + ) + + wr.athena.deallocate_prepared_statement( + statement_name=statement, + workgroup=workgroup0, + ) + + +def test_athena_execute_prepared_statement( + path: str, + path2: str, + glue_database: str, + glue_table: str, + workgroup0: str, + statement: str, +) -> None: + wr.s3.to_parquet( + df=get_df(), + path=path, + index=False, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + partition_cols=["par0", "par1"], + ) + + wr.athena.prepare_statement( + sql=f"SELECT * FROM {glue_table} WHERE string = ?", + statement_name=statement, + database=glue_database, + workgroup=workgroup0, + ) + + df_out1 = wr.athena.read_sql_query( + sql=f"EXECUTE \"{statement}\" USING 'Washington'", + database=glue_database, + ctas_approach=False, + workgroup=workgroup0, + keep_files=False, + s3_output=path2, + ) + df_out2 = wr.athena.read_sql_query( + sql=f"EXECUTE \"{statement}\" USING 'Seattle'", + database=glue_database, + ctas_approach=False, + workgroup=workgroup0, + keep_files=False, + s3_output=path2, + ) + + ensure_data_types(df=df_out1) + ensure_data_types(df=df_out2) + + ensure_athena_query_metadata(df=df_out1, ctas_approach=False, encrypted=False) + ensure_athena_query_metadata(df=df_out2, ctas_approach=False, encrypted=False) + + assert len(df_out1) == 1 + assert len(df_out2) == 1 From 2968fa51fa564abef6c52b46f6a6efe0b9acd303 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 6 Jun 2023 15:20:58 -0500 Subject: [PATCH 03/13] Add list_prepared_statements --- awswrangler/athena/__init__.py | 3 ++- awswrangler/athena/_statements.py | 23 ++++++++++++++++++----- tests/unit/test_athena_prepared.py | 19 ++++++++++++++++++- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/awswrangler/athena/__init__.py b/awswrangler/athena/__init__.py index fd9452ed4..840cd57e7 100644 --- a/awswrangler/athena/__init__.py +++ b/awswrangler/athena/__init__.py @@ -7,7 +7,7 @@ wait_query, ) from awswrangler.athena._spark import create_spark_session, run_spark_calculation -from awswrangler.athena._statements import prepare_statement, deallocate_prepared_statement +from awswrangler.athena._statements import prepare_statement, deallocate_prepared_statement, list_prepared_statements from awswrangler.athena._read import ( # noqa get_query_results, read_sql_query, @@ -53,6 +53,7 @@ "unload", "wait_query", "prepare_statement", + "list_prepared_statements", "deallocate_prepared_statement", "to_iceberg", ] diff --git a/awswrangler/athena/_statements.py b/awswrangler/athena/_statements.py index 8bc45ae13..2195f0b66 100644 --- a/awswrangler/athena/_statements.py +++ b/awswrangler/athena/_statements.py @@ -1,10 +1,11 @@ """Amazon Athena Module gathering all functions related to prepared statements.""" import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, cast import boto3 +from awswrangler import _utils from awswrangler._config import apply_configs from awswrangler.athena._executions import start_query_execution from awswrangler.athena._utils import ( @@ -18,27 +19,38 @@ def prepare_statement( sql: str, statement_name: str, - database: Optional[str] = None, workgroup: Optional[str] = None, boto3_session: Optional[boto3.Session] = None, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, - data_source: Optional[str] = None, ) -> Dict[str, Any]: sql_statement = f""" PREPARE "{statement_name}" FROM {sql} """ + _logger.info(f"Creating prepared statement {statement_name}") return start_query_execution( sql=sql_statement, - database=database, workgroup=workgroup, boto3_session=boto3_session, athena_query_wait_polling_delay=athena_query_wait_polling_delay, - data_source=data_source, wait=True, ) +@apply_configs +def list_prepared_statements(workgroup: str, boto3_session: Optional[boto3.Session] = None) -> List[str]: + athena_client = _utils.client("athena", session=boto3_session) + + response = athena_client.list_prepared_statements(WorkGroup=workgroup) + statements = response["PreparedStatements"] + + while "NextToken" in response: + response = athena_client.list_prepared_statements(WorkGroup=workgroup, NextToken=response["NextToken"]) + statements += response["PreparedStatements"] + + return cast(List[Dict[str, Any]], statements) + + @apply_configs def deallocate_prepared_statement( statement_name: str, @@ -47,6 +59,7 @@ def deallocate_prepared_statement( athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, data_source: Optional[str] = None, ) -> Dict[str, Any]: + _logger.info(f"Deallocating prepared statement {statement_name}") return start_query_execution( sql=f'DEALLOCATE PREPARE "{statement_name}"', workgroup=workgroup, diff --git a/tests/unit/test_athena_prepared.py b/tests/unit/test_athena_prepared.py index db505c7b5..d43fdd2dd 100644 --- a/tests/unit/test_athena_prepared.py +++ b/tests/unit/test_athena_prepared.py @@ -40,6 +40,24 @@ def test_athena_deallocate_prepared_statement(workgroup0: str, statement: str) - ) +def test_list_prepared_statements(workgroup1: str, statement: str) -> None: + wr.athena.prepare_statement( + sql="SELECT 1 as col0", + statement_name=statement, + workgroup=workgroup1, + ) + + statement_list = wr.athena.list_prepared_statements(workgroup1) + + assert len(statement_list) == 1 + assert statement_list[0]["StatementName"] == statement + + wr.athena.deallocate_prepared_statement(statement, workgroup=workgroup1) + + statement_list = wr.athena.list_prepared_statements(workgroup1) + assert len(statement_list) == 0 + + def test_athena_execute_prepared_statement( path: str, path2: str, @@ -62,7 +80,6 @@ def test_athena_execute_prepared_statement( wr.athena.prepare_statement( sql=f"SELECT * FROM {glue_table} WHERE string = ?", statement_name=statement, - database=glue_database, workgroup=workgroup0, ) From a649273eec22f0761753a47ecd80421ef197eaaa Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 6 Jun 2023 15:24:34 -0500 Subject: [PATCH 04/13] Add test_athena_execute_prepared_statement_with_params --- tests/unit/test_athena_prepared.py | 41 ++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/unit/test_athena_prepared.py b/tests/unit/test_athena_prepared.py index d43fdd2dd..f4624380b 100644 --- a/tests/unit/test_athena_prepared.py +++ b/tests/unit/test_athena_prepared.py @@ -108,3 +108,44 @@ def test_athena_execute_prepared_statement( assert len(df_out1) == 1 assert len(df_out2) == 1 + + +def test_athena_execute_prepared_statement_with_params( + path: str, + path2: str, + glue_database: str, + glue_table: str, + workgroup0: str, + statement: str, +) -> None: + wr.s3.to_parquet( + df=get_df(), + path=path, + index=False, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + partition_cols=["par0", "par1"], + ) + + wr.athena.prepare_statement( + sql=f"SELECT * FROM {glue_table} WHERE string = ?", + statement_name=statement, + workgroup=workgroup0, + ) + + df_out1 = wr.athena.read_sql_query( + sql=f'EXECUTE "{statement}"', + database=glue_database, + ctas_approach=False, + workgroup=workgroup0, + execution_params=["Washington"], + keep_files=False, + s3_output=path2, + ) + + ensure_data_types(df=df_out1) + ensure_athena_query_metadata(df=df_out1, ctas_approach=False, encrypted=False) + + assert len(df_out1) == 1 From 7068c3c43a43d52eb1ee730de88502c75409cdd3 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 6 Jun 2023 16:00:18 -0500 Subject: [PATCH 05/13] Fix create and deallocate prepared statements --- awswrangler/athena/_read.py | 1 + awswrangler/athena/_statements.py | 82 ++++++++++++++++++++---------- tests/unit/test_athena_prepared.py | 35 ++++++++++++- 3 files changed, 89 insertions(+), 29 deletions(-) diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 33dd93ce4..f20364644 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -1391,4 +1391,5 @@ def unload( athena_query_wait_polling_delay=athena_query_wait_polling_delay, boto3_session=boto3_session, data_source=data_source, + execution_params=execution_params, ) diff --git a/awswrangler/athena/_statements.py b/awswrangler/athena/_statements.py index 2195f0b66..0b3078e1a 100644 --- a/awswrangler/athena/_statements.py +++ b/awswrangler/athena/_statements.py @@ -1,45 +1,76 @@ """Amazon Athena Module gathering all functions related to prepared statements.""" import logging -from typing import Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast import boto3 +from botocore.exceptions import ClientError -from awswrangler import _utils +from awswrangler import _utils, exceptions from awswrangler._config import apply_configs -from awswrangler.athena._executions import start_query_execution -from awswrangler.athena._utils import ( - _QUERY_WAIT_POLLING_DELAY, -) + +if TYPE_CHECKING: + from mypy_boto3_athena.client import AthenaClient _logger: logging.Logger = logging.getLogger(__name__) +def _does_statement_exist( + statement_name: str, + workgroup: str, + athena_client: "AthenaClient", +) -> bool: + try: + athena_client.get_prepared_statement(StatementName=statement_name, WorkGroup=workgroup) + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + return False + + raise e + + return True + + @apply_configs def prepare_statement( sql: str, statement_name: str, workgroup: Optional[str] = None, + mode: Literal["update", "error"] = "update", boto3_session: Optional[boto3.Session] = None, - athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, ) -> Dict[str, Any]: - sql_statement = f""" -PREPARE "{statement_name}" FROM -{sql} - """ + if mode not in ["update", "error"]: + raise exceptions.InvalidArgumentValue("`mode` must be one of 'update' or 'error'.") + + athena_client = _utils.client("athena", session=boto3_session) + workgroup = workgroup if workgroup else "primary" + + already_exists = _does_statement_exist(statement_name, workgroup, athena_client) + if already_exists and mode == "error": + raise exceptions.AlreadyExists(f"Prepared statement {statement_name} already exists.") + + if already_exists: + _logger.info(f"Updating prepared statement {statement_name}") + return athena_client.update_prepared_statement( + StatementName=statement_name, + WorkGroup=workgroup, + QueryStatement=sql, + ) + _logger.info(f"Creating prepared statement {statement_name}") - return start_query_execution( - sql=sql_statement, - workgroup=workgroup, - boto3_session=boto3_session, - athena_query_wait_polling_delay=athena_query_wait_polling_delay, - wait=True, + return athena_client.create_prepared_statement( + StatementName=statement_name, + WorkGroup=workgroup, + QueryStatement=sql, ) @apply_configs -def list_prepared_statements(workgroup: str, boto3_session: Optional[boto3.Session] = None) -> List[str]: +def list_prepared_statements( + workgroup: Optional[str] = None, boto3_session: Optional[boto3.Session] = None +) -> List[str]: athena_client = _utils.client("athena", session=boto3_session) + workgroup = workgroup if workgroup else "primary" response = athena_client.list_prepared_statements(WorkGroup=workgroup) statements = response["PreparedStatements"] @@ -56,15 +87,12 @@ def deallocate_prepared_statement( statement_name: str, workgroup: Optional[str] = None, boto3_session: Optional[boto3.Session] = None, - athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, - data_source: Optional[str] = None, ) -> Dict[str, Any]: + athena_client = _utils.client("athena", session=boto3_session) + workgroup = workgroup if workgroup else "primary" + _logger.info(f"Deallocating prepared statement {statement_name}") - return start_query_execution( - sql=f'DEALLOCATE PREPARE "{statement_name}"', - workgroup=workgroup, - boto3_session=boto3_session, - athena_query_wait_polling_delay=athena_query_wait_polling_delay, - data_source=data_source, - wait=True, + return athena_client.delete_prepared_statement( + StatementName=statement_name, + WorkGroup=workgroup, ) diff --git a/tests/unit/test_athena_prepared.py b/tests/unit/test_athena_prepared.py index f4624380b..cd90437e0 100644 --- a/tests/unit/test_athena_prepared.py +++ b/tests/unit/test_athena_prepared.py @@ -1,6 +1,7 @@ import logging import pytest +from botocore.exceptions import ClientError import awswrangler as wr @@ -22,11 +23,41 @@ def statement(workgroup0: str) -> str: yield name try: wr.athena.deallocate_prepared_statement(statement_name=name, workgroup=workgroup0) - except wr.exceptions.QueryFailed as e: - if not str(e).startswith(f"PreparedStatement {name} was not found"): + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": raise e +def test_update_prepared_statement(workgroup0: str, statement: str) -> None: + wr.athena.prepare_statement( + sql="SELECT 1 AS col0", + statement_name=statement, + workgroup=workgroup0, + ) + + wr.athena.prepare_statement( + sql="SELECT 1 AS col0, 2 AS col1", + statement_name=statement, + workgroup=workgroup0, + ) + + +def test_update_prepared_statement_error(workgroup0: str, statement: str) -> None: + wr.athena.prepare_statement( + sql="SELECT 1 AS col0", + statement_name=statement, + workgroup=workgroup0, + ) + + with pytest.raises(wr.exceptions.AlreadyExists): + wr.athena.prepare_statement( + sql="SELECT 1 AS col0, 2 AS col1", + statement_name=statement, + workgroup=workgroup0, + mode="error", + ) + + def test_athena_deallocate_prepared_statement(workgroup0: str, statement: str) -> None: wr.athena.prepare_statement( sql="SELECT 1 as col0", From ab54de68f686a7372b6ad0e8696650a691965a7a Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 6 Jun 2023 16:21:53 -0500 Subject: [PATCH 06/13] Add documentaton --- awswrangler/athena/_statements.py | 77 ++++++++++++++++++++++++++++++ docs/source/api.rst | 3 ++ tests/unit/test_athena_prepared.py | 5 +- 3 files changed, 83 insertions(+), 2 deletions(-) diff --git a/awswrangler/athena/_statements.py b/awswrangler/athena/_statements.py index 0b3078e1a..6d911474e 100644 --- a/awswrangler/athena/_statements.py +++ b/awswrangler/athena/_statements.py @@ -39,6 +39,40 @@ def prepare_statement( mode: Literal["update", "error"] = "update", boto3_session: Optional[boto3.Session] = None, ) -> Dict[str, Any]: + """ + Create a SQL statement with the name statement_name to be run at a later time. The statement can include parameters represented by question marks. + + https://docs.aws.amazon.com/athena/latest/ug/sql-prepare.html + + Parameters + ---------- + sql : str + The query string for the prepared statement. + statement_name : str + The name of the prepared statement. + workgroup : str, optional + The name of the workgroup to which the prepared statement belongs. + mode: str + Determines the behaviour if the prepared statement already exists: + + - ``update`` - updates statement if already exists + - ``error`` - throws an error if table exists + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + Dict[str, Any] + Response to `create_prepared_statement `__. + + Examples + -------- + >>> import awswrangler as wr + >>> res = wr.athena.prepare_statement( + ... sql="SELECT * FROM my_table WHERE name = ?", + ... statement_name="statement", + ... ) + """ if mode not in ["update", "error"]: raise exceptions.InvalidArgumentValue("`mode` must be one of 'update' or 'error'.") @@ -69,6 +103,22 @@ def prepare_statement( def list_prepared_statements( workgroup: Optional[str] = None, boto3_session: Optional[boto3.Session] = None ) -> List[str]: + """ + List the prepared statements in the specified workgroup. + + Parameters + ---------- + workgroup: str, optional + The name of the workgroup to which the prepared statement belongs. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + List[Dict[str, Any]] + List of prepared statements in the workgroup. + Each item is a dictionary with the keys ``StatementName`` and ``LastModifiedTime``. + """ athena_client = _utils.client("athena", session=boto3_session) workgroup = workgroup if workgroup else "primary" @@ -88,6 +138,33 @@ def deallocate_prepared_statement( workgroup: Optional[str] = None, boto3_session: Optional[boto3.Session] = None, ) -> Dict[str, Any]: + """ + Delete the prepared statement with the specified name from the specified workgroup. + + https://docs.aws.amazon.com/athena/latest/ug/sql-deallocate-prepare.html + + Parameters + ---------- + statement_name : str + The name of the prepared statement. + workgroup : str, optional + The name of the workgroup to which the prepared statement belongs. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + Dict[str, Any] + Response to `delete_prepared_statement `__. + + Examples + -------- + >>> import awswrangler as wr + >>> res = wr.athena.prepare_statement( + ... sql="SELECT * FROM my_table WHERE name = ?", + ... statement_name="statement", + ... ) + """ athena_client = _utils.client("athena", session=boto3_session) workgroup = workgroup if workgroup else "primary" diff --git a/docs/source/api.rst b/docs/source/api.rst index e1873442a..56eb117a7 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -140,6 +140,9 @@ Amazon Athena to_iceberg unload wait_query + prepare_statement + list_prepared_statements + deallocate_prepared_statement AWS Lake Formation ------------------ diff --git a/tests/unit/test_athena_prepared.py b/tests/unit/test_athena_prepared.py index cd90437e0..410d5ade6 100644 --- a/tests/unit/test_athena_prepared.py +++ b/tests/unit/test_athena_prepared.py @@ -59,16 +59,17 @@ def test_update_prepared_statement_error(workgroup0: str, statement: str) -> Non def test_athena_deallocate_prepared_statement(workgroup0: str, statement: str) -> None: - wr.athena.prepare_statement( + res = wr.athena.prepare_statement( sql="SELECT 1 as col0", statement_name=statement, workgroup=workgroup0, ) - wr.athena.deallocate_prepared_statement( + res2 = wr.athena.deallocate_prepared_statement( statement_name=statement, workgroup=workgroup0, ) + print(res2) def test_list_prepared_statements(workgroup1: str, statement: str) -> None: From 5ad88ff22d3cb38d704107f39782790f72c2a41d Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 6 Jun 2023 16:28:09 -0500 Subject: [PATCH 07/13] remove prints --- tests/unit/test_athena_prepared.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_athena_prepared.py b/tests/unit/test_athena_prepared.py index 410d5ade6..cd90437e0 100644 --- a/tests/unit/test_athena_prepared.py +++ b/tests/unit/test_athena_prepared.py @@ -59,17 +59,16 @@ def test_update_prepared_statement_error(workgroup0: str, statement: str) -> Non def test_athena_deallocate_prepared_statement(workgroup0: str, statement: str) -> None: - res = wr.athena.prepare_statement( + wr.athena.prepare_statement( sql="SELECT 1 as col0", statement_name=statement, workgroup=workgroup0, ) - res2 = wr.athena.deallocate_prepared_statement( + wr.athena.deallocate_prepared_statement( statement_name=statement, workgroup=workgroup0, ) - print(res2) def test_list_prepared_statements(workgroup1: str, statement: str) -> None: From 26b3fc72f98791d78fe123a5f427050159bf18c3 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Wed, 7 Jun 2023 09:13:55 -0500 Subject: [PATCH 08/13] Remove return values --- awswrangler/athena/_statements.py | 32 ++++++++++-------------------- tests/unit/test_athena_prepared.py | 9 ++++++++- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/awswrangler/athena/_statements.py b/awswrangler/athena/_statements.py index 6d911474e..9366ce6da 100644 --- a/awswrangler/athena/_statements.py +++ b/awswrangler/athena/_statements.py @@ -38,7 +38,7 @@ def prepare_statement( workgroup: Optional[str] = None, mode: Literal["update", "error"] = "update", boto3_session: Optional[boto3.Session] = None, -) -> Dict[str, Any]: +) -> None: """ Create a SQL statement with the name statement_name to be run at a later time. The statement can include parameters represented by question marks. @@ -60,11 +60,6 @@ def prepare_statement( boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receive None. - Returns - ------- - Dict[str, Any] - Response to `create_prepared_statement `__. - Examples -------- >>> import awswrangler as wr @@ -85,18 +80,18 @@ def prepare_statement( if already_exists: _logger.info(f"Updating prepared statement {statement_name}") - return athena_client.update_prepared_statement( + athena_client.update_prepared_statement( + StatementName=statement_name, + WorkGroup=workgroup, + QueryStatement=sql, + ) + else: + _logger.info(f"Creating prepared statement {statement_name}") + athena_client.create_prepared_statement( StatementName=statement_name, WorkGroup=workgroup, QueryStatement=sql, ) - - _logger.info(f"Creating prepared statement {statement_name}") - return athena_client.create_prepared_statement( - StatementName=statement_name, - WorkGroup=workgroup, - QueryStatement=sql, - ) @apply_configs @@ -137,7 +132,7 @@ def deallocate_prepared_statement( statement_name: str, workgroup: Optional[str] = None, boto3_session: Optional[boto3.Session] = None, -) -> Dict[str, Any]: +) -> None: """ Delete the prepared statement with the specified name from the specified workgroup. @@ -152,11 +147,6 @@ def deallocate_prepared_statement( boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receive None. - Returns - ------- - Dict[str, Any] - Response to `delete_prepared_statement `__. - Examples -------- >>> import awswrangler as wr @@ -169,7 +159,7 @@ def deallocate_prepared_statement( workgroup = workgroup if workgroup else "primary" _logger.info(f"Deallocating prepared statement {statement_name}") - return athena_client.delete_prepared_statement( + athena_client.delete_prepared_statement( StatementName=statement_name, WorkGroup=workgroup, ) diff --git a/tests/unit/test_athena_prepared.py b/tests/unit/test_athena_prepared.py index cd90437e0..6cf4689d1 100644 --- a/tests/unit/test_athena_prepared.py +++ b/tests/unit/test_athena_prepared.py @@ -1,5 +1,6 @@ import logging +import boto3 import pytest from botocore.exceptions import ClientError @@ -59,12 +60,18 @@ def test_update_prepared_statement_error(workgroup0: str, statement: str) -> Non def test_athena_deallocate_prepared_statement(workgroup0: str, statement: str) -> None: + athena_client = boto3.client("athena") + + sql_statement = "SELECT 1 as col0" wr.athena.prepare_statement( - sql="SELECT 1 as col0", + sql=sql_statement, statement_name=statement, workgroup=workgroup0, ) + resp = athena_client.get_prepared_statement(StatementName=statement, WorkGroup=workgroup0) + assert resp["PreparedStatement"]["QueryStatement"] == sql_statement + wr.athena.deallocate_prepared_statement( statement_name=statement, workgroup=workgroup0, From cccd7ad66cf26e926eec98ba4a32cf4139faa5c5 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Wed, 7 Jun 2023 12:49:44 -0500 Subject: [PATCH 09/13] Add parameter resolution to tutorial --- tutorials/006 - Amazon Athena.ipynb | 77 +++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tutorials/006 - Amazon Athena.ipynb b/tutorials/006 - Amazon Athena.ipynb index b7c93cda2..ee0b8c874 100644 --- a/tutorials/006 - Amazon Athena.ipynb +++ b/tutorials/006 - Amazon Athena.ipynb @@ -299,6 +299,83 @@ " print(len(df.index))" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parameterized queries" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Client-side parameter resolution\n", + "\n", + "The `params` parameter allows client-side resolution of parameters, which are specified with `:col_name`.\n", + "Additionally, Python types will map to the appropriate Athena definitions.\n", + "For example, the value `dt.date(2023, 1, 1)` will resolve to `DATE '2023-01-01`.\n", + "\n", + "For the example below, the following query will be sent to Athena:\n", + "```sql\n", + "SELECT * FROM noaa WHERE S_FLAG = 'E'\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "wr.athena.read_sql_query(\n", + " \"SELECT * FROM noaa WHERE S_FLAG = :flag_value\",\n", + " database=\"awswrangler_test\",\n", + " params={\n", + " \"flag_value\": \"E\",\n", + " },\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Server-side parameter resolution\n", + "\n", + "Alternatively, Athena supports server-side parameter resolution.\n", + "The SQL statement sent to Athena will not contain the values passed in `execution_params`.\n", + "Instead, they will be passed as part of a separate `execution_params` parameter in `boto3`.\n", + "\n", + "The downside of using this approach is that types aren't automatically resolved.\n", + "The values sent to `execution_params` must be strings.\n", + "Therefore, if one of the values is a date, the value passed in `execution_params` has to be `DATE 'XXXX-XX-XX'`.\n", + "\n", + "The upside, however, is that these parameters can be used with prepared statements.\n", + "\n", + "For more information, see \"[Using parameterized queries](https://docs.aws.amazon.com/athena/latest/ug/querying-with-prepared-statements.html)\"." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "wr.athena.read_sql_query(\n", + " \"SELECT * FROM noaa WHERE S_FLAG = ?\",\n", + " database=\"awswrangler_test\",\n", + " execution_params=[\"E\"],\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": {}, From bc2c3cf17dda9d27342724f168ce017830d0f94f Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Wed, 7 Jun 2023 12:55:37 -0500 Subject: [PATCH 10/13] add prepared statements --- tutorials/006 - Amazon Athena.ipynb | 43 +++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tutorials/006 - Amazon Athena.ipynb b/tutorials/006 - Amazon Athena.ipynb index ee0b8c874..0a327e848 100644 --- a/tutorials/006 - Amazon Athena.ipynb +++ b/tutorials/006 - Amazon Athena.ipynb @@ -376,6 +376,49 @@ ")" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepared statements" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wr.athena.prepare_statement(\n", + " sql=\"SELECT * FROM noaa WHERE S_FLAG = ?\",\n", + " statement_name=\"statement\",\n", + ")\n", + "\n", + "# Resolve parameter using Athena execution parameters\n", + "wr.athena.read_sql_query(\n", + " sql=\"EXECUTE statement\",\n", + " database=\"awswrangler_test\",\n", + " execution_params=[\"E\"],\n", + ")\n", + "\n", + "# Resolve parameter using Athena execution parameters (same effect as above)\n", + "wr.athena.read_sql_query(\n", + " sql=\"EXECUTE statement USING ?\",\n", + " database=\"awswrangler_test\",\n", + " execution_params=[\"E\"],\n", + ")\n", + "\n", + "# Resolve parameter using client-side formatter\n", + "wr.athena.read_sql_query(\n", + " sql=\"EXECUTE statement USING :flag_value\",\n", + " database=\"awswrangler_test\",\n", + " params={\n", + " \"flag_value\": \"E\",\n", + " },\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": {}, From 3ed18e0556cd4715ecf4336cc80bb8bfc8635702 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Thu, 8 Jun 2023 13:40:32 -0500 Subject: [PATCH 11/13] Refactor params --- awswrangler/athena/_executions.py | 38 +++++++++------ awswrangler/athena/_executions.pyi | 12 ++--- awswrangler/athena/_read.py | 76 +++++++++++++++++------------ awswrangler/athena/_read.pyi | 29 +++++------ awswrangler/athena/_utils.py | 66 +++++++++++++++++++++++++ tests/unit/test_athena.py | 5 +- tests/unit/test_athena_prepared.py | 3 +- tutorials/006 - Amazon Athena.ipynb | 34 +++++++++---- 8 files changed, 184 insertions(+), 79 deletions(-) diff --git a/awswrangler/athena/_executions.py b/awswrangler/athena/_executions.py index 3f0dd084c..ec3e40020 100644 --- a/awswrangler/athena/_executions.py +++ b/awswrangler/athena/_executions.py @@ -12,15 +12,16 @@ import boto3 import botocore +from typing_extensions import Literal from awswrangler import _utils, exceptions, typing from awswrangler._config import apply_configs -from awswrangler._sql_formatter import _process_sql_params from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results from ._utils import ( _QUERY_FINAL_STATES, _QUERY_WAIT_POLLING_DELAY, + _apply_formatter, _get_workgroup_config, _start_query_execution, _WorkGroupConfig, @@ -37,13 +38,13 @@ def start_query_execution( workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, + params: Union[Dict[str, Any], List[str], None] = None, + paramstyle: Literal["qmark", "named"] = "named", boto3_session: Optional[boto3.Session] = None, athena_cache_settings: Optional[typing.AthenaCacheSettings] = None, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, data_source: Optional[str] = None, wait: bool = False, - execution_params: Optional[List[str]] = None, ) -> Union[str, Dict[str, Any]]: """Start a SQL Query against AWS Athena. @@ -66,17 +67,25 @@ def start_query_execution( None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'. kms_key : str, optional For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID. - params: Dict[str, any], optional - Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported. - The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain - `:name`. - - Note that this formatter is applied client-side, and the query sent to Athena will include the parameter values. - For a server-side application of parameters, see ``execution_params``. - execution_params: List[str], optional - A list of values for the parameters in a query. + params: Dict[str, any] | List[str], optional + Parameters that will be used for constructing the SQL query. + Only named or question mark parameters are supported. + The parameter style needs to be specified in the ``paramstyle`` parameter. + + For ``paramstyle="named"``, this value needs to be a dictionary. + The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain + ``:name``. + The formatter will be applied client-side in this scenario. + + For ``paramstyle="qmark"``, this value needs to be a list of strings. + The formatter will be applied server-side. The values are applied sequentially to the parameters in the query in the order in which the parameters occur. - The parameters will be applied server-side in Athena. + paramstyle: str, optional + Determines the style of ``params``. + Possible values are: + + - ``named`` + - ``qmark`` boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receive None. athena_cache_settings: typing.AthenaCacheSettings, optional @@ -112,7 +121,8 @@ def start_query_execution( >>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...', data_source='...') """ - sql = _process_sql_params(sql, params) + # Substitute query parameters if applicable + sql, execution_params = _apply_formatter(sql, params, paramstyle) _logger.debug("Executing query:\n%s", sql) athena_cache_settings = athena_cache_settings if athena_cache_settings else {} diff --git a/awswrangler/athena/_executions.pyi b/awswrangler/athena/_executions.pyi index fd4407ce5..142cedb74 100644 --- a/awswrangler/athena/_executions.pyi +++ b/awswrangler/athena/_executions.pyi @@ -20,13 +20,13 @@ def start_query_execution( workgroup: Optional[str] = ..., encryption: Optional[str] = ..., kms_key: Optional[str] = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., athena_query_wait_polling_delay: float = ..., data_source: Optional[str] = ..., wait: Literal[False] = ..., - execution_params: Optional[List[str]] = ..., ) -> str: ... @overload def start_query_execution( @@ -37,13 +37,13 @@ def start_query_execution( workgroup: Optional[str] = ..., encryption: Optional[str] = ..., kms_key: Optional[str] = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., athena_query_wait_polling_delay: float = ..., data_source: Optional[str] = ..., wait: Literal[True], - execution_params: Optional[List[str]] = ..., ) -> Dict[str, Any]: ... @overload def start_query_execution( @@ -54,13 +54,13 @@ def start_query_execution( workgroup: Optional[str] = ..., encryption: Optional[str] = ..., kms_key: Optional[str] = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., athena_query_wait_polling_delay: float = ..., data_source: Optional[str] = ..., wait: bool, - execution_params: Optional[List[str]] = ..., ) -> Union[str, Dict[str, Any]]: ... def stop_query_execution(query_execution_id: str, boto3_session: Optional[boto3.Session] = ...) -> None: ... def wait_query( diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index f20364644..1733bd440 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -16,9 +16,9 @@ from awswrangler import _utils, catalog, exceptions, s3, typing from awswrangler._config import apply_configs from awswrangler._data_types import cast_pandas_with_athena_types -from awswrangler._sql_formatter import _process_sql_params from awswrangler.athena._utils import ( _QUERY_WAIT_POLLING_DELAY, + _apply_formatter, _apply_query_metadata, _empty_dataframe_response, _get_query_metadata, @@ -747,8 +747,8 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals athena_cache_settings: Optional[typing.AthenaCacheSettings] = None, data_source: Optional[str] = None, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, - params: Optional[Dict[str, Any]] = None, - execution_params: Optional[List[str]] = None, + params: Union[Dict[str, Any], List[str], None] = None, + paramstyle: Literal["qmark", "named"] = "named", dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", s3_additional_kwargs: Optional[Dict[str, Any]] = None, pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, @@ -918,17 +918,25 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default. athena_query_wait_polling_delay: float, default: 0.25 seconds Interval in seconds for how often the function will check if the Athena query has completed. - params: Dict[str, any], optional - Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported. - The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain - `:name`. - - Note that this formatter is applied client-side, and the query sent to Athena will include the parameter values. - For a server-side application of parameters, see ``execution_params``. - execution_params: List[str], optional - A list of values for the parameters in a query. + params: Dict[str, any] | List[str], optional + Parameters that will be used for constructing the SQL query. + Only named or question mark parameters are supported. + The parameter style needs to be specified in the ``paramstyle`` parameter. + + For ``paramstyle="named"``, this value needs to be a dictionary. + The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain + ``:name``. + The formatter will be applied client-side in this scenario. + + For ``paramstyle="qmark"``, this value needs to be a list of strings. + The formatter will be applied server-side. The values are applied sequentially to the parameters in the query in the order in which the parameters occur. - The parameters will be applied server-side in Athena. + paramstyle: str, optional + Determines the style of ``params``. + Possible values are: + + - ``named`` + - ``qmark`` dtype_backend: str, optional Which dtype_backend to use, e.g. whether a DataFrame should have NumPy arrays, nullable dtypes are used for all dtypes that have a nullable implementation when @@ -984,15 +992,15 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals raise exceptions.InvalidArgumentCombination("Only PARQUET file format is supported if unload_approach=True") chunksize = sys.maxsize if ctas_approach is False and chunksize is True else chunksize + # Substitute query parameters if applicable + sql, execution_params = _apply_formatter(sql, params, paramstyle) + athena_cache_settings = athena_cache_settings if athena_cache_settings else {} max_cache_seconds = athena_cache_settings.get("max_cache_seconds", 0) max_cache_query_inspections = athena_cache_settings.get("max_cache_query_inspections", 50) max_remote_cache_entries = athena_cache_settings.get("max_remote_cache_entries", 50) max_local_cache_entries = athena_cache_settings.get("max_local_cache_entries", 100) - # Substitute query parameters - sql = _process_sql_params(sql, params) - max_remote_cache_entries = min(max_remote_cache_entries, max_local_cache_entries) _cache_manager.max_cache_size = max_local_cache_entries @@ -1309,8 +1317,8 @@ def unload( kms_key: Optional[str] = None, boto3_session: Optional[boto3.Session] = None, data_source: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, - execution_params: Optional[List[str]] = None, + params: Union[Dict[str, Any], List[str], None] = None, + paramstyle: Literal["qmark", "named"] = "named", athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, ) -> _QueryMetadata: """Write query results from a SELECT statement to the specified data format using UNLOAD. @@ -1347,17 +1355,25 @@ def unload( Boto3 Session. The default boto3 session will be used if boto3_session receive None. data_source : str, optional Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default. - params: Dict[str, any], optional - Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported. - The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain - `:name`. - - Note that this formatter is applied client-side, and the query sent to Athena will include the parameter values. - For a server-side application of parameters, see ``execution_params``. - execution_params: List[str], optional - A list of values for the parameters in a query. + params: Dict[str, any] | List[str], optional + Parameters that will be used for constructing the SQL query. + Only named or question mark parameters are supported. + The parameter style needs to be specified in the ``paramstyle`` parameter. + + For ``paramstyle="named"``, this value needs to be a dictionary. + The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain + ``:name``. + The formatter will be applied client-side in this scenario. + + For ``paramstyle="qmark"``, this value needs to be a list of strings. + The formatter will be applied server-side. The values are applied sequentially to the parameters in the query in the order in which the parameters occur. - The parameters will be applied server-side in Athena. + paramstyle: str, optional + Determines the style of ``params``. + Possible values are: + + - ``named`` + - ``qmark`` athena_query_wait_polling_delay: float, default: 0.25 seconds Interval in seconds for how often the function will check if the Athena query has completed. @@ -1375,8 +1391,8 @@ def unload( ... ) """ - # Substitute query parameters - sql = _process_sql_params(sql, params) + # Substitute query parameters if applicable + sql, execution_params = _apply_formatter(sql, params, paramstyle) return _unload( sql=sql, path=path, diff --git a/awswrangler/athena/_read.pyi b/awswrangler/athena/_read.pyi index 9de4cce74..459d35fb7 100644 --- a/awswrangler/athena/_read.pyi +++ b/awswrangler/athena/_read.pyi @@ -73,8 +73,8 @@ def read_sql_query( # pylint: disable=too-many-arguments athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., - params: Optional[Dict[str, Any]] = ..., - execution_params: Optional[List[str]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -100,8 +100,8 @@ def read_sql_query( athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., - params: Optional[Dict[str, Any]] = ..., - execution_params: Optional[List[str]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -127,8 +127,8 @@ def read_sql_query( athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., - params: Optional[Dict[str, Any]] = ..., - execution_params: Optional[List[str]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -154,8 +154,8 @@ def read_sql_query( athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., - params: Optional[Dict[str, Any]] = ..., - execution_params: Optional[List[str]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -181,8 +181,8 @@ def read_sql_query( athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., - params: Optional[Dict[str, Any]] = ..., - execution_params: Optional[List[str]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -207,7 +207,6 @@ def read_sql_table( boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., - execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -232,7 +231,6 @@ def read_sql_table( boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., - execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -257,7 +255,6 @@ def read_sql_table( boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., - execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -282,7 +279,6 @@ def read_sql_table( boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., - execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -307,7 +303,6 @@ def read_sql_table( boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., - execution_params: Optional[List[str]] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -325,7 +320,7 @@ def unload( kms_key: Optional[str] = ..., boto3_session: Optional[boto3.Session] = ..., data_source: Optional[str] = ..., - params: Optional[Dict[str, Any]] = ..., - execution_params: Optional[List[str]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., athena_query_wait_polling_delay: float = ..., ) -> _QueryMetadata: ... diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py index cb8dd6149..d5c8759e2 100644 --- a/awswrangler/athena/_utils.py +++ b/awswrangler/athena/_utils.py @@ -17,6 +17,7 @@ Optional, Sequence, Tuple, + TypedDict, Union, cast, ) @@ -28,6 +29,7 @@ from awswrangler import _data_types, _utils, catalog, exceptions, s3, sts, typing from awswrangler._config import apply_configs +from awswrangler._sql_formatter import _process_sql_params from awswrangler.catalog._utils import _catalog_id, _transaction_id from . import _executions @@ -295,6 +297,70 @@ def _apply_query_metadata(df: pd.DataFrame, query_metadata: _QueryMetadata) -> p return df +class _FormatterTypeQMark(TypedDict): + params: List[str] + paramstyle: Literal["qmark"] + + +class _FormatterTypeNamed(TypedDict): + params: Dict[str, Any] + paramstyle: Literal["named"] + + +_FormatterType = Union[_FormatterTypeQMark, _FormatterTypeNamed, None] + + +def _verify_formatter( + params: Union[Dict[str, Any], List[str], None], + paramstyle: Literal["qmark", "named"], +) -> _FormatterType: + if params is None: + return None + + if paramstyle == "named": + if not isinstance(params, dict): + raise exceptions.InvalidArgumentCombination( + f"`params` must be a dict when paramstyle is `named`. Instead, found type {type(params)}." + ) + + return { + "paramstyle": "named", + "params": params, + } + + if paramstyle == "qmark": + if not isinstance(params, list): + raise exceptions.InvalidArgumentCombination( + f"`params` must be a list when paramstyle is `qmark`. Instead, found type {type(params)}." + ) + + return { + "paramstyle": "qmark", + "params": params, + } + + raise exceptions.InvalidArgumentValue(f"`paramstyle` must be either `qmark` or `named`. Found: {paramstyle}.") + + +def _apply_formatter( + sql: str, + params: Union[Dict[str, Any], List[str], None], + paramstyle: Literal["qmark", "named"], +) -> Tuple[str, Optional[List[str]]]: + formatter_settings = _verify_formatter(params, paramstyle) + + if formatter_settings is None: + return sql, None + + if formatter_settings["paramstyle"] == "named": + # Substitute query parameters] + sql = _process_sql_params(sql, formatter_settings["params"]) + + return sql, None + + return sql, formatter_settings["params"] + + def get_named_query_statement( named_query_id: str, boto3_session: Optional[boto3.Session] = None, diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py index 10380e36b..8153ec22f 100644 --- a/tests/unit/test_athena.py +++ b/tests/unit/test_athena.py @@ -333,7 +333,7 @@ def test_athena_orc(path, glue_database, glue_table): @pytest.mark.parametrize( "col_name,col_value", [("string", "Washington"), ("iint32", "1"), ("date", "DATE '2020-01-01'")] ) -def test_athena_execution_parameters( +def test_athena_paramstyle_qmark_parameters( path: str, path2: str, glue_database: str, @@ -361,7 +361,8 @@ def test_athena_execution_parameters( ctas_approach=ctas_approach, unload_approach=unload_approach, workgroup=workgroup0, - execution_params=[col_value], + params=[col_value], + paramstyle="qmark", keep_files=False, s3_output=path2, ) diff --git a/tests/unit/test_athena_prepared.py b/tests/unit/test_athena_prepared.py index 6cf4689d1..cb7c5735b 100644 --- a/tests/unit/test_athena_prepared.py +++ b/tests/unit/test_athena_prepared.py @@ -178,7 +178,8 @@ def test_athena_execute_prepared_statement_with_params( database=glue_database, ctas_approach=False, workgroup=workgroup0, - execution_params=["Washington"], + params=["Washington"], + paramstyle="qmark", keep_files=False, s3_output=path2, ) diff --git a/tutorials/006 - Amazon Athena.ipynb b/tutorials/006 - Amazon Athena.ipynb index 0a327e848..178695c3a 100644 --- a/tutorials/006 - Amazon Athena.ipynb +++ b/tutorials/006 - Amazon Athena.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -63,6 +64,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -81,6 +83,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -98,6 +101,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -160,6 +164,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -182,6 +187,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -204,6 +210,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -226,6 +233,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -249,6 +257,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -314,7 +323,7 @@ "source": [ "### Client-side parameter resolution\n", "\n", - "The `params` parameter allows client-side resolution of parameters, which are specified with `:col_name`.\n", + "The `params` parameter allows client-side resolution of parameters, which are specified with `:col_name`, when `paramstyle` is set to `named`.\n", "Additionally, Python types will map to the appropriate Athena definitions.\n", "For example, the value `dt.date(2023, 1, 1)` will resolve to `DATE '2023-01-01`.\n", "\n", @@ -348,13 +357,13 @@ "source": [ "### Server-side parameter resolution\n", "\n", - "Alternatively, Athena supports server-side parameter resolution.\n", - "The SQL statement sent to Athena will not contain the values passed in `execution_params`.\n", - "Instead, they will be passed as part of a separate `execution_params` parameter in `boto3`.\n", + "Alternatively, Athena supports server-side parameter resolution when `paramstyle` is defined as `qmark`.\n", + "The SQL statement sent to Athena will not contain the values passed in `params`.\n", + "Instead, they will be passed as part of a separate `params` parameter in `boto3`.\n", "\n", "The downside of using this approach is that types aren't automatically resolved.\n", - "The values sent to `execution_params` must be strings.\n", - "Therefore, if one of the values is a date, the value passed in `execution_params` has to be `DATE 'XXXX-XX-XX'`.\n", + "The values sent to `params` must be strings.\n", + "Therefore, if one of the values is a date, the value passed in `params` has to be `DATE 'XXXX-XX-XX'`.\n", "\n", "The upside, however, is that these parameters can be used with prepared statements.\n", "\n", @@ -372,7 +381,8 @@ "wr.athena.read_sql_query(\n", " \"SELECT * FROM noaa WHERE S_FLAG = ?\",\n", " database=\"awswrangler_test\",\n", - " execution_params=[\"E\"],\n", + " params=[\"E\"],\n", + " paramstyle=\"qmark\",\n", ")" ] }, @@ -399,14 +409,16 @@ "wr.athena.read_sql_query(\n", " sql=\"EXECUTE statement\",\n", " database=\"awswrangler_test\",\n", - " execution_params=[\"E\"],\n", + " params=[\"E\"],\n", + " paramstyle=\"qmark\",\n", ")\n", "\n", "# Resolve parameter using Athena execution parameters (same effect as above)\n", "wr.athena.read_sql_query(\n", " sql=\"EXECUTE statement USING ?\",\n", " database=\"awswrangler_test\",\n", - " execution_params=[\"E\"],\n", + " params=[\"E\"],\n", + " paramstyle=\"qmark\",\n", ")\n", "\n", "# Resolve parameter using client-side formatter\n", @@ -416,10 +428,12 @@ " params={\n", " \"flag_value\": \"E\",\n", " },\n", + " paramstyle=\"named\",\n", ")" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -440,6 +454,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -460,6 +475,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ From 96413111d79879cde458f740dfd49ea0631c9123 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Fri, 9 Jun 2023 09:49:29 -0500 Subject: [PATCH 12/13] fix PS function names --- awswrangler/athena/__init__.py | 6 +++--- awswrangler/athena/_statements.py | 9 ++++----- docs/source/api.rst | 4 ++-- tests/unit/test_athena_prepared.py | 22 +++++++++++----------- tutorials/006 - Amazon Athena.ipynb | 14 ++++++++++++-- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/awswrangler/athena/__init__.py b/awswrangler/athena/__init__.py index 840cd57e7..2ee432a9b 100644 --- a/awswrangler/athena/__init__.py +++ b/awswrangler/athena/__init__.py @@ -7,7 +7,7 @@ wait_query, ) from awswrangler.athena._spark import create_spark_session, run_spark_calculation -from awswrangler.athena._statements import prepare_statement, deallocate_prepared_statement, list_prepared_statements +from awswrangler.athena._statements import create_prepared_statement, delete_prepared_statement, list_prepared_statements from awswrangler.athena._read import ( # noqa get_query_results, read_sql_query, @@ -52,8 +52,8 @@ "stop_query_execution", "unload", "wait_query", - "prepare_statement", + "create_prepared_statement", "list_prepared_statements", - "deallocate_prepared_statement", + "delete_prepared_statement", "to_iceberg", ] diff --git a/awswrangler/athena/_statements.py b/awswrangler/athena/_statements.py index 9366ce6da..9c5c50212 100644 --- a/awswrangler/athena/_statements.py +++ b/awswrangler/athena/_statements.py @@ -32,7 +32,7 @@ def _does_statement_exist( @apply_configs -def prepare_statement( +def create_prepared_statement( sql: str, statement_name: str, workgroup: Optional[str] = None, @@ -63,7 +63,7 @@ def prepare_statement( Examples -------- >>> import awswrangler as wr - >>> res = wr.athena.prepare_statement( + >>> wr.athena.create_prepared_statement( ... sql="SELECT * FROM my_table WHERE name = ?", ... statement_name="statement", ... ) @@ -128,7 +128,7 @@ def list_prepared_statements( @apply_configs -def deallocate_prepared_statement( +def delete_prepared_statement( statement_name: str, workgroup: Optional[str] = None, boto3_session: Optional[boto3.Session] = None, @@ -150,8 +150,7 @@ def deallocate_prepared_statement( Examples -------- >>> import awswrangler as wr - >>> res = wr.athena.prepare_statement( - ... sql="SELECT * FROM my_table WHERE name = ?", + >>> wr.athena.delete_prepared_statement( ... statement_name="statement", ... ) """ diff --git a/docs/source/api.rst b/docs/source/api.rst index 457a31265..d09f34f5c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -144,9 +144,9 @@ Amazon Athena to_iceberg unload wait_query - prepare_statement + create_prepared_statement list_prepared_statements - deallocate_prepared_statement + delete_prepared_statement AWS Lake Formation ------------------ diff --git a/tests/unit/test_athena_prepared.py b/tests/unit/test_athena_prepared.py index cb7c5735b..ad9b0cfb7 100644 --- a/tests/unit/test_athena_prepared.py +++ b/tests/unit/test_athena_prepared.py @@ -23,20 +23,20 @@ def statement(workgroup0: str) -> str: name = f"prepared_statement_{get_time_str_with_random_suffix()}" yield name try: - wr.athena.deallocate_prepared_statement(statement_name=name, workgroup=workgroup0) + wr.athena.delete_prepared_statement(statement_name=name, workgroup=workgroup0) except ClientError as e: if e.response["Error"]["Code"] != "ResourceNotFoundException": raise e def test_update_prepared_statement(workgroup0: str, statement: str) -> None: - wr.athena.prepare_statement( + wr.athena.create_prepared_statement( sql="SELECT 1 AS col0", statement_name=statement, workgroup=workgroup0, ) - wr.athena.prepare_statement( + wr.athena.create_prepared_statement( sql="SELECT 1 AS col0, 2 AS col1", statement_name=statement, workgroup=workgroup0, @@ -44,14 +44,14 @@ def test_update_prepared_statement(workgroup0: str, statement: str) -> None: def test_update_prepared_statement_error(workgroup0: str, statement: str) -> None: - wr.athena.prepare_statement( + wr.athena.create_prepared_statement( sql="SELECT 1 AS col0", statement_name=statement, workgroup=workgroup0, ) with pytest.raises(wr.exceptions.AlreadyExists): - wr.athena.prepare_statement( + wr.athena.create_prepared_statement( sql="SELECT 1 AS col0, 2 AS col1", statement_name=statement, workgroup=workgroup0, @@ -63,7 +63,7 @@ def test_athena_deallocate_prepared_statement(workgroup0: str, statement: str) - athena_client = boto3.client("athena") sql_statement = "SELECT 1 as col0" - wr.athena.prepare_statement( + wr.athena.create_prepared_statement( sql=sql_statement, statement_name=statement, workgroup=workgroup0, @@ -72,14 +72,14 @@ def test_athena_deallocate_prepared_statement(workgroup0: str, statement: str) - resp = athena_client.get_prepared_statement(StatementName=statement, WorkGroup=workgroup0) assert resp["PreparedStatement"]["QueryStatement"] == sql_statement - wr.athena.deallocate_prepared_statement( + wr.athena.delete_prepared_statement( statement_name=statement, workgroup=workgroup0, ) def test_list_prepared_statements(workgroup1: str, statement: str) -> None: - wr.athena.prepare_statement( + wr.athena.create_prepared_statement( sql="SELECT 1 as col0", statement_name=statement, workgroup=workgroup1, @@ -90,7 +90,7 @@ def test_list_prepared_statements(workgroup1: str, statement: str) -> None: assert len(statement_list) == 1 assert statement_list[0]["StatementName"] == statement - wr.athena.deallocate_prepared_statement(statement, workgroup=workgroup1) + wr.athena.delete_prepared_statement(statement, workgroup=workgroup1) statement_list = wr.athena.list_prepared_statements(workgroup1) assert len(statement_list) == 0 @@ -115,7 +115,7 @@ def test_athena_execute_prepared_statement( partition_cols=["par0", "par1"], ) - wr.athena.prepare_statement( + wr.athena.create_prepared_statement( sql=f"SELECT * FROM {glue_table} WHERE string = ?", statement_name=statement, workgroup=workgroup0, @@ -167,7 +167,7 @@ def test_athena_execute_prepared_statement_with_params( partition_cols=["par0", "par1"], ) - wr.athena.prepare_statement( + wr.athena.create_prepared_statement( sql=f"SELECT * FROM {glue_table} WHERE string = ?", statement_name=statement, workgroup=workgroup0, diff --git a/tutorials/006 - Amazon Athena.ipynb b/tutorials/006 - Amazon Athena.ipynb index 178695c3a..84e5e1e7f 100644 --- a/tutorials/006 - Amazon Athena.ipynb +++ b/tutorials/006 - Amazon Athena.ipynb @@ -400,7 +400,7 @@ "metadata": {}, "outputs": [], "source": [ - "wr.athena.prepare_statement(\n", + "wr.athena.create_prepared_statement(\n", " sql=\"SELECT * FROM noaa WHERE S_FLAG = ?\",\n", " statement_name=\"statement\",\n", ")\n", @@ -432,6 +432,16 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Clean up prepared statement\n", + "wr.athena.delete_prepared_statement(statement_name=\"statement\")" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -512,7 +522,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.14" + "version": "3.9.13" } }, "nbformat": 4, From d7b182c1297618475a8ece16e5f04f2a0d48c8bb Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Fri, 9 Jun 2023 09:51:48 -0500 Subject: [PATCH 13/13] fix formatting --- awswrangler/athena/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/awswrangler/athena/__init__.py b/awswrangler/athena/__init__.py index 2ee432a9b..0321a6a93 100644 --- a/awswrangler/athena/__init__.py +++ b/awswrangler/athena/__init__.py @@ -7,7 +7,11 @@ wait_query, ) from awswrangler.athena._spark import create_spark_session, run_spark_calculation -from awswrangler.athena._statements import create_prepared_statement, delete_prepared_statement, list_prepared_statements +from awswrangler.athena._statements import ( + create_prepared_statement, + delete_prepared_statement, + list_prepared_statements, +) from awswrangler.athena._read import ( # noqa get_query_results, read_sql_query,