diff --git a/awswrangler/athena/__init__.py b/awswrangler/athena/__init__.py index 24666392b..0321a6a93 100644 --- a/awswrangler/athena/__init__.py +++ b/awswrangler/athena/__init__.py @@ -7,6 +7,11 @@ wait_query, ) from awswrangler.athena._spark import create_spark_session, run_spark_calculation +from awswrangler.athena._statements import ( + create_prepared_statement, + delete_prepared_statement, + list_prepared_statements, +) from awswrangler.athena._read import ( # noqa get_query_results, read_sql_query, @@ -51,5 +56,8 @@ "stop_query_execution", "unload", "wait_query", + "create_prepared_statement", + "list_prepared_statements", + "delete_prepared_statement", "to_iceberg", ] diff --git a/awswrangler/athena/_executions.py b/awswrangler/athena/_executions.py index 32652bc07..ec3e40020 100644 --- a/awswrangler/athena/_executions.py +++ b/awswrangler/athena/_executions.py @@ -4,6 +4,7 @@ from typing import ( Any, Dict, + List, Optional, Union, cast, @@ -11,15 +12,16 @@ import boto3 import botocore +from typing_extensions import Literal from awswrangler import _utils, exceptions, typing from awswrangler._config import apply_configs -from awswrangler._sql_formatter import _process_sql_params from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results from ._utils import ( _QUERY_FINAL_STATES, _QUERY_WAIT_POLLING_DELAY, + _apply_formatter, _get_workgroup_config, _start_query_execution, _WorkGroupConfig, @@ -36,7 +38,8 @@ def start_query_execution( workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, + params: Union[Dict[str, Any], List[str], None] = None, + paramstyle: Literal["qmark", "named"] = "named", boto3_session: Optional[boto3.Session] = None, athena_cache_settings: Optional[typing.AthenaCacheSettings] = None, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, @@ -64,10 +67,25 @@ def start_query_execution( None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'. 
kms_key : str, optional For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID. - params: Dict[str, any], optional - Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported. - The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain - `:name`. Note that for varchar columns and similar, you must surround the value in single quotes. + params: Dict[str, any] | List[str], optional + Parameters that will be used for constructing the SQL query. + Only named or question mark parameters are supported. + The parameter style needs to be specified in the ``paramstyle`` parameter. + + For ``paramstyle="named"``, this value needs to be a dictionary. + The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain + ``:name``. + The formatter will be applied client-side in this scenario. + + For ``paramstyle="qmark"``, this value needs to be a list of strings. + The formatter will be applied server-side. + The values are applied sequentially to the parameters in the query in the order in which the parameters occur. + paramstyle: str, optional + Determines the style of ``params``. + Possible values are: + + - ``named`` + - ``qmark`` boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receive None. 
athena_cache_settings: typing.AthenaCacheSettings, optional @@ -103,7 +121,8 @@ def start_query_execution( >>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...', data_source='...') """ - sql = _process_sql_params(sql, params) + # Substitute query parameters if applicable + sql, execution_params = _apply_formatter(sql, params, paramstyle) _logger.debug("Executing query:\n%s", sql) athena_cache_settings = athena_cache_settings if athena_cache_settings else {} @@ -139,6 +158,7 @@ def start_query_execution( workgroup=workgroup, encryption=encryption, kms_key=kms_key, + execution_params=execution_params, boto3_session=boto3_session, ) if wait: diff --git a/awswrangler/athena/_executions.pyi b/awswrangler/athena/_executions.pyi index 2d38bc641..142cedb74 100644 --- a/awswrangler/athena/_executions.pyi +++ b/awswrangler/athena/_executions.pyi @@ -1,6 +1,7 @@ from typing import ( Any, Dict, + List, Literal, Optional, Union, @@ -19,7 +20,8 @@ def start_query_execution( workgroup: Optional[str] = ..., encryption: Optional[str] = ..., kms_key: Optional[str] = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., athena_query_wait_polling_delay: float = ..., @@ -35,7 +37,8 @@ def start_query_execution( workgroup: Optional[str] = ..., encryption: Optional[str] = ..., kms_key: Optional[str] = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., athena_query_wait_polling_delay: float = ..., @@ -51,7 +54,8 @@ def start_query_execution( workgroup: Optional[str] = ..., encryption: Optional[str] = ..., kms_key: Optional[str] = ..., - params: 
Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., boto3_session: Optional[boto3.Session] = ..., athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., athena_query_wait_polling_delay: float = ..., diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 13444a996..b4605e744 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -16,9 +16,9 @@ from awswrangler import _utils, catalog, exceptions, s3, typing from awswrangler._config import apply_configs from awswrangler._data_types import cast_pandas_with_athena_types -from awswrangler._sql_formatter import _process_sql_params from awswrangler.athena._utils import ( _QUERY_WAIT_POLLING_DELAY, + _apply_formatter, _apply_query_metadata, _empty_dataframe_response, _get_query_metadata, @@ -287,6 +287,7 @@ def _resolve_query_without_cache_ctas( s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session], pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, + execution_params: Optional[List[str]] = None, dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: ctas_query_info: Dict[str, Union[str, _QueryMetadata]] = create_ctas_table( @@ -304,6 +305,7 @@ def _resolve_query_without_cache_ctas( wait=True, athena_query_wait_polling_delay=athena_query_wait_polling_delay, boto3_session=boto3_session, + execution_params=execution_params, ) fully_qualified_name: str = f'"{ctas_query_info["ctas_database"]}"."{ctas_query_info["ctas_table"]}"' ctas_query_metadata = cast(_QueryMetadata, ctas_query_info["ctas_query_metadata"]) @@ -342,6 +344,7 @@ def _resolve_query_without_cache_unload( s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session], pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, + execution_params: Optional[List[str]] = None, dtype_backend: 
Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: query_metadata = _unload( @@ -358,6 +361,7 @@ def _resolve_query_without_cache_unload( boto3_session=boto3_session, data_source=data_source, athena_query_wait_polling_delay=athena_query_wait_polling_delay, + execution_params=execution_params, ) if file_format == "PARQUET": return _fetch_parquet_result( @@ -389,6 +393,7 @@ def _resolve_query_without_cache_regular( athena_query_wait_polling_delay: float, s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session], + execution_params: Optional[List[str]] = None, dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup) @@ -404,6 +409,7 @@ def _resolve_query_without_cache_regular( workgroup=workgroup, encryption=encryption, kms_key=kms_key, + execution_params=execution_params, boto3_session=boto3_session, ) _logger.debug("Query id: %s", query_id) @@ -450,6 +456,7 @@ def _resolve_query_without_cache( s3_additional_kwargs: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session], pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, + execution_params: Optional[List[str]] = None, dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: """ @@ -483,6 +490,7 @@ def _resolve_query_without_cache( s3_additional_kwargs=s3_additional_kwargs, boto3_session=boto3_session, pyarrow_additional_kwargs=pyarrow_additional_kwargs, + execution_params=execution_params, dtype_backend=dtype_backend, ) finally: @@ -510,6 +518,7 @@ def _resolve_query_without_cache( s3_additional_kwargs=s3_additional_kwargs, boto3_session=boto3_session, pyarrow_additional_kwargs=pyarrow_additional_kwargs, + execution_params=execution_params, dtype_backend=dtype_backend, ) return 
_resolve_query_without_cache_regular( @@ -527,6 +536,7 @@ def _resolve_query_without_cache( athena_query_wait_polling_delay=athena_query_wait_polling_delay, s3_additional_kwargs=s3_additional_kwargs, boto3_session=boto3_session, + execution_params=execution_params, dtype_backend=dtype_backend, ) @@ -545,6 +555,7 @@ def _unload( boto3_session: Optional[boto3.Session], data_source: Optional[str], athena_query_wait_polling_delay: float, + execution_params: Optional[List[str]], ) -> _QueryMetadata: wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup) s3_output: str = _get_s3_output(s3_output=path, wg_config=wg_config, boto3_session=boto3_session) @@ -576,6 +587,7 @@ def _unload( encryption=encryption, kms_key=kms_key, boto3_session=boto3_session, + execution_params=execution_params, ) except botocore.exceptions.ClientError as ex: msg: str = str(ex) @@ -735,7 +747,8 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals athena_cache_settings: Optional[typing.AthenaCacheSettings] = None, data_source: Optional[str] = None, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, - params: Optional[Dict[str, Any]] = None, + params: Union[Dict[str, Any], List[str], None] = None, + paramstyle: Literal["qmark", "named"] = "named", dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", s3_additional_kwargs: Optional[Dict[str, Any]] = None, pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, @@ -905,10 +918,25 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default. athena_query_wait_polling_delay: float, default: 0.25 seconds Interval in seconds for how often the function will check if the Athena query has completed. - params: Dict[str, any], optional - Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported. 
- The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain - `:name`. Note that for varchar columns and similar, you must surround the value in single quotes. + params: Dict[str, any] | List[str], optional + Parameters that will be used for constructing the SQL query. + Only named or question mark parameters are supported. + The parameter style needs to be specified in the ``paramstyle`` parameter. + + For ``paramstyle="named"``, this value needs to be a dictionary. + The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain + ``:name``. + The formatter will be applied client-side in this scenario. + + For ``paramstyle="qmark"``, this value needs to be a list of strings. + The formatter will be applied server-side. + The values are applied sequentially to the parameters in the query in the order in which the parameters occur. + paramstyle: str, optional + Determines the style of ``params``. + Possible values are: + + - ``named`` + - ``qmark`` dtype_backend: str, optional Which dtype_backend to use, e.g. 
whether a DataFrame should have NumPy arrays, nullable dtypes are used for all dtypes that have a nullable implementation when @@ -964,15 +992,15 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals raise exceptions.InvalidArgumentCombination("Only PARQUET file format is supported if unload_approach=True") chunksize = sys.maxsize if ctas_approach is False and chunksize is True else chunksize + # Substitute query parameters if applicable + sql, execution_params = _apply_formatter(sql, params, paramstyle) + athena_cache_settings = athena_cache_settings if athena_cache_settings else {} max_cache_seconds = athena_cache_settings.get("max_cache_seconds", 0) max_cache_query_inspections = athena_cache_settings.get("max_cache_query_inspections", 50) max_remote_cache_entries = athena_cache_settings.get("max_remote_cache_entries", 50) max_local_cache_entries = athena_cache_settings.get("max_local_cache_entries", 100) - # Substitute query parameters - sql = _process_sql_params(sql, params) - max_remote_cache_entries = min(max_remote_cache_entries, max_local_cache_entries) _cache_manager.max_cache_size = max_local_cache_entries @@ -1032,6 +1060,7 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals s3_additional_kwargs=s3_additional_kwargs, boto3_session=boto3_session, pyarrow_additional_kwargs=pyarrow_additional_kwargs, + execution_params=execution_params, dtype_backend=dtype_backend, ) @@ -1288,7 +1317,8 @@ def unload( kms_key: Optional[str] = None, boto3_session: Optional[boto3.Session] = None, data_source: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, + params: Union[Dict[str, Any], List[str], None] = None, + paramstyle: Literal["qmark", "named"] = "named", athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, ) -> _QueryMetadata: """Write query results from a SELECT statement to the specified data format using UNLOAD. @@ -1325,10 +1355,25 @@ def unload( Boto3 Session. 
The default boto3 session will be used if boto3_session receive None. data_source : str, optional Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default. - params: Dict[str, any], optional - Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported. - The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain - `:name`. Note that for varchar columns and similar, you must surround the value in single quotes. + params: Dict[str, any] | List[str], optional + Parameters that will be used for constructing the SQL query. + Only named or question mark parameters are supported. + The parameter style needs to be specified in the ``paramstyle`` parameter. + + For ``paramstyle="named"``, this value needs to be a dictionary. + The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain + ``:name``. + The formatter will be applied client-side in this scenario. + + For ``paramstyle="qmark"``, this value needs to be a list of strings. + The formatter will be applied server-side. + The values are applied sequentially to the parameters in the query in the order in which the parameters occur. + paramstyle: str, optional + Determines the style of ``params``. + Possible values are: + + - ``named`` + - ``qmark`` athena_query_wait_polling_delay: float, default: 0.25 seconds Interval in seconds for how often the function will check if the Athena query has completed. @@ -1346,8 +1391,8 @@ def unload( ... 
) """ - # Substitute query parameters - sql = _process_sql_params(sql, params) + # Substitute query parameters if applicable + sql, execution_params = _apply_formatter(sql, params, paramstyle) return _unload( sql=sql, path=path, @@ -1362,4 +1407,5 @@ def unload( athena_query_wait_polling_delay=athena_query_wait_polling_delay, boto3_session=boto3_session, data_source=data_source, + execution_params=execution_params, ) diff --git a/awswrangler/athena/_read.pyi b/awswrangler/athena/_read.pyi index df5dada13..459d35fb7 100644 --- a/awswrangler/athena/_read.pyi +++ b/awswrangler/athena/_read.pyi @@ -73,7 +73,8 @@ def read_sql_query( # pylint: disable=too-many-arguments athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -99,7 +100,8 @@ def read_sql_query( athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -125,7 +127,8 @@ def read_sql_query( athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., dtype_backend: 
Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -151,7 +154,8 @@ def read_sql_query( athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -177,7 +181,8 @@ def read_sql_query( athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., athena_query_wait_polling_delay: float = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., dtype_backend: Literal["numpy_nullable", "pyarrow"] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -315,6 +320,7 @@ def unload( kms_key: Optional[str] = ..., boto3_session: Optional[boto3.Session] = ..., data_source: Optional[str] = ..., - params: Optional[Dict[str, Any]] = ..., + params: Union[Dict[str, Any], List[str], None] = ..., + paramstyle: Literal["qmark", "named"] = ..., athena_query_wait_polling_delay: float = ..., ) -> _QueryMetadata: ... 
diff --git a/awswrangler/athena/_statements.py b/awswrangler/athena/_statements.py new file mode 100644 index 000000000..9c5c50212 --- /dev/null +++ b/awswrangler/athena/_statements.py @@ -0,0 +1,164 @@ +"""Amazon Athena Module gathering all functions related to prepared statements.""" + +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast + +import boto3 +from botocore.exceptions import ClientError + +from awswrangler import _utils, exceptions +from awswrangler._config import apply_configs + +if TYPE_CHECKING: + from mypy_boto3_athena.client import AthenaClient + +_logger: logging.Logger = logging.getLogger(__name__) + + +def _does_statement_exist( + statement_name: str, + workgroup: str, + athena_client: "AthenaClient", +) -> bool: + try: + athena_client.get_prepared_statement(StatementName=statement_name, WorkGroup=workgroup) + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + return False + + raise e + + return True + + +@apply_configs +def create_prepared_statement( + sql: str, + statement_name: str, + workgroup: Optional[str] = None, + mode: Literal["update", "error"] = "update", + boto3_session: Optional[boto3.Session] = None, +) -> None: + """ + Create a SQL statement with the name statement_name to be run at a later time. The statement can include parameters represented by question marks. + + https://docs.aws.amazon.com/athena/latest/ug/sql-prepare.html + + Parameters + ---------- + sql : str + The query string for the prepared statement. + statement_name : str + The name of the prepared statement. + workgroup : str, optional + The name of the workgroup to which the prepared statement belongs. + mode: str + Determines the behaviour if the prepared statement already exists: + + - ``update`` - updates statement if already exists + - ``error`` - throws an error if statement exists + boto3_session : boto3.Session(), optional + Boto3 Session. 
The default boto3 session will be used if boto3_session receive None. + + Examples + -------- + >>> import awswrangler as wr + >>> wr.athena.create_prepared_statement( + ... sql="SELECT * FROM my_table WHERE name = ?", + ... statement_name="statement", + ... ) + """ + if mode not in ["update", "error"]: + raise exceptions.InvalidArgumentValue("`mode` must be one of 'update' or 'error'.") + + athena_client = _utils.client("athena", session=boto3_session) + workgroup = workgroup if workgroup else "primary" + + already_exists = _does_statement_exist(statement_name, workgroup, athena_client) + if already_exists and mode == "error": + raise exceptions.AlreadyExists(f"Prepared statement {statement_name} already exists.") + + if already_exists: + _logger.info(f"Updating prepared statement {statement_name}") + athena_client.update_prepared_statement( + StatementName=statement_name, + WorkGroup=workgroup, + QueryStatement=sql, + ) + else: + _logger.info(f"Creating prepared statement {statement_name}") + athena_client.create_prepared_statement( + StatementName=statement_name, + WorkGroup=workgroup, + QueryStatement=sql, + ) + + +@apply_configs +def list_prepared_statements( + workgroup: Optional[str] = None, boto3_session: Optional[boto3.Session] = None +) -> List[Dict[str, Any]]: + """ + List the prepared statements in the specified workgroup. + + Parameters + ---------- + workgroup: str, optional + The name of the workgroup to which the prepared statement belongs. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + List[Dict[str, Any]] + List of prepared statements in the workgroup. + Each item is a dictionary with the keys ``StatementName`` and ``LastModifiedTime``. 
+ """ + athena_client = _utils.client("athena", session=boto3_session) + workgroup = workgroup if workgroup else "primary" + + response = athena_client.list_prepared_statements(WorkGroup=workgroup) + statements = response["PreparedStatements"] + + while "NextToken" in response: + response = athena_client.list_prepared_statements(WorkGroup=workgroup, NextToken=response["NextToken"]) + statements += response["PreparedStatements"] + + return cast(List[Dict[str, Any]], statements) + + +@apply_configs +def delete_prepared_statement( + statement_name: str, + workgroup: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, +) -> None: + """ + Delete the prepared statement with the specified name from the specified workgroup. + + https://docs.aws.amazon.com/athena/latest/ug/sql-deallocate-prepare.html + + Parameters + ---------- + statement_name : str + The name of the prepared statement. + workgroup : str, optional + The name of the workgroup to which the prepared statement belongs. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Examples + -------- + >>> import awswrangler as wr + >>> wr.athena.delete_prepared_statement( + ... statement_name="statement", + ... 
) + """ + athena_client = _utils.client("athena", session=boto3_session) + workgroup = workgroup if workgroup else "primary" + + _logger.info(f"Deallocating prepared statement {statement_name}") + athena_client.delete_prepared_statement( + StatementName=statement_name, + WorkGroup=workgroup, + ) diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py index 3f9063c33..d5c8759e2 100644 --- a/awswrangler/athena/_utils.py +++ b/awswrangler/athena/_utils.py @@ -17,6 +17,7 @@ Optional, Sequence, Tuple, + TypedDict, Union, cast, ) @@ -28,6 +29,7 @@ from awswrangler import _data_types, _utils, catalog, exceptions, s3, sts, typing from awswrangler._config import apply_configs +from awswrangler._sql_formatter import _process_sql_params from awswrangler.catalog._utils import _catalog_id, _transaction_id from . import _executions @@ -82,6 +84,7 @@ def _start_query_execution( workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None, + execution_params: Optional[List[str]] = None, boto3_session: Optional[boto3.Session] = None, ) -> str: args: Dict[str, Any] = {"QueryString": sql} @@ -112,6 +115,9 @@ def _start_query_execution( if workgroup is not None: args["WorkGroup"] = workgroup + if execution_params: + args["ExecutionParameters"] = execution_params + client_athena = _utils.client(service_name="athena", session=boto3_session) _logger.debug("Starting query execution with args: \n%s", pprint.pformat(args)) response = _utils.try_it( @@ -207,6 +213,7 @@ def _get_query_metadata( # pylint: disable=too-many-statements query_execution_payload: Optional[Dict[str, Any]] = None, metadata_cache_manager: Optional[_LocalMetadataCacheManager] = None, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, + execution_params: Optional[List[str]] = None, dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable", ) -> _QueryMetadata: """Get query metadata.""" @@ -290,6 +297,70 @@ def _apply_query_metadata(df: 
pd.DataFrame, query_metadata: _QueryMetadata) -> p return df +class _FormatterTypeQMark(TypedDict): + params: List[str] + paramstyle: Literal["qmark"] + + +class _FormatterTypeNamed(TypedDict): + params: Dict[str, Any] + paramstyle: Literal["named"] + + +_FormatterType = Union[_FormatterTypeQMark, _FormatterTypeNamed, None] + + +def _verify_formatter( + params: Union[Dict[str, Any], List[str], None], + paramstyle: Literal["qmark", "named"], +) -> _FormatterType: + if params is None: + return None + + if paramstyle == "named": + if not isinstance(params, dict): + raise exceptions.InvalidArgumentCombination( + f"`params` must be a dict when paramstyle is `named`. Instead, found type {type(params)}." + ) + + return { + "paramstyle": "named", + "params": params, + } + + if paramstyle == "qmark": + if not isinstance(params, list): + raise exceptions.InvalidArgumentCombination( + f"`params` must be a list when paramstyle is `qmark`. Instead, found type {type(params)}." + ) + + return { + "paramstyle": "qmark", + "params": params, + } + + raise exceptions.InvalidArgumentValue(f"`paramstyle` must be either `qmark` or `named`. 
Found: {paramstyle}.") + + +def _apply_formatter( + sql: str, + params: Union[Dict[str, Any], List[str], None], + paramstyle: Literal["qmark", "named"], +) -> Tuple[str, Optional[List[str]]]: + formatter_settings = _verify_formatter(params, paramstyle) + + if formatter_settings is None: + return sql, None + + if formatter_settings["paramstyle"] == "named": + # Substitute query parameters + sql = _process_sql_params(sql, formatter_settings["params"]) + + return sql, None + + return sql, formatter_settings["params"] + + def get_named_query_statement( named_query_id: str, boto3_session: Optional[boto3.Session] = None, @@ -568,6 +639,7 @@ def create_ctas_table( # pylint: disable=too-many-locals categories: Optional[List[str]] = None, wait: bool = False, athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, + execution_params: Optional[List[str]] = None, boto3_session: Optional[boto3.Session] = None, ) -> Dict[str, Union[str, _QueryMetadata]]: """Create a new table populated with the results of a SELECT query. 
@@ -721,6 +793,7 @@ def create_ctas_table( # pylint: disable=too-many-locals encryption=encryption, kms_key=kms_key, boto3_session=boto3_session, + execution_params=execution_params, ) except botocore.exceptions.ClientError as ex: error = ex.response["Error"] diff --git a/docs/source/api.rst b/docs/source/api.rst index 85b8b1f70..d09f34f5c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -144,6 +144,9 @@ Amazon Athena to_iceberg unload wait_query + create_prepared_statement + list_prepared_statements + delete_prepared_statement AWS Lake Formation ------------------ diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py index 6732a67c9..8153ec22f 100644 --- a/tests/unit/test_athena.py +++ b/tests/unit/test_athena.py @@ -1,6 +1,7 @@ import datetime import logging import string +from typing import Any from unittest.mock import patch import boto3 @@ -321,6 +322,56 @@ def test_athena_orc(path, glue_database, glue_table): assert_pandas_equals(df, df_out) +@pytest.mark.parametrize( + "ctas_approach,unload_approach", + [ + pytest.param(False, False, id="regular"), + pytest.param(True, False, id="ctas"), + pytest.param(False, True, id="unload"), + ], +) +@pytest.mark.parametrize( + "col_name,col_value", [("string", "Washington"), ("iint32", "1"), ("date", "DATE '2020-01-01'")] +) +def test_athena_paramstyle_qmark_parameters( + path: str, + path2: str, + glue_database: str, + glue_table: str, + workgroup0: str, + ctas_approach: bool, + unload_approach: bool, + col_name: str, + col_value: Any, +) -> None: + wr.s3.to_parquet( + df=get_df(), + path=path, + index=False, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + partition_cols=["par0", "par1"], + ) + + df_out = wr.athena.read_sql_query( + sql=f"SELECT * FROM {glue_table} WHERE {col_name} = ?", + database=glue_database, + ctas_approach=ctas_approach, + unload_approach=unload_approach, + workgroup=workgroup0, + params=[col_value], + paramstyle="qmark", + 
keep_files=False, + s3_output=path2, + ) + ensure_data_types(df=df_out) + ensure_athena_query_metadata(df=df_out, ctas_approach=ctas_approach, encrypted=False) + + assert len(df_out) == 1 + + def test_read_sql_query_parameter_formatting_respects_prefixes(path, glue_database, glue_table, workgroup0): wr.s3.to_parquet( df=get_df(), diff --git a/tests/unit/test_athena_prepared.py b/tests/unit/test_athena_prepared.py new file mode 100644 index 000000000..ad9b0cfb7 --- /dev/null +++ b/tests/unit/test_athena_prepared.py @@ -0,0 +1,190 @@ +import logging + +import boto3 +import pytest +from botocore.exceptions import ClientError + +import awswrangler as wr + +from .._utils import ( + ensure_athena_query_metadata, + ensure_data_types, + get_df, + get_time_str_with_random_suffix, +) + +logging.getLogger("awswrangler").setLevel(logging.DEBUG) + +pytestmark = pytest.mark.distributed + + +@pytest.fixture(scope="function") +def statement(workgroup0: str) -> str: + name = f"prepared_statement_{get_time_str_with_random_suffix()}" + yield name + try: + wr.athena.delete_prepared_statement(statement_name=name, workgroup=workgroup0) + except ClientError as e: + if e.response["Error"]["Code"] != "ResourceNotFoundException": + raise e + + +def test_update_prepared_statement(workgroup0: str, statement: str) -> None: + wr.athena.create_prepared_statement( + sql="SELECT 1 AS col0", + statement_name=statement, + workgroup=workgroup0, + ) + + wr.athena.create_prepared_statement( + sql="SELECT 1 AS col0, 2 AS col1", + statement_name=statement, + workgroup=workgroup0, + ) + + +def test_update_prepared_statement_error(workgroup0: str, statement: str) -> None: + wr.athena.create_prepared_statement( + sql="SELECT 1 AS col0", + statement_name=statement, + workgroup=workgroup0, + ) + + with pytest.raises(wr.exceptions.AlreadyExists): + wr.athena.create_prepared_statement( + sql="SELECT 1 AS col0, 2 AS col1", + statement_name=statement, + workgroup=workgroup0, + mode="error", + ) + + +def 
test_athena_deallocate_prepared_statement(workgroup0: str, statement: str) -> None: + athena_client = boto3.client("athena") + + sql_statement = "SELECT 1 as col0" + wr.athena.create_prepared_statement( + sql=sql_statement, + statement_name=statement, + workgroup=workgroup0, + ) + + resp = athena_client.get_prepared_statement(StatementName=statement, WorkGroup=workgroup0) + assert resp["PreparedStatement"]["QueryStatement"] == sql_statement + + wr.athena.delete_prepared_statement( + statement_name=statement, + workgroup=workgroup0, + ) + + +def test_list_prepared_statements(workgroup1: str, statement: str) -> None: + wr.athena.create_prepared_statement( + sql="SELECT 1 as col0", + statement_name=statement, + workgroup=workgroup1, + ) + + statement_list = wr.athena.list_prepared_statements(workgroup1) + + assert len(statement_list) == 1 + assert statement_list[0]["StatementName"] == statement + + wr.athena.delete_prepared_statement(statement, workgroup=workgroup1) + + statement_list = wr.athena.list_prepared_statements(workgroup1) + assert len(statement_list) == 0 + + +def test_athena_execute_prepared_statement( + path: str, + path2: str, + glue_database: str, + glue_table: str, + workgroup0: str, + statement: str, +) -> None: + wr.s3.to_parquet( + df=get_df(), + path=path, + index=False, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + partition_cols=["par0", "par1"], + ) + + wr.athena.create_prepared_statement( + sql=f"SELECT * FROM {glue_table} WHERE string = ?", + statement_name=statement, + workgroup=workgroup0, + ) + + df_out1 = wr.athena.read_sql_query( + sql=f"EXECUTE \"{statement}\" USING 'Washington'", + database=glue_database, + ctas_approach=False, + workgroup=workgroup0, + keep_files=False, + s3_output=path2, + ) + df_out2 = wr.athena.read_sql_query( + sql=f"EXECUTE \"{statement}\" USING 'Seattle'", + database=glue_database, + ctas_approach=False, + workgroup=workgroup0, + keep_files=False, + s3_output=path2, + ) + + 
ensure_data_types(df=df_out1) + ensure_data_types(df=df_out2) + + ensure_athena_query_metadata(df=df_out1, ctas_approach=False, encrypted=False) + ensure_athena_query_metadata(df=df_out2, ctas_approach=False, encrypted=False) + + assert len(df_out1) == 1 + assert len(df_out2) == 1 + + +def test_athena_execute_prepared_statement_with_params( + path: str, + path2: str, + glue_database: str, + glue_table: str, + workgroup0: str, + statement: str, +) -> None: + wr.s3.to_parquet( + df=get_df(), + path=path, + index=False, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + partition_cols=["par0", "par1"], + ) + + wr.athena.create_prepared_statement( + sql=f"SELECT * FROM {glue_table} WHERE string = ?", + statement_name=statement, + workgroup=workgroup0, + ) + + df_out1 = wr.athena.read_sql_query( + sql=f'EXECUTE "{statement}"', + database=glue_database, + ctas_approach=False, + workgroup=workgroup0, + params=["Washington"], + paramstyle="qmark", + keep_files=False, + s3_output=path2, + ) + + ensure_data_types(df=df_out1) + ensure_athena_query_metadata(df=df_out1, ctas_approach=False, encrypted=False) + + assert len(df_out1) == 1 diff --git a/tutorials/006 - Amazon Athena.ipynb b/tutorials/006 - Amazon Athena.ipynb index b7c93cda2..84e5e1e7f 100644 --- a/tutorials/006 - Amazon Athena.ipynb +++ b/tutorials/006 - Amazon Athena.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -63,6 +64,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -81,6 +83,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -98,6 +101,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -160,6 +164,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -182,6 +187,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -204,6 
+210,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -226,6 +233,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -249,6 +257,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -300,6 +309,141 @@ ] }, { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parameterized queries" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Client-side parameter resolution\n", + "\n", + "The `params` parameter allows client-side resolution of parameters, which are specified with `:col_name`, when `paramstyle` is set to `named`.\n", + "Additionally, Python types will map to the appropriate Athena definitions.\n", + "For example, the value `dt.date(2023, 1, 1)` will resolve to `DATE '2023-01-01`.\n", + "\n", + "For the example below, the following query will be sent to Athena:\n", + "```sql\n", + "SELECT * FROM noaa WHERE S_FLAG = 'E'\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "wr.athena.read_sql_query(\n", + " \"SELECT * FROM noaa WHERE S_FLAG = :flag_value\",\n", + " database=\"awswrangler_test\",\n", + " params={\n", + " \"flag_value\": \"E\",\n", + " },\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Server-side parameter resolution\n", + "\n", + "Alternatively, Athena supports server-side parameter resolution when `paramstyle` is defined as `qmark`.\n", + "The SQL statement sent to Athena will not contain the values passed in `params`.\n", + "Instead, they will be passed as part of a separate `params` parameter in `boto3`.\n", + "\n", + "The downside of using this approach is that types aren't automatically resolved.\n", + "The values sent to `params` must be strings.\n", + "Therefore, if one of the values is a 
date, the value passed in `params` has to be `DATE 'XXXX-XX-XX'`.\n", + "\n", + "The upside, however, is that these parameters can be used with prepared statements.\n", + "\n", + "For more information, see \"[Using parameterized queries](https://docs.aws.amazon.com/athena/latest/ug/querying-with-prepared-statements.html)\"." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "wr.athena.read_sql_query(\n", + " \"SELECT * FROM noaa WHERE S_FLAG = ?\",\n", + " database=\"awswrangler_test\",\n", + " params=[\"E\"],\n", + " paramstyle=\"qmark\",\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepared statements" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wr.athena.create_prepared_statement(\n", + " sql=\"SELECT * FROM noaa WHERE S_FLAG = ?\",\n", + " statement_name=\"statement\",\n", + ")\n", + "\n", + "# Resolve parameter using Athena execution parameters\n", + "wr.athena.read_sql_query(\n", + " sql=\"EXECUTE statement\",\n", + " database=\"awswrangler_test\",\n", + " params=[\"E\"],\n", + " paramstyle=\"qmark\",\n", + ")\n", + "\n", + "# Resolve parameter using Athena execution parameters (same effect as above)\n", + "wr.athena.read_sql_query(\n", + " sql=\"EXECUTE statement USING ?\",\n", + " database=\"awswrangler_test\",\n", + " params=[\"E\"],\n", + " paramstyle=\"qmark\",\n", + ")\n", + "\n", + "# Resolve parameter using client-side formatter\n", + "wr.athena.read_sql_query(\n", + " sql=\"EXECUTE statement USING :flag_value\",\n", + " database=\"awswrangler_test\",\n", + " params={\n", + " \"flag_value\": \"E\",\n", + " },\n", + " paramstyle=\"named\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Clean up prepared statement\n", + 
"wr.athena.delete_prepared_statement(statement_name=\"statement\")" + ] + }, + { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -320,6 +464,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -340,6 +485,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -376,7 +522,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.14" + "version": "3.9.13" } }, "nbformat": 4,