diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 7821b81c5..a6ebf7a2b 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -202,7 +202,7 @@ def _set_target_inputs( model_input.targets.append(target_input) - def _get_label_counts(self, mask: torch.Tensor): + def _get_label_counts(self, mask: torch.Tensor) -> torch.Tensor: # Count the number of non-masked labels in each document through cumulative sums. mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)]) length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0) @@ -210,7 +210,6 @@ def _get_label_counts(self, mask: torch.Tensor): labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1] # Expand to one entry per token: find each token's document index via the sorted # length cumsum, then look up that document's label count. - # TODO: Document index already computed in `LengthModelInputPreprocessor`. document_index = torch.searchsorted( length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right" ) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 29720b90b..2920c1334 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -21,6 +21,16 @@ class ScheduleConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + docs_per_step: int = Field( + default=0, + desc="Target number of documents (rollouts) per optimizer step, globally across all data-parallel ranks. " + "When >0, each training step dynamically accumulates microbatches until the globally all-reduced " + "document count reaches this value, then triggers the optimizer step. " + "depth_first_micro_batches is ignored when this is set. " + "0 = use depth_first_micro_batches as-is (fixed microbatch count per step).", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) breadth_first_micro_batches: int = Field( default=1, desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.", diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index b2e212946..128b95e8e 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -320,7 +320,8 @@ def _preprocess_data( if context.schedule.phase.is_training else None ) - model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)] + n_micro_batches = context.schedule._eff_sequential_micro_batches + model_inputs = [next(data_iterator) for _ in range(n_micro_batches)] model_inputs[0][0].share_batch_data( [model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed ) @@ -336,7 +337,7 @@ def _preprocess_data( extra_kwargs={ "grad_output": grad_output, "micro_batch": micro_batch, - "num_micro_batches": self._config.sequential_micro_batches, + "num_micro_batches": n_micro_batches, "micro_batch_splits": self._config.micro_batch_splits, }, ) diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 6f7bf1d95..845b5df82 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -115,15 +115,17 @@ def __init__( batch_meta: list[ModelInput], distributed_config: DistributedConfig, phase: PhaseType, + _depth_first_override: int | None = None, ): super().__init__(config) + self._depth_first_override = _depth_first_override self._multi_stage = multi_stage self._distributed_config = distributed_config self._num_stages = len(self._multi_stage.stages) self._phase = phase self._is_training = self._phase.is_training - if self._config.num_inputs < self._distributed_config.pipeline_parallel: + if self._eff_num_inputs < self._distributed_config.pipeline_parallel: warnings.warn("Not enough input to achieve true pipeline parallelism.") # Setup the activation metas. @@ -155,9 +157,25 @@ def __init__( def phase(self) -> PhaseType: return self._phase + @property + def _eff_depth_first(self) -> int: + return ( + self._depth_first_override + if self._depth_first_override is not None + else self._config.depth_first_micro_batches + ) + + @property + def _eff_sequential_micro_batches(self) -> int: + return self._eff_depth_first * self._config.breadth_first_micro_batches + + @property + def _eff_num_inputs(self) -> int: + return self._eff_sequential_micro_batches * self._config.micro_batch_splits + @property def samples_per_batch(self) -> int: - return self._config.sequential_micro_batches * self._distributed_config.batch_data_parallel + return self._eff_sequential_micro_batches * self._distributed_config.batch_data_parallel def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]: return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank]) @@ -189,7 +207,7 @@ def _create_index(self) -> None: Assert.in_range( step.index, 0, - self._config.num_inputs, + self._eff_num_inputs, ) Assert.incl(step.type_, (StepType.forward, StepType.backward)) step.global_index = i @@ -205,7 +223,7 @@ def _create_index(self) -> None: Assert.custom(all, self._device_steps) # Consistency checks step_map = self._step_map.copy() - for data_index in range(self._config.num_inputs): + for data_index in range(self._eff_num_inputs): for type_ in (StepType.forward, StepType.backward): for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages): assert ( @@ -470,14 +488,11 @@ def _create_steps(self) -> tuple[list[Step], int]: first_grad_stage += 1 else: first_grad_stage = self._num_stages - for depth_first_micro_batch in range(self._config.depth_first_micro_batches): + for depth_first_micro_batch in range(self._eff_depth_first): for stage in range(self._num_stages): for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches): for micro_batch_split in range(self._config.micro_batch_splits): - micro_batch = ( - breadth_first_micro_batch * self._config.depth_first_micro_batches - + depth_first_micro_batch - ) + micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch steps.append( Step( stage=stage, @@ -492,10 +507,7 @@ def _create_steps(self) -> tuple[list[Step], int]: for stage in reversed(range(first_grad_stage, self._num_stages)): for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches): for micro_batch_split in reversed(range(self._config.micro_batch_splits)): - micro_batch = ( - breadth_first_micro_batch * self._config.depth_first_micro_batches - + depth_first_micro_batch - ) + micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch steps.append( Step( stage=stage, diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 1ed18c449..77a88377e 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -115,10 +115,12 @@ def setup(self, distributed: Distributed, run: Run) -> None: preprocessing_config = self._multi_stage.get_preprocessing_config( PhaseType.training, self._config.schedule.micro_batch_splits ) + self._single_mb_meta = preprocessing_config.get_input_meta(self._data.config.micro_batch_size) + self._schedule_cache: dict[int, Schedule] = {} self._schedule = Schedule( config=self._config.schedule, multi_stage=self._multi_stage, - batch_meta=preprocessing_config.get_input_meta(self._data.config.micro_batch_size), + batch_meta=self._single_mb_meta, distributed_config=self._config.model.distributed, phase=PhaseType.training, ) @@ -140,6 +142,41 @@ def setup(self, distributed: Distributed, run: Run) -> None: self._is_setup = True + def _get_or_build_schedule(self, n_microbatches: int) -> Schedule: + if n_microbatches not in self._schedule_cache: + bfmb = self._config.schedule.breadth_first_micro_batches + depth_first = n_microbatches // bfmb + self._schedule_cache[n_microbatches] = Schedule( + config=self._config.schedule, + multi_stage=self._multi_stage, + batch_meta=self._single_mb_meta, + distributed_config=self._config.model.distributed, + phase=PhaseType.training, + _depth_first_override=depth_first, + ) + return self._schedule_cache[n_microbatches] + + def _prefetch_to_doc_target(self, data_iterator) -> list: + target = self._config.schedule.docs_per_step + bfmb = self._config.schedule.breadth_first_micro_batches + buffer = [] + total_docs = 0 + while total_docs < target: + mb = next(data_iterator) + mb[0].share_batch_data(mb, self._distributed) + total_docs += mb[0].num_documents_in_batch + buffer.append(mb) + Assert.eq( + len(buffer) % bfmb, + 0, + msg=f"Fetched {len(buffer)} microbatches not divisible by breadth_first_micro_batches={bfmb}", + ) + # Reset num_documents_in_batch to the step total on all microbatches + for mb in buffer: + for mi in mb: + mi.num_documents_in_batch = total_docs + return buffer + @abc.abstractmethod def _get_data(self) -> Data: pass @@ -220,12 +257,22 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Data loader hates getting all micro-batches at once. # (Also preprocessing adds overhead) - reduced_losses, update_successful, train_metrics = self._runner.run_step( - train_iterator, - self._schedule, - iteration=self._completed_steps, - return_metrics=is_logging, - ) + if self._config.schedule.docs_per_step > 0: + buffer = self._prefetch_to_doc_target(train_iterator) + step_schedule = self._get_or_build_schedule(len(buffer)) + reduced_losses, update_successful, train_metrics = self._runner.run_step( + iter(buffer), + step_schedule, + iteration=self._completed_steps, + return_metrics=is_logging, + ) + else: + reduced_losses, update_successful, train_metrics = self._runner.run_step( + train_iterator, + self._schedule, + iteration=self._completed_steps, + return_metrics=is_logging, + ) # Advanced, skipped, and Nan iterations. if update_successful: diff --git a/fast_llm/functional/triton/grpo_loss.py b/fast_llm/functional/triton/grpo_loss.py index 39d832ccd..709bbc73c 100644 --- a/fast_llm/functional/triton/grpo_loss.py +++ b/fast_llm/functional/triton/grpo_loss.py @@ -137,6 +137,7 @@ def triton_grpo_loss_forward_backward( logits_scale_factor: float = 1.0, num_labels_in_seq: torch.Tensor | None = None, divisor: float | None = None, + grad_divisor: float | None = None, # Optional separate divisor for the gradient (defaults to divisor) block_size: int | None = None, num_warps: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: @@ -148,6 +149,8 @@ def triton_grpo_loss_forward_backward( n_cols = logits.size(-1) if divisor is None: divisor = n_rows + if grad_divisor is None: + grad_divisor = divisor if block_size is None: block_size = min(triton.next_power_of_2(n_cols), 32768) if num_warps is None: @@ -171,7 +174,7 @@ def triton_grpo_loss_forward_backward( grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits backward_kwargs = { "grad_logits_ptr": grad_logits, - "grad_losses": grad_output / divisor, + "grad_losses": grad_output / grad_divisor, "grad_logits_stride_0": grad_logits.stride(-2), "accumulate": accumulate, } diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index bde33f297..6a0bfcfd6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -131,6 +131,13 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.architecture, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + fp32_lm_head: bool = Field( + default=False, + desc="Upcast input and weight to float32 before the lm_head linear. " + "Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs " + "are computed at the same numerical precision, keeping the IS ratio near 1 at init.", + hint=FieldHint.feature, + ) prediction_heads: int = Field( default=1, desc="Prediction heads.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 22c750082..eb67cd553 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,7 +22,7 @@ ) from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss -from fast_llm.tensor import TensorMeta +from fast_llm.tensor import TensorMeta, accumulate_gradient from fast_llm.utils import Assert, safe_merge_dicts logger = logging.getLogger(__name__) @@ -252,9 +252,17 @@ def _logits_loss_forward_backward_partial( split_index: int = 0, return_logits: bool = False, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if self._config.fp32_lm_head: + input_dtype = input_.dtype + input_ = input_.to(torch.float32) + # detach → requires_grad=False → output_parallel_linear_backward skips weight grad + weight = self.output_weights.detach().to(torch.float32) + else: + weight = self.output_weights + logits, context = output_parallel_linear_forward( input_=input_, - weight=self.output_weights, + weight=weight, bias=None, group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, @@ -285,12 +293,26 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) - if grad is not None and self._config.final_logit_softcap is not None: + if not self.training or grad is None: + return sum(losses_) if losses_ else None, None + + if self._config.final_logit_softcap is not None: grad = _softcap_backward(grad, logits, self._config.final_logit_softcap) - return sum(losses_) if losses_ else None, ( - output_parallel_linear_backward(grad, context) if self.training else None - ) + input_grad = output_parallel_linear_backward(grad, context) + if self._config.fp32_lm_head: + # Weight grad was skipped because weight.requires_grad=False; accumulate manually. + # context: (input_, weight, bias, group, sequence_parallel, ...) + saved_input = context[0] + if context[4]: # sequence_parallel + from fast_llm.core.ops import gather_op + + saved_input = gather_op(saved_input, context[3], dim=0) + grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2)) + accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype)) + input_grad = input_grad.to(input_dtype) + + return sum(losses_) if losses_ else None, input_grad def get_loss_definitions(self) -> list[LossDef]: return [ diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 8e9594534..c514c2a5f 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -209,22 +209,18 @@ class GRPOMetricsLevel(enum.StrEnum): with_entropy = "with_entropy" -@config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) -class LanguageModelGRPOLossConfig(LanguageModelLossConfig): +@config_class() +class LanguageModelPolicyGradientLossConfig(LanguageModelLossConfig): + """Shared base for policy-gradient losses (GRPO, GSPO).""" - _abstract: typing.ClassVar[bool] = False + _abstract: typing.ClassVar[bool] = True epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") - use_triton: bool | None = Field( - default=None, - desc="Enable triton implementation. Default: use if available.", - hint=FieldHint.expert, - ) metrics: GRPOMetricsLevel = Field( default=GRPOMetricsLevel.none, desc=( - "Additional GRPO metrics to log. " + "Additional policy-gradient metrics to log. " "`basic`: per-token ratio, KL, and advantage statistics. " "`with_entropy`: also log per-token entropy. " "Not supported with pipeline_parallel > 1." @@ -232,6 +228,23 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): hint=FieldHint.feature, ) + @property + def loss_class(self) -> "type[LanguageModelPolicyGradientLoss]": + raise NotImplementedError() + + +@config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) +class LanguageModelGRPOLossConfig(LanguageModelPolicyGradientLossConfig): + """Group-Relative Policy Optimization: per-token IS-ratio clipping.""" + + _abstract: typing.ClassVar[bool] = False + + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. Default: use if available.", + hint=FieldHint.expert, + ) + @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 4bbaeb581..019697228 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -14,6 +14,7 @@ GRPOMetricsLevel, LanguageModelGRPOLossConfig, LanguageModelLossKwargs, + LanguageModelPolicyGradientLossConfig, ) from fast_llm.layers.language_model.loss.loss import LanguageModelLoss from fast_llm.utils import Assert @@ -33,7 +34,16 @@ class GRPOMetrics(typing.NamedTuple): entropy: torch.Tensor | None -class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]): +class LanguageModelPolicyGradientLoss[ConfigType: LanguageModelPolicyGradientLossConfig]( + LanguageModelLoss[ConfigType] +): + """Shared scaffolding for policy-gradient losses (GRPO, GSPO). + + Subclasses set `self._forward` to the actual kernel function in `__init__` and implement + `_forward_backward` to call it. Shared logic — divisor selection, loss/metric registration — + lives here. + """ + def __init__( self, config: ConfigType, @@ -66,51 +76,51 @@ def __init__( distributed_config.pipeline_parallel, ) - def _forward_backward( + def _compute_divisors(self, kwargs: dict[str, typing.Any]) -> tuple[float | int, float | int | None]: + return self._get_label_count(kwargs), None + + def _shared_kernel_kwargs( self, - logits: "torch.Tensor", kwargs: dict[str, typing.Any], - losses: dict | None = None, - split_index: int = 0, - grad_logits: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - if TritonConfig.enabled(logits.device, self._config.use_triton): - from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward - - fn = triton_grpo_loss_forward_backward - else: - fn = fused_grpo_loss_forward_backward - loss, grad, new_logprobs_mean = fn( - logits, - self._get_labels(kwargs, split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), - grad_logits=grad_logits, - grad_output=self._get_grad_output(kwargs), - group=self._parallel_dim.group if self._vocab_parallel else None, - epsilon_low=self._config.epsilon_low, - epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._logits_scale_factor, - num_labels_in_seq=( + losses: dict | None, + split_index: int, + grad_logits: torch.Tensor | None, + divisor: float | int, + grad_divisor: float | int | None, + ) -> dict[str, typing.Any]: + return { + "grad_logits": grad_logits, + "grad_output": self._get_grad_output(kwargs), + "group": self._parallel_dim.group if self._vocab_parallel else None, + "epsilon_low": self._config.epsilon_low, + "epsilon_high": self._config.epsilon_high, + "logits_scale_factor": self._logits_scale_factor, + "num_labels_in_seq": ( None if losses is None else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) ), - divisor=self._get_label_count(kwargs), - ) + "divisor": divisor, + "grad_divisor": grad_divisor, + } + def _finalize_loss( + self, + new_logprobs_mean: torch.Tensor | None, + logits: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict | None, + split_index: int, + ) -> None: if new_logprobs_mean is not None: new_logprobs_mean = new_logprobs_mean / kwargs[LanguageModelKwargs.num_documents_in_batch] self._register_loss( self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM ) - # Skip the extra softmax pass when there is nothing to register. if losses is not None and self._config.metrics != GRPOMetricsLevel.none: self._register_extra_metrics(logits, kwargs, losses, split_index) - return loss, grad - def _register_extra_metrics( self, logits: torch.Tensor, @@ -187,9 +197,7 @@ def get_loss_definitions(self) -> list[LossDef]: defs.append(LossDef(f"{self._name}_entropy")) return defs - def get_preprocessing_config( - self, - ) -> dict[str, typing.Any]: + def get_preprocessing_config(self) -> dict[str, typing.Any]: return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} @functools.cached_property @@ -197,6 +205,43 @@ def _logprob_metric_name(self) -> str: return f"{self._name}_new_logprobs" +class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelPolicyGradientLoss[ConfigType]): + """GRPO: per-token IS-ratio clipping.""" + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + **kwargs: typing.Any, + ): + super().__init__(config, distributed_config, **kwargs) + if TritonConfig.enabled(torch.device("cuda"), config.use_triton): + from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward + + self._forward = triton_grpo_loss_forward_backward + else: + self._forward = fused_grpo_loss_forward_backward + + def _forward_backward( + self, + logits: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict | None = None, + split_index: int = 0, + grad_logits: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + divisor, grad_divisor = self._compute_divisors(kwargs) + loss, grad, new_logprobs_mean = self._forward( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + **self._shared_kernel_kwargs(kwargs, losses, split_index, grad_logits, divisor, grad_divisor), + ) + self._finalize_loss(new_logprobs_mean, logits, kwargs, losses, split_index) + return loss, grad + + @torch.compile def compute_grpo_metrics( logits: torch.Tensor, # (*batch, vocab_local) @@ -268,10 +313,13 @@ def fused_grpo_loss_forward_backward( torch.Tensor | None ) = None, # (*batch,) — response-span length broadcast per token, 0 for non-response divisor: float | None = None, + grad_divisor: float | None = None, # Optional separate divisor for the gradient (defaults to divisor) ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: if divisor is None: divisor = logits.shape[:-1].numel() - grad_output = None if grad_output is None else grad_output / divisor * logits_scale_factor + if grad_divisor is None: + grad_divisor = divisor + grad_output = None if grad_output is None else grad_output / grad_divisor * logits_scale_factor loss_mask = target >= 0 logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 034953b50..53e6e6e88 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -39,6 +39,8 @@ def __init__( self._vocab_parallel = distributed_config.tensor_parallel > 1 and vocab_parallel self._sequence_parallel = distributed_config.sequence_tensor_parallel and not self._vocab_parallel self._parallel_dim = distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._sdp_dim = distributed_config.get_distributed_dim(DistributedDimNames.sequence_data) + self._sdp_active = distributed_config.sequence_data_parallel > 1 def forward_backward( self, diff --git a/tests/layers/test_docs_per_step.py b/tests/layers/test_docs_per_step.py new file mode 100644 index 000000000..477109931 --- /dev/null +++ b/tests/layers/test_docs_per_step.py @@ -0,0 +1,204 @@ +""" +Unit tests for docs_per_step. + +Covers: + 1. Divisor scaling in fused_grpo_loss_forward_backward + 2. Schedule._eff_depth_first / _eff_sequential_micro_batches / _eff_num_inputs properties + 3. Trainer._prefetch_to_doc_target accumulation logic +""" + +import dataclasses +import types + +import pytest +import torch + +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.layers.language_model.loss.grpo import fused_grpo_loss_forward_backward + +device = "cuda" if torch.cuda.is_available() else "cpu" +_atol = 1e-4 if device == "cuda" else 1e-5 + + +# --------------------------------------------------------------------------- +# 1. Divisor-scaling correctness in raw kernels +# --------------------------------------------------------------------------- + + +def test_grpo_divisor_scales_loss(): + """Halving the divisor should double the loss.""" + torch.manual_seed(10) + n_tok, vocab = 16, 32 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + + d1 = float(n_tok) + d2 = float(n_tok) * 2 + + loss1, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_lp, divisor=d1) + loss2, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_lp, divisor=d2) + + assert ( + abs(loss1.item() - 2.0 * loss2.item()) < _atol * 10 + ), f"Expected loss(d1) ≈ 2*loss(d2), got {loss1.item():.6f} vs {2*loss2.item():.6f}" + + +# --------------------------------------------------------------------------- +# 2. Schedule._eff_* properties +# --------------------------------------------------------------------------- + + +def _make_bare_schedule(depth_first: int, breadth_first: int, splits: int, override: int | None) -> Schedule: + """Create a Schedule with __init__ bypassed to test the _eff_* properties only.""" + config = ScheduleConfig( + depth_first_micro_batches=depth_first, + breadth_first_micro_batches=breadth_first, + micro_batch_splits=splits, + ) + sched = object.__new__(Schedule) + # Minimal attributes used by the three _eff_* properties. + object.__setattr__(sched, "_config", config) + object.__setattr__(sched, "_depth_first_override", override) + # samples_per_batch also needs _distributed_config.batch_data_parallel + fake_distributed = types.SimpleNamespace(batch_data_parallel=1) + object.__setattr__(sched, "_distributed_config", fake_distributed) + return sched + + +def test_schedule_eff_properties_no_override(): + sched = _make_bare_schedule(depth_first=4, breadth_first=2, splits=3, override=None) + assert sched._eff_depth_first == 4 + assert sched._eff_sequential_micro_batches == 8 # 4 * 2 + assert sched._eff_num_inputs == 24 # 8 * 3 + assert sched.samples_per_batch == 8 # 8 * dp=1 + + +def test_schedule_eff_properties_with_override(): + sched = _make_bare_schedule(depth_first=4, breadth_first=2, splits=3, override=7) + assert sched._eff_depth_first == 7 # override wins + assert sched._eff_sequential_micro_batches == 14 # 7 * 2 + assert sched._eff_num_inputs == 42 # 14 * 3 + assert sched.samples_per_batch == 14 # 14 * dp=1 + + +def test_schedule_eff_properties_override_equals_config(): + """Override equal to config value → same result as no override.""" + sched_no = _make_bare_schedule(depth_first=3, breadth_first=2, splits=1, override=None) + sched_yes = _make_bare_schedule(depth_first=3, breadth_first=2, splits=1, override=3) + assert sched_no._eff_depth_first == sched_yes._eff_depth_first + assert sched_no._eff_sequential_micro_batches == sched_yes._eff_sequential_micro_batches + assert sched_no._eff_num_inputs == sched_yes._eff_num_inputs + + +def test_schedule_samples_per_batch_uses_eff(): + """samples_per_batch should scale with _eff_sequential, not config.sequential.""" + sched = _make_bare_schedule(depth_first=2, breadth_first=2, splits=1, override=5) + # Config says depth_first=2 → sequential=4; override=5 → eff_sequential=10 + assert sched._eff_sequential_micro_batches == 10 + assert sched.samples_per_batch == 10 # dp=1 + + +# --------------------------------------------------------------------------- +# 3. _prefetch_to_doc_target accumulation logic +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class _FakeMicrobatch: + """Stub for a single split of one microbatch.""" + + num_documents: int + num_documents_in_batch: int | None = None + + @classmethod + def share_batch_data(cls, inputs, distributed): + """Mimic TokenModelInput.share_batch_data with group=None (single process).""" + if inputs[0].num_documents_in_batch is None: + total = sum(inp.num_documents for inp in inputs) + for inp in inputs: + inp.num_documents_in_batch = total + + +def _fake_iterator(doc_counts: list[int]): + """Yield [_FakeMicrobatch(n)] for each n in doc_counts.""" + for n in doc_counts: + yield [_FakeMicrobatch(num_documents=n)] + + +class _StubTrainer: + """Concrete stub that exposes only the interface _prefetch_to_doc_target needs.""" + + # Borrow the method directly so it runs against this stub's attributes. + from fast_llm.engine.training.trainer import Trainer as _Trainer + + _prefetch_to_doc_target = _Trainer._prefetch_to_doc_target + + +def _make_fake_trainer(docs_per_step: int, bfmb: int = 1): + """Create a _StubTrainer with the attributes _prefetch_to_doc_target reads.""" + schedule_cfg = types.SimpleNamespace( + docs_per_step=docs_per_step, + breadth_first_micro_batches=bfmb, + ) + config = types.SimpleNamespace(schedule=schedule_cfg) + distributed = types.SimpleNamespace(batch_data_group=None) + + trainer = _StubTrainer() + trainer._config = config + trainer._distributed = distributed + return trainer + + +def test_prefetch_stops_at_target(): + """Buffer should stop growing once cumulative docs ≥ docs_per_step.""" + trainer = _make_fake_trainer(docs_per_step=6, bfmb=1) + # Each microbatch has 2 docs; need ≥6 → expect 3 microbatches + it = _fake_iterator([2, 2, 2, 2, 2]) + buffer = trainer._prefetch_to_doc_target(it) + + assert len(buffer) == 3, f"Expected 3 microbatches, got {len(buffer)}" + + +def test_prefetch_resets_num_documents_in_batch(): + """After the call, every microbatch input has num_documents_in_batch = step total.""" + trainer = _make_fake_trainer(docs_per_step=5, bfmb=1) + # 3 docs, 3 docs → total=6 (overshoots 5, stops after 2nd) + it = _fake_iterator([3, 3, 3]) + buffer = trainer._prefetch_to_doc_target(it) + + step_total = sum(mb[0].num_documents for mb in buffer) + for mb in buffer: + for mi in mb: + assert ( + mi.num_documents_in_batch == step_total + ), f"Expected num_documents_in_batch={step_total}, got {mi.num_documents_in_batch}" + + +def test_prefetch_overshoot_is_included(): + """A microbatch that pushes the total over the target IS included (not dropped).""" + trainer = _make_fake_trainer(docs_per_step=5, bfmb=1) + it = _fake_iterator([4, 4]) # 4 < 5, then 8 ≥ 5 → 2 microbatches + buffer = trainer._prefetch_to_doc_target(it) + assert len(buffer) == 2 + assert buffer[-1][0].num_documents_in_batch == 8 # step total = 4+4 + + +def test_prefetch_divisibility_check(): + """Raises when fetched count is not divisible by breadth_first_micro_batches.""" + trainer = _make_fake_trainer(docs_per_step=4, bfmb=2) + # Each microbatch has 5 docs → only 1 mb needed, but 1 % 2 != 0 + it = _fake_iterator([5, 5, 5]) + with pytest.raises(Exception): + trainer._prefetch_to_doc_target(it) + + +def test_prefetch_exact_divisibility(): + """No error when fetched count is exactly divisible by breadth_first_micro_batches.""" + trainer = _make_fake_trainer(docs_per_step=4, bfmb=2) + # 2 docs each → need ≥4 → fetch 2 microbatches → 2 % 2 == 0 + it = _fake_iterator([2, 2, 2, 2]) + buffer = trainer._prefetch_to_doc_target(it) + assert len(buffer) == 2