diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py index d27d702324520..00acb278e836a 100644 --- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py +++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py @@ -11,7 +11,7 @@ from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies.ddp import DDPStrategy from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher -from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13 from tests_fabric.helpers.runif import RunIf @@ -231,9 +231,7 @@ def _test_distributed_collectives_fn(strategy, collective): @skip_distributed_unavailable -@pytest.mark.parametrize("n", [1, 2]) -@RunIf(skip_windows=True) -@mock.patch.dict(os.environ, os.environ.copy(), clear=True) # sets CUDA_MODULE_LOADING in torch==1.13 +@pytest.mark.parametrize("n", [1, pytest.param(2, marks=pytest.mark.xfail(raises=TimeoutError, strict=False))]) def test_collectives_distributed(n): collective_launch(_test_distributed_collectives_fn, [torch.device("cpu")] * n) @@ -268,8 +266,8 @@ def _test_two_groups(strategy, left_collective, right_collective): @skip_distributed_unavailable -@pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut exception") # Todo -@pytest.mark.xfail(strict=False, reason="TODO(carmocca): causing hangs in CI") +@RunIf(skip_windows=True) # unhandled timeouts +@pytest.mark.xfail(raises=TimeoutError, strict=False) def test_two_groups(): collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2) @@ -285,8 +283,7 @@ def _test_default_process_group(strategy, *collectives): @skip_distributed_unavailable -@RunIf(skip_windows=True) -@mock.patch.dict(os.environ, os.environ.copy(), clear=True) # sets CUDA_MODULE_LOADING in torch==1.13 +@RunIf(skip_windows=True) # unhandled timeouts def test_default_process_group(): collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)