Commit 0431dbf

arguments: add --enable-r3-correctness-check CLI flag
When set, flips RoutingReplayManager.enable_check_replay_result = True so the per-step overlap check (replay_base.py:178-219) fires for every training step. Off by default because the check roughly doubles the cost of routing.

Intended for the R3 regression E2E on LLM360/RL360, which runs a small GPU sbatch on M2 every time a submodule-pin bump PR opens. With this flag, miles will raise AssertionError("R3 mismatch tokens ...") if the overlap drops below MILES_TEST_R3_THRESHOLD (default 1e-2), giving the E2E a hard pass/fail signal.

The R3 master switch (--use-rollout-routing-replay) is still required; this flag has no effect without it.
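The overlap check described in the commit message can be sketched in plain Python as follows. This is a minimal illustration and not the code in replay_base.py: the function names `topk_indices`, `routing_overlap`, and `check_replay_result` are hypothetical, the scores are plain lists rather than tensors, and the threshold is compared against the overlap fraction exactly as the message words it, which may differ from how the real check defines it.

```python
import os


def topk_indices(scores, k):
    """Return the indices of the k largest scores (ties go to the lower index)."""
    return sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]


def routing_overlap(scores_per_token, rollout_indices_per_token, k):
    """Fraction of expert selections shared by training-side top-k and rollout."""
    matched = 0
    total = 0
    for scores, rollout in zip(scores_per_token, rollout_indices_per_token):
        train_side = set(topk_indices(scores, k))
        matched += len(train_side & set(rollout))
        total += k
    return matched / total


def check_replay_result(scores_per_token, rollout_indices_per_token, k, threshold=None):
    """Assert that the overlap is at least `threshold` (assumed semantics)."""
    if threshold is None:
        # Environment variable name and default are taken from the commit message.
        threshold = float(os.environ.get("MILES_TEST_R3_THRESHOLD", "1e-2"))
    overlap = routing_overlap(scores_per_token, rollout_indices_per_token, k)
    assert overlap >= threshold, (
        f"R3 mismatch tokens: overlap {overlap:.4f} < threshold {threshold}"
    )
    return overlap
```

Recomputing the top-k for every token on every step is what makes a check of this kind roughly as expensive as the routing it verifies, which is consistent with the flag being off by default.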
1 parent 1838542 commit 0431dbf

2 files changed

Lines changed: 16 additions & 0 deletions

miles/utils/arguments.py

Lines changed: 13 additions & 0 deletions
@@ -1024,6 +1024,19 @@ def add_algo_arguments(parser):
         default=False,
         help="The rollout routing replay technique from https://arxiv.org/abs/2510.11370",
     )
+    parser.add_argument(
+        "--enable-r3-correctness-check",
+        action="store_true",
+        default=False,
+        help=(
+            "Run RoutingReplayManager's per-step overlap check that "
+            "recomputes the training-side topk on the same scores and "
+            "asserts overlap with the rollout indices. Roughly 2x routing "
+            "cost; off by default. Intended for the R3 regression E2E "
+            "(LLM360/RL360 scripts/r3-e2e/). No effect unless "
+            "--use-rollout-routing-replay is also set."
+        ),
+    )
     parser.add_argument(
         "--use-opsm",
         action="store_true",
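To illustrate how the two flags interact, here is a small standalone sketch using only `argparse`. The parser below is a stand-in and not the miles parser, and `r3_check_active` is a hypothetical helper that encodes the "no effect unless --use-rollout-routing-replay is also set" rule from the help text.

```python
import argparse


def build_parser():
    """Stand-in parser carrying only the two R3-related flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--use-rollout-routing-replay", action="store_true", default=False)
    parser.add_argument("--enable-r3-correctness-check", action="store_true", default=False)
    return parser


def r3_check_active(args):
    """Hypothetical helper: the check only matters when R3 itself is enabled."""
    return args.use_rollout_routing_replay and args.enable_r3_correctness_check


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--use-rollout-routing-replay", "--enable-r3-correctness-check"]
    )
    # argparse maps dashes to underscores, so the flag is read as
    # args.enable_r3_correctness_check.
    print(r3_check_active(args))
```

Because both options use `action="store_true"`, omitting either one leaves its attribute at `False`, so existing launch scripts behave as before.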

train_async.py

Lines changed: 3 additions & 0 deletions
@@ -75,4 +75,7 @@ async def train(args):
 
 if __name__ == "__main__":
     args = parse_args()
+    if getattr(args, "enable_r3_correctness_check", False):
+        from miles.utils.replay_base import RoutingReplayManager
+        RoutingReplayManager.enable_check_replay_result = True
     asyncio.run(train(args))
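The change above relies on setting a class attribute, so every instance that has not overridden it sees the new value. A minimal sketch of that pattern, using a stand-in class (`RoutingReplayManagerStub` and `apply_r3_check_flag` are hypothetical names, not part of miles):

```python
from types import SimpleNamespace


class RoutingReplayManagerStub:
    """Stand-in for RoutingReplayManager; only the class-level switch is modeled."""

    enable_check_replay_result = False


def apply_r3_check_flag(args, manager_cls=RoutingReplayManagerStub):
    """Flip the class-level switch when the parsed args request it."""
    # getattr with a default tolerates argument namespaces that lack the flag.
    if getattr(args, "enable_r3_correctness_check", False):
        manager_cls.enable_check_replay_result = True
    return manager_cls.enable_check_replay_result


if __name__ == "__main__":
    existing = RoutingReplayManagerStub()
    apply_r3_check_flag(SimpleNamespace(enable_r3_correctness_check=True))
    # Instances created before the flip also observe the new class attribute.
    print(existing.enable_check_replay_result)
```

The assignment affects only the Python process in which it runs.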
