diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py
index 1005c503a..3c23e816e 100644
--- a/src/litdata/utilities/env.py
+++ b/src/litdata/utilities/env.py
@@ -49,8 +49,14 @@ def detect(cls) -> "_DistributedEnv":
             world_size = torch.distributed.get_world_size()
             global_rank = torch.distributed.get_rank()
             # Note: On multi node CPU, the number of nodes won't be correct.
-            num_nodes = world_size // torch.cuda.device_count() if torch.cuda.is_available() else world_size
-            if torch.cuda.is_available() and world_size % torch.cuda.device_count() != 0:
+            if torch.cuda.is_available() and world_size // torch.cuda.device_count() >= 1:
+                num_nodes = world_size // torch.cuda.device_count()
+            else:
+                num_nodes = 1
+
+            # If you are using multiple nodes, we assume you are using all the GPUs.
+            # On single node, a user can be using only a few GPUs of the node.
+            if torch.cuda.is_available() and num_nodes > 1 and world_size % torch.cuda.device_count() != 0:
                 raise RuntimeError("The world size should be divisible by the number of GPUs.")
         else:
             world_size = None
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index 278b4872b..24b27ff41 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -255,7 +255,12 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir, compr
     "compression",
     [
         pytest.param(None),
-        pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")),
+        pytest.param(
+            "zstd",
+            marks=pytest.mark.skipif(
+                condition=not _ZSTD_AVAILABLE or sys.platform == "darwin", reason="Requires: ['zstd']"
+            ),
+        ),
     ],
 )
 @pytest.mark.timeout(30)
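
For reference, here is a minimal standalone sketch of the node-detection behavior introduced in the `env.py` hunk above. The helper name `infer_num_nodes` and the example `(world_size, device_count)` values are hypothetical, chosen only to illustrate the change; the actual logic lives inside `_DistributedEnv.detect()` and reads the real values from `torch.distributed` and `torch.cuda`.

```python
def infer_num_nodes(world_size: int, device_count: int, cuda_available: bool) -> int:
    """Hypothetical helper mirroring the updated logic in _DistributedEnv.detect().

    On CUDA machines, the node count is world_size // device_count when that
    quotient is at least 1; otherwise (e.g. fewer processes than visible GPUs
    on a single node) we fall back to a single node.
    """
    if cuda_available and world_size // device_count >= 1:
        num_nodes = world_size // device_count
    else:
        num_nodes = 1

    # The divisibility check now only applies on multi-node runs: on a single
    # node a user may intentionally train on a subset of the node's GPUs.
    if cuda_available and num_nodes > 1 and world_size % device_count != 0:
        raise RuntimeError("The world size should be divisible by the number of GPUs.")
    return num_nodes


if __name__ == "__main__":
    # Assumed example values: 2 processes on an 8-GPU node no longer raises,
    # while an evenly filled 2-node layout is still detected as 2 nodes.
    print(infer_num_nodes(world_size=2, device_count=8, cuda_available=True))   # -> 1
    print(infer_num_nodes(world_size=16, device_count=8, cuda_available=True))  # -> 2
```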