RL training features (#502 minus GSPO)#520
Open
jlamypoirier wants to merge 27 commits into
Open
Conversation
Adds GRPO metrics parity with DeepSpeed: old_logprobs, ratio, ratio_sum, ratio_sq_sum, kl_new_old, clamp_frac, advantage, max/min_advantage, num_tokens, and optional per-token entropy. New files: - fast_llm/layers/language_model/loss/pg_metrics.py: reusable PolicyGradientMetrics dataclass + compute_policy_gradient_metrics() (callable by future PPO), with chunked vocab-parallel entropy support. - tests/layers/test_grpo_metrics.py: 8 unit tests covering single-seq, packed multi-seq, masked tokens, clamp fraction, entropy correctness, mock SDP correctness, mock vocab-parallel entropy, normalization parity. Config additions to LanguageModelGRPOLossConfig: - compute_extra_metrics (default False): log all non-entropy metrics - compute_entropy_metric (default False): additionally log per-token entropy - entropy_chunk_size (default 4096): batch chunk size for entropy pass Normalization matches existing new_logprobs_mean: sum(v*mask/label_counts) then divided by num_documents_in_batch. MAX/MIN use LossDef ReductionType and correct ReduceOp so they aggregate correctly across microbatches and SDP/sequence-parallel ranks.
Rename four metrics to match DeepSpeed's naming exactly so runs on both backends produce comparable WandB keys: ratio → ratio_new_old ratio_sum → ratio_new_old_sum ratio_sq_sum → ratio_new_old_squared_sum clamp_frac → clamp_log_ratio_new_old_indicator
Implements GSPO (geometric-mean sequence-level policy-gradient loss) as an alternative to the existing per-token GRPO clipping. Controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo". Key changes: - data pipeline: expose per-token document_index when return_document_index=True - LanguageModelKwargs.document_index: new kwarg constant - LanguageModelLoss: store SDP dim for cross-rank segment aggregation - grpo.py: fused_gspo_loss_forward_backward with all_reduce(SUM) across SDP ranks before computing segment-level R_s and A_s; gradient derivation exploits tok_count cancellation so every token in a segment gets the same gradient factor R_s * clip_indicator_s - tests/layers/test_gspo_loss.py: 8 unit tests (single-segment, packed, ratio-1 equivalence, clipping, masking, SDP mock, gradient finite-diff, per-token metrics unchanged)
Adds ScheduleConfig.rollouts_per_step (default 0). When >0, TrainerConfig._from_dict
computes depth_first_micro_batches = rollouts_per_step // (batch_data_parallel ×
breadth_first_micro_batches) before sub-configs are created (and frozen).
Matches DeepSpeed gradient_accumulation_passes semantics for RL: with train_batch_size=1
each microbatch holds one rollout, so setting rollouts_per_step=1024 with data_parallel=8
gives depth_first_micro_batches=128 → exactly 1024 rollouts per optimizer step globally.
YAML usage:
schedule:
rollouts_per_step: 1024 # replaces manual depth_first_micro_batches
model:
distributed:
data_parallel: 8 # used for the division
- Rename rollouts_per_step → docs_per_step in ScheduleConfig; depth_first
is now determined at runtime rather than statically in _from_dict
- Add Schedule._depth_first_override and _eff_{depth_first,sequential,num_inputs}
properties so per-step schedules share the same config object as the runner
- Add Trainer._prefetch_to_doc_target: fetches microbatches one at a time,
all-reduces doc count per microbatch, stops when global total ≥ docs_per_step,
then resets num_documents_in_batch to the step total on all inputs
- Add Trainer._get_or_build_schedule: builds and caches per-N Schedule with
_depth_first_override=N//breadth_first_micro_batches
- Add normalize_by_documents flag to LanguageModelGRPOLossConfig; when True
both GRPO and GSPO paths divide by num_documents_in_batch instead of
num_labels_in_batch (matches DeepSpeed's per-rollout normalization)
- Add tests/layers/test_docs_per_step.py: 13 unit tests covering divisor
scaling, normalize_by_documents layer routing, Schedule._eff_* properties,
and _prefetch_to_doc_target accumulation logic
Add temperature field to LanguageModelGRPOLossConfig. When set to match the actor's sampling temperature (e.g. 0.7), new log-probs are computed at the same temperature as the stored old log-probs, so the IS ratio starts near 1.0 instead of ~1.08. Implementation: _effective_logits_scale = logits_scale_factor / temperature, substituted for logits_scale_factor at all three callsites in _forward_backward (GRPO path, GSPO path, _register_pg_metrics). Default temperature=1.0 preserves existing behaviour exactly.
Add fp32_lm_head to LanguageModelHeadConfig. When enabled, input hidden states and output_weights are cast to float32 before the lm_head linear, producing FP32 logits. This matches vLLM's bf16_last_layer_fp32 quantization (pipelinerl/vllm_quantization.py) and the DeepSpeed trainer's apply_fp32_lm_head() patch, so new_logprobs and old_logprobs are computed at the same numerical precision and the IS ratio starts near 1.0 at init. The gradient flowing back through the linear is cast to the original input dtype (bf16) before returning, keeping the transformer backward pass in its native dtype.
…accumulation Detaching the FP32 weight copy (requires_grad=False) prevents output_parallel_linear_backward from trying to write to a non-existent grad_buffer on the copy. Weight grad is then computed explicitly from the FP32 matmul and accumulated into the original BF16 param's grad_buffer via accumulate_gradient, restoring the correct FSDP gradient contract.
When normalize_by_documents=true, fast-LLM's reported grad_norm was ~1024×
larger than DeepSpeed's for the equivalent loss, causing the default
gradient_norm_clipping=0.3 to over-clip by ~500× and making training ~10
reward points slower than DS GSPO at the same step count. The lm_head_loss
metric was also off — 1024× smaller than DS's rl/loss in the previous
divisor=num_documents² formulation, then 2× too large from SDP doubling.
Root cause analysis
-------------------
DeepSpeed has TWO 1/batch_size factors with different sources:
1. Loss reported (rl/loss) uses /batch_size via tokens_weights = 1/batch_size
(pipelinerl/finetune/rl/__init__.py:246). The reported `rl/loss = -1.7`
value is the raw policy_loss_total, divided once by batch_size.
2. Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes
from `scale_wrt_gas=True` in engine.backward()
(deepspeed/runtime/engine.py:1995-1996) and `tensor.div_(world_sz)` in
reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124).
For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size
= batch_size, so DS's effective gradient buffer factor is 1/batch_size² while
the loss metric factor is 1/batch_size. Loss and gradient have asymmetric
scaling.
Fast-LLM's existing implementation used a single `divisor` for both loss and
gradient. Worse, the data_parallel × grad_scale factor in grad_output
(runner.py:318) cancels with FSDP's RS-AVG /world_size, structurally removing
DS's /(gas × world_size) factor from the gradient. So fast-LLM's gradient
buffer ended up at 1/batch_size while DS's was at 1/batch_size² — a
~batch_size = 1024× mismatch.
Additionally, GSPO's SDP allreduce of lrn_sum/adv_sum/tok_sum makes both SDP
ranks compute IDENTICAL per-segment loss values. When LossDef.reduce sums
over the data_group (which includes SDP ranks), the loss metric is
double-counted by sdp_size. The gradient buffer is NOT double-counted —
each SDP rank contributes gradient from its own LOCAL tokens, with different
contributions for different tokens of the same segment.
Fixes
-----
1. Add a `grad_divisor` parameter to `fused_gspo_loss_forward_backward`,
`fused_grpo_loss_forward_backward`, and `triton_grpo_loss_forward_backward`,
defaulting to `divisor` (existing behavior). Allows the gradient to use a
different divisor than the loss.
2. In `LanguageModelGRPOLoss._forward_backward`, when normalize_by_documents
is True, set:
loss divisor = num_documents_in_batch (matches DS rl/loss)
gradient divisor = num_documents_in_batch² (matches DS grad_norm)
This is independent of TP/PP/SDP/DP parallelism and microbatching schedule
because batch_size is invariant under all of these.
3. In the GSPO path, divide the loss by sdp_size when sdp_group is active
(`fused_gspo_loss_forward_backward`). This pre-cancels the SDP doubling
that LossDef.reduce's SUM over data_group introduces. The gradient is
unaffected — different SDP ranks naturally contribute gradient from
different LOCAL token positions, no double-counting at any layer.
Verification
------------
Tested on 7B math run with 4 nodes, GSPO, gradient_norm_clipping=0.3:
Before fix | After fix | DS GSPO reference
------------------- | ------------------ | ------------------
step 1 grad_norm=141| step 1 grad_norm=0.135 | step 1 grad_norm=0.145
step 1 lm_head_loss | step 1 lm_head_loss | step 1 rl/loss
= -13.7 | ~ -1.7 (sign varies | = -1.7
| per data sample) |
clip_coeff=0.002 | clip_coeff=1.000 | no clipping at step 1
newlp at step 50 | newlp at step 50 | newlp at step 50
trapped at -0.17 | = -0.103 | = -0.105
newlp trajectory tracks DS step-by-step: step 1 within 3%, step 50 within 2%.
Both systems show grad_norm spikes at the same training phase (steps 14-20)
during warmup ramp-up — DS step 16 grad_norm=6.365 vs Fast-LLM 6.093.
Files changed
-------------
- fast_llm/layers/language_model/loss/grpo.py:
- LanguageModelGRPOLoss._forward_backward: split divisor and grad_divisor
based on normalize_by_documents flag, with detailed comments referencing
the corresponding lines in DeepSpeed and PipelineRL.
- fused_gspo_loss_forward_backward: add grad_divisor parameter; divide loss
by sdp_size when sdp_group is active.
- fused_grpo_loss_forward_backward: add grad_divisor parameter.
- fast_llm/functional/triton/grpo_loss.py:
- triton_grpo_loss_forward_backward: add grad_divisor parameter.
- Inline pg_metrics.py into grpo.py; rename to GRPOMetrics - Drop entropy_chunk_size; reuse fused_softmax_base outputs for entropy - Replace two bool flags with a single metrics: GRPOMetricsLevel enum - Rename clamp_log_ratio_new_old_indicator -> clipped_ratio_fraction - Raise on metrics enabled with pipeline_parallel > 1 (MAX/MIN reduce would be corrupted by the zero placeholder on empty pipeline ranks) - Migrate tests into tests/layers/test_lm_losses.py, reusing the existing helpers and parametrization (single + distributed runner) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Drop stale "second softmax pass" overhead note from `metrics` description (entropy now reuses the base softmax outputs) - De-mirror max/min in reference_grpo_metrics: use advantages[loss_mask].max()/.min() instead of the implementation's -inf/+inf sentinel pattern Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Align (logits, target, advantages, old_log_probabilities, ...) order across compute_grpo_metrics, fused_grpo_loss_forward_backward, and reference_grpo_metrics - Replace **kwargs in LanguageModelGRPOLoss.__init__ with the explicit keyword-only signature mirroring LanguageModelLoss.__init__ - num_docs -> num_documents - Drop the comment that restated the k3 KL formula - Give compute_grpo_metrics the same defaults as the loss kernel - Trim the metrics field description to category-level wording - Always exercise varying label_counts in _test_grpo_metrics so per-token denominator broadcasting is covered - reference_grpo_metrics returns GRPOMetrics; comparison loop iterates dataclasses.fields - Drop name = self._name micro-rebinds; use self._name inline - defs = super()...; defs.append(...); defs.extend(...) consistently - Tighten _register_extra_metrics losses type to dict[str, list[Tensor]] - Split compiled tuple-returning core from outer GRPOMetrics wrapper to avoid @torch.compile graph-breaks on dataclass construction - One-line comment on the metrics gate explaining the softmax-skip Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
NamedTuple is a tuple subclass that dynamo handles natively, so the previous wrapper/inner split (added to dodge a dataclass graph-break) collapses into one @torch.compile function. Field order now lives exactly once — on the class. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Entropy under vocab-parallel TP was wrong: the dot-product term (exp_logits * logits_norm).sum(-1) summed only the local vocab slice, so dividing by the global sum_exp_logits gave a per-rank fragment instead of the full E_p[logit_norm]. All-reduce the partial sum. - Replace the verbose pipeline-parallel guard with Assert.custom; the field description already explains the constraint. - Drop the cryptic `# k3` comment. - Match _register_extra_metrics losses annotation to the base class (dict | None). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
# Conflicts: # fast_llm/layers/language_model/loss/config.py # fast_llm/layers/language_model/loss/grpo.py
# Conflicts: # fast_llm/layers/language_model/loss/config.py # fast_llm/layers/language_model/loss/grpo.py
- Drop unused self._preprocessing_config store in Trainer.setup. - Replace torch.ones + index_add_ with torch.bincount for tok_sum in fused_gspo_loss_forward_backward. - Drop load-bearing-sounding docs_per_step reference from the normalize_by_documents field description (no cross-config check exists to enforce it). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Splits the policy-gradient loss config and class hierarchy: - LanguageModelPolicyGradientLossConfig (abstract base): shared fields (epsilon_low/high, metrics, normalize_by_documents, temperature). - LanguageModelGRPOLossConfig: registers `type: grpo` (keeps GRPO-only use_triton). - LanguageModelGSPOLossConfig: registers `type: gspo`. - LanguageModelPolicyGradientLoss (abstract base): shared __init__/_forward_backward/_register_extra_metrics/get_loss_definitions/ get_preprocessing_config plumbing; abstract `_call_kernel`. - LanguageModelGRPOLoss / LanguageModelGSPOLoss: each implements `_call_kernel` against its kernel; GSPO overrides `get_preprocessing_config` to add `return_document_index`. Drops the stringly-typed `policy_loss: str` switch and the in-method if/else dispatch, addressing review items #1 and #5 plus Note 2. YAML migration: `type: grpo` + `policy_loss: gspo` → `type: gspo`. No checked-in YAML configs use the old form. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the abstract `_call_kernel` + per-algorithm subclass pattern with the assignment-at-init pattern used by `Normalization._forward`. - Single LanguageModelPolicyGradientLoss class hosts both kernel calls as `_call_grpo_kernel` and `_call_gspo_kernel`. - __init__ assigns `self._call_kernel` to the matching method based on isinstance(config, LanguageModelGSPOLossConfig). - get_preprocessing_config dispatches inline on the same isinstance. - Both LanguageModelGRPOLossConfig and LanguageModelGSPOLossConfig return the same loss class — the YAML-side type split (registered via @config_class(dynamic_type=...)) stays as in #1. Drops ~30 lines net from grpo.py: removes the abstract `_call_kernel` declaration and the two single-method subclasses. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Reverts the class merge from d2c051a in favor of the assignment-at-init pattern used by Normalization._forward. Drops the per-call _call_kernel wrapper that just shuffled args. - LanguageModelPolicyGradientLoss now hosts only shared scaffolding: _compute_divisors (token vs document), _shared_kernel_kwargs (the 9 kwargs both kernels accept), _finalize_loss (post-call register + extra metrics), and the per-token metrics machinery. - LanguageModelGRPOLoss and LanguageModelGSPOLoss are restored. Each __init__ assigns self._forward to the actual kernel function: GRPO: triton_grpo_loss_forward_backward or fused_grpo_loss_forward_backward GSPO: fused_gspo_loss_forward_backward - Each subclass's _forward_backward calls self._forward(...) directly with the kernel's real signature; no intermediate wrapper. - Configs map type:grpo → LanguageModelGRPOLoss, type:gspo → LanguageModelGSPOLoss again. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
# Conflicts: # fast_llm/layers/language_model/config.py # fast_llm/layers/language_model/head.py
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
GSPO core landed via #517 (cleaner reimplementation from main). Drop the GSPO loss class, config, kernel, dedicated test file, and the supporting LanguageModelKwargs.document_index plumbing (#517 reads document_index_q from BlockKwargs instead). Also drop two GSPO-specific knobs that no longer apply once GSPO is removed: - normalize_by_documents on LanguageModelPolicyGradientLossConfig — was GRPO/GSPO's DS-style /M loss with /M^2 gradient. The GSPO loss in #517 bakes /num_documents in unconditionally and the GRPO path here keeps the per-token-count normalization. - The kernel's /sdp_size "fix" only existed in the GSPO kernel (global per-segment loss made identical on every SDP rank); deleted with the GSPO kernel. The rest of #502 (docs_per_step, fp32_lm_head, grad_divisor parameter on GRPO kernels) is preserved as-is for follow-up review. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Follow-up to #517 (which took the GSPO core from #502 and reimplemented it cleanly from main).
What this PR is
The remainder of #502 with the GSPO-specific content removed. Built on top of the original
gspobranch with just three deletions on top.What was removed
LanguageModelKwargs.document_indexand its data-pipeline plumbing — only needed for the GSPO kernel. Add GSPO loss #517 readsdocument_index_qfromBlockKwargsinstead.normalize_by_documentsflag onLanguageModelPolicyGradientLossConfig— was the DS-style/Mloss with/M^2gradient. Add GSPO loss #517 bakes/num_documentsinto GSPO unconditionally; the GRPO path here keeps the per-token-count normalization./sdp_size"fix" — only existed in the GSPO kernel; deleted with it.What's kept (unchanged from #502)
Schedule._eff_*properties,Trainer._prefetch_to_doc_target, and the corresponding unit tests.head.py.grad_divisorparameter onfused_grpo_loss_forward_backwardandtriton_grpo_loss_forward_backward— allows the gradient to use a different divisor than the loss. Currently always defaults todivisor(callers no longer pass a different value), but the plumbing is in place.Diff stats
7 files changed, 11 insertions(+), 808 deletions(-)— net delete-only.Test plan
pytest tests/layers/test_lm_losses.py::test_grpo_loss tests/layers/test_lm_losses.py::test_grpo_metrics tests/layers/test_docs_per_step.py— 70 cases passNote: this branch is behind
mainby one commit (#508). Will rebase before merge.