diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index 7276a87e..1005c503 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -60,6 +60,10 @@ def detect(cls) -> "_DistributedEnv": if world_size is None or world_size == -1: world_size = 1 + world_size = int(os.environ.get("WORLD_SIZE", world_size)) + global_rank = int(os.environ.get("GLOBAL_RANK", global_rank)) + num_nodes = int(os.environ.get("NNODES", num_nodes)) + return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes) @classmethod diff --git a/tests/utilities/test_env.py b/tests/utilities/test_env.py new file mode 100644 index 00000000..064d9380 --- /dev/null +++ b/tests/utilities/test_env.py @@ -0,0 +1,12 @@ +from litdata.utilities.env import _DistributedEnv + + +def test_distributed_env_from_env(monkeypatch): + monkeypatch.setenv("WORLD_SIZE", 2) + monkeypatch.setenv("GLOBAL_RANK", 1) + monkeypatch.setenv("NNODES", 2) + + dist_env = _DistributedEnv.detect() + assert dist_env.world_size == 2 + assert dist_env.global_rank == 1 + assert dist_env.num_nodes == 2