
Implement forced lag in RL#3517

Merged
tdene merged 37 commits into NVIDIA:main from tdene:tde/rl_forced_lag
Mar 19, 2026

Conversation

Contributor

@tdene tdene commented Feb 20, 2026

What does this PR do ?

⚠️ For major changes (either in lines of code or in impact), please make sure to first share a design doc with the team. If you're unsure of the best way to do so, contact the @mcore-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see the Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.


copy-pr-bot bot commented Feb 20, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


group_index = yielded_groups
yielded_groups += 1
for rollout in group:
    rollout.submission_index = group_index
Contributor

nit: remove group_index and do rollout.submission_index = yielded_groups - 1?

Contributor Author

Done!
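For illustration, the simplified bookkeeping might look like the following sketch. The standalone names (Rollout, tag_groups) are hypothetical; the PR's actual generator pulls groups from the inference interface.

```python
# Sketch of the simplified counter bookkeeping: no separate group_index
# variable, the submission index is derived from the running yield count.
class Rollout:
    def __init__(self):
        self.submission_index = None

def tag_groups(groups):
    """Tag every rollout with the zero-based index of its group's submission."""
    yielded_groups = 0
    for group in groups:
        yielded_groups += 1
        for rollout in group:
            rollout.submission_index = yielded_groups - 1
        yield group
```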

def get_rollout_generator(args, inference_interface, n_prompts, samples_per_group):
    global _ROLLOUT_GENERATOR
    if not args.rl_partial_rollouts or _ROLLOUT_GENERATOR is None:
        oversubscribed = args.rl_partial_rollouts or args.rl_forced_lag > 0
Contributor

Could you explain why you are calling this 'oversubscribed'?

Contributor Author

I could not think of a better word at the time. It's regarding the concept of submitting more requests than you actually need.

Contributor

Yeah, that's why I asked! partial rollouts or forced lag are not the only cases when we sample more than we consume. Sending a request when batch_size < group_size * prompts will also do this.

Contributor Author

It's now called streaming.
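A minimal sketch of the cached-generator behavior this thread settled on, with illustrative arguments (a factory callable standing in for the real generator construction): in streaming mode the module-level generator is reused across training steps; otherwise a fresh one is built on every call.

```python
# Module-level cache for the rollout generator (illustrative names).
_ROLLOUT_GENERATOR = None

def get_rollout_generator(make_generator, streaming):
    global _ROLLOUT_GENERATOR
    # Streaming reuses the cached generator so in-flight rollouts survive
    # across steps; non-streaming rebuilds it every time.
    if not streaming or _ROLLOUT_GENERATOR is None:
        _ROLLOUT_GENERATOR = make_generator()
    return _ROLLOUT_GENERATOR
```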

assert args.micro_batch_size == 1, \
    "micro_batch_size must be 1 when using sequence packing. To increase compute per micro batch increase the sequence length."
assert rl.forced_lag > 0 or not args.rl_partial_rollouts, (
    "--rl-forced-lag and --rl-partial-rollouts are incompatible."
Contributor

nice!

buffered_rollouts is None or
iteration == runtime_state.last_collection_iteration +
runtime_state.data_iterator is None or
iteration >= runtime_state.last_collection_iteration +
Contributor

Why >=?

Contributor Author

I think this is now entirely stale code.

RerunDataIterator for the current training step
"""
runtime_state = get_rl_runtime_state()
args = get_args()
Contributor

Please send the argument as a function arg; do not call get_args() here.

Contributor Author

Addressed!

):
    if forced_lag > 0:
        runtime_state.lag_buffer.append(
            get_environment_rollouts(
Contributor

This can be above the if forced_lag branch.

Contributor Author

I think this is now entirely stale code.

                model, inference_model, optimizer, grpo_prompts_per_step, grpo_group_size,
            )
        )
    rollouts = runtime_state.lag_buffer.popleft()
Contributor

Will be useful/interesting to track the length of the lag buffer. Is it going to increase over time?

Contributor

Also! Related to the staleness tracking, without the plotting fix, the staleness will be reported incorrectly even more.
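As a hedged sketch of the suggestion above (hypothetical names; in the PR the rollouts come from the environment, and the depth would go to the run's metrics): with a fixed forced lag, logging the deque's length should show a plateau rather than unbounded growth.

```python
from collections import deque

# One training step against a lag buffer: append the freshly started rollouts,
# record the buffer depth, and consume the oldest (most lagged) entry.
def step_with_lag(lag_buffer, new_rollouts, depth_log):
    lag_buffer.append(new_rollouts)
    depth_log.append(len(lag_buffer))  # in practice: report to training metrics
    return lag_buffer.popleft()
```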


runtime_state.reset_iteration_counters(iteration)
return runtime_state.data_iterator
return runtime_state.data_iterator
Contributor

With those changes, we need to run tests for these scenarios.

  • Normal, GRPO_ITERATIONS=1 sample X, consume X (batch size = prompts * group_size)
  • sample 2X, consume X batch size = prompts * group_size / 2
  • your newly added lag thing. Not sure if we need it in both scenarios above.

if runtime_state.start_iteration is None:
    runtime_state.start_iteration = iteration
if forced_lag > 0:
    runtime_state.lag_buffer = deque()
Contributor

For me this is the kind of stuff `__init__()` is for. Why do we need all this logic here?

Contributor Author

I think this is now entirely stale code.

@tdene tdene force-pushed the tde/rl_forced_lag branch 3 times, most recently from 241f1bf to 95815f4 on March 4, 2026 10:09
@tdene tdene force-pushed the tde/rl_forced_lag branch from fe5d072 to edb5149 on March 4, 2026 20:32
@tdene tdene marked this pull request as ready for review March 4, 2026 20:32
@tdene tdene requested a review from a team as a code owner March 4, 2026 20:32
@tdene tdene added the Expert Review label Mar 4, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team March 4, 2026 20:32
@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Mar 4, 2026
inference_interface: InferenceInterface
validation: bool = False
filter_groups_with_same_reward: bool = False
streaming: bool = False
Contributor Author

New parameter that matches what num_groups = -1 meant in the old code. It's bad practice to have a parameter that means one thing when it's positive and an entirely different thing when it's set to -1.

Contributor

These changes will break the nemo_gym integration. Please adjust those scripts accordingly; ./nemo_gym_agent.py is the file you need.

Contributor Author

Will do this later today.

Contributor

Let's get back to this after we deal with completed_at_step parameter.

validation: bool = False
filter_groups_with_same_reward: bool = False
streaming: bool = False
batch_results: bool = False
Contributor Author

When False, this returns groups one at a time, as in the old flow. When True, this waits until num_groups are ready and then returns them all at once.

# The semaphore ensures that each batch only starts after the previous is consumed.
groups_per_worker = request.num_groups
num_workers = self.parallel_generation_tasks // groups_per_worker
submission_gate = asyncio.Semaphore(num_workers)
Contributor Author

The whole point of this submission_gate is the following.

If we are doing forced lag, we can force the RL setup to only "generate" a certain number of groups per step. This is the flow without submission_gate. This is more performant, but does not guarantee consistent lag.

The flow with submission_gate forces the RL setup to only "consume" a certain number of groups per step. This is less performant, but guarantees consistent lag.

The difference between "generate N per step" and "consume N per step" is in the edge-case.
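The gating idea can be sketched as a runnable toy (illustrative names, not the PR's exact code): workers must acquire a semaphore permit before generating a batch, and the consumer releases one permit per batch it drains, so at most `num_workers` batches are ever in flight regardless of how fast generation runs.

```python
import asyncio

async def gated_generation(num_workers=2, total_batches=6):
    gate = asyncio.Semaphore(num_workers)
    results = asyncio.Queue()
    in_flight = 0
    high_water = 0

    async def worker(batch_id):
        nonlocal in_flight, high_water
        await gate.acquire()           # blocks until the consumer frees a slot
        in_flight += 1
        high_water = max(high_water, in_flight)
        await asyncio.sleep(0)         # stand-in for actual rollout generation
        await results.put(batch_id)

    tasks = [asyncio.create_task(worker(i)) for i in range(total_batches)]
    consumed = []
    for _ in range(total_batches):
        consumed.append(await results.get())
        in_flight -= 1
        gate.release()                 # consuming a batch admits the next one
    await asyncio.gather(*tasks)
    return consumed, high_water
```

Because a permit is only returned on consumption, the in-flight count can never exceed the permit count, which is what makes the lag consistent at the cost of some generation throughput.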

request.inference_interface, InferenceServer
), "Rollout requests to remote server must contain an InferenceServer object"
assert request.num_groups != -1, "FastAPIEnvServer does not support group rollout streaming"
assert not request.streaming, "FastAPIEnvServer does not support group rollout streaming"
Contributor Author

This is what I meant above. num_groups != -1 was secretly hidden behavior. Now it's out in the open.

streaming = args.rl_partial_rollouts or args.rl_forced_lag > 0
if not streaming or _ROLLOUT_GENERATOR is None:
    if args.rl_forced_lag > 0:
        pgt = (args.rl_forced_lag + 1) * n_prompts
Contributor Author

We must control the inference batch size in forced lag mode.
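Spelled out as a hypothetical helper (the name is illustrative; only the arithmetic comes from the snippet above): with a forced lag of L steps, (L + 1) * n_prompts generation tasks keep one step's worth of rollouts completing per training step once the warm-up is over.

```python
# Sizing rule for forced-lag mode: keep L + 1 batches of prompts in flight.
def parallel_generation_tasks(rl_forced_lag, n_prompts):
    return (rl_forced_lag + 1) * n_prompts
```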

@jon-barker
Copy link
Contributor

/claude review

group.add_argument('--rl-forced-lag', type=int, default=0,
                   help='Forced rollout lag of L steps. After an initial warm-up of L steps, '
                        'All steps N+L use only rollouts that were started on step N. '
                        '0 (default) disabled this behavior.'
Contributor

Typo: "disabled" → "disables"

Suggested change
-                        '0 (default) disabled this behavior.'
+                        '0 (default) disables this behavior.'

Contributor Author

Resolved. Thank you Claude!

if submission_gate is not None:
    await submission_gate.acquire()
batch_id = submitted_groups // groups_per_worker
submitted_groups += groups_per_worker
Contributor

The assert only guards the groups_per_worker > 1 path. When groups_per_worker == 1 (i.e. n_prompts == 1) with batch_results=True and filter_groups_with_same_reward=True, a filtered group will leave a semaphore permit unreleased. The consumer is then waiting on grouped_rollouts.get() while every worker is stuck waiting on submission_gate.acquire() — deadlock.

Since filtering + forced lag is unsupported, it should be caught in validate_args rather than at runtime inside the generator (where the assert kills the training run with no cleanup). Consider adding to validate_args:

assert not (args.rl_forced_lag > 0 and args.rl_filter_groups_with_same_reward), \
    "--rl-forced-lag and --rl-filter-groups-with-same-reward are incompatible."

(adjust the flag name to whatever the actual arg is.)

Contributor Author

Resolved. Thank you Claude!

Contributor

ArEsKay3 commented Mar 7, 2026

I actually have pretty strong feelings about the arguments here.
If I want to change something, I shouldn't have to adjust a lot of others.
As it stands, a change in, say, group size will change the number of active requests (even if the global batch size is adjusted accordingly), something I think we want to avoid.

So. My proposal is
--rl-num-parallel-generations OR --rl-num-parallel-generation-batches

They would be mutually exclusive. Both would require --rl-partial-rollouts to be set.

The first would imply no batching and would represent the number of parallel rollouts. This is what we really want to control. We would internally adjust for multiple tasks and divide it by grpo_group_size to ensure that changing the grpo_group_size or the number of tasks doesn't require an associated change in this argument to keep the same number of "active" rollouts.

The second would imply batching and would represent the number of batches active at a time. In this case we would set our controls so we have that many batches in flight. This would be our "forced lag" or "N-step off-policy" setup.



@dataclass(slots=True)
class RolloutGroup:
Contributor

Should it inherit from the BaseModel?

Contributor Author

Done

while grouped_rollouts.qsize() > 0 or not all(task.done() for task in tasks):
yield await grouped_rollouts.get()
next_batch_id = 0
pending: dict[int, list[RolloutGroup]] = {}
Contributor

GroupedRollouts?

Contributor Author

Done

class RolloutGroup:
    """A group of rollouts (e.g. multiple completions for one prompt) with batch metadata."""

    rollouts: list[Rollout]
Contributor

Rollout or Rollout|TokenRollout?

Contributor Author

Done

if request.enforce_order:
    # Accumulate groups and enforce submission order across batches.
    pending.setdefault(group.batch_id, []).append(group)
    while len(pending.get(next_batch_id, [])) >= groups_per_worker:
Contributor

Trying to understand this logic. Is there a chance we never exit this loop, i.e. get stuck not because we have exhausted pending but for some other reason?

Contributor Author

Certainly. The only way we can get stuck here is if we filter out rollouts (which we assert that we do not do), or if the worker tasks die. But if the worker tasks die, we shut down the asyncio Queues (because I back-ported Python 3.13 functionality into this repo; asyncio Queues lack shutdown support before Python 3.13), and that breaks us out of a potentially infinite loop.
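The order-enforcing drain can be sketched synchronously (illustrative names; the async queue and shutdown machinery from the PR is elided): finished groups arrive tagged with a batch_id in arbitrary order, are parked in `pending`, and are released strictly in batch order once a full batch has accumulated.

```python
def ordered_batches(tagged_groups, groups_per_worker):
    pending = {}
    next_batch_id = 0
    released = []
    for batch_id, group in tagged_groups:
        pending.setdefault(batch_id, []).append(group)
        # Release complete batches in submission order only; a later batch
        # finishing first simply waits in `pending`.
        while len(pending.get(next_batch_id, [])) >= groups_per_worker:
            released.extend(pending.pop(next_batch_id))
            next_batch_id += 1
    return released
```

Note how this makes the deadlock discussed earlier visible: if a group from batch 0 is filtered out and never arrives, `pending[0]` never fills and nothing after it is ever released.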

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Approved All necessary approvals have been made label Mar 18, 2026
@tdene tdene enabled auto-merge March 18, 2026 17:50
@tdene tdene added this pull request to the merge queue Mar 18, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/23271910667

@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/23273525669

@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Mar 19, 2026
@tdene tdene added this pull request to the merge queue Mar 19, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/23275855415

@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/23278219551

Merged via the queue into NVIDIA:main with commit dde4701 Mar 19, 2026
55 of 57 checks passed
@tdene tdene deleted the tde/rl_forced_lag branch March 19, 2026 06:19

Labels

Approved (All necessary approvals have been made)
complexity: medium

8 participants