Merged
54 changes: 53 additions & 1 deletion fast_llm/core/distributed.py
@@ -107,6 +107,54 @@ def broadcast_scalar(
    return tensor.item()


def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None, src: int = 0) -> typing.Any:
    """
    Broadcasts a Python object from src rank to all other ranks in the ProcessGroup.
    Returns the object on all ranks.
    """
    assert group is not None

    if group.rank() == src:
        tensor = _object_to_tensor(input_object)
        size = tensor.numel()
        broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
        broadcast_tensor.copy_(tensor)
        broadcast_scalar(size, torch.int64, group, src)
        broadcast(broadcast_tensor, src, group)
        return input_object
    else:
        size = int(broadcast_scalar(None, torch.int64, group, src))
        output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
        broadcast(output_tensor, src, group)
        return _tensor_to_object(output_tensor)
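
For orientation, a minimal usage sketch of broadcast_object (the group/rank wiring below is illustrative, not part of this diff):

# Illustrative only: `group` is assumed to be an existing ProcessGroup.
work_item = {"task": "generate", "request_ids": [0, 1, 2]} if group.rank() == 0 else None
work_item = broadcast_object(work_item, group, src=0)
# After the call, every rank holds an identical copy of the rank-0 dict.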


def broadcast_optional(tensor: torch.Tensor | None, group: ProcessGroup | None = None, src: int = 0) -> torch.Tensor | None:
    """
    Broadcasts an optional tensor whose presence, shape, and dtype are not known in advance on the receiving ranks.
    Returns the tensor on all ranks, or None if no tensor was sent.
    """
    assert group is not None

    if group.rank() == src:
        has_tensor = tensor is not None
        if has_tensor:
            meta = (has_tensor, tensor.shape, tensor.dtype)
        else:
            meta = (has_tensor, None, None)
        broadcast_object(meta, group, src)
        if has_tensor:
            broadcast(tensor.to(torch.cuda.current_device()), src, group)
        return tensor
    else:
        has_tensor, shape, dtype = broadcast_object(None, group, src)
        if not has_tensor:
            return None
        output_tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device())
        broadcast(output_tensor, src, group)
        return output_tensor
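
And a corresponding sketch for broadcast_optional, which first shares the (presence, shape, dtype) metadata via broadcast_object and then the payload itself (again illustrative; `logits` and `group` are assumed to exist):

# Illustrative only: the source rank may or may not have a tensor this round.
maybe_logits = logits if group.rank() == 0 else None
maybe_logits = broadcast_optional(maybe_logits, group, src=0)
if maybe_logits is None:
    # Nothing was sent; all ranks see None and can skip this step.
    pass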


def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
    assert group is not None
    work = group.send([tensor], dst, tag)
@@ -186,7 +234,11 @@ def scatter(
def _object_to_tensor(obj: typing.Any) -> torch.Tensor:
    f = io.BytesIO()
    pickle.Pickler(f).dump(obj)
-   return torch.tensor(torch.UntypedStorage.from_buffer(f.getvalue(), dtype=torch.uint8), dtype=torch.uint8)
+   byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+   # Do not replace `torch.ByteTensor` or `torch.LongTensor` with `torch.tensor(...)` and an explicit dtype:
+   # doing so causes a ~100x slowdown.
+   # See: https://github.com/pytorch/pytorch/issues/65696
+   return torch.ByteTensor(byte_storage)
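
These two helpers are expected to round-trip any picklable object; a quick sanity sketch (the _tensor_to_object body is truncated in this diff, so the sketch assumes it simply unpickles the buffer, as its use in broadcast_object suggests):

payload = {"prompts": ["hello", "world"], "max_tokens": 16}
buffer = _object_to_tensor(payload)          # uint8 tensor holding the pickled bytes
assert buffer.dtype == torch.uint8
assert _tensor_to_object(buffer) == payload  # assumed inverse of _object_to_tensor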


def _tensor_to_object(tensor: torch.Tensor) -> typing.Any:
7 changes: 7 additions & 0 deletions fast_llm/engine/evaluation/config.py
@@ -98,6 +98,13 @@ class LmEvalEvaluatorConfig(EvaluatorConfig):
" If not set, it is inferred from the Fast-LLM model config or tokenizer.",
)

communication_timeout_sec: float = Field(
Collaborator:

timeout. Unnecessarily long timeouts are often bad, so I recommend making it optional (default None) and enabling it only as needed.

Contributor Author:

Context

Conceptually, wait primitives such as worker_forward or the data-parallel_worker should only exit under one of three conditions:

  1. They receive work
  2. They receive a finish message
  3. The connection with peers/coordinator is lost (after some timeout)

However, this is not how torch.distributed works. It is designed for more or less synchronous communication, while here we are trying to adapt it for asynchronous communication.

Problem

If we set the default timeout to None, users will end up seeing random timeouts in different places.

Discussion

A better long-term solution would be to use a distributed messaging framework that is more appropriate for sending work and finish messages. However, introducing another communication layer into fast_llm is likely outside the scope of this PR.

Proposal

  • Keep the default timeout as it is, applied only to these entry points, and reset the timeout to the 60 sec default after the wait operation.
  • Clarify the naming/description to avoid confusion.
  • Add a TODO to revisit this later with a more suitable communication framework.

        default=600.0,
        desc="Maximum wait time (in seconds) for tensor-parallel or data-parallel model "
        "operations such as forward, generate, or gathering data. Needed because some "
        "ranks may have no data or post-processing can be slow, exceeding the default 60s timeout.",
    )

    def get_evaluator(
        self,
        name: str,
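On the timeout discussion above: a minimal sketch of how a longer timeout can be scoped to a dedicated process group with plain torch.distributed. The actual fast_llm group plumbing and the proposed reset to the 60 sec default are not shown in this diff, so the names below are illustrative only:

import datetime
import torch.distributed as dist

# Hypothetical: a dedicated group for lm_eval coordination, with a longer timeout so
# ranks with no data or slow post-processing do not trip the default 60 s limit.
eval_group = dist.new_group(
    ranks=list(range(dist.get_world_size())),
    timeout=datetime.timedelta(seconds=600),  # mirrors communication_timeout_sec
)
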
2 changes: 2 additions & 0 deletions fast_llm/engine/evaluation/lm_eval/evaluator.py
@@ -66,6 +66,8 @@ def setup(
            add_bos_token=self._config.add_bos_token,
            prefix_token_id=self._config.prefix_token_id,
            max_length=self._config.max_length,
            batch_config=self._batch_config,
            communication_timeout_sec=self._config.communication_timeout_sec,
        )
        self._is_setup = True
