diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py
index 86f8e7297..16b7c3921 100644
--- a/fast_llm/core/distributed.py
+++ b/fast_llm/core/distributed.py
@@ -107,6 +107,54 @@ def broadcast_scalar(
     return tensor.item()
 
 
+def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None, src: int = 0) -> typing.Any:
+    """
+    Broadcasts a Python object from the src rank to all other ranks in the ProcessGroup.
+    Returns the object on all ranks.
+    """
+    assert group is not None
+
+    if group.rank() == src:
+        tensor = _object_to_tensor(input_object)
+        size = tensor.numel()
+        broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
+        broadcast_tensor.copy_(tensor)
+        broadcast_scalar(size, torch.int64, group, src)
+        broadcast(broadcast_tensor, src, group)
+        return input_object
+    else:
+        size = int(broadcast_scalar(None, torch.int64, group, src))
+        output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
+        broadcast(output_tensor, src, group)
+        return _tensor_to_object(output_tensor)
+
+
+def broadcast_optional(tensor: torch.Tensor | None, group: ProcessGroup | None = None, src: int = 0) -> torch.Tensor | None:
+    """
+    Broadcasts an optional tensor whose presence, shape, and dtype are unknown in advance.
+    Returns the tensor on all ranks, or None if no tensor was sent.
+    """
+    assert group is not None
+
+    if group.rank() == src:
+        has_tensor = tensor is not None
+        if has_tensor:
+            meta = (has_tensor, tensor.shape, tensor.dtype)
+        else:
+            meta = (has_tensor, None, None)
+        broadcast_object(meta, group, src)
+        if has_tensor:
+            broadcast(tensor.to(torch.cuda.current_device()), src, group)
+        return tensor
+    else:
+        has_tensor, shape, dtype = broadcast_object(None, group, src)
+        if not has_tensor:
+            return None
+        output_tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device())
+        broadcast(output_tensor, src, group)
+        return output_tensor
+
+
 def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
     assert group is not None
     work = group.send([tensor], dst, tag)
@@ -186,7 +234,11 @@ def scatter(
 def _object_to_tensor(obj: typing.Any) -> torch.Tensor:
     f = io.BytesIO()
     pickle.Pickler(f).dump(obj)
-    return torch.tensor(torch.UntypedStorage.from_buffer(f.getvalue(), dtype=torch.uint8), dtype=torch.uint8)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with `torch.tensor(...)` and an explicit dtype.
+    # Otherwise, it will cause a ~100x slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    return torch.ByteTensor(byte_storage)
 
 
 def _tensor_to_object(tensor: torch.Tensor) -> typing.Any:
diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py
index 4eb5d71df..4f035e174 100644
--- a/fast_llm/engine/evaluation/config.py
+++ b/fast_llm/engine/evaluation/config.py
@@ -98,6 +98,13 @@ class LmEvalEvaluatorConfig(EvaluatorConfig):
         " If not set, it is inferred from the Fast-LLM model config or tokenizer.",
     )
 
+    communication_timeout_sec: float = Field(
+        default=600.0,
+        desc="Maximum wait time (in seconds) for tensor-parallel or data-parallel model "
+        "operations such as forward, generate, or gathering data. Needed because some "
+        "ranks may have no data or post-processing can be slow, exceeding the default 60s timeout.",
+    )
+
     def get_evaluator(
         self,
         name: str,
diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py
index 9040b11b4..14aed65c4 100644
--- a/fast_llm/engine/evaluation/lm_eval/evaluator.py
+++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py
@@ -66,6 +66,8 @@ def setup(
             add_bos_token=self._config.add_bos_token,
             prefix_token_id=self._config.prefix_token_id,
             max_length=self._config.max_length,
+            batch_config=self._batch_config,
+            communication_timeout_sec=self._config.communication_timeout_sec,
         )
 
         self._is_setup = True
diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py
index 3a606b41d..bc42515e7 100644
--- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py
+++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py
@@ -16,6 +16,7 @@
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results
 from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM
+from fast_llm.engine.schedule.config import BatchConfig
 from fast_llm.layers.attention.rotary.config import NoRotaryConfig
 
 logger = logging.getLogger(__name__)
@@ -34,22 +35,29 @@ def __init__(
         add_bos_token: bool | None = False,
         prefix_token_id: int | None = None,
         max_length: int | None = None,
+        batch_config: BatchConfig | None = None,
+        communication_timeout_sec: float = 600.0,
     ):
         super().__init__()
 
         # === Distributed setup ===
         self._rank = 0  # For lm_eval: always run on main rank
-        self._world_size = 1
+        self._world_size = 1  # For lm_eval: always world size 1
+
+        self.communication_timeout_sec = communication_timeout_sec
+        self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed
+        self._world_group = self._distributed.world_group
+
         if (
             self._distributed.config.sequence_data_rank == 0
             and self._distributed.config.pipeline_rank == 0
            and self._distributed.config.tensor_rank == 0
         ):
-            self._group = self._distributed.batch_data_group
+            self._leading_batch_data_group = self._distributed.batch_data_group
         else:
-            self._group = torch.distributed.GroupMember.NON_GROUP_MEMBER
+            self._leading_batch_data_group = None
 
         # === Model & tokenizer setup ===
         self._model = model
@@ -76,7 +84,10 @@ def __init__(
         # === Batch configuration ===
         self._batch_schedule = 1
         self._batch_sizes = {}  # Not used dynamically by lm_eval
-        self._batch_size_per_gpu = model._inference_runner._batch_config.micro_batch_size
+
+        # NOTE: We cannot take the batch configuration from the inference runner, as it only has a dummy batch config.
+        self._batch_size_per_gpu = batch_config.micro_batch_size if batch_config else 1
+
         self._batch_size = self._batch_size_per_gpu * self._distributed.config.batch_data_parallel
         self._max_batch_size = self._batch_size
 
@@ -171,11 +182,15 @@ def run(self, cli_args: list[str], completed_steps: int, run_index: int):
                 completed_steps,
             )
         else:
-            self.worker_model_invoke()
+            # On the remaining batch data group leaders, we run the full generate/forward pass.
+            # On all other ranks, we only invoke worker_forward.
+            if self._leading_batch_data_group:
+                self.worker_model_invoke()
+            else:
+                self._model.worker_forward(communication_timeout_sec=self.communication_timeout_sec)
 
-        # TODO: do we need it here as self.stop_workers() and self.worker_model_invoke()
-        # already have barrier
-        safe_barrier(self._distributed.world_group, f"lm_eval Run end")
+        # Model forward workers end earlier, so sync here for all GPUs.
+        safe_barrier(self._world_group, f"lm_eval Run end")
 
     def _model_invoke(
         self,
@@ -185,39 +200,44 @@ def _model_invoke(
         input_ids,
         attention_mask,
         labels,
         max_length,
         stop,
         generate: bool,
-        continue_generate: bool,
+        continue_work: bool,
         **generation_kwargs,
     ):
         # TODO: Consider passing true messages and payloads around instead of combining all data into a large tuple.
         # Messages could include types like logits, generate, finished.
 
-        # Group is always None if world size is 1
-        if self._group is None:
-            # Must not be called with continue_generate false on one process
-            assert continue_generate
+        # Call directly if on one GPU or if the data group size is 1.
+        if self._world_group is None or self._leading_batch_data_group is None:
+            # Must not be called with continue_work=False on a single GPU.
+            assert self._world_group or continue_work
+            # Still called when continue_work is False, to stop the model forward workers.
             return self._model_invoke_inner(
-                input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs
+                input_ids, attention_mask, labels, max_length, stop, generate, continue_work, **generation_kwargs
             )
 
-        world_size = self._group.size()
+        assert self._world_group.rank() == 0
 
-        assert self._group.rank() == 0
+        batch_data_parallel_size = self._leading_batch_data_group.size()
 
-        if continue_generate:
+        if continue_work:
             assert input_ids is not None
             if generate:
                 assert max_length is not None and stop is not None
 
-            # always divide by world_size, if not full batch, some ranks will get less work or not at all
-            assert self._batch_size % world_size == 0
-            step = self._batch_size // world_size
+            # Always divide by batch_data_parallel_size; if the batch is not full, some ranks get less work or none at all.
+            assert self._batch_size % batch_data_parallel_size == 0
+            step = self._batch_size // batch_data_parallel_size
 
-            input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)]
+            # Data is sent to every rank, and micro-batches are repeated for the same batch_data_parallel rank.
+            input_ids = [input_ids[i * step : (i + 1) * step] for i in range(batch_data_parallel_size)]
             attention_mask = [
-                attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None
-                for i in range(world_size)
+                (attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None)
+                for i in range(batch_data_parallel_size)
+            ]
+            labels = [
+                (labels[i * step : (i + 1) * step] if labels is not None else None)
+                for i in range(batch_data_parallel_size)
             ]
-            labels = [labels[i * step : (i + 1) * step] if labels is not None else None for i in range(world_size)]
 
             scatter_list = [
                 [
@@ -227,36 +247,46 @@ def _model_invoke(
                     max_length,
                     stop,
                     generate,
-                    continue_generate,
+                    continue_work,
                     generation_kwargs,
                 ]
-                for i in range(world_size)
+                for i in range(batch_data_parallel_size)
             ]
         else:
-            scatter_list = [[None, None, None, None, None, None, False, None] for _ in range(world_size)]
+            scatter_list = [[[], None, None, None, None, None, False, {}] for _ in range(batch_data_parallel_size)]
 
-        input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = (
+        # Some tasks may post-process too slowly, so waiting for the next batch or
+        # the end of work can exceed the standard 60s timeout.
+        safe_barrier(self._leading_batch_data_group, "model_invoke_wait", timeout=self.communication_timeout_sec)
+        input_ids, attention_mask, labels, max_length, stop, generate, continue_work, generation_kwargs = (
             scatter_object(
                 scatter_list,
-                group=self._group,
+                group=self._leading_batch_data_group,
             )
         )
 
-        if not continue_generate:
+        # Always call the inner function, to propagate the stop signal to TP workers when continue_work is False.
+        result = self._model_invoke_inner(
+            input_ids, attention_mask, labels, max_length, stop, generate, continue_work, **generation_kwargs
+        )
+
+        if not continue_work:
             return None
 
         assert len(input_ids) > 0
 
-        result = self._model_invoke_inner(
-            input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs
+        # At the end, some data-parallel ranks may have no data, so the wait can
+        # exceed the standard 60s timeout.
+        safe_barrier(
+            self._leading_batch_data_group, "model_invoke_gather_wait", timeout=self.communication_timeout_sec
         )
+        gather_list = gather_object(result, group=self._leading_batch_data_group)
 
-        gather_list = gather_object(result, group=self._group)
-        # Clean gather list from empty shards
+        # Clean the gather list of empty shards (from non-full batches).
         gather_list = [el for el in gather_list if len(el) > 0]
 
         # If it was model generate tensors could be of different length
-        # so we aggregate results to list instead of a tensor
+        # so we aggregate results into a list instead of a tensor.
         if generate:
             result = sum((el.tolist() for el in gather_list), [])
         else:
@@ -266,57 +296,109 @@ def _model_invoke(
         return result
 
     def worker_model_invoke(self):
-        assert self._group is not None
-        # if isinstance(self.group, dist.ProcessGroup):
-        if not isinstance(self._group, int):
-            # groups is None for world_size 1
-            assert self._group.rank() != 0
-            # on worker ranks the function need to wait to be called multiple times
-            while True:
-                input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = (
-                    scatter_object(
-                        None,
-                        group=self._group,
-                    )
+        # The group is None for world_size 1, and this function must not be called for world_size 1.
+        assert self._world_group
+        assert self._leading_batch_data_group
+        # The function must not be called on the main rank.
+        assert self._world_group.rank() != 0
+
+        device = torch.cuda.current_device()
+
+        # On worker ranks, the function needs to wait to be called multiple times.
+        while True:
+            # Some tasks may post-process too slowly, so waiting for the next batch or
+            # the end of work can exceed the standard 60s timeout.
+            safe_barrier(self._leading_batch_data_group, "model_invoke_wait", timeout=self.communication_timeout_sec)
+            input_ids, attention_mask, labels, max_length, stop, generate, continue_work, generation_kwargs = (
+                scatter_object(
+                    None,
+                    group=self._leading_batch_data_group,
                 )
+            )
+            # NOTE: scatter_object keeps tensors on the same device as the source,
+            # so they must be moved to the current device.
+            # TODO: With scatter_object, tensors are copied GPU → CPU → GPU, then scattered,
+            # and finally copied again to the correct GPU here. We already have a scatter
+            # primitive for tensors of known size and type; we need to extend it to
+            # handle optional tensors of unknown type and size directly (src GPU → dst GPU),
+            # and use it for scattering tensors like input_ids, attention_mask, labels.
+            if isinstance(input_ids, torch.Tensor):
+                input_ids = input_ids.to(device)
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = attention_mask.to(device)
+            if isinstance(labels, torch.Tensor):
+                labels = labels.to(device)
+
+            if continue_work:
+                logger.info(f"worker_model_invoke: input_id device {input_ids.device}, shape {input_ids.shape}")
+
+            # Always call the inner function, to propagate the stop signal to TP workers when continue_work is False.
+            result = self._model_invoke_inner(
+                input_ids, attention_mask, labels, max_length, stop, generate, continue_work, **generation_kwargs
+            )
 
-                # Stop signal was send, end waiting/processing loop
-                if not continue_generate:
-                    break
+            # Stop signal was sent; end the waiting/processing loop.
+            if not continue_work:
+                break
 
-                # if some data was received, work, otherwise return empty tensor
-                if len(input_ids) > 0:
-                    result = self._model_invoke_inner(
-                        input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs
-                    )
-                else:
-                    result = input_ids
+            # If some data was received, return the processed results; otherwise return the empty tensor.
+            if len(input_ids) == 0:
+                result = input_ids
 
-                gather_object(result, group=self._group)
-        else:
-            # TODO: implement distributed model support
-            assert self._group == torch.distributed.GroupMember.NON_GROUP_MEMBER
-        safe_barrier(self._distributed.world_group, "lm_eval_end")
+            # At the end, some data-parallel ranks may have no data, so the wait can
+            # exceed the standard 60s timeout.
+            safe_barrier(
+                self._leading_batch_data_group, "model_invoke_gather_wait", timeout=self.communication_timeout_sec
+            )
+            gather_object(result, group=self._leading_batch_data_group)
+
+        safe_barrier(self._leading_batch_data_group, "lm_eval_end")
 
     def stop_workers(self):
         # Group is always None if world size is 1
-        if self._group is None:
+        if self._world_group is None:
             return
-        self._model_invoke(None, None, None, None, None, None, continue_generate=False)
-        safe_barrier(self._distributed.world_group, "lm_eval_end")
+
+        self._model_invoke([], None, None, None, None, None, continue_work=False)
+
+        # worker_model_invoke is only called when the data group size is > 1, so it only needs to be synced here then.
+        if self._leading_batch_data_group:
+            safe_barrier(self._leading_batch_data_group, "lm_eval_end")
 
     def _model_invoke_inner(
-        self, input_ids, attention_mask, labels, max_length, stop, generate: bool, **generation_kwargs
+        self,
+        input_ids,
+        attention_mask,
+        labels,
+        max_length,
+        stop,
+        generate: bool,
+        continue_work: bool,
+        **generation_kwargs,
     ):
+        # If stopping, stop the model forward workers and return.
+        if not continue_work:
+            # continue_work=False should not occur on a single GPU.
+            # This function must be called on batch data parallel group leaders.
+            # If there is only one data rank, the leader is global rank 0 and the data group will be None.
+            assert self._world_group is not None
+            assert self._world_group.rank() == 0 or self._leading_batch_data_group
+            self._model.stop_workers()
+            return None
+
+        # If input_ids is empty, there is no work to process - return early
+        if len(input_ids) == 0:
+            # Receiving no work can only happen on non-zero ranks
+            assert self._world_group is not None and self._world_group.rank() != 0
+            return None
+
         if generate:
             return self._model_generate_inner(input_ids, attention_mask, max_length, stop, **generation_kwargs)
         else:
             return self._model_call_inner(input_ids, attention_mask, labels)
 
     def _model_call(self, input_ids, attention_mask=None, labels=None):
-        return self._model_invoke(
-            input_ids, attention_mask, labels, None, None, generate=False, continue_generate=True
-        )
+        return self._model_invoke(input_ids, attention_mask, labels, None, None, generate=False, continue_work=True)
 
     def _model_generate(self, input_ids, attention_mask, max_length, stop, **generation_kwargs):
         return self._model_invoke(
@@ -326,7 +408,7 @@ def _model_generate(self, input_ids, attention_mask, max_length, stop, **generat
             max_length,
             stop,
             generate=True,
-            continue_generate=True,
+            continue_work=True,
             **generation_kwargs,
         )
 
@@ -370,6 +452,8 @@ def _model_call_inner(self, input_ids, attention_mask=None, labels=None):
             output_attentions=False,
             output_hidden_states=False,
             return_dict=True,
+            coordinator_forward=True,
+            communication_timeout_sec=self.communication_timeout_sec,
         ).logits.cpu()
 
     def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs):
@@ -398,6 +482,8 @@ def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **g
             stopping_criteria=stopping_criteria,
             pad_token_id=self._tokenizer.pad_token_id,
             use_cache=False,
+            coordinator_forward=True,
+            communication_timeout_sec=self.communication_timeout_sec,
             **generation_kwargs,
         )
diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py
index 54a82492b..b634f8a4d 100644
--- a/fast_llm/engine/inference/huggingface.py
+++ b/fast_llm/engine/inference/huggingface.py
@@ -7,7 +7,9 @@
 import transformers.generation.utils
 import transformers.modeling_outputs
 
+from fast_llm.core.distributed import broadcast_object, broadcast_optional, safe_barrier
 from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
+from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.engine.inference.config import HuggingfaceModelConfig
 from fast_llm.engine.inference.runner import InferenceRunner
 from fast_llm.engine.multi_stage.config import StageMode
@@ -57,12 +59,12 @@ def __init__(
         # or set existing model which also must be setup, so, do not accept not setup model
         assert fast_llm_model.is_setup
 
-        # We only support data parallel for now
-        Assert.eq(fast_llm_model.distributed.config.model_parallel, 1)
-        Assert.eq(fast_llm_model.distributed.config.sequence_data_parallel, 1)
-
         self._inference_runner.setup()
 
+        # We only support data parallel and tensor parallel for now
+        Assert.eq(fast_llm_model.distributed.config.pipeline_parallel, 1)
+        Assert.eq(fast_llm_model.distributed.config.sequence_data_parallel, 1)
+
+        # Transformers needs to be able to inspect the base model.
         self.fast_llm_base_model = fast_llm_model.base_model
@@ -112,7 +114,7 @@ def _init_weights(self, module) -> None:
 
 
 class HuggingfaceBaseModelForCausalLM(HuggingfacePreTrainedModel, transformers.generation.utils.GenerationMixin):
-    def forward(
+    def inner_forward(
         self,
         input_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
@@ -127,3 +129,148 @@ def forward(
     ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast:
         # Meant to be overridden in derived classes
         raise NotImplementedError()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.Tensor | None = None,
+        past_key_values=None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        labels: torch.LongTensor | None = None,
+        use_cache: bool | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        coordinator_forward: bool = False,
+        communication_timeout_sec: float = 600.0,
+        continue_work: bool = True,
+    ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast | None:
+        """
+        Forward pass compatible with HuggingFace forward.
+
+        Additional arguments:
+            coordinator_forward (bool):
+                If True, only the TP group coordinator (rank 0) should call forward;
+                other ranks must call worker_forward.
+                If False, all TP group ranks call forward independently and return logits.
+            communication_timeout_sec (float): Maximum time (in seconds) to wait for the start of
+                forward or for a stop signal to worker ranks before timing out in worker_forward.
+                Must match the value passed to worker_forward.
+            continue_work (bool): Whether to continue processing in a TP group.
+                Only applies for coordinator_forward=True.
+
+        Notes:
+            - In coordinator_forward=True mode, forward on rank 0 distributes data to other ranks.
+            - After processing, the coordinator (rank 0) must call `stop_workers()` before continuing,
+              to unblock worker_forward on other ranks.
+            - This mode augments HuggingFace generate with tensor-parallel capability.
+        """
+        distributed: Distributed = self._inference_runner._fast_llm_model.distributed
+
+        if coordinator_forward and distributed.world_group and distributed.tensor_group:
+            assert distributed.tensor_group.rank() == 0
+            assert past_key_values is None and not use_cache
+
+            # Some tasks may post-process too slowly, so waiting for the next batch or
+            # the end of work can exceed the standard 60s timeout.
+            safe_barrier(distributed.tensor_group, "forward_wait", timeout=communication_timeout_sec)
+
+            broadcast_optional(input_ids, distributed.tensor_group, 0)
+            broadcast_optional(attention_mask, distributed.tensor_group, 0)
+            broadcast_optional(position_ids, distributed.tensor_group, 0)
+            broadcast_optional(inputs_embeds, distributed.tensor_group, 0)
+            broadcast_optional(labels, distributed.tensor_group, 0)
+
+            broadcast_object(
+                (past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work),
+                distributed.tensor_group,
+                0,
+            )
+
+        if not coordinator_forward or continue_work:
+            return self.inner_forward(
+                input_ids,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                inputs_embeds,
+                labels,
+                use_cache,
+                output_attentions,
+                output_hidden_states,
+                return_dict,
+            )
+
+        return None
+
+    def worker_forward(
+        self,
+        communication_timeout_sec: float = 600.0,
+    ):
+        """
+        Run the forward loop on worker ranks in coordinated mode.
+
+        This function must be called on all worker ranks (i.e., all ranks except the
+        coordinator/leading data-parallel rank). In coordinated mode, the coordinator
+        rank calls `forward`, which distributes inputs to workers. Each worker then
+        receives its inputs and runs a forward pass.
+
+        Workers stay in this loop until a stop signal is broadcast, which happens when
+        the coordinator rank calls `stop_workers`.
+
+        Args:
+            communication_timeout_sec (float): Maximum time (in seconds) to wait for the
+                start of a forward call or for a stop signal from the coordinator before
+                timing out. Must match the value passed to `forward`.
+
+        Notes:
+            - Coordinator rank: calls `forward` in coordinated mode and later
+              `stop_workers` to unblock workers.
+            - Worker ranks: call `worker_forward` once and remain inside the loop,
+              executing forward passes with broadcasted inputs until a stop signal
+              is received.
+        """
+        distributed: Distributed = self._inference_runner._fast_llm_model.distributed
+        assert distributed.world_group and distributed.tensor_group and distributed.tensor_group.rank() != 0
+
+        while True:
+            # Some tasks may post-process too slowly, so waiting for the next batch or
+            # the end of work can exceed the standard 60s timeout.
+            safe_barrier(distributed.tensor_group, "forward_wait", timeout=communication_timeout_sec)
+
+            input_ids = broadcast_optional(None, distributed.tensor_group, 0)
+            attention_mask = broadcast_optional(None, distributed.tensor_group, 0)
+            position_ids = broadcast_optional(None, distributed.tensor_group, 0)
+            inputs_embeds = broadcast_optional(None, distributed.tensor_group, 0)
+            labels = broadcast_optional(None, distributed.tensor_group, 0)
+
+            past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, continue_work = (
+                broadcast_object(None, distributed.tensor_group, 0)
+            )
+
+            if not continue_work:
+                break
+
+            self.inner_forward(
+                input_ids,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                inputs_embeds,
+                labels,
+                use_cache,
+                output_attentions,
+                output_hidden_states,
+                return_dict,
+            )
+
+        safe_barrier(distributed.world_group, "forward_work_end")
+
+    def stop_workers(self):
+        distributed: Distributed = self._inference_runner._fast_llm_model.distributed
+        # On a single GPU, or without tensor parallelism, there is no worker_forward to stop.
+        if distributed.world_group is None or distributed.tensor_group is None:
+            return
+        self.forward(coordinator_forward=True, continue_work=False)
+        safe_barrier(distributed.world_group, "forward_work_end")
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index ade1144d2..0ab64cc9f 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -5,7 +5,7 @@
 from torch._C._distributed_c10d import ReduceOp  # noqa
 from torch.distributed import all_reduce
 
-from fast_llm.core.ops import split_op
+from fast_llm.core.ops import gather_op, split_op
 from fast_llm.engine.base_model.config import ResourceUsageConfig
 from fast_llm.engine.config_utils.initialization import init_normal_
 from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim
@@ -273,6 +273,14 @@ def _logits_cross_entropy_forward_backward_split(
         )
         if targets is None:
             # TODO: Make a proper way of returning the model output.
+            loss = loss.detach()
+            if kwargs.get("global_logits"):
+                if self._vocab_parallel:
+                    loss = gather_op(loss, self._parallel_dim.group, 2)
+                elif self._sequence_parallel_logits:
+                    loss = gather_op(
+                        loss, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1
+                    )
             kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss
             return None, None
         else:
diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py
index 2f99ae4c3..680d8bfb2 100644
--- a/fast_llm/models/gpt/huggingface.py
+++ b/fast_llm/models/gpt/huggingface.py
@@ -32,7 +32,7 @@ class HuggingfaceGPTModelForCausalLM(HuggingfaceBaseModelForCausalLM):
     # _supports_cache_class = False
     # _tied_weights_keys = []
 
-    def forward(
+    def inner_forward(
         self,
         input_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
@@ -99,10 +99,15 @@ def forward(
         else:
             kwargs["output_hidden_states"] = False
 
+        kwargs["global_logits"] = True
+
         self._inference_runner.forward(input_, kwargs, iteration=iteration)
 
         # TODO: Make a proper way of returning the model output.
-        logits = kwargs["logits"]
+        if kwargs[TransformerKwargs.sequence_first]:
+            logits = kwargs["logits"].transpose(0, 1)
+        else:
+            logits = kwargs["logits"]
 
         # TODO: convert hidden state form dict to list to be the same as with HFs
         hidden_states = None
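
Usage note (editor's addition, not part of the patch): the sketch below shows how the coordinator/worker pattern added in fast_llm/engine/inference/huggingface.py is intended to be driven. It is a minimal sketch under assumptions: `model` is an already set-up HuggingfaceGPTModelForCausalLM whose Fast-LLM model runs with tensor parallelism, `input_ids` already lives on the current CUDA device, and the attribute chain `model._inference_runner._fast_llm_model.distributed` mirrors the one used in fast_llm_wrapper.py.

# Hypothetical driver code; `model` and `input_ids` are assumed to exist and are not defined in the diff.
distributed = model._inference_runner._fast_llm_model.distributed

if distributed.tensor_group is None or distributed.tensor_group.rank() == 0:
    # Coordinator rank (or single GPU): run any number of coordinated forward/generate calls.
    logits = model.forward(
        input_ids,
        coordinator_forward=True,
        communication_timeout_sec=600.0,
        return_dict=True,
    ).logits
    # Release the worker ranks that are blocked inside worker_forward().
    model.stop_workers()
else:
    # Other tensor-parallel ranks: loop on broadcast inputs until rank 0 calls stop_workers().
    model.worker_forward(communication_timeout_sec=600.0)

The same split is what LmEvalEvaluator relies on: rank 0 drives lm_eval, the other batch-data-group leaders sit in worker_model_invoke, and all remaining ranks sit in worker_forward until stop_workers is called.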