Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 89 additions & 32 deletions awswrangler/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@
import pandas as pd

from awswrangler import exceptions
from awswrangler.typing import AthenaCacheSettings

_logger: logging.Logger = logging.getLogger(__name__)


_ConfigValueType = Union[str, bool, int, botocore.config.Config, None]
_ConfigValueType = Union[str, bool, int, botocore.config.Config, dict]


class _ConfigArg(NamedTuple):
dtype: Type[Union[str, bool, int, botocore.config.Config]]
dtype: Type[_ConfigValueType]
nullable: bool
enforced: bool = False
loaded: bool = False
default: Optional[_ConfigValueType] = None
parent_parameter_key: Optional[str] = None
is_parent: bool = False


# Please, also add any new argument as a property in the _Config class
Expand All @@ -30,10 +33,11 @@ class _ConfigArg(NamedTuple):
"concurrent_partitioning": _ConfigArg(dtype=bool, nullable=False),
"ctas_approach": _ConfigArg(dtype=bool, nullable=False),
"database": _ConfigArg(dtype=str, nullable=True),
"max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False),
"max_cache_seconds": _ConfigArg(dtype=int, nullable=False),
"max_remote_cache_entries": _ConfigArg(dtype=int, nullable=False),
"max_local_cache_entries": _ConfigArg(dtype=int, nullable=False),
"athena_cache_settings": _ConfigArg(dtype=dict, nullable=False, is_parent=True, loaded=True),
"max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"),
"max_cache_seconds": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"),
"max_remote_cache_entries": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"),
"max_local_cache_entries": _ConfigArg(dtype=int, nullable=False, parent_parameter_key="athena_cache_settings"),
"s3_block_size": _ConfigArg(dtype=int, nullable=False, enforced=True),
"workgroup": _ConfigArg(dtype=str, nullable=False, enforced=True),
"chunksize": _ConfigArg(dtype=int, nullable=False, enforced=True),
Expand Down Expand Up @@ -71,7 +75,7 @@ class _Config: # pylint: disable=too-many-instance-attributes,too-many-public-m
"""AWS Wrangler's Configuration class."""

def __init__(self) -> None:
self._loaded_values: Dict[str, _ConfigValueType] = {}
self._loaded_values: Dict[str, Optional[_ConfigValueType]] = {}
name: str
self.s3_endpoint_url = None
self.athena_endpoint_url = None
Expand Down Expand Up @@ -132,12 +136,16 @@ def to_pandas(self) -> pd.DataFrame:
"""
args: List[Dict[str, Any]] = []
for k, v in _CONFIG_ARGS.items():
if v.is_parent:
continue

arg: Dict[str, Any] = {
"name": k,
"Env. Variable": f"WR_{k.upper()}",
"Env.Variable": f"WR_{k.upper()}",
"type": v.dtype,
"nullable": v.nullable,
"enforced": v.enforced,
"parent_parameter_name": v.parent_parameter_key,
}
if k in self._loaded_values:
arg["configured"] = True
Expand All @@ -148,48 +156,77 @@ def to_pandas(self) -> pd.DataFrame:
args.append(arg)
return pd.DataFrame(args)

def _load_config(self, name: str) -> bool:
loaded_config: bool = False
def _load_config(self, name: str) -> None:
if _CONFIG_ARGS[name].is_parent:
if self._loaded_values.get(name) is None:
self._set_config_value(key=name, value={})
return

if _CONFIG_ARGS[name].loaded:
self._set_config_value(key=name, value=_CONFIG_ARGS[name].default)
loaded_config = True

env_var: Optional[str] = os.getenv(f"WR_{name.upper()}")
if env_var is not None:
self._set_config_value(key=name, value=env_var)
loaded_config = True
return loaded_config

def _set_config_value(self, key: str, value: Any) -> None:
if key not in _CONFIG_ARGS:
raise exceptions.InvalidArgumentValue(
f"{key} is not a valid configuration. Please use: {list(_CONFIG_ARGS.keys())}"
)
value_casted: _ConfigValueType = self._apply_type(
value_casted = self._apply_type(
name=key,
value=value,
dtype=_CONFIG_ARGS[key].dtype, # type: ignore[arg-type]
dtype=_CONFIG_ARGS[key].dtype,
nullable=_CONFIG_ARGS[key].nullable,
)
self._loaded_values[key] = value_casted

def __getitem__(self, item: str) -> _ConfigValueType:
if item not in self._loaded_values:
parent_key = _CONFIG_ARGS[key].parent_parameter_key
if parent_key:
self._loaded_values[parent_key][key] = value_casted # type: ignore[index]
else:
self._loaded_values[key] = value_casted

def __getitem__(self, item: str) -> Optional[_ConfigValueType]:
if issubclass(_CONFIG_ARGS[item].dtype, dict):
return self._loaded_values[item]

loaded_values: Dict[str, Optional[_ConfigValueType]]
parent_key = _CONFIG_ARGS[item].parent_parameter_key
if parent_key:
loaded_values = self[parent_key] # type: ignore[assignment]
else:
loaded_values = self._loaded_values

if item not in loaded_values:
raise AttributeError(f"{item} not configured yet.")
return self._loaded_values[item]

return loaded_values[item]

def _reset_item(self, item: str) -> None:
if item in self._loaded_values:
if _CONFIG_ARGS[item].loaded:
self._loaded_values[item] = _CONFIG_ARGS[item].default
config_arg = _CONFIG_ARGS[item]
loaded_values: Dict[str, Optional[_ConfigValueType]]

if config_arg.parent_parameter_key:
loaded_values = self[config_arg.parent_parameter_key] # type: ignore[assignment]
else:
loaded_values = self._loaded_values

if item in loaded_values:
if config_arg.is_parent:
loaded_values[item] = {}
elif config_arg.loaded:
loaded_values[item] = config_arg.default
else:
del self._loaded_values[item]
del loaded_values[item]

self._load_config(name=item)

def _repr_html_(self) -> Any:
return self.to_pandas().to_html()

@staticmethod
def _apply_type(name: str, value: Any, dtype: Type[Union[str, bool, int]], nullable: bool) -> _ConfigValueType:
def _apply_type(name: str, value: Any, dtype: Type[_ConfigValueType], nullable: bool) -> Optional[_ConfigValueType]:
if _Config._is_null(value=value):
if nullable is True:
return None
Expand All @@ -202,7 +239,7 @@ def _apply_type(name: str, value: Any, dtype: Type[Union[str, bool, int]], nulla
raise exceptions.InvalidConfiguration(f"Config {name} must receive a {dtype} value.") from ex

@staticmethod
def _is_null(value: _ConfigValueType) -> bool:
def _is_null(value: Optional[_ConfigValueType]) -> bool:
if value is None:
return True
if isinstance(value, str) is True:
Expand Down Expand Up @@ -247,6 +284,11 @@ def database(self) -> Optional[str]:
def database(self, value: Optional[str]) -> None:
self._set_config_value(key="database", value=value)

@property
def athena_cache_settings(self) -> AthenaCacheSettings:
"""Property athena_cache_settings."""
return cast(AthenaCacheSettings, self["athena_cache_settings"])

@property
def max_cache_query_inspections(self) -> int:
"""Property max_cache_query_inspections."""
Expand Down Expand Up @@ -576,6 +618,25 @@ def _inject_config_doc(doc: Optional[str], available_configs: Tuple[str, ...]) -
return _insert_str(text=doc, token="\n Parameters", insert=insertion)


def _assign_args_value(args: Dict[str, Any], name: str, value: Any) -> None:
if _CONFIG_ARGS[name].is_parent:
if name not in args:
args[name] = {}

nested_args = cast(Dict[str, Any], value)
for nested_arg_name, nested_arg_value in nested_args.items():
_assign_args_value(args[name], nested_arg_name, nested_arg_value)
return

if name not in args:
_logger.debug("Applying default config argument %s with value %s.", name, value)
args[name] = value

elif _CONFIG_ARGS[name].enforced is True:
_logger.debug("Applying ENFORCED config argument %s with value %s.", name, value)
args[name] = value


FunctionType = TypeVar("FunctionType", bound=Callable[..., Any])


Expand All @@ -589,13 +650,9 @@ def wrapper(*args_raw: Any, **kwargs: Any) -> Any:
args: Dict[str, Any] = signature.bind_partial(*args_raw, **kwargs).arguments
for name in available_configs:
if hasattr(config, name) is True:
value: _ConfigValueType = config[name]
if name not in args:
_logger.debug("Applying default config argument %s with value %s.", name, value)
args[name] = value
elif _CONFIG_ARGS[name].enforced is True:
_logger.debug("Applying ENFORCED config argument %s with value %s.", name, value)
args[name] = value
value = config[name]
_assign_args_value(args, name, value)

for name, param in signature.parameters.items():
if param.kind == param.VAR_KEYWORD and name in args:
if isinstance(args[name], dict) is False:
Expand Down
Loading