arguments: add --enable-r3-correctness-check CLI flag#25
Open
DavidBellamy wants to merge 5 commits into
Open
Conversation
When set, flips RoutingReplayManager.enable_check_replay_result = True
so the per-step overlap check (replay_base.py:178-219) fires for every
training step. Off by default because the check roughly doubles the
cost of routing.
Intended for the R3 regression E2E on LLM360/RL360, which runs a small
GPU sbatch on M2 every time a submodule-pin bump PR opens. With this
flag, miles will raise AssertionError("R3 mismatch tokens ...") if the
overlap drops below MILES_TEST_R3_THRESHOLD (default 1e-2), giving the
E2E a hard pass/fail signal.
The R3 master switch (--use-rollout-routing-replay) is still required;
this flag has no effect without it.
Six files on the prod base had black-non-compliant formatting that pre-commit on PR #25 flagged as failures. Applying `black==24.3.0` (matches .pre-commit-config.yaml) brings them in line so CI passes. Also fixes the single line in train_async.py from this PR that black wants (blank line after the import). No behavioral changes; pure whitespace + line breaks.
The previous --enable-r3-correctness-check flag turned on the overlap check but produced no log output unless an actual mismatch happened, making it impossible to distinguish "check passed" from "check never ran." Add two unconditional logs gated on enable_check_replay_result: 1. get_topk_fn / new_topk_fn replay_forward + replay_backward branches: log when the wrapper actually returns replay indices rather than falling through to old_topk_fn. Direct evidence megatron's MoE forward used the rollout indices (vs recomputing them). 2. check_replay_result: log n_tokens and mismatch_count on every call, including the mismatch_count==0 case (which previously returned silently). Direct evidence the check ran, plus the actual overlap number for cross-step / cross-rank comparison. Both logs gated on enable_check_replay_result so production training runs (which leave it False) stay quiet. Adds no overhead when off. Intended to make the LLM360/RL360 R3 regression E2E able to assert directly that R3 worked end-to-end, rather than inferring from absence of failure messages.
actor.py:111-112 unconditionally set m.enable_check_replay_result = m.enabled and self.args.ci_test which overrode the value we set in train_async.py from --enable-r3-correctness-check. The flag was effectively a no-op. This change keeps backward-compat for --ci-test and ALSO honors --enable-r3-correctness-check on its own, so callers can enable the R3 overlap check without enabling the rest of --ci-test's invariants. In particular --ci-test also enables a strict log_probs == ref_log_probs equality check that trips on routine floating-point precision differences (~1e-3 gap), so R3 callers need a way to opt into ONLY the replay check. Found during R3 E2E pre-merge validation: with --ci-test on, the R3 overlap check fired cleanly (1976+ checks, all mismatch=0%) but the job then failed at the unrelated log_probs assertion before the backward pass. With --enable-r3-correctness-check now wired through properly, the same run reaches backward and can show replay_backward branch evidence too.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What this does
Three small, targeted miles changes to enable direct end-to-end validation of R3 (Rollout Router Replay) without dragging in unrelated
--ci-testinvariants.A. New CLI flag
--enable-r3-correctness-check(miles/utils/arguments.py) that flipsRoutingReplayManager.enable_check_replay_result = True. The flag also writes the value at module level intrain_async.pyafterparse_args()(line 78-81).B. Make the flag actually take effect (
miles/backends/megatron_utils/actor.py). Previouslyactor.pyunconditionally overwrote the value toself.args.ci_testfor every replay manager (single-line assignment at L112), so the new flag was a no-op. Extends that condition to also honor--enable-r3-correctness-check:This lets callers turn on just the R3 overlap check without enabling
--ci-test's other strict-equality invariants — in particular, thelog_probs == ref_log_probscheck that trips on routine floating-point precision differences (~1e-3 gap) and was breaking the E2E run before this fix.C. Add direct-evidence logs in
RoutingReplayManager(miles/utils/replay_base.py+ check). Two unconditionallogger.infocalls, gated onenable_check_replay_resultso production training stays quiet:new_topk_fn'sreplay_forwardandreplay_backwardbranches that logsR3 wrapper: replay_{forward,backward} branch taken (rank ..., n_tokens=..., topk=..., replay_idx_sum=...). Without this, there is no log evidence that megatron actually called through the wrapper (vs falling through toold_topk_fn).check_replay_resultthat always logsR3 check (rank ..., stage ...): n_tokens=... mismatch=... (...%). Previously the check returned silently whenmismatch_count == 0, making "check passed" indistinguishable from "check never ran."These three logs together let the LLM360/RL360 E2E daemon prove that megatron's MoE forward and backward both used the rollout indices, not just that no assertion fired. See LLM360/RL360#317 validation evidence section.
Why this matters
Without this PR, the R3 regression E2E on M2 has no way to assert direct correctness of R3. The previous
replay_base.py:178-219check existed but was only callable via--ci-test, which trips a separatelog_probs == ref_log_probsstrict-equality check on the same run and crashes before backward fires.Commits in this PR (substantive vs formatting)
Sorted by what kind of change they make:
Substantive (please review)
0431dbf5—arguments: add --enable-r3-correctness-check CLI flagmiles/utils/arguments.py,train_async.py0854adcc—replay_base: direct-evidence logs for R3 wrapper + overlap checkmiles/utils/replay_base.pyf019a625—actor: make --enable-r3-correctness-check independent of --ci-testmiles/backends/megatron_utils/actor.pyFormatting only (mechanical, no behavior change)
db437d22—prod: apply black drift cleanupblack==24.3.0(the version pinned in.pre-commit-config.yaml) wanted to reformat 7 files when CI ranpre-commit run --all-files. Six were pre-existing drift on theprodbase; the seventh is one blank line intrain_async.pyfrom this PR. 64 lines touched (+30/-34) acrosslog_utils.py,loss.py,rollout.py,openai_endpoint_utils.py,linear_trajectory.py,replay_base.py,train_async.py. No symbol added or removed.e06a6b3f—actor: apply black to the new condition (CI fix, no logic change)self.args.ci_test\n or getattr(...)collapsed onto one line per black's preference.Validation
LLM360/RL360#317 section "Validation evidence" walks through the full forward + backward proof using miles at this PR's head (
f019a625snapshot, behaviorally identical to current heade06a6b3f). SLURM job 1654622 (manual) and 1655337 (daemon-driven, posted byllm360-deploy-boton RL360 radixark#319) both COMPLETED with zero non-zero mismatches across thousands of per-rank-per-layer R3 wrapper calls in both forward and backward.