
Adding support for agentic grpo trainer.#3540

Merged
copybara-service[bot] merged 1 commit into main from nicogrande/async-rollouts on Apr 3, 2026

Conversation

@NicoGrande NicoGrande (Collaborator) commented Apr 1, 2026

Description

Add support for the Tunix Agentic GRPO Learner, which enables asynchronous rollouts leveraging an online vLLM server.

To enable the Agentic GRPO Learner, this PR introduces the rl.use_agentic_rollout flag. Similarly, the maximum concurrency for the online vLLM server is set with the rl.max_concurrency argument. Other arguments relevant to the Agentic GRPO Learner are also included in this PR.
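As a hedged illustration of how these options relate, the sketch below mirrors the flags described in this PR; the dataclass and selection function are hypothetical stand-ins, not MaxText's actual config plumbing:

```python
from dataclasses import dataclass

# Hypothetical sketch of the rl.* options described above; the real
# MaxText config wiring differs.
@dataclass
class RLConfig:
    use_agentic_rollout: bool = False  # rl.use_agentic_rollout (introduced in this PR)
    max_concurrency: int = 1           # rl.max_concurrency: cap on concurrent vLLM rollouts

def select_rollout_mode(rl: RLConfig) -> str:
    """Pick the rollout path based on the new flag (illustrative only)."""
    if rl.use_agentic_rollout:
        return f"agentic(async, max_concurrency={rl.max_concurrency})"
    return "standard(sync)"
```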

Tests

Standard GRPO for qwen3-0.6b on v6e: 523.08s

Agentic GRPO for qwen3-0.6b on v6e: 363.93s
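For context, the two timings above imply roughly a 1.44x end-to-end speedup for the agentic path on this workload:

```python
# Speedup implied by the v6e timings reported above.
standard_s = 523.08  # standard GRPO, qwen3-0.6b
agentic_s = 363.93   # agentic GRPO, qwen3-0.6b
speedup = standard_s / agentic_s
print(f"{speedup:.2f}x")  # ~1.44x
```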

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov Bot commented Apr 1, 2026

Codecov Report

❌ Patch coverage is 17.77778% with 37 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/trainers/post_train/rl/utils_rl.py | 26.92% | 19 Missing ⚠️ |
| src/maxtext/trainers/post_train/rl/train_rl.py | 5.26% | 18 Missing ⚠️ |


@NicoGrande NicoGrande force-pushed the nicogrande/async-rollouts branch 2 times, most recently from 38196ba to ec02199 Compare April 1, 2026 16:36
@NicoGrande NicoGrande requested a review from andytwigg April 1, 2026 17:31
Comment thread src/maxtext/trainers/post_train/rl/train_rl.py
Comment thread src/maxtext/trainers/post_train/rl/train_rl.py
@NicoGrande NicoGrande force-pushed the nicogrande/async-rollouts branch 5 times, most recently from 856c176 to fab1f04 Compare April 1, 2026 23:12
@NicoGrande NicoGrande force-pushed the nicogrande/async-rollouts branch 3 times, most recently from a1faa01 to aa94c87 Compare April 2, 2026 23:23
@richjames0 richjames0 (Collaborator) left a comment

LGTM, with a couple of concerns that you can ignore if not relevant, but I do note Andy has one unresolved comment.

Comment thread src/maxtext/trainers/post_train/rl/train_rl.py
Comment thread src/maxtext/trainers/post_train/rl/train_rl.py
return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)


def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
Collaborator:

is this also going to be performant enough?

Collaborator Author:

From looking at the implementation of other chat parsers in Tunix, I think we should be fine: https://github.com/google/tunix/blob/main/tunix/rl/agentic/parser/chat_template_parser/parser.py

My biggest concern with this change was matching the code we previously used for pre-processing and applying chat templates exactly, to avoid unintended bugs.


Hmm, our parser already handles all of these; if a model is missing, it should be added to the Tunix codebase instead.

@NicoGrande NicoGrande force-pushed the nicogrande/async-rollouts branch from aa94c87 to daac9e0 Compare April 3, 2026 02:34
}


class MaxTextChatParser(agentic_chat_template_parser.DefaultChatTemplateParser):

Why do you need this? It should just be the qwen/gemma/llama parser.

Collaborator:

This is needed because of the difference between how MaxText and Tunix do parsing: Tunix subclasses a parser per model, while MaxText uses a single class driven by config.

Collaborator Author:

The alternative is to write a get_chat_parser() helper which attempts to load the chat parser corresponding to the model from Tunix and falls back to the default implementation if it is not found.

I would prefer to implement MaxTextChatParser for now, for simplicity and compatibility with MaxText's single-class-plus-config model.
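The fallback alternative described above could look roughly like this; the parser class names and registry are hypothetical stand-ins, not Tunix's actual API:

```python
# Hedged sketch of a get_chat_parser() helper that prefers a model-specific
# parser and falls back to a default. Names are illustrative, not Tunix API.
_PARSER_REGISTRY = {
    "qwen": "QwenChatParser",
    "gemma": "GemmaChatParser",
    "llama": "LlamaChatParser",
}

def get_chat_parser(model_name: str) -> str:
    """Return a model-specific parser name, or the default if none matches."""
    name = model_name.lower()
    for prefix, parser in _PARSER_REGISTRY.items():
        if name.startswith(prefix):
            return parser
    return "DefaultChatTemplateParser"
```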


Hmm, I'm not sure I'm following. MaxText implements the OSS models, so why does the parser matter? Say MaxText uses a Qwen model and the Qwen chat parser is already there; couldn't we just use it?


beta=trainer_config.rl.grpo_beta,
epsilon=trainer_config.rl.grpo_epsilon,
loss_algo=trainer_config.rl.loss_algo,
max_response_length=trainer_config.max_target_length - trainer_config.max_prefill_predict_length,

What are max_target_length and max_prefill_predict_length? This looks like a confusing user-facing knob.

Collaborator Author:

Yes, I agree this is confusing; I will follow up with another PR to clean up the interface a bit.

max_target_length is the max generation length, while max_prefill_predict_length is the max size of prefill for prompts. We should rename max_target_length to something like max_tokens_to_generate and define the max model size as the sum of max_tokens_to_generate + max_prefill_predict_length
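The subtraction in the diff above reflects the current semantics, where prefill is carved out of a total sequence budget. A worked sketch with illustrative values (the numbers are assumptions, not this PR's defaults):

```python
# Illustrative values only; the subtraction mirrors the diff above.
max_target_length = 2048           # total sequence budget today (prompt + generation)
max_prefill_predict_length = 1024  # max prompt/prefill length
max_response_length = max_target_length - max_prefill_predict_length
```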

@andytwigg andytwigg self-requested a review April 3, 2026 16:30
def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
"""Helper to inject MaxText's system prompt into the input user messages."""
formatted_messages = []
for msg in messages:

This just looks like duplicated logic that already exists in our chat parser?
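For reference, a hedged completion of the truncated format_maxtext_messages excerpt above; the loop body and simplified signature are assumptions based on the docstring, not the PR's actual code:

```python
# Hypothetical completion: inject a system prompt ahead of the user messages,
# per the docstring in the excerpt. Signature simplified for illustration.
def format_maxtext_messages(messages, system_prompt):
    formatted_messages = []
    if system_prompt and (not messages or messages[0].get("role") != "system"):
        formatted_messages.append({"role": "system", "content": system_prompt})
    for msg in messages:
        formatted_messages.append(dict(msg))  # copy to avoid mutating the input
    return formatted_messages
```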

@copybara-service copybara-service Bot merged commit 1f04ad1 into main Apr 3, 2026
113 of 115 checks passed
@copybara-service copybara-service Bot deleted the nicogrande/async-rollouts branch April 3, 2026 21:18


5 participants