
gspo: GSPO loss + DeepSpeed parity fixes (loss/grad divisors, SDP, fp32_lm_head, docs_per_step, temperature)#502

Open
bigximik wants to merge 20 commits into main from gspo

Conversation

@bigximik
Collaborator

@bigximik bigximik commented Apr 29, 2026

Summary

This PR adds GSPO loss to fast-LLM along with a suite of supporting fixes that together achieve full metric and training-trajectory parity with DeepSpeed's GRPO/GSPO implementation. Targets the grpo-metrics branch. Six logical units:

1. GSPO loss (sequence-level IS-ratio clipping)

Implements GSPO as an alternative policy-gradient loss alongside the existing per-token GRPO clipping. Controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo".

  • New fused_gspo_loss_forward_backward kernel: computes per-segment geometric-mean log-ratio R_s, clips at [1−ε_low, 1+ε_high], and applies R_s × A_s as a uniform per-token gradient within each segment. An all_reduce(SUM) over sequence-data-parallel ranks aggregates (lrn_sum, adv_sum, tok_count) before clipping so the ratio is correct under sequence parallelism.
  • New document_index data field and LanguageModelKwargs.document_index kwarg constant to route per-token segment membership through the data pipeline.
  • 8 unit tests in tests/layers/test_gspo_loss.py (single-segment, packed sequences, ratio=1 equivalence, clipping, masking, SDP mock, gradient finite-diff, independence from per-token metrics).
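The segment-level ratio computation above can be sketched in plain PyTorch. This is a single-process illustration only (no SDP all-reduce, no fused kernel); the function name, the simple clamp-then-multiply objective, and the masking treatment are assumptions for clarity, not the actual kernel:

```python
import torch

def gspo_loss_sketch(new_logprobs, old_logprobs, advantages, document_index,
                     loss_mask, eps_low=0.2, eps_high=0.2):
    """Sequence-level IS-ratio clipping: one geometric-mean ratio per document."""
    n_segs = int(document_index.max().item()) + 1
    mask = loss_mask.to(new_logprobs.dtype)
    log_ratio = (new_logprobs - old_logprobs) * mask
    # Per-segment sums of log-ratio and (masked) token counts.
    lrn_sum = torch.zeros(n_segs).index_add_(0, document_index, log_ratio)
    tok_count = torch.zeros(n_segs).index_add_(0, document_index, mask).clamp(min=1)
    # Geometric-mean ratio R_s = exp(mean per-token log-ratio) for each segment.
    ratio = torch.exp(lrn_sum / tok_count)
    # Per-segment advantage A_s (mean of the per-token advantages).
    adv = torch.zeros(n_segs).index_add_(0, document_index, advantages * mask) / tok_count
    # Clip R_s at [1 - eps_low, 1 + eps_high], then apply R_s * A_s per segment.
    clipped = torch.clamp(ratio, 1 - eps_low, 1 + eps_high)
    loss = -(clipped * adv).sum()
    return loss, ratio
```

With `new_logprobs == old_logprobs` every segment ratio is exactly 1 and no clipping occurs, which is the ratio=1 equivalence case the unit tests check.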

2. Dynamic docs_per_step accumulation

Replaces static depth_first_micro_batches with a runtime document-count target — matching DeepSpeed's gradient_accumulation_passes semantics for RL (where each microbatch holds one rollout).

  • ScheduleConfig.docs_per_step: when >0, Trainer._prefetch_to_doc_target fetches microbatches one at a time, all-reduces the per-microbatch document count, and stops once the global total ≥ docs_per_step. The final step total is broadcast to all inputs so the normalisation denominator is consistent.
  • Trainer._get_or_build_schedule builds and caches a per-N Schedule with _depth_first_override = N // breadth_first_micro_batches, so the existing schedule machinery is reused without changes to the runner.
  • Schedule._eff_{depth_first,sequential,num_inputs} properties expose the effective values for a given override.
  • 13 unit tests in tests/layers/test_docs_per_step.py.
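The accumulation loop can be sketched as a free function. This single-process version stands in for `Trainer._prefetch_to_doc_target`; in the real trainer the per-microbatch document count is all-reduced across data-parallel ranks before the comparison, and the names here are illustrative:

```python
def prefetch_to_doc_target(fetch_microbatch, docs_per_step):
    """Fetch microbatches one at a time until the accumulated document count
    reaches the target, then stamp the step total onto every microbatch."""
    buffer, total_docs = [], 0
    while total_docs < docs_per_step:
        batch, num_documents = fetch_microbatch()
        buffer.append(batch)
        total_docs += num_documents  # stand-in for the global all-reduced count
    # The final step total becomes the shared normalization denominator.
    for batch in buffer:
        batch["num_documents_in_batch"] = total_docs
    return buffer, total_docs
```

Note the stop condition is `>=`, so the step total can overshoot the target by up to one microbatch's worth of documents; that is why the final total, not the target, is broadcast as the denominator.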

3. normalize_by_documents

Adds a normalize_by_documents flag to LanguageModelGRPOLossConfig. When True, both the GRPO and GSPO paths divide the loss by num_documents_in_batch (the step-level rollout count) rather than the token count. Matches DeepSpeed's normalization where tokens_weights = 1 / batch_size.

4. Temperature scaling for IS ratio parity

Adds a temperature field to LanguageModelGRPOLossConfig. When set to match the actor's sampling temperature (e.g. 0.7), new log-probabilities are computed at the same temperature as the stored old log-probabilities from vLLM, so the IS ratio starts near 1.0 at step 0 instead of ~1.08. Implementation: _effective_logits_scale = logits_scale_factor / temperature, substituted at all three call-sites in _forward_backward. Default temperature=1.0 preserves existing behaviour exactly.
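The effect of the substitution can be seen in a few lines. This sketch (function name is illustrative) folds the temperature into the logits scale before the log-softmax, which is all the fix does:

```python
import torch

def token_logprobs(logits, tokens, temperature=1.0, logits_scale_factor=1.0):
    # Equivalent to substituting _effective_logits_scale = logits_scale_factor / temperature.
    scaled = logits * (logits_scale_factor / temperature)
    logp = torch.log_softmax(scaled, dim=-1)
    return logp.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
```

If the stored old log-probabilities were produced at temperature 0.7 but the new ones at 1.0, the IS ratio differs from 1 even for identical logits; computing both at the same temperature makes them match exactly.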

5. fp32_lm_head precision fix (matches vLLM's bf16_last_layer_fp32)

Adds a fp32_lm_head flag (default False) on LanguageModelHeadConfig. When True, the LM head's logits computation upcasts both input and weight to FP32 before the linear projection, matching vLLM's bf16_last_layer_fp32 quantization. This ensures the trainer computes log-probabilities at the same numerical precision as the actor's sampling, so new_logprobs ≈ old_logprobs at step 0 (IS ratio at training start ≈ 1.0, not artificially inflated by precision mismatch).

  • Commit d8cb9ef5: introduces the flag, upcasts input/weight, casts back to BF16 before downstream consumption.
  • Commit 0f90f20b: fixes the gradient flow when fp32_lm_head=True. The detached FP32 weight copy has requires_grad=False, which makes output_parallel_linear_backward skip writing to the original weight's grad_buffer. We restore the FSDP gradient contract by computing grad_weight = grad.t() @ saved_input explicitly and accumulating into the BF16 param's grad_buffer via accumulate_gradient.
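The forward/backward pattern from the two commits can be sketched without the FSDP machinery. Everything here is a simplification: the dense matmul stands in for `output_parallel_linear_forward`, plain `.grad` accumulation stands in for `accumulate_gradient` into the grad buffer, and `grad_logits_fn` is a hypothetical stand-in for the loss backward:

```python
import torch

def fp32_lm_head_forward_backward(hidden, weight, grad_logits_fn):
    """Compute logits in FP32 from lower-precision inputs, then route the
    weight gradient back to the original parameter explicitly, because the
    detached FP32 copy has requires_grad=False and gets no autograd grad."""
    input_fp32 = hidden.detach().to(torch.float32)
    weight_fp32 = weight.detach().to(torch.float32)  # detached FP32 copy
    logits = input_fp32 @ weight_fp32.t()
    grad_logits = grad_logits_fn(logits)  # loss gradient w.r.t. logits (FP32)
    # Explicit weight grad: grad.t() @ saved_input, accumulated on the original param.
    grad_weight = grad_logits.t() @ input_fp32
    # Input grad is cast back so the transformer backward stays in its native dtype.
    grad_input = (grad_logits @ weight_fp32).to(hidden.dtype)
    if weight.grad is None:
        weight.grad = grad_weight.to(weight.dtype)
    else:
        weight.grad += grad_weight.to(weight.dtype)
    return logits, grad_input
```

The `if/else` accumulation mirrors the `param_grad_is_zero` toggle that matters for the `cross_entropy_splits > 1` case, where this path runs once per split.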

6. Decoupled loss/gradient divisors and SDP loss double-counting fix

Even with normalize_by_documents=true, fast-LLM's reported grad_norm was ~1024× larger than DeepSpeed's, causing the default gradient_norm_clipping=0.3 to over-clip by ~500× and making training ~10 reward points slower than DS GSPO at the same step count. Two issues, fixed in commit 557a3c4c:

Asymmetric loss/gradient scaling in DS:

  • DS loss reported uses /batch_size once (via tokens_weights = 1/batch_size, pipelinerl/finetune/rl/__init__.py:246).
  • DS gradient buffer has an ADDITIONAL /(gas × world_size) factor from scale_wrt_gas=True in engine.backward() (deepspeed/runtime/engine.py:1995-1996) and tensor.div_(world_sz) in reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124).
  • For samples_per_microbatch=1 (PipelineRL standard), gas × world_size = batch_size, so the gradient buffer effectively has 1/batch_size² while the loss metric has 1/batch_size.

Fast-LLM cancels DS's /(gas × world_size) factor structurally via grad_output = data_parallel × grad_scale (runner.py:318) interacting with FSDP's RS-AVG over data_parallel ranks (fsdp.py:396). So we need to apply the second 1/batch_size factor explicitly only to the gradient — keeping the loss metric matched to DS.

Fix: add a grad_divisor parameter to fused_gspo_loss_forward_backward, fused_grpo_loss_forward_backward, and triton_grpo_loss_forward_backward. When normalize_by_documents=true:

  • loss divisor = num_documents_in_batch (matches DS rl/loss)
  • gradient divisor = num_documents_in_batch² (matches DS grad_norm)

Independent of TP/PP/SDP/DP parallelism and microbatching schedule, because batch_size is invariant under all of them.
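The decoupled-divisor idea reduces to a few lines. A minimal sketch (names illustrative, not the kernel signature): the reported loss uses one divisor while the scale applied on the backward path uses another, with the old single-divisor behaviour as the default:

```python
import torch

def loss_and_grad_with_divisors(per_doc_losses, loss_divisor, grad_divisor=None):
    """Report the loss with one divisor while scaling the gradient with another."""
    if grad_divisor is None:
        grad_divisor = loss_divisor  # previous behaviour: one divisor for both
    total = per_doc_losses.sum()
    reported_loss = total / loss_divisor                    # e.g. 1/batch_size
    grad = torch.ones_like(per_doc_losses) / grad_divisor   # e.g. 1/batch_size**2
    return reported_loss, grad
```

With `loss_divisor = num_documents_in_batch` and `grad_divisor = num_documents_in_batch ** 2`, the loss metric carries one `1/batch_size` factor and the gradient two, matching the asymmetric DS scaling described above.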

SDP loss double-counting:
After the SDP allreduce of lrn_sum/adv_sum/tok_sum in fused_gspo_loss_forward_backward, both SDP ranks compute IDENTICAL per-segment loss values. When LossDef.reduce SUMs across data_group (which includes SDP ranks), the loss metric is double-counted by sdp_size. The gradient is NOT double-counted — each SDP rank contributes gradient from its own LOCAL tokens, with different contributions for different tokens of the same segment.

Fix: divide loss by sdp_size when sdp_group is active. Gradient unaffected.
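A toy check of the double-count, assuming a hypothetical reducer that SUMs per-rank losses: after the all-reduce both SDP ranks hold the identical segment loss, so the SUM over-counts by `sdp_size`, and pre-dividing restores the true value:

```python
import torch

def reduce_loss_across_sdp(per_rank_losses, sdp_size):
    """Each SDP rank holds the SAME post-all-reduce segment loss, so a SUM
    reduction over the data group over-counts by sdp_size; pre-divide it out."""
    return sum(loss / sdp_size for loss in per_rank_losses)

# Both ranks computed the identical post-all-reduce loss of 3.0; the
# corrected SUM reduction still reports 3.0, not 6.0.
assert reduce_loss_across_sdp([torch.tensor(3.0), torch.tensor(3.0)], 2).item() == 3.0
```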

Verification

End-to-end 7B math run on 4 nodes, GSPO, gradient_norm_clipping=0.3 (default), normalize_by_documents=true, fp32_lm_head=true, temperature=0.7:

| Metric | Before unit-6 fix | After unit-6 fix | DS GSPO reference |
| --- | --- | --- | --- |
| step 1 grad_norm | 141 (1000× DS) | 0.135 | 0.145 |
| step 1 lm_head_loss | -13.7 | ~-1.7 magnitude | -1.7 |
| step 1 clip_coeff | 0.002 (severe over-clip) | 1.000 (no clip) | no clip |
| step 50 newlp | trapped at -0.17 | -0.103 | -0.105 |

newlp trajectory tracks DS step-by-step. Both systems show same gradient-spike pattern during warmup ramp-up at steps 14-20 (DS step 16 grad_norm=6.365, fast-LLM step 15=9.005). Match within data variance.

Test plan

  • pytest tests/layers/test_gspo_loss.py — GSPO unit tests pass
  • pytest tests/layers/test_docs_per_step.py — docs_per_step unit tests pass
  • pytest tests/layers/test_lm_losses.py — existing GRPO loss + per-token metrics tests unaffected (the metrics tests previously in test_grpo_metrics.py moved into this file on the base branch)
  • End-to-end: 4-node Qwen2.5-7B math run with full config (docs_per_step=1024, temperature=0.7, normalize_by_documents=true, fp32_lm_head=true, default gradient_norm_clipping=0.3) — grad_norm matches DS at step 1, training trajectory matches DS step-by-step through step 50+ (ongoing run validates through step ~410).

Adds GRPO metrics parity with DeepSpeed: old_logprobs, ratio, ratio_sum,
ratio_sq_sum, kl_new_old, clamp_frac, advantage, max/min_advantage,
num_tokens, and optional per-token entropy.

New files:
- fast_llm/layers/language_model/loss/pg_metrics.py: reusable
  PolicyGradientMetrics dataclass + compute_policy_gradient_metrics()
  (callable by future PPO), with chunked vocab-parallel entropy support.
- tests/layers/test_grpo_metrics.py: 8 unit tests covering single-seq,
  packed multi-seq, masked tokens, clamp fraction, entropy correctness,
  mock SDP correctness, mock vocab-parallel entropy, normalization parity.

Config additions to LanguageModelGRPOLossConfig:
- compute_extra_metrics (default False): log all non-entropy metrics
- compute_entropy_metric (default False): additionally log per-token entropy
- entropy_chunk_size (default 4096): batch chunk size for entropy pass

Normalization matches existing new_logprobs_mean: sum(v*mask/label_counts)
then divided by num_documents_in_batch. MAX/MIN use LossDef ReductionType
and correct ReduceOp so they aggregate correctly across microbatches and
SDP/sequence-parallel ranks.
Rename four metrics to match DeepSpeed's naming exactly so runs on both
backends produce comparable WandB keys:

  ratio        → ratio_new_old
  ratio_sum    → ratio_new_old_sum
  ratio_sq_sum → ratio_new_old_squared_sum
  clamp_frac   → clamp_log_ratio_new_old_indicator
Implements GSPO (geometric-mean sequence-level policy-gradient loss) as
an alternative to the existing per-token GRPO clipping. Controlled via
LanguageModelGRPOLossConfig.policy_loss = "gspo".

Key changes:
- data pipeline: expose per-token document_index when return_document_index=True
- LanguageModelKwargs.document_index: new kwarg constant
- LanguageModelLoss: store SDP dim for cross-rank segment aggregation
- grpo.py: fused_gspo_loss_forward_backward with all_reduce(SUM) across
  SDP ranks before computing segment-level R_s and A_s; gradient derivation
  exploits tok_count cancellation so every token in a segment gets the
  same gradient factor R_s * clip_indicator_s
- tests/layers/test_gspo_loss.py: 8 unit tests (single-segment, packed,
  ratio-1 equivalence, clipping, masking, SDP mock, gradient finite-diff,
  per-token metrics unchanged)
Adds ScheduleConfig.rollouts_per_step (default 0). When >0, TrainerConfig._from_dict
computes depth_first_micro_batches = rollouts_per_step // (batch_data_parallel ×
breadth_first_micro_batches) before sub-configs are created (and frozen).

Matches DeepSpeed gradient_accumulation_passes semantics for RL: with train_batch_size=1
each microbatch holds one rollout, so setting rollouts_per_step=1024 with data_parallel=8
gives depth_first_micro_batches=128 → exactly 1024 rollouts per optimizer step globally.

YAML usage:
  schedule:
    rollouts_per_step: 1024   # replaces manual depth_first_micro_batches
  model:
    distributed:
      data_parallel: 8        # used for the division
- Rename rollouts_per_step → docs_per_step in ScheduleConfig; depth_first
  is now determined at runtime rather than statically in _from_dict
- Add Schedule._depth_first_override and _eff_{depth_first,sequential,num_inputs}
  properties so per-step schedules share the same config object as the runner
- Add Trainer._prefetch_to_doc_target: fetches microbatches one at a time,
  all-reduces doc count per microbatch, stops when global total ≥ docs_per_step,
  then resets num_documents_in_batch to the step total on all inputs
- Add Trainer._get_or_build_schedule: builds and caches per-N Schedule with
  _depth_first_override=N//breadth_first_micro_batches
- Add normalize_by_documents flag to LanguageModelGRPOLossConfig; when True
  both GRPO and GSPO paths divide by num_documents_in_batch instead of
  num_labels_in_batch (matches DeepSpeed's per-rollout normalization)
- Add tests/layers/test_docs_per_step.py: 13 unit tests covering divisor
  scaling, normalize_by_documents layer routing, Schedule._eff_* properties,
  and _prefetch_to_doc_target accumulation logic
Add temperature field to LanguageModelGRPOLossConfig. When set to match
the actor's sampling temperature (e.g. 0.7), new log-probs are computed
at the same temperature as the stored old log-probs, so the IS ratio
starts near 1.0 instead of ~1.08.

Implementation: _effective_logits_scale = logits_scale_factor / temperature,
substituted for logits_scale_factor at all three callsites in
_forward_backward (GRPO path, GSPO path, _register_pg_metrics). Default
temperature=1.0 preserves existing behaviour exactly.
@bigximik bigximik requested a review from jlamypoirier April 29, 2026 08:04
bigximik added 3 commits May 4, 2026 07:14
Add fp32_lm_head to LanguageModelHeadConfig. When enabled, input hidden
states and output_weights are cast to float32 before the lm_head linear,
producing FP32 logits. This matches vLLM's bf16_last_layer_fp32
quantization (pipelinerl/vllm_quantization.py) and the DeepSpeed trainer's
apply_fp32_lm_head() patch, so new_logprobs and old_logprobs are computed
at the same numerical precision and the IS ratio starts near 1.0 at init.

The gradient flowing back through the linear is cast to the original
input dtype (bf16) before returning, keeping the transformer backward pass
in its native dtype.
…accumulation

Detaching the FP32 weight copy (requires_grad=False) prevents
output_parallel_linear_backward from trying to write to a non-existent
grad_buffer on the copy. Weight grad is then computed explicitly from
the FP32 matmul and accumulated into the original BF16 param's grad_buffer
via accumulate_gradient, restoring the correct FSDP gradient contract.
When normalize_by_documents=true, fast-LLM's reported grad_norm was ~1024×
larger than DeepSpeed's for the equivalent loss, causing the default
gradient_norm_clipping=0.3 to over-clip by ~500× and making training ~10
reward points slower than DS GSPO at the same step count. The lm_head_loss
metric was also off — 1024× smaller than DS's rl/loss in the previous
divisor=num_documents² formulation, then 2× too large from SDP doubling.

Root cause analysis
-------------------

DeepSpeed has TWO 1/batch_size factors with different sources:

  1. Loss reported (rl/loss) uses /batch_size via tokens_weights = 1/batch_size
     (pipelinerl/finetune/rl/__init__.py:246). The reported `rl/loss = -1.7`
     value is the raw policy_loss_total, divided once by batch_size.

  2. Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes
     from `scale_wrt_gas=True` in engine.backward()
     (deepspeed/runtime/engine.py:1995-1996) and `tensor.div_(world_sz)` in
     reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124).

For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size
= batch_size, so DS's effective gradient buffer factor is 1/batch_size² while
the loss metric factor is 1/batch_size. Loss and gradient have asymmetric
scaling.

Fast-LLM's existing implementation used a single `divisor` for both loss and
gradient. Worse, the data_parallel × grad_scale factor in grad_output
(runner.py:318) cancels with FSDP's RS-AVG /world_size, structurally removing
DS's /(gas × world_size) factor from the gradient. So fast-LLM's gradient
buffer ended up at 1/batch_size while DS's was at 1/batch_size² — a
~batch_size = 1024× mismatch.

Additionally, GSPO's SDP allreduce of lrn_sum/adv_sum/tok_sum makes both SDP
ranks compute IDENTICAL per-segment loss values. When LossDef.reduce sums
over the data_group (which includes SDP ranks), the loss metric is
double-counted by sdp_size. The gradient buffer is NOT double-counted —
each SDP rank contributes gradient from its own LOCAL tokens, with different
contributions for different tokens of the same segment.

Fixes
-----

1. Add a `grad_divisor` parameter to `fused_gspo_loss_forward_backward`,
   `fused_grpo_loss_forward_backward`, and `triton_grpo_loss_forward_backward`,
   defaulting to `divisor` (existing behavior). Allows the gradient to use a
   different divisor than the loss.

2. In `LanguageModelGRPOLoss._forward_backward`, when normalize_by_documents
   is True, set:
     loss divisor      = num_documents_in_batch     (matches DS rl/loss)
     gradient divisor  = num_documents_in_batch²    (matches DS grad_norm)
   This is independent of TP/PP/SDP/DP parallelism and microbatching schedule
   because batch_size is invariant under all of these.

3. In the GSPO path, divide the loss by sdp_size when sdp_group is active
   (`fused_gspo_loss_forward_backward`). This pre-cancels the SDP doubling
   that LossDef.reduce's SUM over data_group introduces. The gradient is
   unaffected — different SDP ranks naturally contribute gradient from
   different LOCAL token positions, no double-counting at any layer.

Verification
------------

Tested on 7B math run with 4 nodes, GSPO, gradient_norm_clipping=0.3:

  Before fix          | After fix          | DS GSPO reference
  ------------------- | ------------------ | ------------------
  step 1 grad_norm=141| step 1 grad_norm=0.135 | step 1 grad_norm=0.145
  step 1 lm_head_loss | step 1 lm_head_loss   | step 1 rl/loss
   = -13.7            |  ~ -1.7 (sign varies  |   = -1.7
                      |   per data sample)    |
  clip_coeff=0.002    | clip_coeff=1.000      | no clipping at step 1
  newlp at step 50    | newlp at step 50      | newlp at step 50
   trapped at -0.17   |  = -0.103             |  = -0.105

newlp trajectory tracks DS step-by-step: step 1 within 3%, step 50 within 2%.
Both systems show grad_norm spikes at the same training phase (steps 14-20)
during warmup ramp-up — DS step 16 grad_norm=6.365 vs Fast-LLM 6.093.

Files changed
-------------

- fast_llm/layers/language_model/loss/grpo.py:
  - LanguageModelGRPOLoss._forward_backward: split divisor and grad_divisor
    based on normalize_by_documents flag, with detailed comments referencing
    the corresponding lines in DeepSpeed and PipelineRL.
  - fused_gspo_loss_forward_backward: add grad_divisor parameter; divide loss
    by sdp_size when sdp_group is active.
  - fused_grpo_loss_forward_backward: add grad_divisor parameter.

- fast_llm/functional/triton/grpo_loss.py:
  - triton_grpo_loss_forward_backward: add grad_divisor parameter.
@bigximik bigximik changed the title gspo: GSPO loss, docs_per_step accumulation, normalize_by_documents, temperature scaling gspo: GSPO loss + DeepSpeed parity fixes (loss/grad divisors, SDP, fp32_lm_head, docs_per_step, temperature) May 5, 2026
jlamypoirier and others added 8 commits May 5, 2026 11:02
- Inline pg_metrics.py into grpo.py; rename to GRPOMetrics
- Drop entropy_chunk_size; reuse fused_softmax_base outputs for entropy
- Replace two bool flags with a single metrics: GRPOMetricsLevel enum
- Rename clamp_log_ratio_new_old_indicator -> clipped_ratio_fraction
- Raise on metrics enabled with pipeline_parallel > 1 (MAX/MIN reduce
  would be corrupted by the zero placeholder on empty pipeline ranks)
- Migrate tests into tests/layers/test_lm_losses.py, reusing the
  existing helpers and parametrization (single + distributed runner)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Drop stale "second softmax pass" overhead note from `metrics`
  description (entropy now reuses the base softmax outputs)
- De-mirror max/min in reference_grpo_metrics: use
  advantages[loss_mask].max()/.min() instead of the implementation's
  -inf/+inf sentinel pattern

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Align (logits, target, advantages, old_log_probabilities, ...) order
  across compute_grpo_metrics, fused_grpo_loss_forward_backward, and
  reference_grpo_metrics
- Replace **kwargs in LanguageModelGRPOLoss.__init__ with the explicit
  keyword-only signature mirroring LanguageModelLoss.__init__
- num_docs -> num_documents
- Drop the comment that restated the k3 KL formula
- Give compute_grpo_metrics the same defaults as the loss kernel
- Trim the metrics field description to category-level wording
- Always exercise varying label_counts in _test_grpo_metrics so per-token
  denominator broadcasting is covered
- reference_grpo_metrics returns GRPOMetrics; comparison loop iterates
  dataclasses.fields
- Drop name = self._name micro-rebinds; use self._name inline
- defs = super()...; defs.append(...); defs.extend(...) consistently
- Tighten _register_extra_metrics losses type to dict[str, list[Tensor]]
- Split compiled tuple-returning core from outer GRPOMetrics wrapper to
  avoid @torch.compile graph-breaks on dataclass construction
- One-line comment on the metrics gate explaining the softmax-skip

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
NamedTuple is a tuple subclass that dynamo handles natively, so the
previous wrapper/inner split (added to dodge a dataclass graph-break)
collapses into one @torch.compile function. Field order now lives
exactly once — on the class.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Entropy under vocab-parallel TP was wrong: the dot-product term
  (exp_logits * logits_norm).sum(-1) summed only the local vocab slice,
  so dividing by the global sum_exp_logits gave a per-rank fragment
  instead of the full E_p[logit_norm]. All-reduce the partial sum.
- Replace the verbose pipeline-parallel guard with Assert.custom; the
  field description already explains the constraint.
- Drop the cryptic `# k3` comment.
- Match _register_extra_metrics losses annotation to the base class
  (dict | None).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
# Conflicts:
#	fast_llm/layers/language_model/loss/config.py
#	fast_llm/layers/language_model/loss/grpo.py
@jlamypoirier
Collaborator

Coarse review — pass 1 of 2

Reviewed git diff origin/grpo-metrics...origin/gspo (head 15ae8d66) against the base branch grpo-metrics. Structural / correctness pass only — naming, comments, and formatting nits go to /review-fine.

1. LanguageModelGRPOLossConfig.policy_loss: str at fast_llm/layers/language_model/loss/config.py:208-212 violates the loss-config dispatch convention. Sibling losses (label, distillation, dpo, z_loss, grpo) are all registered via @config_class(dynamic_type={LanguageModelLossConfig: "..."}) (fast_llm/layers/language_model/loss/config.py:78-204). Add a @config_class(dynamic_type={LanguageModelLossConfig: "gspo"}) sibling — extracting a shared base for the fields/code GSPO and GRPO actually have in common — instead of a stringly-typed switch inside LanguageModelGRPOLoss._forward_backward.

2. The fp32_lm_head manual weight-grad path at fast_llm/layers/language_model/head.py:288-298 reaches into output_parallel_linear_forward's context tuple by positional index (context[0], context[3], context[4]) with only a comment as documentation. The forward returns 9 fields (fast_llm/functional/linear.py:136-145); a future reorder will silently miscompute the weight gradient or feed gather_op the wrong arg. Lift the FP32-weight handling into output_parallel_linear_backward (e.g. accept an explicit weight override or a "compute-grad-against-this-other-param" hook) so the head doesn't replicate linear-backward internals.

3. Same fp32_lm_head block re-gathers saved_input on the sequence-parallel path even though output_parallel_linear_backward already gathers it internally (fast_llm/functional/linear.py:152-157). On the FP32 path that's two gather collectives per backward instead of one. Fix together with item 2 — once the weight grad lives in output_parallel_linear_backward, the gathered input1 is already available there.

4. Schedule.__init__(_depth_first_override=...) at fast_llm/engine/schedule/schedule.py:118-121 is a private back-channel that shadows config.depth_first_micro_batches via a parallel _eff_depth_first / _eff_sequential_micro_batches / _eff_num_inputs API (fast_llm/engine/schedule/schedule.py:160-174) used in five places. Make depth_first_micro_batches a regular constructor argument (default = config.depth_first_micro_batches), drop the _eff_* shadow, and let the existing self._config.X reads resolve through the constructor-stored value.

5. GSPO and GRPO call sites in LanguageModelGRPOLoss._forward_backward at fast_llm/layers/language_model/loss/grpo.py:101-133 duplicate the kernel-kwargs block almost verbatim (only the GSPO-specific document_index and sdp_group differ). Build one shared kwargs dict and dispatch the kernel choice on top, or — better — move the dispatch to the subclass introduced in item 1 so each kernel call lives in its own override.

6. _prefetch_to_doc_target at fast_llm/engine/training/trainer.py:170-174 hard-fails via Assert.eq(len(buffer) % bfmb, 0, ...) whenever the natural document-target stopping point isn't breadth_first_micro_batches-aligned. With a streaming dataloader and any bfmb > 1, this crashes mid-step on data-dependent boundaries. Either keep fetching microbatches until alignment is reached (rounding total_docs upward), or validate the combination at config time so it can't happen at runtime.

7. self._preprocessing_config = preprocessing_config at fast_llm/engine/training/trainer.py:118 stores a value nothing else in the diff (or in the file) reads. Drop the assignment.

8. self._schedule = Schedule(...) at fast_llm/engine/training/trainer.py:121-127 is built unconditionally but used only on the docs_per_step == 0 path (line 281 vs 290). Either skip the build when docs_per_step > 0, or unify by feeding the static path through _get_or_build_schedule(sequential_micro_batches) and removing the duplicate construction site.

9. The SDP-loss double-counting fix at fast_llm/layers/language_model/loss/grpo.py:481-482 (loss = loss / world_size(sdp_group)) is the most subtle correctness change in the PR and is not exercised by any test. tests/layers/test_gspo_loss.py::test_sdp_mock calls the kernel with sdp_group=None throughout and reconstructs the SDP semantics manually — the actual if sdp_group is not None: loss = loss / sdp_size branch never runs in tests. Add a multi-rank torchrun-spawned test asserting loss(SDP=k) == loss(SDP=1) on the same global batch.

10. No test exercises Trainer._get_or_build_schedule (fast_llm/engine/training/trainer.py:146-158) — only its _eff_* byproduct properties on Schedule. The cache key behavior, the actual schedule built with the override, and reuse across iterations are untested. Add a unit test that asks for several n_microbatches values, asserts caching, and checks the resulting schedules' _eff_* values.

11. No test covers the fp32_lm_head gradient path at fast_llm/layers/language_model/head.py:288-298, especially the cross_entropy_splits > 1 case where accumulate_gradient(self.output_weights, ...) is called once per split and relies on param_grad_is_zero being correctly toggled across calls. Add a unit test that runs forward+backward with fp32_lm_head=True (and again with cross_entropy_splits=2) and compares output_weights.grad_buffer to a non-FP32 reference within tolerance.

12. int(document_index.max().item()) at fast_llm/layers/language_model/loss/grpo.py:440 triggers a host sync on every GSPO microbatch, and the SDP n_segs all_reduce(MAX) at line 443 is a separate collective from the lrn_sum/adv_sum/tok_sum all_reduce(SUM) at lines 461-463. Compute n_segs once at batch construction (it's already known from length_cumsum in _get_label_counts), thread it through kwargs, and drop both syncs.

13. torch.ones(masked_doc_ids.numel(), device=logits.device) at fast_llm/layers/language_model/loss/grpo.py:457 allocates a fresh ones-tensor every forward, only to feed index_add_. Use torch.bincount(masked_doc_ids, minlength=n_segs).to(log_ratio.dtype) — no allocation, no separate index_add pass.
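The suggested replacement in item 13 can be checked in isolation; the two forms produce identical token counts per segment:

```python
import torch

doc_ids = torch.tensor([0, 0, 1, 2, 2, 2])
n_segs = 3
# Original pattern: fresh ones tensor every forward, fed to index_add_.
tok_sum_old = torch.zeros(n_segs).index_add_(0, doc_ids, torch.ones(doc_ids.numel()))
# Suggested replacement: bincount, no intermediate allocation or index_add pass.
tok_sum_new = torch.bincount(doc_ids, minlength=n_segs).to(tok_sum_old.dtype)
assert torch.equal(tok_sum_old, tok_sum_new)  # both give [2., 1., 3.]
```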

14. LanguageModelGRPOLossConfig.normalize_by_documents's description at fast_llm/layers/language_model/loss/config.py:230-236 says "Set to True when using docs_per_step for full DS parity" but nothing in _validate enforces or even cross-references the two flags (which live in different configs). Either add a cross-config consistency check on the trainer/experiment config, or drop the load-bearing-sounding sentence from the field doc.

Notes

  • The decision-rich comment block at fast_llm/layers/language_model/loss/grpo.py:78-91 (the loss-vs-grad divisor derivation) belongs in a design doc or commit-message reference rather than in the hot-path forward — but it's the kind of "why" CLAUDE.md permits, and removing it would lose hard-won context. Flagging only as a maintenance hazard: the line numbers it cites in DeepSpeed will drift.

  • LanguageModelGRPOLoss.get_preprocessing_config at fast_llm/layers/language_model/loss/grpo.py:240-243 adds return_document_index only when policy_loss == "gspo". If item 1 is acted on (sibling subclass), the override moves naturally into the GSPO subclass and the conditional disappears.

  • The _StubTrainer pattern in tests/layers/test_docs_per_step.py:996-1001 (borrowing _prefetch_to_doc_target directly off Trainer) silently breaks if the method gains an attribute access not stubbed on the fake. Prefer extracting the prefetch logic into a free function that takes its dependencies as arguments — the test then exercises the same function the trainer calls.

Base automatically changed from grpo-metrics to main May 6, 2026 16:38
jlamypoirier and others added 2 commits May 6, 2026 12:52
# Conflicts:
#	fast_llm/layers/language_model/loss/config.py
#	fast_llm/layers/language_model/loss/grpo.py
- Drop unused self._preprocessing_config store in Trainer.setup.
- Replace torch.ones + index_add_ with torch.bincount for tok_sum
  in fused_gspo_loss_forward_backward.
- Drop load-bearing-sounding docs_per_step reference from the
  normalize_by_documents field description (no cross-config check
  exists to enforce it).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>