Skip to content

Commit ce0240c

Browse files
authored
Always send the rank when broadcasting (#257)
* rank
* debug
* Revert "debug" (this reverts commit 373ddff)
* rank
1 parent 09129c1 commit ce0240c

File tree

3 files changed, with 5 additions and 8 deletions.

3 files changed

+5
-8
lines changed

src/litdata/processing/data_processor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -934,7 +934,7 @@ def __init__(
934934
raise ValueError("Either the reader or the weights needs to be defined.")
935935

936936
# Ensure the input dir is the same across all nodes
937-
self.input_dir = broadcast_object("input_dir", self.input_dir)
937+
self.input_dir = broadcast_object("input_dir", self.input_dir, rank=_get_node_rank())
938938

939939
if self.output_dir:
940940
# Ensure the output dir is the same across all nodes

src/litdata/utilities/broadcast.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -123,11 +123,8 @@ def __init__(self) -> None:
123123

124124
self.private_client: _HTTPClient = _HTTPClient(lightning_app_state_url, auth_token=token, use_retry=False)
125125

126-
def set_and_get(self, key: str, value: Any, rank: Optional[int] = None) -> Any:
127-
payload = {"key": key, "value": pickle.dumps(value, 0).decode()}
128-
129-
if rank is not None:
130-
payload["rank"] = str(rank)
126+
def set_and_get(self, key: str, value: Any, rank: int) -> Any:
127+
payload = {"key": key, "value": pickle.dumps(value, 0).decode(), "rank": str(rank)}
131128

132129
# Try the public address first
133130
try:
@@ -145,7 +142,7 @@ def set_and_get(self, key: str, value: Any, rank: Optional[int] = None) -> Any:
145142
return pickle.loads(bytes(value, "utf-8")) # noqa: S301
146143

147144

148-
def broadcast_object(key: str, obj: Any, rank: Optional[int] = None) -> Any:
145+
def broadcast_object(key: str, obj: Any, rank: int) -> Any:
149146
"""This function enables to broadcast object across machines."""
150147
if os.getenv("LIGHTNING_APP_EXTERNAL_URL") is not None:
151148
value = None

tests/utilities/test_broadcast.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ def fn(*args, **kwargs):
2020
resp.json = fn
2121
session.post.return_value = resp
2222
monkeypatch.setattr(requests, "Session", mock.MagicMock(return_value=session))
23-
assert broadcast_object("key", "value") == "value"
23+
assert broadcast_object("key", "value", rank=0) == "value"
2424

2525

2626
@mock.patch.dict(

0 commit comments

Comments
 (0)