
LR scheduler exhausts early in agentic training with AgentNativeStepEnvManager #407

@shamanez

Bug Description

When using AgentNativeStepEnvManager (step-level env manager) for agentic training, the LR scheduler exhausts its step budget far before all pipeline steps complete, causing the learning rate to drop to zero mid-training.

In a 200-step training run with lr_scheduler_type: "linear", the LR reached zero at pipeline step 123 — meaning 38.5% of training happened with zero learning rate and no learning.

Root Cause

PPOConfig.set_max_steps() computes the total optimizer steps for the LR scheduler using rollout_batch_size (number of trajectories):

https://github.com/alibaba/ROLL/blob/main/roll/configs/base_config.py#L701-L718

self.actor_train.training_args.max_steps = max(1, (
    max_steps
    * self.rollout_batch_size              # trajectories per rollout
    * self.actor_infer.generating_args.num_return_sequences
    * self.ppo_epochs
    // actor_backward_batch_size
))

With the config in agent_val_rock_swe_qwen35_2b.yaml:

max_steps=200, rollout_batch_size=4, num_return_sequences=1, ppo_epochs=1, backward_batch_size=4
→ scheduler total = 200 * 4 * 1 * 1 // 4 = 200 optimizer steps
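The arithmetic can be checked in isolation (a standalone sketch mirroring the `set_max_steps` formula, not ROLL code):

```python
# Recompute the scheduler budget from the values quoted above
# (agent_val_rock_swe_qwen35_2b.yaml), mirroring PPOConfig.set_max_steps().
max_steps = 200
rollout_batch_size = 4        # trajectories per rollout
num_return_sequences = 1
ppo_epochs = 1
actor_backward_batch_size = 4

scheduler_total = max(1, (
    max_steps
    * rollout_batch_size
    * num_return_sequences
    * ppo_epochs
    // actor_backward_batch_size
))
print(scheduler_total)  # → 200
```

So the scheduler assumes exactly one optimizer step per pipeline step.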

But the training batch contains chunks (one per agent turn), not trajectories. AgentNativeStepEnvManager.formulate_rollouts() creates one training sample per turn:

https://github.com/alibaba/ROLL/blob/main/roll/pipeline/agentic/env_manager/agent_native_env_manager.py#L248

for step, history in enumerate(rollout_cache.history):
    # ... one DataProto per turn ...
    samples.append(lm_input)
batch = DataProto.concat(samples)  # all turns as separate samples

So with 4 trajectories × ~10 turns each = ~40 training samples per pipeline step. With backward_batch_size=4, that's ~10 optimizer steps per pipeline step — not the 1 that the scheduler was budgeted for.
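The same arithmetic with per-turn samples shows the mismatch (the ~10 turns per trajectory is an assumption taken from the run, not a fixed config value):

```python
# Actual optimizer steps per pipeline step when the batch holds one
# training sample per agent turn rather than one per trajectory.
rollout_batch_size = 4
avg_turns_per_traj = 10   # assumed average; varies per rollout
backward_batch_size = 4

samples_per_pipeline_step = rollout_batch_size * avg_turns_per_traj
optimizer_steps = samples_per_pipeline_step // backward_batch_size
print(samples_per_pipeline_step, optimizer_steps)  # → 40 10
```

Ten optimizer steps per pipeline step against a budget of one exhausts a 200-step schedule roughly ten times too fast.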

Additionally, batch_adjust_mode: "random_sample" has a bifurcation that makes this worse:

  • When total_chunks % backward_batch_size == 0: keeps ALL chunks → many optimizer steps
  • When not divisible: subsamples to exactly backward_batch_size → 1 optimizer step

https://github.com/alibaba/ROLL/blob/main/roll/pipeline/agentic/agentic_pipeline.py (search for adjust_batch)

This creates wildly inconsistent optimizer step counts per pipeline step, making the scheduler exhaustion unpredictable.
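The bifurcation can be illustrated with a toy function (a hypothetical sketch of the described behavior, not ROLL's actual `adjust_batch` code):

```python
# Sketch of the divisibility bifurcation described above for
# batch_adjust_mode "random_sample".
def optimizer_steps_after_adjust(total_chunks: int, backward_batch_size: int) -> int:
    if total_chunks % backward_batch_size == 0:
        kept = total_chunks            # divisible: keep ALL chunks
    else:
        kept = backward_batch_size     # not divisible: subsample to one batch
    return kept // backward_batch_size

print(optimizer_steps_after_adjust(40, 4))  # divisible → 10 optimizer steps
print(optimizer_steps_after_adjust(39, 4))  # not divisible → 1 optimizer step
```

A one-chunk difference in turn count changes the optimizer-step count by 10×, which is why the exhaustion point is unpredictable.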

Evidence from Training Run

wandb run: https://wandb.ai/shamanework-pl/roll-agentic/runs/gvoe0mq8

Config: openreward_endless_terminals_IPA_qwen35_2b.yaml (same architecture, rollout_batch_size=16, backward_batch_size=16)

| Pipeline Step | Backward Steps | Cumulative Optimizer Steps | LR |
|---:|---:|---:|---:|
| 0 | 1 | 1 | 9.95e-7 |
| 2 | 15 | 17 | 9.15e-7 |
| 21 | 20 | 55 | 7.25e-7 |
| 64 | 11 | 118 | 4.10e-7 |
| 108 | 12 | 178 | 1.10e-7 |
| 123 | 6 | 205 | 0.00 |
| 199 | 1 | 313 | 0.00 |

Total: 313 optimizer steps across 200 pipeline steps, but scheduler budgeted for 200. LR hit zero at step 123.
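The logged LRs are consistent with a plain linear schedule over the 200-step budget and a base LR of 1e-6 (both assumed; warmup ignored). A minimal sketch:

```python
# Linear decay clamped at zero: LR vanishes once the cumulative optimizer
# step count passes the budgeted total.
def linear_lr(step: int, total: int, base_lr: float = 1e-6) -> float:
    return base_lr * max(0.0, 1.0 - step / total)

print(linear_lr(1, 200))    # → 9.95e-07, matches the table's first row
print(linear_lr(205, 200))  # → 0.0, pipeline step 123 in the table
print(linear_lr(313, 200))  # → 0.0, still zero at the final pipeline step
```

Everything past cumulative step 200 trains at zero LR, so the last 77 pipeline steps did no learning.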

Affected Configs

Any agentic config using AgentNativeStepEnvManager with a decaying LR scheduler (linear, cosine, etc.):

  • examples/agentic_demo/agent_val_rock_swe_qwen35_2b.yaml
  • Any similar agentic training config

Suggested Fix

Option A — Use constant LR (simplest, no code change):

actor_train:
  training_args:
    lr_scheduler_type: "constant_with_warmup"
    warmup_steps: 10
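Why this works, sketched with an assumed base LR of 1e-6: after warmup the LR never decays, so overshooting the budgeted step count cannot drive it to zero.

```python
# Sketch of a constant_with_warmup schedule (hypothetical helper, not the
# HF/ROLL implementation; base LR and warmup_steps=10 are assumptions).
def constant_with_warmup_lr(step: int, warmup_steps: int = 10, base_lr: float = 1e-6) -> float:
    if step < warmup_steps:
        return base_lr * step / warmup_steps  # linear ramp during warmup
    return base_lr                            # constant forever after

print(constant_with_warmup_lr(5))    # mid-warmup: 5e-07
print(constant_with_warmup_lr(313))  # far past the 200-step budget: 1e-06
```

The trade-off is losing LR decay entirely; Option B keeps a decaying schedule.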

Option B — Fix set_max_steps for agentic training (proper fix):

Override set_max_steps in AgenticConfig to account for the actual number of chunks per trajectory:

# In AgenticConfig:
def set_max_steps(self, max_steps: int):
    actor_backward_batch_size = (
        self.actor_train.training_args.per_device_train_batch_size
        * self.actor_train.training_args.gradient_accumulation_steps
    )
    # Estimate chunks per trajectory (each turn = 1 training sample)
    estimated_avg_turns = self.max_actions_per_traj // 2  # conservative midpoint
    self.actor_train.training_args.max_steps = max(1, (
        max_steps
        * self.rollout_batch_size
        * self.actor_infer.generating_args.num_return_sequences
        * estimated_avg_turns
        * self.ppo_epochs
        // actor_backward_batch_size
    ))

Option C — Fix batch_adjust_mode (complementary):

Change random_sample to always subsample down to exactly backward_batch_size (in the divisible case too), so each pipeline step produces exactly 1 optimizer step and matches the current set_max_steps formula. Note that rounding down to a multiple of backward_batch_size (e.g. "delete" mode) makes the count consistent per step but still greater than 1, so that variant would still require Option B.
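A sketch of the deterministic adjustment (hypothetical helper, not ROLL's `adjust_batch`): always keep exactly one backward batch, so every pipeline step contributes exactly one optimizer step.

```python
import random

# Subsample any oversized chunk list down to one backward batch so the
# optimizer-step count per pipeline step is always 1.
def adjust_to_one_batch(samples: list, backward_batch_size: int) -> list:
    if len(samples) <= backward_batch_size:
        return samples
    return random.sample(samples, backward_batch_size)

batch = adjust_to_one_batch(list(range(40)), 4)
print(len(batch) // 4)  # → 1 optimizer step, regardless of turn count
```

This discards most per-turn samples, so Option B (budgeting for the real chunk count) preserves more data per rollout.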

Environment

  • ROLL version: main branch
  • Model: Qwen3.5-2B
  • Environment: SWE-bench / OpenReward EndlessTerminals
  • GPUs: 8× (TP=2, CP=2 for training, 8× vLLM inference)

/cc @shamanez
