From 426b5513c4e728d4f99340abf516f89e88f748b4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 22 Jul 2024 18:20:00 +0200 Subject: [PATCH 1/4] rank --- src/litdata/processing/data_processor.py | 2 +- src/litdata/utilities/broadcast.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index ba1007941..b5bb0fade 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -934,7 +934,7 @@ def __init__( raise ValueError("Either the reader or the weights needs to be defined.") # Ensure the input dir is the same across all nodes - self.input_dir = broadcast_object("input_dir", self.input_dir) + self.input_dir = broadcast_object("input_dir", self.input_dir, rank=_get_node_rank()) if self.output_dir: # Ensure the output dir is the same across all nodes diff --git a/src/litdata/utilities/broadcast.py b/src/litdata/utilities/broadcast.py index feb87f099..770aabb36 100644 --- a/src/litdata/utilities/broadcast.py +++ b/src/litdata/utilities/broadcast.py @@ -123,11 +123,8 @@ def __init__(self) -> None: self.private_client: _HTTPClient = _HTTPClient(lightning_app_state_url, auth_token=token, use_retry=False) - def set_and_get(self, key: str, value: Any, rank: Optional[int] = None) -> Any: - payload = {"key": key, "value": pickle.dumps(value, 0).decode()} - - if rank is not None: - payload["rank"] = str(rank) + def set_and_get(self, key: str, value: Any, rank: int) -> Any: + payload = {"key": key, "value": pickle.dumps(value, 0).decode(), "rank": str(rank)} # Try the public address first try: @@ -145,7 +142,7 @@ def set_and_get(self, key: str, value: Any, rank: Optional[int] = None) -> Any: return pickle.loads(bytes(value, "utf-8")) # noqa: S301 -def broadcast_object(key: str, obj: Any, rank: Optional[int] = None) -> Any: +def broadcast_object(key: str, obj: Any, rank: int) -> Any: """This function enables to broadcast object across machines.""" if os.getenv("LIGHTNING_APP_EXTERNAL_URL") is not None: value = None From 373ddff38a44e3d806d52e8139acaef162e02ba5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 22 Jul 2024 19:43:17 +0200 Subject: [PATCH 2/4] debug --- src/litdata/processing/data_processor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index b5bb0fade..23d394620 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -938,7 +938,10 @@ def __init__( if self.output_dir: # Ensure the output dir is the same across all nodes - self.output_dir = broadcast_object("output_dir", self.output_dir, rank=_get_node_rank()) + print(f"Sending request for output_dir, {self.output_dir=}, rank=", _get_node_rank()) + result = broadcast_object("output_dir", self.output_dir, rank=_get_node_rank()) + print(f"Broadcast result was", result, "rank=", _get_node_rank()) + self.output_dir = result print(f"Storing the files under {self.output_dir.path if self.output_dir.path else self.output_dir.url}") self.random_seed = random_seed From fa3485480436b00e0b8c0d63751f249834082e96 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 22 Jul 2024 19:43:43 +0200 Subject: [PATCH 3/4] Revert "debug" This reverts commit 373ddff38a44e3d806d52e8139acaef162e02ba5. --- src/litdata/processing/data_processor.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 23d394620..b5bb0fade 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -938,10 +938,7 @@ def __init__( if self.output_dir: # Ensure the output dir is the same across all nodes - print(f"Sending request for output_dir, {self.output_dir=}, rank=", _get_node_rank()) - result = broadcast_object("output_dir", self.output_dir, rank=_get_node_rank()) - print(f"Broadcast result was", result, "rank=", _get_node_rank()) - self.output_dir = result + self.output_dir = broadcast_object("output_dir", self.output_dir, rank=_get_node_rank()) print(f"Storing the files under {self.output_dir.path if self.output_dir.path else self.output_dir.url}") self.random_seed = random_seed From 8f2ae0420d7a8d30fd1a8aa03db23b17f9536218 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 22 Jul 2024 19:45:32 +0200 Subject: [PATCH 4/4] rank --- tests/utilities/test_broadcast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_broadcast.py b/tests/utilities/test_broadcast.py index 65a14b507..75c0c24f0 100644 --- a/tests/utilities/test_broadcast.py +++ b/tests/utilities/test_broadcast.py @@ -20,7 +20,7 @@ def fn(*args, **kwargs): resp.json = fn session.post.return_value = resp monkeypatch.setattr(requests, "Session", mock.MagicMock(return_value=session)) - assert broadcast_object("key", "value") == "value" + assert broadcast_object("key", "value", rank=0) == "value" @mock.patch.dict(