diff --git a/tests/unit/test_distributed.py b/tests/unit/test_distributed.py index 2eee67c3b..929b1a1a5 100644 --- a/tests/unit/test_distributed.py +++ b/tests/unit/test_distributed.py @@ -23,7 +23,8 @@ def wr() -> Iterator[ModuleType]: yield reload(awswrangler) # Reset for future tests - awswrangler.engine.register() + awswrangler.engine.set(awswrangler.engine.get_installed().value) + awswrangler.memory_format.set(awswrangler.memory_format.get_installed().value) @pytest.mark.skipif(condition=not is_ray_modin, reason="ray not available") @@ -71,22 +72,20 @@ def test_engine_python_without_ray_installed(wr: ModuleType) -> None: @pytest.mark.skipif(condition=not is_ray_modin, reason="ray not available") -def test_engine_switch() -> None: +def test_engine_switch(wr: ModuleType) -> None: from modin.pandas import DataFrame as ModinDataFrame from pandas import DataFrame as PandasDataFrame - import awswrangler as wr2 + assert wr.engine.get_installed() == wr.EngineEnum.RAY + assert wr.memory_format.get_installed() == wr.MemoryFormatEnum.MODIN - assert wr2.engine.get_installed() == wr2.EngineEnum.RAY - assert wr2.memory_format.get_installed() == wr2.MemoryFormatEnum.MODIN + assert wr.engine.get() == wr.EngineEnum.RAY + assert wr.memory_format.get() == wr.MemoryFormatEnum.MODIN + assert wr.pandas.DataFrame == ModinDataFrame - assert wr2.engine.get() == wr2.EngineEnum.RAY - assert wr2.memory_format.get() == wr2.MemoryFormatEnum.MODIN - assert wr2.pandas.DataFrame == ModinDataFrame + wr.engine.set("python") + wr.memory_format.set("pandas") - wr2.engine.set("python") - wr2.memory_format.set("pandas") - - assert wr2.engine.get() == wr2.EngineEnum.PYTHON - assert wr2.memory_format.get() == wr2.MemoryFormatEnum.PANDAS - assert wr2.pandas.DataFrame == PandasDataFrame + assert wr.engine.get() == wr.EngineEnum.PYTHON + assert wr.memory_format.get() == wr.MemoryFormatEnum.PANDAS + assert wr.pandas.DataFrame == PandasDataFrame