Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies.ddp import DDPStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13

from tests_fabric.helpers.runif import RunIf

Expand Down Expand Up @@ -231,9 +231,7 @@ def _test_distributed_collectives_fn(strategy, collective):


@skip_distributed_unavailable
@pytest.mark.parametrize("n", [1, 2])
@RunIf(skip_windows=True)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True) # sets CUDA_MODULE_LOADING in torch==1.13
@pytest.mark.parametrize("n", [1, pytest.param(2, marks=pytest.mark.xfail(raises=TimeoutError, strict=False))])
def test_collectives_distributed(n):
collective_launch(_test_distributed_collectives_fn, [torch.device("cpu")] * n)

Expand Down Expand Up @@ -268,8 +266,8 @@ def _test_two_groups(strategy, left_collective, right_collective):


@skip_distributed_unavailable
@pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut exception") # Todo
@pytest.mark.xfail(strict=False, reason="TODO(carmocca): causing hangs in CI")
@RunIf(skip_windows=True) # unhandled timeouts
@pytest.mark.xfail(raises=TimeoutError, strict=False)
def test_two_groups():
collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)

Expand All @@ -285,8 +283,7 @@ def _test_default_process_group(strategy, *collectives):


@skip_distributed_unavailable
@RunIf(skip_windows=True)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True) # sets CUDA_MODULE_LOADING in torch==1.13
@RunIf(skip_windows=True) # unhandled timeouts
def test_default_process_group():
collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)

Expand Down