diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index ba1007941..b5bb0fade 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -934,7 +934,7 @@ def __init__( raise ValueError("Either the reader or the weights needs to be defined.") # Ensure the input dir is the same across all nodes - self.input_dir = broadcast_object("input_dir", self.input_dir) + self.input_dir = broadcast_object("input_dir", self.input_dir, rank=_get_node_rank()) if self.output_dir: # Ensure the output dir is the same across all nodes diff --git a/src/litdata/utilities/broadcast.py b/src/litdata/utilities/broadcast.py index feb87f099..770aabb36 100644 --- a/src/litdata/utilities/broadcast.py +++ b/src/litdata/utilities/broadcast.py @@ -123,11 +123,8 @@ def __init__(self) -> None: self.private_client: _HTTPClient = _HTTPClient(lightning_app_state_url, auth_token=token, use_retry=False) - def set_and_get(self, key: str, value: Any, rank: Optional[int] = None) -> Any: - payload = {"key": key, "value": pickle.dumps(value, 0).decode()} - - if rank is not None: - payload["rank"] = str(rank) + def set_and_get(self, key: str, value: Any, rank: int) -> Any: + payload = {"key": key, "value": pickle.dumps(value, 0).decode(), "rank": str(rank)} # Try the public address first try: @@ -145,7 +142,7 @@ def set_and_get(self, key: str, value: Any, rank: Optional[int] = None) -> Any: return pickle.loads(bytes(value, "utf-8")) # noqa: S301 -def broadcast_object(key: str, obj: Any, rank: Optional[int] = None) -> Any: +def broadcast_object(key: str, obj: Any, rank: int) -> Any: """This function enables to broadcast object across machines.""" if os.getenv("LIGHTNING_APP_EXTERNAL_URL") is not None: value = None diff --git a/tests/utilities/test_broadcast.py b/tests/utilities/test_broadcast.py index 65a14b507..75c0c24f0 100644 --- a/tests/utilities/test_broadcast.py +++ b/tests/utilities/test_broadcast.py @@ -20,7 +20,7 @@ def fn(*args, **kwargs): resp.json = fn session.post.return_value = resp monkeypatch.setattr(requests, "Session", mock.MagicMock(return_value=session)) - assert broadcast_object("key", "value") == "value" + assert broadcast_object("key", "value", rank=0) == "value" @mock.patch.dict(