From 1b60413e01b510e9e701609027210801a03882fe Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 30 Jul 2025 11:36:57 +0000 Subject: [PATCH 1/9] added model and sequence parallel to forward --- fast_llm/engine/inference/huggingface.py | 5 ---- fast_llm/layers/language_model/head.py | 32 +++++++++++++----------- fast_llm/models/gpt/huggingface.py | 7 +++++- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 54a82492b..501798eed 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -13,7 +13,6 @@ from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -57,10 +56,6 @@ def __init__( # or set existing model which also must be setup, so, do not accept not setup model assert fast_llm_model.is_setup - # We only support data parallel for now - Assert.eq(fast_llm_model.distributed.config.model_parallel, 1) - Assert.eq(fast_llm_model.distributed.config.sequence_data_parallel, 1) - self._inference_runner.setup() # Transformers needs to be able to inspect the base model. diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..cb32673ba 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -267,6 +267,9 @@ def _logits_cross_entropy_forward_backward_split( ) if targets is None: # TODO: Make a proper way of returning the model output. + if "global_logits" in kwargs and kwargs["global_logits"] == True: + logits_meta = self._get_logits_meta(kwargs, loss) + loss, _ = logits_meta.local_to_global(loss.detach(), distributed=self._tensor_space.distributed) kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss return None, None else: @@ -306,6 +309,19 @@ def _logits_cross_entropy_forward_backward_split( all_reduce(loss, group=self._tensor_space.distributed.tensor_group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None + def _get_logits_meta(self, kwargs: dict, logits: torch.Tensor, logits_name: str = "transformer logits"): + vocab_dim = self._tensor_space.get_tensor_dim( + LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp + ) + dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] + sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) + dims[sequence_index] = ( + TensorDim(TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor) + if self._sequence_parallel_logits + else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) + ) + return TensorMeta.from_dims(tuple(dims), tensor_name=logits_name, dtype=logits.dtype) + def _logits_cross_entropy_forward_backward( self, input_: torch.Tensor, @@ -334,19 +350,7 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, ) if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space.get_tensor_dim( - LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) - dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] - sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) - 
dims[sequence_index] = ( - TensorDim( - TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor - ) - if self._sequence_parallel_logits - else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) - ) - + # TODO: what for we are creating dim names list here? dim_names = ( [TransformerDimNames.sequence_q_tp, LanguageModelDimNames.vocab] if self._sequence_parallel_logits @@ -358,7 +362,7 @@ def _logits_cross_entropy_forward_backward( "", logits, level=self._debug_transformer, - meta=TensorMeta.from_dims(tuple(dims), tensor_name="transformer logits", dtype=logits.dtype), + meta=self._get_logits_meta(kwargs, logits), distributed=self._tensor_space.distributed, scale=self._logits_scale_factor, ) diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index cf7da3872..d8100c4e4 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -99,10 +99,15 @@ def forward( else: kwargs["output_hidden_states"] = False + kwargs["global_logits"] = True + self._inference_runner.forward(input_, kwargs, iteration=iteration) # TODO: Make a proper way of returning the model output. - logits = kwargs["logits"] + if kwargs[TransformerKwargs.sequence_first]: + logits = kwargs["logits"].transpose(0, 1) + else: + logits = kwargs["logits"] # TODO: convert hidden state form dict to list to be the same as with HFs hidden_states = None From 5eac62131bef155c9f762379c41a967b498f739a Mon Sep 17 00:00:00 2001 From: bigximik Date: Thu, 31 Jul 2025 14:57:48 +0000 Subject: [PATCH 2/9] added asserts for pipeline and sequence parallel to be 1 as not supported in forward --- fast_llm/engine/inference/huggingface.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 501798eed..df9361c8a 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -13,6 +13,7 @@ from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -58,6 +59,10 @@ def __init__( self._inference_runner.setup() + # We only support data parallel and tensor parallel for now + Assert.eq(fast_llm_model.distributed.config.pipeline_parallel, 1) + Assert.eq(fast_llm_model.distributed.config.sequence_data_parallel, 1) + # Transformers needs to be able to inspect the base model. 
self.fast_llm_base_model = fast_llm_model.base_model From d882d7b3ef51b9d68adb136fc563dcd69a4ef553 Mon Sep 17 00:00:00 2001 From: bigximik Date: Thu, 7 Aug 2025 12:19:41 +0000 Subject: [PATCH 3/9] changed logits gathering for only tp and stp dimensions --- fast_llm/layers/language_model/head.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index cb32673ba..3fc5c7965 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -6,7 +6,7 @@ from torch.distributed import all_reduce from fast_llm.config import Configurable -from fast_llm.core.ops import split_op +from fast_llm.core.ops import gather_op, split_op from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames @@ -267,9 +267,17 @@ def _logits_cross_entropy_forward_backward_split( ) if targets is None: # TODO: Make a proper way of returning the model output. - if "global_logits" in kwargs and kwargs["global_logits"] == True: - logits_meta = self._get_logits_meta(kwargs, loss) - loss, _ = logits_meta.local_to_global(loss.detach(), distributed=self._tensor_space.distributed) + loss = loss.detach() + if kwargs.get("global_logits"): + dims, sequence_index = self._get_logits_dims(kwargs, loss) + # Only gather along the sequence dimension (for sequence tensor parallelism) + # or the vocabulary dimension (for tensor parallelism), if applicable. + for i in (sequence_index, len(dims) - 1): + dim = dims[i] + if dim.parallel_group is not None: + loss = gather_op( + loss.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index + ).flatten(i, i + len(dim.expanded_shape) - 1) kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss return None, None else: @@ -309,7 +317,7 @@ def _logits_cross_entropy_forward_backward_split( all_reduce(loss, group=self._tensor_space.distributed.tensor_group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None - def _get_logits_meta(self, kwargs: dict, logits: torch.Tensor, logits_name: str = "transformer logits"): + def _get_logits_dims(self, kwargs: dict, logits: torch.Tensor): vocab_dim = self._tensor_space.get_tensor_dim( LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp ) @@ -320,7 +328,7 @@ def _get_logits_meta(self, kwargs: dict, logits: torch.Tensor, logits_name: str if self._sequence_parallel_logits else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) ) - return TensorMeta.from_dims(tuple(dims), tensor_name=logits_name, dtype=logits.dtype) + return dims, sequence_index def _logits_cross_entropy_forward_backward( self, @@ -362,7 +370,11 @@ def _logits_cross_entropy_forward_backward( "", logits, level=self._debug_transformer, - meta=self._get_logits_meta(kwargs, logits), + meta=TensorMeta.from_dims( + tuple(self._get_logits_dims(kwargs, logits)[0]), + tensor_name="transformer logits", + dtype=logits.dtype, + ), distributed=self._tensor_space.distributed, scale=self._logits_scale_factor, ) From b9851c23bb063647ad5fc4d5df3d348d8c177869 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 18 Aug 2025 16:43:35 +0000 Subject: [PATCH 4/9] added more broadcast primitives and changed _object_to_tensor to be faster according 
to torch src --- fast_llm/core/distributed.py | 54 +++++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index 86f8e7297..16b7c3921 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -107,6 +107,54 @@ def broadcast_scalar( return tensor.item() +def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None, src: int = 0) -> typing.Any: + """ + Broadcasts a Python object from src rank to all other ranks in the ProcessGroup. + Returns the object on all ranks. + """ + assert group is not None + + if group.rank() == src: + tensor = _object_to_tensor(input_object) + size = tensor.numel() + broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + broadcast_tensor.copy_(tensor) + broadcast_scalar(size, torch.int64, group, src) + broadcast(broadcast_tensor, src, group) + return input_object + else: + size = int(broadcast_scalar(None, torch.int64, group, src)) + output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + broadcast(output_tensor, src, group) + return _tensor_to_object(output_tensor) + + +def broadcast_optional(tensor: torch.Tensor | None, group: ProcessGroup = None, src: int = 0) -> torch.Tensor: + """ + Broadcasts an optional tensor of size, shape, and dtype unknown in advance. + Returns the tensor on all ranks or None if no tensor was sent. + """ + assert group is not None + + if group.rank() == src: + has_tensor = tensor is not None + if has_tensor: + meta = (has_tensor, tensor.shape, tensor.dtype) + else: + meta = (has_tensor, None, None) + broadcast_object(meta, group, src) + if has_tensor: + broadcast(tensor.to(torch.cuda.current_device()), src, group) + return tensor + else: + has_tensor, shape, dtype = broadcast_object(None, group, src) + if not has_tensor: + return None + output_tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device()) + broadcast(output_tensor, src, group) + return output_tensor + + def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: assert group is not None work = group.send([tensor], dst, tag) @@ -186,7 +234,11 @@ def scatter( def _object_to_tensor(obj: typing.Any) -> torch.Tensor: f = io.BytesIO() pickle.Pickler(f).dump(obj) - return torch.tensor(torch.UntypedStorage.from_buffer(f.getvalue(), dtype=torch.uint8), dtype=torch.uint8) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will casue 100X slowdown. 
+ # See: https://github.com/pytorch/pytorch/issues/65696 + return torch.ByteTensor(byte_storage) def _tensor_to_object(tensor: torch.Tensor) -> typing.Any: From 750ea1c43c8435e36a2375d3dd7e1f5a48f636cf Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 18 Aug 2025 16:44:50 +0000 Subject: [PATCH 5/9] added support to TP in forward for generate --- fast_llm/engine/inference/huggingface.py | 110 ++++++++++++++++++++++- fast_llm/models/gpt/huggingface.py | 2 +- 2 files changed, 110 insertions(+), 2 deletions(-) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index df9361c8a..e200fdd07 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -7,7 +7,9 @@ import transformers.generation.utils import transformers.modeling_outputs +from fast_llm.core.distributed import broadcast_object, broadcast_optional, safe_barrier from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat +from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode @@ -112,7 +114,7 @@ def _init_weights(self, module) -> None: class HuggingfaceBaseModelForCausalLM(HuggingfacePreTrainedModel, transformers.generation.utils.GenerationMixin): - def forward( + def inner_forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, @@ -127,3 +129,109 @@ def forward( ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: # Meant to be overridden in derived classes raise NotImplementedError() + + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values=None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + coordinator_forward: bool = False, + continue_work: bool = True, + ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast | None: + """ + Forward pass compatible with HuggingFace forward. + + Additional arguments: + coordinator_forward (bool): + If True, only the TP group coordinator (rank 0) should call forward; + other ranks must call worker_forward. + If False, all TP group ranks call forward independently and return logits. + continue_work (bool): Whether to continue processing in a TP group. + Only applies for coordinator_forward=True. + + Notes: + - In coordinator_forward=True mode, forward on rank 0 distributes data to other ranks. + - After processing, the coordinator (rank 0) must call `stop_workers()` before continuing, + to unblock worker_forward on other ranks. + - This mode augments HuggingFace generate with tensor-parallel capability. 
+ """ + distributed: Distributed = self._inference_runner._fast_llm_model.distributed + + if coordinator_forward and distributed.world_group and distributed.tensor_group: + assert distributed.tensor_group.rank() == 0 + assert past_key_values is None and not use_cache + + broadcast_optional(input_ids, distributed.tensor_group, 0) + broadcast_optional(attention_mask, distributed.tensor_group, 0) + broadcast_optional(position_ids, distributed.tensor_group, 0) + broadcast_optional(inputs_embeds, distributed.tensor_group, 0) + broadcast_optional(labels, distributed.tensor_group, 0) + + broadcast_object( + (past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work), + distributed.tensor_group, + 0, + ) + + if not coordinator_forward or continue_work: + return self.inner_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + ) + + return None + + def worker_forward(self): + distributed: Distributed = self._inference_runner._fast_llm_model.distributed + assert distributed.world_group and distributed.tensor_group and distributed.tensor_group.rank() != 0 + + while True: + input_ids = broadcast_optional(None, distributed.tensor_group, 0) + attention_mask = broadcast_optional(None, distributed.tensor_group, 0) + position_ids = broadcast_optional(None, distributed.tensor_group, 0) + inputs_embeds = broadcast_optional(None, distributed.tensor_group, 0) + labels = broadcast_optional(None, distributed.tensor_group, 0) + + past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work = ( + broadcast_object(None, distributed.tensor_group, 0) + ) + if not continue_work: + break + + self.inner_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + ) + + safe_barrier(distributed.world_group, "forward_end") + + def stop_workers(self): + distributed: Distributed = self._inference_runner._fast_llm_model.distributed + # On single gpu or no tp, no worker_forward to stop + if distributed.world_group is None or distributed.tensor_group is None: + return + self.forward(coordinator_forward=True, continue_work=False) + safe_barrier(distributed.world_group, "forward_end") diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index d8100c4e4..0c252a877 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -32,7 +32,7 @@ class HuggingfaceGPTModelForCausalLM(HuggingfaceBaseModelForCausalLM): # _supports_cache_class = False # _tied_weights_keys = [] - def forward( + def inner_forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, From 0f196dae4367ca77637591b299e0c364b320f0b9 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 18 Aug 2025 16:46:28 +0000 Subject: [PATCH 6/9] added suppport to other parallelism additionally to data parallelism --- .../evaluation/lm_eval/fast_llm_wrapper.py | 184 +++++++++++------- 1 file changed, 114 insertions(+), 70 deletions(-) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 8f4dffedf..ebc2b9a65 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -39,17 +39,20 @@ def __init__( # === Distributed setup === 
        self._rank = 0  # For lm_eval: always run on main rank
-        self._world_size = 1
+        self._world_size = 1  # For lm_eval: always world size 1
+
+        self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed
+        self._world_group = self._distributed.world_group
+
         if (
             self._distributed.config.sequence_data_rank == 0
             and self._distributed.config.pipeline_rank == 0
             and self._distributed.config.tensor_rank == 0
         ):
-            self._group = self._distributed.batch_data_group
+            self._leading_batch_data_group = self._distributed.batch_data_group
         else:
-            self._group = torch.distributed.GroupMember.NON_GROUP_MEMBER
+            self._leading_batch_data_group = None

         # === Model & tokenizer setup ===
         self._model = model
@@ -171,11 +174,15 @@ def run(self, cli_args: list[str], completed_steps: int, run_index: int):
                 completed_steps,
             )
         else:
-            self.worker_model_invoke()
+            # On the rest of the batch data group leaders, we run the full generate/forward pass.
+            # On all other ranks, we only invoke worker_forward.
+            if self._leading_batch_data_group:
+                self.worker_model_invoke()
+            else:
+                self._model.worker_forward()

-        # TODO: do we need it here as self.stop_workers() and self.worker_model_invoke()
-        # already have barrier
-        safe_barrier(self._distributed.world_group, f"lm_eval Run end")
+        # Model forward workers end earlier, so sync here for all gpus
+        safe_barrier(self._world_group, f"lm_eval Run end")

     def _model_invoke(
         self,
@@ -185,39 +192,44 @@ def _model_invoke(
         max_length,
         stop,
         generate: bool,
-        continue_generate: bool,
+        continue_work: bool,
         **generation_kwargs,
     ):
         # TODO: Consider passing true messages and payloads around instead of combining all data into a large tuple.
         # Messages could include types like logits, generate, finished.

-        # Group is always None if world size is 1
-        if self._group is None:
-            # Must not be called with continue_generate false on one process
-            assert continue_generate
+        # Call directly if on one gpu or data group size is 1
+        if self._world_group is None or self._leading_batch_data_group is None:
+            # Must not be called with continue_work false on one gpu
+            assert self._world_group or continue_work
+            # Still call when continue_work is false to stop model forward workers
             return self._model_invoke_inner(
-                input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs
+                input_ids, attention_mask, labels, max_length, stop, generate, continue_work, **generation_kwargs
             )

-        world_size = self._group.size()
+        assert self._world_group.rank() == 0

-        assert self._group.rank() == 0
+        batch_data_parallel_size = self._leading_batch_data_group.size()

-        if continue_generate:
+        if continue_work:
             assert input_ids is not None
             if generate:
                 assert max_length is not None and stop is not None

-            # always divide by world_size, if not full batch, some ranks will get less work or not at all
-            assert self._batch_size % world_size == 0
-            step = self._batch_size // world_size
+            # Always divide by batch_data_parallel_size; if the batch is not full, some ranks will get less work or none at all.
+            assert self._batch_size % batch_data_parallel_size == 0
+            step = self._batch_size // batch_data_parallel_size

-            input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)]
+            # Data is sent to every rank, and micro batches are repeated for the same batch_data_parallel rank.
+ input_ids = [input_ids[i * step : (i + 1) * step] for i in range(batch_data_parallel_size)] attention_mask = [ - attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None - for i in range(world_size) + (attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None) + for i in range(batch_data_parallel_size) + ] + labels = [ + (labels[i * step : (i + 1) * step] if labels is not None else None) + for i in range(batch_data_parallel_size) ] - labels = [labels[i * step : (i + 1) * step] if labels is not None else None for i in range(world_size)] scatter_list = [ [ @@ -227,36 +239,40 @@ def _model_invoke( max_length, stop, generate, - continue_generate, + continue_work, generation_kwargs, ] - for i in range(world_size) + for i in range(batch_data_parallel_size) ] else: - scatter_list = [[None, None, None, None, None, None, False, None] for _ in range(world_size)] + scatter_list = [[[], None, None, None, None, None, False, {}] for _ in range(batch_data_parallel_size)] - input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = ( + input_ids, attention_mask, labels, max_length, stop, generate, continue_work, generation_kwargs = ( scatter_object( scatter_list, - group=self._group, + group=self._leading_batch_data_group, ) ) - if not continue_generate: + # Always call inner function to propagate stop signal to TP workers if continue_work is False + result = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, continue_work, **generation_kwargs + ) + + if not continue_work: return None assert len(input_ids) > 0 - result = self._model_invoke_inner( - input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs - ) + # Data is gathered only from the batch_data_group leaders, + # since the forward pass produces the same output on all ranks within the group. + gather_list = gather_object(result, group=self._leading_batch_data_group) - gather_list = gather_object(result, group=self._group) - # Clean gather list from empty shards + # Clean gather list from empty shards (from not full batches). gather_list = [el for el in gather_list if len(el) > 0] # If it was model generate tensors could be of different length - # so we aggregate results to list instead of a tensor + # so we aggregate results to list instead of a tensor. if generate: result = sum((el.tolist() for el in gather_list), []) else: @@ -266,57 +282,83 @@ def _model_invoke( return result def worker_model_invoke(self): - assert self._group is not None - # if isinstance(self.group, dist.ProcessGroup): - if not isinstance(self._group, int): - # groups is None for world_size 1 - assert self._group.rank() != 0 - # on worker ranks the function need to wait to be called multiple times - while True: - input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = ( - scatter_object( - None, - group=self._group, - ) + # Group is None for world_size 1 and this function must not be called for world_size 1. + assert self._world_group + assert self._leading_batch_data_group + # The function must not be called on the main rank. 
+        assert self._world_group.rank() != 0
+
+        # On worker ranks the function needs to wait to be called multiple times
+        while True:
+            input_ids, attention_mask, labels, max_length, stop, generate, continue_work, generation_kwargs = (
+                scatter_object(
+                    None,
+                    group=self._leading_batch_data_group,
                 )
             )
+
+            # Always call inner function to propagate stop signal to TP workers if continue_work is False
+            result = self._model_invoke_inner(
+                input_ids, attention_mask, labels, max_length, stop, generate, continue_work, **generation_kwargs
+            )

-            # Stop signal was send, end waiting/processing loop
-            if not continue_generate:
-                break
+            # Stop signal was sent, end waiting/processing loop
+            if not continue_work:
+                break

-            # if some data was received, work, otherwise return empty tensor
-            if len(input_ids) > 0:
-                result = self._model_invoke_inner(
-                    input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs
-                )
-            else:
-                result = input_ids
+            # If some data was received, return processed results, otherwise return empty tensor
+            if len(input_ids) == 0:
+                result = input_ids

-            gather_object(result, group=self._group)
-        else:
-            # TODO: implement distributed model support
-            assert self._group == torch.distributed.GroupMember.NON_GROUP_MEMBER
-        safe_barrier(self._distributed.world_group, "lm_eval_end")
+            gather_object(result, group=self._leading_batch_data_group)
+
+        safe_barrier(self._leading_batch_data_group, "lm_eval_end")

     def stop_workers(self):
         # Group is always None if world size is 1
-        if self._group is None:
+        if self._world_group is None:
             return
-        self._model_invoke(None, None, None, None, None, None, continue_generate=False)
-        safe_barrier(self._distributed.world_group, "lm_eval_end")
+
+        self._model_invoke([], None, None, None, None, None, continue_work=False)
+
+        # worker_model_invoke is only called if data group size > 1, and it needs to be synced here
+        if self._leading_batch_data_group:
+            safe_barrier(self._leading_batch_data_group, "lm_eval_end")

     def _model_invoke_inner(
-        self, input_ids, attention_mask, labels, max_length, stop, generate: bool, **generation_kwargs
+        self,
+        input_ids,
+        attention_mask,
+        labels,
+        max_length,
+        stop,
+        generate: bool,
+        continue_work: bool,
+        **generation_kwargs,
     ):
+        # If stopping, stop model forward workers and return.
+        if not continue_work:
+            # continue_work=False should not occur on a single GPU.
+            # This function must be called on batch data parallel group leaders.
+            # If there is only one data rank, the leader is global rank 0 and the data group will be None.
+ assert self._world_group is not None + assert self._world_group.rank() == 0 or self._leading_batch_data_group + self._model.stop_workers() + return None + + # If input_ids is empty, there is no work to process - return early + if len(input_ids) == 0: + # Receiving no work can only happen on non-zero ranks + assert self._world_group is not None and self._world_group.rank() != 0 + return None + if generate: return self._model_generate_inner(input_ids, attention_mask, max_length, stop, **generation_kwargs) else: return self._model_call_inner(input_ids, attention_mask, labels) def _model_call(self, input_ids, attention_mask=None, labels=None): - return self._model_invoke( - input_ids, attention_mask, labels, None, None, generate=False, continue_generate=True - ) + return self._model_invoke(input_ids, attention_mask, labels, None, None, generate=False, continue_work=True) def _model_generate(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): return self._model_invoke( @@ -326,7 +368,7 @@ def _model_generate(self, input_ids, attention_mask, max_length, stop, **generat max_length, stop, generate=True, - continue_generate=True, + continue_work=True, **generation_kwargs, ) @@ -370,6 +412,7 @@ def _model_call_inner(self, input_ids, attention_mask=None, labels=None): output_attentions=False, output_hidden_states=False, return_dict=True, + coordinator_forward=True, ).logits.cpu() def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): @@ -398,6 +441,7 @@ def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **g stopping_criteria=stopping_criteria, pad_token_id=self._tokenizer.pad_token_id, use_cache=False, + coordinator_forward=True, **generation_kwargs, ) From 82b901dc63a1c07c10b7d4170191a999d2d64feb Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 18 Aug 2025 17:34:42 +0000 Subject: [PATCH 7/9] removed out of date comment --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index ebc2b9a65..34d292293 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -264,8 +264,6 @@ def _model_invoke( assert len(input_ids) > 0 - # Data is gathered only from the batch_data_group leaders, - # since the forward pass produces the same output on all ranks within the group. gather_list = gather_object(result, group=self._leading_batch_data_group) # Clean gather list from empty shards (from not full batches). 
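[Reviewer sketch, not part of the patch series] The batch sharding that _model_invoke (patches 6-7 above) performs splits the global lm_eval batch across the leading batch-data-parallel ranks and later drops empty shards from the gathered result. The following is a small, self-contained Python illustration of that slicing behaviour with made-up sizes (batch_size=8, four leaders) and a partial batch of 5 requests; it is an assumption-level sketch, not code from the patch.

import torch

batch_size, batch_data_parallel_size = 8, 4
step = batch_size // batch_data_parallel_size  # 2 requests per leading rank

input_ids = torch.arange(5 * 3).reshape(5, 3)  # partial batch: only 5 of the 8 slots are filled
shards = [input_ids[i * step : (i + 1) * step] for i in range(batch_data_parallel_size)]
print([len(s) for s in shards])  # [2, 2, 1, 0] -> the last leader receives no work

# Mirrors `gather_list = [el for el in gather_list if len(el) > 0]` above: empty shards
# are filtered out after gather_object and the remaining outputs keep the batch order.
logits_like = torch.cat([s for s in shards if len(s) > 0])
assert logits_like.shape[0] == 5
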
From 543f3d661777cf16a9b78f322d2ae817fc0489a6 Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 19 Aug 2025 19:05:08 +0000 Subject: [PATCH 8/9] added extended wait in key places, fix to right batch config, fix move to right gpu after scatter --- fast_llm/engine/evaluation/config.py | 7 +++ .../engine/evaluation/lm_eval/evaluator.py | 2 + .../evaluation/lm_eval/fast_llm_wrapper.py | 48 ++++++++++++++++++- fast_llm/engine/inference/huggingface.py | 19 ++++++-- 4 files changed, 71 insertions(+), 5 deletions(-) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 04e4227f1..70cc278ed 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -103,6 +103,13 @@ class LmEvalEvaluatorConfig(EvaluatorConfig): " If not set, it is inferred from the Fast-LLM model config or tokenizer.", ) + communication_timeout_sec: float = Field( + default=600.0, + desc="Maximum wait time (in seconds) for tensor-parallel or data-parallel model " + "operations such as forward, generate, or gathering data. Needed because some " + "ranks may have no data or post-processing can be slow, exceeding the default 60s timeout.", + ) + def get_evaluator( self, name: str, diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py index 162ceaf60..c444dd111 100644 --- a/fast_llm/engine/evaluation/lm_eval/evaluator.py +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -68,6 +68,8 @@ def setup( add_bos_token=self._config.add_bos_token, prefix_token_id=self._config.prefix_token_id, max_length=self._config.max_length, + batch_config=self._batch_config, + communication_timeout_sec=self._config.communication_timeout_sec, ) self._is_setup = True diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 34d292293..e6f0ebba5 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -16,6 +16,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM +from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.transformer.rotary.config import NoRotaryConfig logger = logging.getLogger(__name__) @@ -34,6 +35,8 @@ def __init__( add_bos_token: bool | None = False, prefix_token_id: int | None = None, max_length: int | None = None, + batch_config: BatchConfig | None = None, + communication_timeout_sec: float = 600.0, ): super().__init__() @@ -41,6 +44,8 @@ def __init__( self._rank = 0 # For lm_eval: always run on main rank self._world_size = 1 # For lm_eval: always world size 1 + self.communication_timeout_sec = communication_timeout_sec + self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed self._world_group = self._distributed.world_group @@ -79,7 +84,10 @@ def __init__( # === Batch configuration === self._batch_schedule = 1 self._batch_sizes = {} # Not used dynamically by lm_eval - self._batch_size_per_gpu = model._inference_runner._batch_config.micro_batch_size + + # NOTE: We can not take batch configuration from inference runner as it has a dummy batch config + self._batch_size_per_gpu = batch_config.micro_batch_size if batch_config else 1 + self._batch_size = self._batch_size_per_gpu * 
self._distributed.config.batch_data_parallel self._max_batch_size = self._batch_size @@ -179,7 +187,7 @@ def run(self, cli_args: list[str], completed_steps: int, run_index: int): if self._leading_batch_data_group: self.worker_model_invoke() else: - self._model.worker_forward() + self._model.worker_forward(communication_timeout_sec=self.communication_timeout_sec) # Model forward workers end earlier, so sync here for all gpus safe_barrier(self._world_group, f"lm_eval Run end") @@ -247,6 +255,9 @@ def _model_invoke( else: scatter_list = [[[], None, None, None, None, None, False, {}] for _ in range(batch_data_parallel_size)] + # Some tasks may post-process too slowly, so waiting for the next batch or + # the end of work can exceed the standard 60s timeout. + safe_barrier(self._leading_batch_data_group, "model_invoke_wait", timeout=self.communication_timeout_sec) input_ids, attention_mask, labels, max_length, stop, generate, continue_work, generation_kwargs = ( scatter_object( scatter_list, @@ -264,6 +275,11 @@ def _model_invoke( assert len(input_ids) > 0 + # At the end, some data-parallel ranks may have no data, so the wait can + # exceed the standard 60s timeout. + safe_barrier( + self._leading_batch_data_group, "model_invoke_gather_wait", timeout=self.communication_timeout_sec + ) gather_list = gather_object(result, group=self._leading_batch_data_group) # Clean gather list from empty shards (from not full batches). @@ -286,14 +302,35 @@ def worker_model_invoke(self): # The function must not be called on the main rank. assert self._world_group.rank() != 0 + device = torch.cuda.current_device() + # On worker ranks the function need to wait to be called multiple times while True: + # Some tasks may post-process too slowly, so waiting for the next batch or + # the end of work can exceed the standard 60s timeout. + safe_barrier(self._leading_batch_data_group, "model_invoke_wait", timeout=self.communication_timeout_sec) input_ids, attention_mask, labels, max_length, stop, generate, continue_work, generation_kwargs = ( scatter_object( None, group=self._leading_batch_data_group, ) ) + # NOTE: scatter_object keeps tensors on the same device as the source, + # so they must be moved to the current device. + # TODO: With scatter_object, tensors are copied GPU → CPU → GPU, then scattered, + # and finally copied again to the correct GPU here. We already have a scatter + # primitive for tensors of known size and type; we need to extend it to + # handle optional tensors of unknown type and size directly (src GPU → dst GPU); + # and use it for scattering tensors like input_ids, attention_mask, labels. + if isinstance(input_ids, torch.Tensor): + input_ids = input_ids.to(device) + if isinstance(attention_mask, torch.Tensor): + attention_mask = attention_mask.to(device) + if isinstance(labels, torch.Tensor): + labels = labels.to(device) + + if continue_work: + logger.info(f"worker_model_invoke: input_id device {input_ids.device}, shape {input_ids.shape}") # Always call inner function to propagate stop signal to TP workers if continue_work is False result = self._model_invoke_inner( @@ -308,6 +345,11 @@ def worker_model_invoke(self): if len(input_ids) == 0: result = input_ids + # At the end, some data-parallel ranks may have no data, so the wait can + # exceed the standard 60s timeout. 
+ safe_barrier( + self._leading_batch_data_group, "model_invoke_gather_wait", timeout=self.communication_timeout_sec + ) gather_object(result, group=self._leading_batch_data_group) safe_barrier(self._leading_batch_data_group, "lm_eval_end") @@ -411,6 +453,7 @@ def _model_call_inner(self, input_ids, attention_mask=None, labels=None): output_hidden_states=False, return_dict=True, coordinator_forward=True, + communication_timeout_sec=self.communication_timeout_sec, ).logits.cpu() def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): @@ -440,6 +483,7 @@ def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **g pad_token_id=self._tokenizer.pad_token_id, use_cache=False, coordinator_forward=True, + communication_timeout_sec=self.communication_timeout_sec, **generation_kwargs, ) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index e200fdd07..18e8e96ce 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -143,6 +143,7 @@ def forward( output_hidden_states: bool | None = None, return_dict: bool | None = None, coordinator_forward: bool = False, + communication_timeout_sec: float = 600.0, continue_work: bool = True, ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast | None: """ @@ -168,6 +169,10 @@ def forward( assert distributed.tensor_group.rank() == 0 assert past_key_values is None and not use_cache + # Some tasks may post-process too slowly, so waiting for the next batch or + # the end of work can exceed the standard 60s timeout. + safe_barrier(distributed.tensor_group, "forward_wait", timeout=communication_timeout_sec) + broadcast_optional(input_ids, distributed.tensor_group, 0) broadcast_optional(attention_mask, distributed.tensor_group, 0) broadcast_optional(position_ids, distributed.tensor_group, 0) @@ -196,11 +201,18 @@ def forward( return None - def worker_forward(self): + def worker_forward( + self, + communication_timeout_sec: float = 600.0, + ): distributed: Distributed = self._inference_runner._fast_llm_model.distributed assert distributed.world_group and distributed.tensor_group and distributed.tensor_group.rank() != 0 while True: + # Some tasks may post-process too slowly, so waiting for the next batch or + # the end of work can exceed the standard 60s timeout. 
+ safe_barrier(distributed.tensor_group, "forward_wait", timeout=communication_timeout_sec) + input_ids = broadcast_optional(None, distributed.tensor_group, 0) attention_mask = broadcast_optional(None, distributed.tensor_group, 0) position_ids = broadcast_optional(None, distributed.tensor_group, 0) @@ -210,6 +222,7 @@ def worker_forward(self): past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work = ( broadcast_object(None, distributed.tensor_group, 0) ) + if not continue_work: break @@ -226,7 +239,7 @@ def worker_forward(self): return_dict, ) - safe_barrier(distributed.world_group, "forward_end") + safe_barrier(distributed.world_group, "forward_work_end") def stop_workers(self): distributed: Distributed = self._inference_runner._fast_llm_model.distributed @@ -234,4 +247,4 @@ def stop_workers(self): if distributed.world_group is None or distributed.tensor_group is None: return self.forward(coordinator_forward=True, continue_work=False) - safe_barrier(distributed.world_group, "forward_end") + safe_barrier(distributed.world_group, "forward_work_end") From be8050ce08f83a85d4bb5c2216c95329301ff0e5 Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 20 Aug 2025 11:57:48 +0000 Subject: [PATCH 9/9] added more docs --- fast_llm/engine/inference/huggingface.py | 26 ++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 18e8e96ce..b634f8a4d 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -154,6 +154,9 @@ def forward( If True, only the TP group coordinator (rank 0) should call forward; other ranks must call worker_forward. If False, all TP group ranks call forward independently and return logits. + communication_timeout_sec (float): Maximum time (in seconds) to wait for the start of + forward or for a stop signal to worker ranks before timing out in worker_forward. + Must match the value passed to worker_forward. continue_work (bool): Whether to continue processing in a TP group. Only applies for coordinator_forward=True. @@ -205,6 +208,29 @@ def worker_forward( self, communication_timeout_sec: float = 600.0, ): + """ + Run the forward loop on worker ranks in coordinated mode. + + This function must be called on all worker ranks (i.e., all ranks except the + coordinator/leading data-parallel rank). In coordinated mode, the coordinator + rank calls `forward`, which distributes inputs to workers. Each worker then + receives its inputs and runs a forward pass. + + Workers stay in this loop until a stop signal is broadcast, which happens when + the coordinator rank calls `stop_workers`. + + Args: + communication_timeout_sec (float): Maximum time (in seconds) to wait for the + start of a forward call or for a stop signal from the coordinator before + timing out. Must match the value passed to `forward`. + + Notes: + - Coordinator rank: calls `forward` in coordinated mode and later + `stop_workers` to unblock workers. + - Worker ranks: call `worker_forward` once and remain inside the loop, + executing forward passes with broadcasted inputs until a stop signal + is received. + """ distributed: Distributed = self._inference_runner._fast_llm_model.distributed assert distributed.world_group and distributed.tensor_group and distributed.tensor_group.rank() != 0
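
[Reviewer sketch, not part of the patch series] A minimal usage sketch of the coordinated tensor-parallel generation flow described in the forward/worker_forward docstrings above. The function name run_tp_generate and its arguments are placeholders; only the calls to generate(..., coordinator_forward=True, use_cache=False), worker_forward(), and stop_workers() come from this series, and data-parallel sharding (handled separately by fast_llm_wrapper.py) is ignored here.

def run_tp_generate(model, input_ids, attention_mask, max_length):
    # `model` is assumed to be an already set up HuggingfaceGPTModelForCausalLM.
    distributed = model._inference_runner._fast_llm_model.distributed
    tensor_group = distributed.tensor_group

    if tensor_group is None or tensor_group.rank() == 0:
        # Coordinator (or single-GPU / no-TP) path: drive generation, then release the workers.
        # use_cache must stay False, matching the assert in the coordinator forward.
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            use_cache=False,
            coordinator_forward=True,
        )
        model.stop_workers()  # unblocks worker_forward on the other TP ranks
        return output

    # Worker path: stays inside worker_forward until the coordinator calls stop_workers().
    model.worker_forward()
    return None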