From 523061d0afa249aafb06eecb4e0e9feaad365581 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Fri, 10 Feb 2023 15:24:46 -0600 Subject: [PATCH 01/18] Extract Athena cache settings to TypedDict --- awswrangler/athena/_read.py | 72 +++++++++------------------------ awswrangler/athena/_utils.py | 26 +++++------- awswrangler/typing.py | 15 ++++--- tests/unit/test_athena_cache.py | 71 ++++++++++++++++++++++++++------ 4 files changed, 98 insertions(+), 86 deletions(-) diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index edc879479..4a7c0d7cc 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -732,10 +732,7 @@ def read_sql_query( # pylint: disable=too-many-arguments keep_files: bool = ..., use_threads: Union[bool, int] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., params: Optional[Dict[str, Any]] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -762,10 +759,7 @@ def read_sql_query( keep_files: bool = ..., use_threads: Union[bool, int] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., params: Optional[Dict[str, Any]] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -792,10 +786,7 @@ def read_sql_query( keep_files: bool = ..., use_threads: Union[bool, int] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., params: Optional[Dict[str, Any]] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -822,10 +813,7 @@ def read_sql_query( keep_files: bool = ..., use_threads: Union[bool, int] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., params: Optional[Dict[str, Any]] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -852,10 +840,7 @@ def read_sql_query( keep_files: bool = ..., use_threads: Union[bool, int] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., params: Optional[Dict[str, Any]] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -881,10 +866,7 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals keep_files: bool = True, use_threads: Union[bool, int] = True, boto3_session: Optional[boto3.Session] = None, - max_cache_seconds: int = 0, - max_cache_query_inspections: int = 50, - max_remote_cache_entries: int = 50, - max_local_cache_entries: int = 100, + athena_cache_settings: Optional[typing.AthenaCacheSettings] = None, data_source: Optional[str] = None, params: Optional[Dict[str, Any]] = None, s3_additional_kwargs: Optional[Dict[str, Any]] = None, @@ -1108,6 +1090,12 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals raise exceptions.InvalidArgumentCombination("Only PARQUET file format is supported if unload_approach=True") chunksize = sys.maxsize if ctas_approach is False and chunksize is True else chunksize + athena_cache_settings = athena_cache_settings if athena_cache_settings else {} + max_cache_seconds = athena_cache_settings.get("max_cache_seconds", 0) + max_cache_query_inspections = athena_cache_settings.get("max_cache_query_inspections", 50) + max_remote_cache_entries = athena_cache_settings.get("max_remote_cache_entries", 50) + max_local_cache_entries = athena_cache_settings.get("max_local_cache_entries", 100) + # Substitute query parameters sql = _process_sql_params(sql, params) @@ -1188,10 +1176,7 @@ def read_sql_table( keep_files: bool = ..., use_threads: Union[bool, int] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -1217,10 +1202,7 @@ def read_sql_table( keep_files: bool = ..., use_threads: Union[bool, int] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -1246,10 +1228,7 @@ def read_sql_table( keep_files: bool = ..., use_threads: Union[bool, int] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -1275,10 +1254,7 @@ def read_sql_table( keep_files: bool = ..., use_threads: Union[bool, int] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -1304,10 +1280,7 @@ def read_sql_table( keep_files: bool = ..., use_threads: Union[bool, int] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., s3_additional_kwargs: Optional[Dict[str, Any]] = ..., pyarrow_additional_kwargs: Optional[Dict[str, Any]] = ..., @@ -1332,10 +1305,7 @@ def read_sql_table( keep_files: bool = True, use_threads: Union[bool, int] = True, boto3_session: Optional[boto3.Session] = None, - max_cache_seconds: int = 0, - max_cache_query_inspections: int = 50, - max_remote_cache_entries: int = 50, - max_local_cache_entries: int = 100, + athena_cache_settings: Optional[typing.AthenaCacheSettings] = None, data_source: Optional[str] = None, s3_additional_kwargs: Optional[Dict[str, Any]] = None, pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None, @@ -1528,6 +1498,7 @@ def read_sql_table( """ table = catalog.sanitize_table_name(table=table) + return read_sql_query( sql=f'SELECT * FROM "{table}"', database=database, @@ -1544,10 +1515,7 @@ def read_sql_table( keep_files=keep_files, use_threads=use_threads, boto3_session=boto3_session, - max_cache_seconds=max_cache_seconds, - max_cache_query_inspections=max_cache_query_inspections, - max_remote_cache_entries=max_remote_cache_entries, - max_local_cache_entries=max_local_cache_entries, + athena_cache_settings=athena_cache_settings, data_source=data_source, s3_additional_kwargs=s3_additional_kwargs, pyarrow_additional_kwargs=pyarrow_additional_kwargs, diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py index dff93c827..6671d44f7 100644 --- a/awswrangler/athena/_utils.py +++ b/awswrangler/athena/_utils.py @@ -393,10 +393,7 @@ def start_query_execution( kms_key: Optional[str] = ..., params: Optional[Dict[str, Any]] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., wait: Literal[False] = ..., ) -> str: @@ -414,10 +411,7 @@ def start_query_execution( kms_key: Optional[str] = ..., params: Optional[Dict[str, Any]] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., wait: Literal[True], ) -> Dict[str, Any]: @@ -435,10 +429,7 @@ def start_query_execution( kms_key: Optional[str] = ..., params: Optional[Dict[str, Any]] = ..., boto3_session: Optional[boto3.Session] = ..., - max_cache_seconds: int = ..., - max_cache_query_inspections: int = ..., - max_remote_cache_entries: int = ..., - max_local_cache_entries: int = ..., + athena_cache_settings: Optional[typing.AthenaCacheSettings] = ..., data_source: Optional[str] = ..., wait: bool, ) -> Union[str, Dict[str, Any]]: @@ -455,10 +446,7 @@ def start_query_execution( kms_key: Optional[str] = None, params: Optional[Dict[str, Any]] = None, boto3_session: Optional[boto3.Session] = None, - max_cache_seconds: int = 0, - max_cache_query_inspections: int = 50, - max_remote_cache_entries: int = 50, - max_local_cache_entries: int = 100, + athena_cache_settings: Optional[typing.AthenaCacheSettings] = None, data_source: Optional[str] = None, wait: bool = False, ) -> Union[str, Dict[str, Any]]: @@ -533,6 +521,12 @@ def start_query_execution( """ sql = _process_sql_params(sql, params) + athena_cache_settings = athena_cache_settings if athena_cache_settings else {} + max_cache_seconds = athena_cache_settings.get("max_cache_seconds", 0) + max_cache_query_inspections = athena_cache_settings.get("max_cache_query_inspections", 50) + max_remote_cache_entries = athena_cache_settings.get("max_remote_cache_entries", 50) + max_local_cache_entries = athena_cache_settings.get("max_local_cache_entries", 100) + max_remote_cache_entries = min(max_remote_cache_entries, max_local_cache_entries) _cache_manager.max_cache_size = max_local_cache_entries diff --git a/awswrangler/typing.py b/awswrangler/typing.py index 38e4ae8e8..e5b18770b 100644 --- a/awswrangler/typing.py +++ b/awswrangler/typing.py @@ -3,11 +3,7 @@ import sys from typing import Dict, List, Literal, Tuple, TypedDict -if sys.version_info >= (3, 11): - from typing import NotRequired, Required -else: - from typing_extensions import NotRequired, Required - +from typing_extensions import NotRequired, Required BucketingInfoTuple = Tuple[List[str], int] @@ -87,6 +83,15 @@ class AthenaUNLOADSettings(TypedDict): """ +class AthenaCacheSettings(TypedDict): + """Typed dictionary defining the settings for using cached Athena results.""" + + max_cache_seconds: NotRequired[int] + max_cache_query_inspections: NotRequired[int] + max_remote_cache_entries: NotRequired[int] + max_local_cache_entries: NotRequired[int] + + class _S3WriteDataReturnValue(TypedDict): """Typed dictionary defining the dictionary returned by S3 write functions.""" diff --git a/tests/unit/test_athena_cache.py b/tests/unit/test_athena_cache.py index 8ecdfd002..f21dbc0d8 100644 --- a/tests/unit/test_athena_cache.py +++ b/tests/unit/test_athena_cache.py @@ -20,19 +20,32 @@ def test_athena_cache(wr, path, glue_database, glue_table, workgroup1): wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", database=glue_database, table=glue_table) df2 = wr.athena.read_sql_table( - glue_table, glue_database, ctas_approach=False, max_cache_seconds=1, workgroup=workgroup1 + glue_table, + glue_database, + ctas_approach=False, + workgroup=workgroup1, + athena_cache_settings={"max_cache_seconds": 1}, ) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() df2 = wr.athena.read_sql_table( - glue_table, glue_database, ctas_approach=False, max_cache_seconds=900, workgroup=workgroup1 + glue_table, + glue_database, + ctas_approach=False, + athena_cache_settings={"max_cache_seconds": 900}, + workgroup=workgroup1, ) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() dfs = wr.athena.read_sql_table( - glue_table, glue_database, ctas_approach=False, max_cache_seconds=900, workgroup=workgroup1, chunksize=1 + glue_table, + glue_database, + ctas_approach=False, + athena_cache_settings={"max_cache_seconds": 900}, + workgroup=workgroup1, + chunksize=1, ) assert len(list(dfs)) == 2 @@ -59,7 +72,11 @@ def test_cache_query_ctas_approach_true(wr, path, glue_database, glue_table, dat return_value=wr.athena._read._CacheInfo(has_valid_cache=False), ) as mocked_cache_attempt: df2 = wr.athena.read_sql_table( - glue_table, glue_database, ctas_approach=True, max_cache_seconds=0, data_source=data_source + glue_table, + glue_database, + ctas_approach=True, + athena_cache_settings={"max_cache_seconds": 0}, + data_source=data_source, ) mocked_cache_attempt.assert_called() assert df.shape == df2.shape @@ -67,7 +84,11 @@ def test_cache_query_ctas_approach_true(wr, path, glue_database, glue_table, dat with patch("awswrangler.athena._read._resolve_query_without_cache") as resolve_no_cache: df3 = wr.athena.read_sql_table( - glue_table, glue_database, ctas_approach=True, max_cache_seconds=900, data_source=data_source + glue_table, + glue_database, + ctas_approach=True, + athena_cache_settings={"max_cache_seconds": 900}, + data_source=data_source, ) resolve_no_cache.assert_not_called() assert df.shape == df3.shape @@ -97,7 +118,11 @@ def test_cache_query_ctas_approach_false(wr, path, glue_database, glue_table, da return_value=wr.athena._read._CacheInfo(has_valid_cache=False), ) as mocked_cache_attempt: df2 = wr.athena.read_sql_table( - glue_table, glue_database, ctas_approach=False, max_cache_seconds=0, data_source=data_source + glue_table, + glue_database, + ctas_approach=False, + athena_cache_settings={"max_cache_seconds": 0}, + data_source=data_source, ) mocked_cache_attempt.assert_called() assert df.shape == df2.shape @@ -105,7 +130,11 @@ def test_cache_query_ctas_approach_false(wr, path, glue_database, glue_table, da with patch("awswrangler.athena._read._resolve_query_without_cache") as resolve_no_cache: df3 = wr.athena.read_sql_table( - glue_table, glue_database, ctas_approach=False, max_cache_seconds=900, data_source=data_source + glue_table, + glue_database, + ctas_approach=False, + athena_cache_settings={"max_cache_seconds": 900}, + data_source=data_source, ) resolve_no_cache.assert_not_called() assert df.shape == df3.shape @@ -122,7 +151,10 @@ def test_cache_query_semicolon(wr, path, glue_database, glue_table): return_value=wr.athena._read._CacheInfo(has_valid_cache=False), ) as mocked_cache_attempt: df2 = wr.athena.read_sql_query( - f"SELECT * FROM {glue_table}", database=glue_database, ctas_approach=True, max_cache_seconds=0 + f"SELECT * FROM {glue_table}", + database=glue_database, + ctas_approach=True, + athena_cache_settings={"max_cache_seconds": 0}, ) mocked_cache_attempt.assert_called() assert df.shape == df2.shape @@ -130,7 +162,10 @@ def test_cache_query_semicolon(wr, path, glue_database, glue_table): with patch("awswrangler.athena._read._resolve_query_without_cache") as resolve_no_cache: df3 = wr.athena.read_sql_query( - f"SELECT * FROM {glue_table};", database=glue_database, ctas_approach=True, max_cache_seconds=900 + f"SELECT * FROM {glue_table};", + database=glue_database, + ctas_approach=True, + athena_cache_settings={"max_cache_seconds": 900}, ) resolve_no_cache.assert_not_called() assert df.shape == df3.shape @@ -148,7 +183,10 @@ def test_local_cache(wr, path, glue_database, glue_table): return_value=wr.athena._read._CacheInfo(has_valid_cache=False), ) as mocked_cache_attempt: df2 = wr.athena.read_sql_query( - f"SELECT * FROM {glue_table}", database=glue_database, ctas_approach=True, max_cache_seconds=0 + f"SELECT * FROM {glue_table}", + database=glue_database, + ctas_approach=True, + athena_cache_settings={"max_cache_seconds": 0}, ) mocked_cache_attempt.assert_called() assert df.shape == df2.shape @@ -157,7 +195,10 @@ def test_local_cache(wr, path, glue_database, glue_table): assert first_query_id in wr.athena._read._cache_manager df3 = wr.athena.read_sql_query( - f"SELECT * FROM {glue_table}", database=glue_database, ctas_approach=True, max_cache_seconds=0 + f"SELECT * FROM {glue_table}", + database=glue_database, + ctas_approach=True, + athena_cache_settings={"max_cache_seconds": 0}, ) mocked_cache_attempt.assert_called() assert df.shape == df3.shape @@ -174,7 +215,11 @@ def test_paginated_remote_cache(wr, path, glue_database, glue_table, workgroup1) wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", database=glue_database, table=glue_table) df2 = wr.athena.read_sql_table( - glue_table, glue_database, ctas_approach=False, max_cache_seconds=1, workgroup=workgroup1 + glue_table, + glue_database, + ctas_approach=False, + athena_cache_settings={"max_cache_seconds": 1}, + workgroup=workgroup1, ) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() @@ -209,7 +254,7 @@ def test_cache_start_query(wr, path, glue_database, glue_table, data_source): with patch("awswrangler.athena._utils._start_query_execution") as internal_start_query: query_id_2 = wr.athena.start_query_execution( - sql=f"SELECT * FROM {glue_table}", database=glue_database, max_cache_seconds=900 + sql=f"SELECT * FROM {glue_table}", database=glue_database, athena_cache_settings={"max_cache_seconds": 900} ) internal_start_query.assert_not_called() assert query_id == query_id_2 From c535279dfdbf9e1083198ddcfc21d79bde9ea281 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Mon, 13 Feb 2023 10:49:10 -0600 Subject: [PATCH 02/18] Refactor _ConfigValueType --- awswrangler/_config.py | 20 ++++++++++---------- awswrangler/typing.py | 1 - 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 0710b5972..edf6752cf 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -13,11 +13,11 @@ _logger: logging.Logger = logging.getLogger(__name__) -_ConfigValueType = Union[str, bool, int, botocore.config.Config, None] +_ConfigValueType = Union[str, bool, int, botocore.config.Config] class _ConfigArg(NamedTuple): - dtype: Type[Union[str, bool, int, botocore.config.Config]] + dtype: Type[_ConfigValueType] nullable: bool enforced: bool = False loaded: bool = False @@ -71,7 +71,7 @@ class _Config: # pylint: disable=too-many-instance-attributes,too-many-public-m """AWS Wrangler's Configuration class.""" def __init__(self) -> None: - self._loaded_values: Dict[str, _ConfigValueType] = {} + self._loaded_values: Dict[str, Optional[_ConfigValueType]] = {} name: str self.s3_endpoint_url = None self.athena_endpoint_url = None @@ -134,7 +134,7 @@ def to_pandas(self) -> pd.DataFrame: for k, v in _CONFIG_ARGS.items(): arg: Dict[str, Any] = { "name": k, - "Env. Variable": f"WR_{k.upper()}", + "Env.Variable": f"WR_{k.upper()}", "type": v.dtype, "nullable": v.nullable, "enforced": v.enforced, @@ -164,15 +164,15 @@ def _set_config_value(self, key: str, value: Any) -> None: raise exceptions.InvalidArgumentValue( f"{key} is not a valid configuration. Please use: {list(_CONFIG_ARGS.keys())}" ) - value_casted: _ConfigValueType = self._apply_type( + value_casted = self._apply_type( name=key, value=value, - dtype=_CONFIG_ARGS[key].dtype, # type: ignore[arg-type] + dtype=_CONFIG_ARGS[key].dtype, nullable=_CONFIG_ARGS[key].nullable, ) self._loaded_values[key] = value_casted - def __getitem__(self, item: str) -> _ConfigValueType: + def __getitem__(self, item: str) -> Optional[_ConfigValueType]: if item not in self._loaded_values: raise AttributeError(f"{item} not configured yet.") return self._loaded_values[item] @@ -189,7 +189,7 @@ def _repr_html_(self) -> Any: return self.to_pandas().to_html() @staticmethod - def _apply_type(name: str, value: Any, dtype: Type[Union[str, bool, int]], nullable: bool) -> _ConfigValueType: + def _apply_type(name: str, value: Any, dtype: Type[_ConfigValueType], nullable: bool) -> Optional[_ConfigValueType]: if _Config._is_null(value=value): if nullable is True: return None @@ -202,7 +202,7 @@ def _apply_type(name: str, value: Any, dtype: Type[Union[str, bool, int]], nulla raise exceptions.InvalidConfiguration(f"Config {name} must receive a {dtype} value.") from ex @staticmethod - def _is_null(value: _ConfigValueType) -> bool: + def _is_null(value: Optional[_ConfigValueType]) -> bool: if value is None: return True if isinstance(value, str) is True: @@ -589,7 +589,7 @@ def wrapper(*args_raw: Any, **kwargs: Any) -> Any: args: Dict[str, Any] = signature.bind_partial(*args_raw, **kwargs).arguments for name in available_configs: if hasattr(config, name) is True: - value: _ConfigValueType = config[name] + value = config[name] if name not in args: _logger.debug("Applying default config argument %s with value %s.", name, value) args[name] = value diff --git a/awswrangler/typing.py b/awswrangler/typing.py index e5b18770b..44d2ba851 100644 --- a/awswrangler/typing.py +++ b/awswrangler/typing.py @@ -1,6 +1,5 @@ """Module with parameter types.""" -import sys from typing import Dict, List, Literal, Tuple, TypedDict from typing_extensions import NotRequired, Required From 15b437f929edd0042cfbacdf981fcf967f68c0b5 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Mon, 13 Feb 2023 13:06:17 -0600 Subject: [PATCH 03/18] Add support for nested args --- awswrangler/_config.py | 64 +++++++++++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index edf6752cf..c40790aa3 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -9,11 +9,12 @@ import pandas as pd from awswrangler import exceptions +from awswrangler.typing import AthenaCacheSettings _logger: logging.Logger = logging.getLogger(__name__) -_ConfigValueType = Union[str, bool, int, botocore.config.Config] +_ConfigValueType = Union[str, bool, int, botocore.config.Config, dict] class _ConfigArg(NamedTuple): @@ -22,6 +23,8 @@ class _ConfigArg(NamedTuple): enforced: bool = False loaded: bool = False default: Optional[_ConfigValueType] = None + parent_key: Optional[str] = None + is_parent: bool = False # Please, also add any new argument as a property in the _Config class @@ -30,10 +33,11 @@ class _ConfigArg(NamedTuple): "concurrent_partitioning": _ConfigArg(dtype=bool, nullable=False), "ctas_approach": _ConfigArg(dtype=bool, nullable=False), "database": _ConfigArg(dtype=str, nullable=True), - "max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False), - "max_cache_seconds": _ConfigArg(dtype=int, nullable=False), - "max_remote_cache_entries": _ConfigArg(dtype=int, nullable=False), - "max_local_cache_entries": _ConfigArg(dtype=int, nullable=False), + "athena_cache_settings": _ConfigArg(dtype=dict, nullable=False, is_parent=True), + "max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False, parent_key="athena_cache_settings"), + "max_cache_seconds": _ConfigArg(dtype=int, nullable=False, parent_key="athena_cache_settings"), + "max_remote_cache_entries": _ConfigArg(dtype=int, nullable=False, parent_key="athena_cache_settings"), + "max_local_cache_entries": _ConfigArg(dtype=int, nullable=False, parent_key="athena_cache_settings"), "s3_block_size": _ConfigArg(dtype=int, nullable=False, enforced=True), "workgroup": _ConfigArg(dtype=str, nullable=False, enforced=True), "chunksize": _ConfigArg(dtype=int, nullable=False, enforced=True), @@ -149,14 +153,19 @@ def to_pandas(self) -> pd.DataFrame: return pd.DataFrame(args) def _load_config(self, name: str) -> bool: + if _CONFIG_ARGS[name].is_parent: + return False + loaded_config: bool = False if _CONFIG_ARGS[name].loaded: self._set_config_value(key=name, value=_CONFIG_ARGS[name].default) loaded_config = True + env_var: Optional[str] = os.getenv(f"WR_{name.upper()}") if env_var is not None: self._set_config_value(key=name, value=env_var) loaded_config = True + return loaded_config def _set_config_value(self, key: str, value: Any) -> None: @@ -170,7 +179,15 @@ def _set_config_value(self, key: str, value: Any) -> None: dtype=_CONFIG_ARGS[key].dtype, nullable=_CONFIG_ARGS[key].nullable, ) - self._loaded_values[key] = value_casted + + parent_key = _CONFIG_ARGS[key].parent_key + if parent_key: + if self._loaded_values.get(parent_key) is None: + self._loaded_values[parent_key] = {} + + self._loaded_values[parent_key][key] = value_casted # type: ignore[index] + else: + self._loaded_values[key] = value_casted def __getitem__(self, item: str) -> Optional[_ConfigValueType]: if item not in self._loaded_values: @@ -247,10 +264,14 @@ def database(self) -> Optional[str]: def database(self, value: Optional[str]) -> None: self._set_config_value(key="database", value=value) + @property + def athena_cache_settings(self) -> AthenaCacheSettings: + return cast(AthenaCacheSettings, self._loaded_values.get("athena_cache_settings")) + @property def max_cache_query_inspections(self) -> int: """Property max_cache_query_inspections.""" - return cast(int, self["max_cache_query_inspections"]) + return self.athena_cache_settings["max_cache_query_inspections"] @max_cache_query_inspections.setter def max_cache_query_inspections(self, value: int) -> None: @@ -259,7 +280,7 @@ def max_cache_query_inspections(self, value: int) -> None: @property def max_cache_seconds(self) -> int: """Property max_cache_seconds.""" - return cast(int, self["max_cache_seconds"]) + return self.athena_cache_settings["max_cache_seconds"] @max_cache_seconds.setter def max_cache_seconds(self, value: int) -> None: @@ -268,7 +289,7 @@ def max_cache_seconds(self, value: int) -> None: @property def max_local_cache_entries(self) -> int: """Property max_local_cache_entries.""" - return cast(int, self["max_local_cache_entries"]) + return self.athena_cache_settings["max_local_cache_entries"] @max_local_cache_entries.setter def max_local_cache_entries(self, value: int) -> None: @@ -288,7 +309,7 @@ def max_local_cache_entries(self, value: int) -> None: @property def max_remote_cache_entries(self) -> int: """Property max_remote_cache_entries.""" - return cast(int, self["max_remote_cache_entries"]) + return self.athena_cache_settings["max_remote_cache_entries"] @max_remote_cache_entries.setter def max_remote_cache_entries(self, value: int) -> None: @@ -579,6 +600,22 @@ def _inject_config_doc(doc: Optional[str], available_configs: Tuple[str, ...]) - FunctionType = TypeVar("FunctionType", bound=Callable[..., Any]) +def _assign_args_value(args: Dict[str, Any], name: str, value: Any) -> None: + if _CONFIG_ARGS[name].is_parent: + nested_args = cast(Dict[str, Any], value) + for nested_arg_name, nested_arg_value in nested_args.items(): + _assign_args_value(args[name], nested_arg_name, nested_arg_value) + return + + if name not in args: + _logger.debug("Applying default config argument %s with value %s.", name, value) + args[name] = value + + elif _CONFIG_ARGS[name].enforced is True: + _logger.debug("Applying ENFORCED config argument %s with value %s.", name, value) + args[name] = value + + def apply_configs(function: FunctionType) -> FunctionType: """Decorate some function with configs.""" signature = inspect.signature(function) @@ -590,12 +627,7 @@ def wrapper(*args_raw: Any, **kwargs: Any) -> Any: for name in available_configs: if hasattr(config, name) is True: value = config[name] - if name not in args: - _logger.debug("Applying default config argument %s with value %s.", name, value) - args[name] = value - elif _CONFIG_ARGS[name].enforced is True: - _logger.debug("Applying ENFORCED config argument %s with value %s.", name, value) - args[name] = value + _assign_args_value(args, name, value) for name, param in signature.parameters.items(): if param.kind == param.VAR_KEYWORD and name in args: if isinstance(args[name], dict) is False: From 8b673e2f7bce0eb566527772b130dd469bb799cd Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Mon, 13 Feb 2023 15:10:15 -0600 Subject: [PATCH 04/18] Fix getters and setters --- awswrangler/_config.py | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index c40790aa3..37475b96c 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -23,7 +23,7 @@ class _ConfigArg(NamedTuple): enforced: bool = False loaded: bool = False default: Optional[_ConfigValueType] = None - parent_key: Optional[str] = None + parent_parameter_key: Optional[str] = None is_parent: bool = False @@ -34,10 +34,10 @@ class _ConfigArg(NamedTuple): "ctas_approach": _ConfigArg(dtype=bool, nullable=False), "database": _ConfigArg(dtype=str, nullable=True), "athena_cache_settings": _ConfigArg(dtype=dict, nullable=False, is_parent=True), - "max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False, parent_key="athena_cache_settings"), - "max_cache_seconds": _ConfigArg(dtype=int, nullable=False, parent_key="athena_cache_settings"), - "max_remote_cache_entries": _ConfigArg(dtype=int, nullable=False, parent_key="athena_cache_settings"), - "max_local_cache_entries": _ConfigArg(dtype=int, nullable=False, parent_key="athena_cache_settings"), + "max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"), + "max_cache_seconds": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"), + "max_remote_cache_entries": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"), + "max_local_cache_entries": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"), "s3_block_size": _ConfigArg(dtype=int, nullable=False, enforced=True), "workgroup": _ConfigArg(dtype=str, nullable=False, enforced=True), "chunksize": _ConfigArg(dtype=int, nullable=False, enforced=True), @@ -136,12 +136,16 @@ def to_pandas(self) -> pd.DataFrame: """ args: List[Dict[str, Any]] = [] for k, v in _CONFIG_ARGS.items(): + if v.is_parent: + continue + arg: Dict[str, Any] = { "name": k, "Env.Variable": f"WR_{k.upper()}", "type": v.dtype, "nullable": v.nullable, "enforced": v.enforced, + "parent_parameter_name": v.parent_parameter_key, } if k in self._loaded_values: arg["configured"] = True @@ -180,7 +184,7 @@ def _set_config_value(self, key: str, value: Any) -> None: nullable=_CONFIG_ARGS[key].nullable, ) - parent_key = _CONFIG_ARGS[key].parent_key + parent_key = _CONFIG_ARGS[key].parent_parameter_key if parent_key: if self._loaded_values.get(parent_key) is None: self._loaded_values[parent_key] = {} @@ -190,9 +194,20 @@ def _set_config_value(self, key: str, value: Any) -> None: self._loaded_values[key] = value_casted def __getitem__(self, item: str) -> Optional[_ConfigValueType]: - if item not in self._loaded_values: + if _CONFIG_ARGS[item].is_parent: + return self._loaded_values.get(item, {}) + + loaded_values: Dict[str, Optional[_ConfigValueType]] + parent_key = _CONFIG_ARGS[item].parent_parameter_key + if parent_key: + loaded_values = self[parent_key] # type: ignore[assignment] + else: + loaded_values = self._loaded_values + + if item not in loaded_values: raise AttributeError(f"{item} not configured yet.") - return self._loaded_values[item] + + return loaded_values[item] def _reset_item(self, item: str) -> None: if item in self._loaded_values: @@ -266,12 +281,12 @@ def database(self, value: Optional[str]) -> None: @property def athena_cache_settings(self) -> AthenaCacheSettings: - return cast(AthenaCacheSettings, self._loaded_values.get("athena_cache_settings")) + return cast(AthenaCacheSettings, self["athena_cache_settings"]) @property def max_cache_query_inspections(self) -> int: """Property max_cache_query_inspections.""" - return self.athena_cache_settings["max_cache_query_inspections"] + return cast(int, self["max_cache_query_inspections"]) @max_cache_query_inspections.setter def max_cache_query_inspections(self, value: int) -> None: @@ -280,7 +295,7 @@ def max_cache_query_inspections(self, value: int) -> None: @property def max_cache_seconds(self) -> int: """Property max_cache_seconds.""" - return self.athena_cache_settings["max_cache_seconds"] + return cast(int, self["max_cache_seconds"]) @max_cache_seconds.setter def max_cache_seconds(self, value: int) -> None: @@ -289,7 +304,7 @@ def max_cache_seconds(self, value: int) -> None: @property def max_local_cache_entries(self) -> int: """Property max_local_cache_entries.""" - return self.athena_cache_settings["max_local_cache_entries"] + return cast(int, self["max_local_cache_entries"]) @max_local_cache_entries.setter def max_local_cache_entries(self, value: int) -> None: @@ -309,7 +324,7 @@ def max_local_cache_entries(self, value: int) -> None: @property def max_remote_cache_entries(self) -> int: """Property max_remote_cache_entries.""" - return self.athena_cache_settings["max_remote_cache_entries"] + return cast(int, self["max_remote_cache_entries"]) @max_remote_cache_entries.setter def max_remote_cache_entries(self, value: int) -> None: @@ -628,6 +643,7 @@ def wrapper(*args_raw: Any, **kwargs: Any) -> Any: if hasattr(config, name) is True: value = config[name] _assign_args_value(args, name, value) + for name, param in signature.parameters.items(): if param.kind == param.VAR_KEYWORD and name in args: if isinstance(args[name], dict) is False: From 162c8512d919814f7245e86b1efaeb5495801f57 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Mon, 13 Feb 2023 15:15:19 -0600 Subject: [PATCH 05/18] fix formatting --- awswrangler/_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 37475b96c..0b5d2eec5 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -200,7 +200,7 @@ def __getitem__(self, item: str) -> Optional[_ConfigValueType]: loaded_values: Dict[str, Optional[_ConfigValueType]] parent_key = _CONFIG_ARGS[item].parent_parameter_key if parent_key: - loaded_values = self[parent_key] # type: ignore[assignment] + loaded_values = self[parent_key] # type: ignore[assignment] else: loaded_values = self._loaded_values From 71d3c42c3d2e66a288379c3e4f87037f3efae0e4 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Mon, 13 Feb 2023 17:09:30 -0600 Subject: [PATCH 06/18] Refactor config --- awswrangler/_config.py | 18 +++++++++--------- tests/unit/test_config.py | 20 ++++++++++++++++++-- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 0b5d2eec5..4f2c0ceb9 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -33,7 +33,7 @@ class _ConfigArg(NamedTuple): "concurrent_partitioning": _ConfigArg(dtype=bool, nullable=False), "ctas_approach": _ConfigArg(dtype=bool, nullable=False), "database": _ConfigArg(dtype=str, nullable=True), - "athena_cache_settings": _ConfigArg(dtype=dict, nullable=False, is_parent=True), + "athena_cache_settings": _ConfigArg(dtype=dict, nullable=False, is_parent=True, default={}, loaded=True), "max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"), "max_cache_seconds": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"), "max_remote_cache_entries": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"), @@ -158,6 +158,9 @@ def to_pandas(self) -> pd.DataFrame: def _load_config(self, name: str) -> bool: if _CONFIG_ARGS[name].is_parent: + if self._loaded_values.get(name) is None: + self._set_config_value(key=name, value={}) + return True return False loaded_config: bool = False @@ -186,16 +189,13 @@ def _set_config_value(self, key: str, value: Any) -> None: parent_key = _CONFIG_ARGS[key].parent_parameter_key if parent_key: - if self._loaded_values.get(parent_key) is None: - self._loaded_values[parent_key] = {} - self._loaded_values[parent_key][key] = value_casted # type: ignore[index] else: self._loaded_values[key] = value_casted def __getitem__(self, item: str) -> Optional[_ConfigValueType]: - if _CONFIG_ARGS[item].is_parent: - return self._loaded_values.get(item, {}) + if issubclass(_CONFIG_ARGS[item].dtype, dict): + return self._loaded_values[item] loaded_values: Dict[str, Optional[_ConfigValueType]] parent_key = _CONFIG_ARGS[item].parent_parameter_key @@ -612,9 +612,6 @@ def _inject_config_doc(doc: Optional[str], available_configs: Tuple[str, ...]) - return _insert_str(text=doc, token="\n Parameters", insert=insertion) -FunctionType = TypeVar("FunctionType", bound=Callable[..., Any]) - - def _assign_args_value(args: Dict[str, Any], name: str, value: Any) -> None: if _CONFIG_ARGS[name].is_parent: nested_args = cast(Dict[str, Any], value) @@ -631,6 +628,9 @@ def _assign_args_value(args: Dict[str, Any], name: str, value: Any) -> None: args[name] = value +FunctionType = TypeVar("FunctionType", bound=Callable[..., Any]) + + def apply_configs(function: FunctionType) -> FunctionType: """Decorate some function with configs.""" signature = inspect.signature(function) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index a70c2a785..f7daf6605 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -131,8 +131,24 @@ def test_basics(path, glue_database, glue_table, workgroup0, workgroup1): def test_athena_cache_configuration(): wr.config.max_remote_cache_entries = 50 - wr.config.max_local_cache_entries = 20 - assert wr.config.max_remote_cache_entries == 20 + wr.config.max_cache_seconds = 20 + + assert wr.config.max_remote_cache_entries == 50 + assert wr.config.athena_cache_settings["max_remote_cache_entries"] == 50 + + assert wr.config.max_cache_seconds == 20 + assert wr.config.athena_cache_settings["max_cache_seconds"] == 20 + + +def test_athena_cache_configuration_dict(): + wr.config.athena_cache_settings["max_remote_cache_entries"] = 50 + wr.config.athena_cache_settings["max_cache_seconds"] = 20 + + assert wr.config.max_remote_cache_entries == 50 + assert wr.config.athena_cache_settings["max_remote_cache_entries"] == 50 + + assert wr.config.max_cache_seconds == 20 + assert wr.config.athena_cache_settings["max_cache_seconds"] == 20 def test_botocore_config(path): From 91262dc152ba24e1ad229b86c2e6eec05a51a8d5 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Mon, 13 Feb 2023 17:13:16 -0600 Subject: [PATCH 07/18] Fix config shape --- tests/unit/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index f7daf6605..d51157709 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -101,7 +101,7 @@ def test_basics(path, glue_database, glue_table, workgroup0, workgroup1): with pytest.raises(TypeError): wr.catalog.does_table_exist(table=glue_table) - assert wr.config.to_pandas().shape == (len(wr._config._CONFIG_ARGS), 7) + assert wr.config.to_pandas().shape == (len(wr._config._CONFIG_ARGS) - 1, 8) # Workgroup wr.config.workgroup = workgroup0 From ba0ab7d2e883e28de9e70cd29c3afd532914ef88 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 14 Feb 2023 10:08:21 -0600 Subject: [PATCH 08/18] Fix arg --- awswrangler/_config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 4f2c0ceb9..81612d827 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -614,6 +614,9 @@ def _inject_config_doc(doc: Optional[str], available_configs: Tuple[str, ...]) - def _assign_args_value(args: Dict[str, Any], name: str, value: Any) -> None: if _CONFIG_ARGS[name].is_parent: + if name not in args: + args[name] = {} + nested_args = cast(Dict[str, Any], value) for nested_arg_name, nested_arg_value in nested_args.items(): _assign_args_value(args[name], nested_arg_name, nested_arg_value) From 53df867dfd11d6aeae59aa132a132e8cf3827bb1 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 14 Feb 2023 10:31:03 -0600 Subject: [PATCH 09/18] Isolate config unit tests --- tests/unit/test_config.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index d51157709..85e067512 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,5 +1,6 @@ import logging import os +from types import ModuleType from unittest.mock import create_autospec, patch import boto3 @@ -8,14 +9,13 @@ import botocore.config import pytest -import awswrangler as wr from awswrangler._config import apply_configs from awswrangler.s3._fs import open_s3_object logging.getLogger("awswrangler").setLevel(logging.DEBUG) -def _urls_test(glue_database): +def _urls_test(wr: ModuleType, glue_database: str) -> None: original = botocore.client.ClientCreator.create_client def wrapper(self, **kwarg): @@ -41,7 +41,9 @@ def wrapper(self, **kwarg): wr.athena.read_sql_query(sql="SELECT 1 as col0", database=glue_database) -def test_basics(path, glue_database, glue_table, workgroup0, workgroup1): +def test_basics( + wr: ModuleType, path: str, glue_database: str, glue_table: str, workgroup0: str, workgroup1: str +) -> None: args = {"table": glue_table, "path": "", "columns_types": {"col0": "bigint"}} # Missing database argument @@ -129,7 +131,7 @@ def test_basics(path, glue_database, glue_table, workgroup0, workgroup1): _urls_test(glue_database) -def test_athena_cache_configuration(): +def test_athena_cache_configuration(wr: ModuleType) -> None: wr.config.max_remote_cache_entries = 50 wr.config.max_cache_seconds = 20 @@ -140,7 +142,7 @@ def test_athena_cache_configuration(): assert wr.config.athena_cache_settings["max_cache_seconds"] == 20 -def test_athena_cache_configuration_dict(): +def test_athena_cache_configuration_dict(wr: ModuleType) -> None: wr.config.athena_cache_settings["max_remote_cache_entries"] = 50 wr.config.athena_cache_settings["max_cache_seconds"] = 20 @@ -151,7 +153,7 @@ def test_athena_cache_configuration_dict(): assert wr.config.athena_cache_settings["max_cache_seconds"] == 20 -def test_botocore_config(path): +def test_botocore_config(wr: ModuleType, path: str) -> None: original = botocore.client.ClientCreator.create_client # Default values for botocore.config.Config @@ -208,7 +210,7 @@ def wrapper(self, **kwarg): wr.config.reset() -def test_chunk_size(): +def test_chunk_size(wr: ModuleType) -> None: expected_chunksize = 123 wr.config.chunksize = expected_chunksize From 7947acde5ad0c863eafd5d1584dc3a7c52891c4c Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 14 Feb 2023 10:31:42 -0600 Subject: [PATCH 10/18] Refactor Athena cache tutorial --- tutorials/019 - Athena Cache.ipynb | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tutorials/019 - Athena Cache.ipynb b/tutorials/019 - Athena Cache.ipynb index c097efcc9..5f8db3d41 100644 --- a/tutorials/019 - Athena Cache.ipynb +++ b/tutorials/019 - Athena Cache.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -8,7 +9,7 @@ "\n", "# 19 - Amazon Athena Cache\n", "\n", - "[awswrangler](https://github.com/aws/aws-sdk-pandas) has a cache strategy that is disabled by default and can be enabled by passing `max_cache_seconds` bigger than 0. This cache strategy for Amazon Athena can help you to **decrease query times and costs**.\n", + "[awswrangler](https://github.com/aws/aws-sdk-pandas) has a cache strategy that is disabled by default and can be enabled by passing `max_cache_seconds` bigger than 0 as part of the `athena_cache_settings` parameter. This cache strategy for Amazon Athena can help you to **decrease query times and costs**.\n", "\n", "When calling `read_sql_query`, instead of just running the query, we now can verify if the query has been run before. If so, and this last run was within `max_cache_seconds` (a new parameter to `read_sql_query`), we return the same results as last time if they are still available in S3. We have seen this increase performance more than 100x, but the potential is pretty much infinite.\n", "\n", @@ -44,7 +45,7 @@ "metadata": {}, "outputs": [ { - "name": "stdin", + "name": "stdout", "output_type": "stream", "text": [ " ···········································\n" @@ -1189,7 +1190,7 @@ "source": [ "%%time\n", "\n", - "wr.athena.read_sql_query(query, database=\"awswrangler_test\", max_cache_seconds=900, max_cache_query_inspections=500)" + "wr.athena.read_sql_query(query, database=\"awswrangler_test\", athena_cache_settings={\"max_cache_seconds\": 900, \"max_cache_query_inspections\": 500})" ] }, { From 532bb915b2732cce4cb1c607443b0b08d7ff29d9 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 14 Feb 2023 10:38:51 -0600 Subject: [PATCH 11/18] Update documentation --- awswrangler/_config.py | 1 + awswrangler/athena/_read.py | 40 ++++++------------------------------ awswrangler/athena/_utils.py | 23 +++++---------------- awswrangler/typing.py | 21 +++++++++++++++++++ 4 files changed, 33 insertions(+), 52 deletions(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 81612d827..d8102c54e 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -281,6 +281,7 @@ def database(self, value: Optional[str]) -> None: @property def athena_cache_settings(self) -> AthenaCacheSettings: + """Property athena_cache_settings.""" return cast(AthenaCacheSettings, self["athena_cache_settings"]) @property diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 4a7c0d7cc..25f412631 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -1024,26 +1024,12 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals If integer is provided, specified number is used. boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receive None. - max_cache_seconds : int - awswrangler can look up in Athena's history if this query has been run before. - If so, and its completion time is less than `max_cache_seconds` before now, awswrangler - skips query execution and just returns the same results as last time. + athena_cache_settings: typing.AthenaCacheSettings, optional + Params of the Athena cache settings such as max_cache_seconds, max_cache_query_inspections, + max_remote_cache_entries, and max_local_cache_entries. If cached results are valid, awswrangler ignores the `ctas_approach`, `s3_output`, `encryption`, `kms_key`, `keep_files` and `ctas_temp_table_name` params. If reading cached data fails for any reason, execution falls back to the usual query run path. - max_cache_query_inspections : int - Max number of queries that will be inspected from the history to try to find some result to reuse. - The bigger the number of inspection, the bigger will be the latency for not cached queries. - Only takes effect if max_cache_seconds > 0. - max_remote_cache_entries : int - Max number of queries that will be retrieved from AWS for cache inspection. - The bigger the number of inspection, the bigger will be the latency for not cached queries. - Only takes effect if max_cache_seconds > 0 and default value is 50. - max_local_cache_entries : int - Max number of queries for which metadata will be cached locally. This will reduce the latency and also - enables keeping more than `max_remote_cache_entries` available for the cache. This value should not be - smaller than max_remote_cache_entries. - Only takes effect if max_cache_seconds > 0 and default value is 100. data_source : str, optional Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default. params: Dict[str, any], optional @@ -1455,26 +1441,12 @@ def read_sql_table( If integer is provided, specified number is used. boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receive None. - max_cache_seconds: int - awswrangler can look up in Athena's history if this table has been read before. - If so, and its completion time is less than `max_cache_seconds` before now, awswrangler - skips query execution and just returns the same results as last time. + athena_cache_settings: typing.AthenaCacheSettings, optional + Params of the Athena cache settings such as max_cache_seconds, max_cache_query_inspections, + max_remote_cache_entries, and max_local_cache_entries. If cached results are valid, awswrangler ignores the `ctas_approach`, `s3_output`, `encryption`, `kms_key`, `keep_files` and `ctas_temp_table_name` params. If reading cached data fails for any reason, execution falls back to the usual query run path. - max_cache_query_inspections : int - Max number of queries that will be inspected from the history to try to find some result to reuse. - The bigger the number of inspection, the bigger will be the latency for not cached queries. - Only takes effect if max_cache_seconds > 0. - max_remote_cache_entries : int - Max number of queries that will be retrieved from AWS for cache inspection. - The bigger the number of inspection, the bigger will be the latency for not cached queries. - Only takes effect if max_cache_seconds > 0 and default value is 50. - max_local_cache_entries : int - Max number of queries for which metadata will be cached locally. This will reduce the latency and also - enables keeping more than `max_remote_cache_entries` available for the cache. This value should not be - smaller than max_remote_cache_entries. - Only takes effect if max_cache_seconds > 0 and default value is 100. data_source : str, optional Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default. s3_additional_kwargs : Optional[Dict[str, Any]] diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py index 6671d44f7..933045852 100644 --- a/awswrangler/athena/_utils.py +++ b/awswrangler/athena/_utils.py @@ -477,25 +477,12 @@ def start_query_execution( `:name`. Note that for varchar columns and similar, you must surround the value in single quotes. boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receive None. - max_cache_seconds: int - awswrangler can look up in Athena's history if this query has been run before. - If so, and its completion time is less than `max_cache_seconds` before now, awswrangler - skips query execution and just returns the same results as last time. - If cached results are valid, awswrangler ignores the `s3_output`, `encryption` and `kms_key` params. + athena_cache_settings: typing.AthenaCacheSettings, optional + Params of the Athena cache settings such as max_cache_seconds, max_cache_query_inspections, + max_remote_cache_entries, and max_local_cache_entries. + If cached results are valid, awswrangler ignores the `ctas_approach`, `s3_output`, `encryption`, `kms_key`, + `keep_files` and `ctas_temp_table_name` params. If reading cached data fails for any reason, execution falls back to the usual query run path. - max_cache_query_inspections : int - Max number of queries that will be inspected from the history to try to find some result to reuse. - The bigger the number of inspection, the bigger will be the latency for not cached queries. - Only takes effect if max_cache_seconds > 0. - max_remote_cache_entries : int - Max number of queries that will be retrieved from AWS for cache inspection. - The bigger the number of inspection, the bigger will be the latency for not cached queries. - Only takes effect if max_cache_seconds > 0 and default value is 50. - max_local_cache_entries : int - Max number of queries for which metadata will be cached locally. This will reduce the latency and also - enables keeping more than `max_remote_cache_entries` available for the cache. This value should not be - smaller than max_remote_cache_entries. - Only takes effect if max_cache_seconds > 0 and default value is 100. data_source : str, optional Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default. wait : bool, default False diff --git a/awswrangler/typing.py b/awswrangler/typing.py index 44d2ba851..36cad8c89 100644 --- a/awswrangler/typing.py +++ b/awswrangler/typing.py @@ -86,9 +86,30 @@ class AthenaCacheSettings(TypedDict): """Typed dictionary defining the settings for using cached Athena results.""" max_cache_seconds: NotRequired[int] + """ + awswrangler can look up in Athena's history if this table has been read before. + If so, and its completion time is less than `max_cache_seconds` before now, awswrangler + skips query execution and just returns the same results as last time. + """ max_cache_query_inspections: NotRequired[int] + """ + Max number of queries that will be inspected from the history to try to find some result to reuse. + The bigger the number of inspection, the bigger will be the latency for not cached queries. + Only takes effect if max_cache_seconds > 0. + """ max_remote_cache_entries: NotRequired[int] + """ + Max number of queries that will be retrieved from AWS for cache inspection. + The bigger the number of inspection, the bigger will be the latency for not cached queries. + Only takes effect if max_cache_seconds > 0 and default value is 50. + """ max_local_cache_entries: NotRequired[int] + """ + Max number of queries for which metadata will be cached locally. This will reduce the latency and also + enables keeping more than `max_remote_cache_entries` available for the cache. This value should not be + smaller than max_remote_cache_entries. + Only takes effect if max_cache_seconds > 0 and default value is 100. + """ class _S3WriteDataReturnValue(TypedDict): From 8b7365c858bdee84cf9e3212a74611fa59f90b59 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 14 Feb 2023 11:35:03 -0600 Subject: [PATCH 12/18] Fix use of _urls_test --- tests/unit/test_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 85e067512..560580380 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -121,14 +121,14 @@ def test_basics( wr.config.athena_endpoint_url = f"https://athena.{region}.amazonaws.com" wr.config.glue_endpoint_url = f"https://glue.{region}.amazonaws.com" wr.config.secretsmanager_endpoint_url = f"https://secretsmanager.{region}.amazonaws.com" - _urls_test(glue_database) + _urls_test(wr, glue_database) os.environ["WR_STS_ENDPOINT_URL"] = f"https://sts.{region}.amazonaws.com" os.environ["WR_S3_ENDPOINT_URL"] = f"https://s3.{region}.amazonaws.com" os.environ["WR_ATHENA_ENDPOINT_URL"] = f"https://athena.{region}.amazonaws.com" os.environ["WR_GLUE_ENDPOINT_URL"] = f"https://glue.{region}.amazonaws.com" os.environ["WR_SECRETSMANAGER_ENDPOINT_URL"] = f"https://secretsmanager.{region}.amazonaws.com" wr.config.reset() - _urls_test(glue_database) + _urls_test(wr, glue_database) def test_athena_cache_configuration(wr: ModuleType) -> None: From 7fd5f1da3c4d130d6add8b7cccd1f0b247860017 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 14 Feb 2023 12:39:51 -0600 Subject: [PATCH 13/18] Fix test_cache_start_query data_source --- tests/unit/test_athena_cache.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_athena_cache.py b/tests/unit/test_athena_cache.py index f21dbc0d8..457461a5c 100644 --- a/tests/unit/test_athena_cache.py +++ b/tests/unit/test_athena_cache.py @@ -246,7 +246,9 @@ def test_cache_start_query(wr, path, glue_database, glue_table, data_source): "awswrangler.athena._utils._check_for_cached_results", return_value=wr.athena._read._CacheInfo(has_valid_cache=False), ) as mocked_cache_attempt: - query_id = wr.athena.start_query_execution(sql=f"SELECT * FROM {glue_table}", database=glue_database) + query_id = wr.athena.start_query_execution( + sql=f"SELECT * FROM {glue_table}", database=glue_database, data_source=data_source + ) mocked_cache_attempt.assert_called() # Wait for query to finish in order to successfully check cache @@ -254,7 +256,10 @@ def test_cache_start_query(wr, path, glue_database, glue_table, data_source): with patch("awswrangler.athena._utils._start_query_execution") as internal_start_query: query_id_2 = wr.athena.start_query_execution( - sql=f"SELECT * FROM {glue_table}", database=glue_database, athena_cache_settings={"max_cache_seconds": 900} + sql=f"SELECT * FROM {glue_table}", + database=glue_database, + data_source=data_source, + athena_cache_settings={"max_cache_seconds": 900}, ) internal_start_query.assert_not_called() assert query_id == query_id_2 From 76b7a042560028217846f65902036c643622f172 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 14 Feb 2023 16:38:00 -0600 Subject: [PATCH 14/18] Fix reset_item --- awswrangler/_config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index d8102c54e..937a6ee1d 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -33,7 +33,7 @@ class _ConfigArg(NamedTuple): "concurrent_partitioning": _ConfigArg(dtype=bool, nullable=False), "ctas_approach": _ConfigArg(dtype=bool, nullable=False), "database": _ConfigArg(dtype=str, nullable=True), - "athena_cache_settings": _ConfigArg(dtype=dict, nullable=False, is_parent=True, default={}, loaded=True), + "athena_cache_settings": _ConfigArg(dtype=dict, nullable=False, is_parent=True, loaded=True), "max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"), "max_cache_seconds": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"), "max_remote_cache_entries": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"), @@ -211,7 +211,9 @@ def __getitem__(self, item: str) -> Optional[_ConfigValueType]: def _reset_item(self, item: str) -> None: if item in self._loaded_values: - if _CONFIG_ARGS[item].loaded: + if _CONFIG_ARGS[item].is_parent: + self._loaded_values[item] = {} + elif _CONFIG_ARGS[item].loaded: self._loaded_values[item] = _CONFIG_ARGS[item].default else: del self._loaded_values[item] From 22cd1f964ffce7cb6088a686c03c9f7579c455ee Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 14 Feb 2023 17:23:52 -0600 Subject: [PATCH 15/18] Fix reset of single parameter --- awswrangler/_config.py | 21 +++++++++++++++------ tests/unit/test_config.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 937a6ee1d..5c9a2c0f8 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -210,13 +210,22 @@ def __getitem__(self, item: str) -> Optional[_ConfigValueType]: return loaded_values[item] def _reset_item(self, item: str) -> None: - if item in self._loaded_values: - if _CONFIG_ARGS[item].is_parent: - self._loaded_values[item] = {} - elif _CONFIG_ARGS[item].loaded: - self._loaded_values[item] = _CONFIG_ARGS[item].default + config_arg = _CONFIG_ARGS[item] + loaded_values: Dict[str, Optional[_ConfigValueType]] + + if config_arg.parent_parameter_key: + loaded_values = self[config_arg.parent_parameter_key] # type: ignore[assignment] + else: + loaded_values = self._loaded_values + + if item in loaded_values: + if config_arg.is_parent: + loaded_values[item] = {} + elif config_arg.loaded: + loaded_values[item] = config_arg.default else: - del self._loaded_values[item] + del loaded_values[item] + self._load_config(name=item) def _repr_html_(self) -> Any: diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 560580380..1c227575e 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -131,6 +131,18 @@ def test_basics( _urls_test(wr, glue_database) +def test_config_reset_nested_value(wr: ModuleType) -> None: + wr.config.max_remote_cache_entries = 50 + wr.config.max_cache_seconds = 20 + + wr.config.reset("max_remote_cache_entries") + + assert wr.config.max_cache_seconds == 20 + + with pytest.raises(AttributeError): + wr.config.max_remote_cache_entries + + def test_athena_cache_configuration(wr: ModuleType) -> None: wr.config.max_remote_cache_entries = 50 wr.config.max_cache_seconds = 20 From 8e837a56b7759bffae9ebe39c8ba947e58074fcb Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Wed, 15 Feb 2023 10:20:43 -0600 Subject: [PATCH 16/18] Isolate test_config --- tests/unit/test_config.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 1c227575e..3d81ac479 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -9,9 +9,6 @@ import botocore.config import pytest -from awswrangler._config import apply_configs -from awswrangler.s3._fs import open_s3_object - logging.getLogger("awswrangler").setLevel(logging.DEBUG) @@ -67,9 +64,9 @@ def test_basics( # Testing configured s3 block size size = 1 * 2**20 # 1 MB wr.config.s3_block_size = size - with open_s3_object(path, mode="wb") as s3obj: + with wr.s3._fs.open_s3_object(path, mode="wb") as s3obj: s3obj.write(b"foo") - with open_s3_object(path, mode="rb") as s3obj: + with wr.s3._fs.open_s3_object(path, mode="rb") as s3obj: assert s3obj._s3_block_size == size # Resetting all configs @@ -183,7 +180,7 @@ def wrapper(self, **kwarg): # Check for default values with patch("botocore.client.ClientCreator.create_client", new=wrapper): - with open_s3_object(path, mode="wb") as s3obj: + with wr.s3._fs.open_s3_object(path, mode="wb") as s3obj: s3obj.write(b"foo") # Update default config with environment variables @@ -196,7 +193,7 @@ def wrapper(self, **kwarg): os.environ["AWS_RETRY_MODE"] = expected_retry_mode with patch("botocore.client.ClientCreator.create_client", new=wrapper): - with open_s3_object(path, mode="wb") as s3obj: + with wr.s3._fs.open_s3_object(path, mode="wb") as s3obj: s3obj.write(b"foo") del os.environ["AWS_MAX_ATTEMPTS"] @@ -216,7 +213,7 @@ def wrapper(self, **kwarg): wr.config.botocore_config = botocore_config with patch("botocore.client.ClientCreator.create_client", new=wrapper): - with open_s3_object(path, mode="wb") as s3obj: + with wr.s3._fs.open_s3_object(path, mode="wb") as s3obj: s3obj.write(b"foo") wr.config.reset() @@ -229,7 +226,7 @@ def test_chunk_size(wr: ModuleType) -> None: for function_to_mock in [wr.postgresql.to_sql, wr.mysql.to_sql, wr.sqlserver.to_sql, wr.redshift.to_sql]: mock = create_autospec(function_to_mock) - apply_configs(mock)(df=None, con=None, table=None, schema=None) + wr._config.apply_configs(mock)(df=None, con=None, table=None, schema=None) mock.assert_called_with(df=None, con=None, table=None, schema=None, chunksize=expected_chunksize) expected_chunksize = 456 @@ -238,5 +235,5 @@ def test_chunk_size(wr: ModuleType) -> None: for function_to_mock in [wr.postgresql.to_sql, wr.mysql.to_sql, wr.sqlserver.to_sql, wr.redshift.to_sql]: mock = create_autospec(function_to_mock) - apply_configs(mock)(df=None, con=None, table=None, schema=None) + wr._config.apply_configs(mock)(df=None, con=None, table=None, schema=None) mock.assert_called_with(df=None, con=None, table=None, schema=None, chunksize=expected_chunksize) From 56888b6b3daf415fc54aeca03042e021dbb53dfa Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Mon, 20 Feb 2023 12:24:56 -0600 Subject: [PATCH 17/18] Add more details to documentation --- awswrangler/athena/_read.py | 13 +++++++++++++ awswrangler/athena/_utils.py | 2 ++ 2 files changed, 15 insertions(+) diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index 25f412631..2e1ba6867 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -1027,6 +1027,8 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals athena_cache_settings: typing.AthenaCacheSettings, optional Params of the Athena cache settings such as max_cache_seconds, max_cache_query_inspections, max_remote_cache_entries, and max_local_cache_entries. + AthenaCacheSettings is a `TypedDict`, meaning the passed parameter can be instantiated either as an + instance of AthenaCacheSettings or as a regular Python dict. If cached results are valid, awswrangler ignores the `ctas_approach`, `s3_output`, `encryption`, `kms_key`, `keep_files` and `ctas_temp_table_name` params. If reading cached data fails for any reason, execution falls back to the usual query run path. @@ -1061,6 +1063,15 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals ... params={"name": "filtered_name", "city": "filtered_city"} ... ) + >>> import awswrangler as wr + >>> df = wr.athena.read_sql_query( + ... sql="...", + ... database="...", + ... athena_cache_settings={ + ... "max_cache_seconds": 90, + ... }, + ... ) + """ if ctas_approach and data_source not in (None, "AwsDataCatalog"): raise exceptions.InvalidArgumentCombination( @@ -1444,6 +1455,8 @@ def read_sql_table( athena_cache_settings: typing.AthenaCacheSettings, optional Params of the Athena cache settings such as max_cache_seconds, max_cache_query_inspections, max_remote_cache_entries, and max_local_cache_entries. + AthenaCacheSettings is a `TypedDict`, meaning the passed parameter can be instantiated either as an + instance of AthenaCacheSettings or as a regular Python dict. If cached results are valid, awswrangler ignores the `ctas_approach`, `s3_output`, `encryption`, `kms_key`, `keep_files` and `ctas_temp_table_name` params. If reading cached data fails for any reason, execution falls back to the usual query run path. diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py index 933045852..b141f5a1e 100644 --- a/awswrangler/athena/_utils.py +++ b/awswrangler/athena/_utils.py @@ -480,6 +480,8 @@ def start_query_execution( athena_cache_settings: typing.AthenaCacheSettings, optional Params of the Athena cache settings such as max_cache_seconds, max_cache_query_inspections, max_remote_cache_entries, and max_local_cache_entries. + AthenaCacheSettings is a `TypedDict`, meaning the passed parameter can be instantiated either as an + instance of AthenaCacheSettings or as a regular Python dict. If cached results are valid, awswrangler ignores the `ctas_approach`, `s3_output`, `encryption`, `kms_key`, `keep_files` and `ctas_temp_table_name` params. If reading cached data fails for any reason, execution falls back to the usual query run path. From daa8f0c993b38ac135a72204b91ebcd68d907897 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Tue, 21 Feb 2023 09:03:49 -0600 Subject: [PATCH 18/18] Refactor _load_config --- awswrangler/_config.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 5c9a2c0f8..da084368b 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -156,24 +156,18 @@ def to_pandas(self) -> pd.DataFrame: args.append(arg) return pd.DataFrame(args) - def _load_config(self, name: str) -> bool: + def _load_config(self, name: str) -> None: if _CONFIG_ARGS[name].is_parent: if self._loaded_values.get(name) is None: self._set_config_value(key=name, value={}) - return True - return False + return - loaded_config: bool = False if _CONFIG_ARGS[name].loaded: self._set_config_value(key=name, value=_CONFIG_ARGS[name].default) - loaded_config = True env_var: Optional[str] = os.getenv(f"WR_{name.upper()}") if env_var is not None: self._set_config_value(key=name, value=env_var) - loaded_config = True - - return loaded_config def _set_config_value(self, key: str, value: Any) -> None: if key not in _CONFIG_ARGS: