From 456cc3fb5f4fbeb786d275e4967ac61d5d96e9b0 Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Wed, 31 May 2023 15:38:49 -0500 Subject: [PATCH 1/2] fix: Engine initialization did not respect OS env --- awswrangler/_distributed.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/awswrangler/_distributed.py b/awswrangler/_distributed.py index 7c704fe38..93deab550 100644 --- a/awswrangler/_distributed.py +++ b/awswrangler/_distributed.py @@ -11,8 +11,14 @@ from importlib import reload from typing import Any, Callable, Dict, Literal, Optional, TypeVar, cast -WR_ENGINE = os.getenv("WR_ENGINE") -WR_MEMORY_FORMAT = os.getenv("WR_MEMORY_FORMAT") +EngineLiteral = Literal["python", "ray"] +MemoryFormatLiteral = Literal["pandas", "modin"] + +FunctionType = TypeVar("FunctionType", bound=Callable[..., Any]) + + +WR_ENGINE: Optional[EngineLiteral] = os.getenv("WR_ENGINE") # type: ignore[assignment] +WR_MEMORY_FORMAT: Optional[MemoryFormatLiteral] = os.getenv("WR_MEMORY_FORMAT") # type: ignore[assignment] @unique @@ -31,11 +37,6 @@ class MemoryFormatEnum(Enum): PANDAS = "pandas" -EngineLiteral = Literal["python", "ray"] -MemoryFormatLiteral = Literal["pandas", "modin"] -FunctionType = TypeVar("FunctionType", bound=Callable[..., Any]) - - class Engine: """Execution engine configuration class.""" @@ -111,7 +112,7 @@ def wrapper(*args: Any, **kw: Dict[str, Any]) -> Any: def register(cls, name: Optional[EngineLiteral] = None) -> None: """Register the distribution engine dispatch methods.""" with cls._lock: - engine_name = cast(EngineLiteral, name or cls.get_installed().value) + engine_name = cast(EngineLiteral, name or WR_ENGINE or cls.get_installed().value) cls.set(engine_name) cls._registry.clear() From fbb7947a47f4e7a0af42d9f5eb938e9505216e7e Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Thu, 1 Jun 2023 09:16:23 -0500 Subject: [PATCH 2/2] fix register --- awswrangler/_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awswrangler/_distributed.py b/awswrangler/_distributed.py index 93deab550..36f1bba9d 100644 --- a/awswrangler/_distributed.py +++ b/awswrangler/_distributed.py @@ -112,7 +112,7 @@ def wrapper(*args: Any, **kw: Dict[str, Any]) -> Any: def register(cls, name: Optional[EngineLiteral] = None) -> None: """Register the distribution engine dispatch methods.""" with cls._lock: - engine_name = cast(EngineLiteral, name or WR_ENGINE or cls.get_installed().value) + engine_name = cast(EngineLiteral, name or cls.get().value) cls.set(engine_name) cls._registry.clear()