Adding support for agentic grpo trainer. #3540
Conversation
richjames0
left a comment
lgtm with a couple of concerns that you can ignore if not relevant, but I do note that Andy has one unresolved comment
return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
...
def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
is this also going to be performant enough?
From looking at the implementation of other chat parsers in Tunix, I think we should be fine: https://github.com/google/tunix/blob/main/tunix/rl/agentic/parser/chat_template_parser/parser.py
My biggest concern with this change was just matching the code we were previously using for pre-processing and applying chat templates exactly, to avoid unintended bugs related to this.
hmmmmm our parser already handles all of these; if there's a missing model, it should be added to the Tunix codebase instead
}
...
class MaxTextChatParser(agentic_chat_template_parser.DefaultChatTemplateParser):
why do you need this? it should just be the qwen/gemma/llama parser
this is needed because of the difference between how maxtext and tunix do parsing; tunix subclasses per model, maxtext uses a single class but with config
The alternative is to write a get_chat_parser() helper which attempts to load the chat parser corresponding to the model from Tunix and falls back to the default implementation if it is not found.
I would prefer to implement the MaxTextChatParser for now, for simplicity and compatibility with the MaxText single-class + config model.
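For concreteness, a minimal sketch of the get_chat_parser() fallback idea; the per-model class names and constructor signature are assumptions (only DefaultChatTemplateParser and the module path are taken from this thread):

```python
# Sketch only: assumes Tunix exposes per-model parser subclasses in this module.
# The lookup table below is illustrative, not Tunix's actual API.
from tunix.rl.agentic.parser.chat_template_parser import parser as chat_template_parser

def get_chat_parser(model_name: str, tokenizer):
  """Return a model-specific Tunix chat parser, falling back to the default."""
  # Hypothetical mapping from model family to a Tunix per-model subclass;
  # getattr(..., None) keeps this safe if a class doesn't exist.
  per_model_parsers = {
      "qwen": getattr(chat_template_parser, "QwenChatTemplateParser", None),
      "gemma": getattr(chat_template_parser, "GemmaChatTemplateParser", None),
      "llama": getattr(chat_template_parser, "LlamaChatTemplateParser", None),
  }
  for family, parser_cls in per_model_parsers.items():
    if parser_cls is not None and family in model_name.lower():
      return parser_cls(tokenizer)
  # No model-specific parser found: fall back to the default implementation,
  # mirroring the MaxText single-class + config approach.
  return chat_template_parser.DefaultChatTemplateParser(tokenizer)
```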
hmmm i'm not sure I'm following; maxtext implements the OSS models, so why does the parser matter? say maxtext uses a qwen model and the qwen chat parser is already there, couldn't we just use it?
beta=trainer_config.rl.grpo_beta,
epsilon=trainer_config.rl.grpo_epsilon,
loss_algo=trainer_config.rl.loss_algo,
max_response_length=trainer_config.max_target_length - trainer_config.max_prefill_predict_length,
what are max_target_length and max_prefill_predict_length? these look like confusing user-facing knobs
Yes, I agree this is confusing - I will follow up with another PR to clean up the interface a bit.
max_target_length is the total sequence length (prefill plus generation), while max_prefill_predict_length is the max prefill size for prompts, so the response budget is the difference between the two. We should rename max_target_length to something like max_tokens_to_generate and define the max model length as the sum of max_tokens_to_generate + max_prefill_predict_length
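To make the relationship concrete (numbers illustrative only):

```python
# Illustrative values; the expression mirrors the config line quoted above.
max_target_length = 2048            # total sequence budget: prefill + generation
max_prefill_predict_length = 1024   # max prompt/prefill tokens
max_response_length = max_target_length - max_prefill_predict_length  # 1024 tokens left to generate
```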
def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
  """Helper to inject MaxText's system prompt into the input user messages."""
  formatted_messages = []
  for msg in messages:
this just looks like duplicated logic that already exists in our chat parser?
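For context, a minimal sketch of what a system-prompt-injection helper like format_maxtext_messages plausibly does; the template_config key and the injection strategy are assumptions, not the PR's actual code:

```python
def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
  """Helper to inject MaxText's system prompt into the input user messages."""
  formatted_messages = []
  # Assumed config key; the real PR may source the system prompt differently.
  system_prompt = template_config.get("system_prompt", "")
  for msg in messages:
    if system_prompt and msg.get("role") == "user":
      # Prepend the configured system prompt to each user turn.
      msg = {**msg, "content": f"{system_prompt}\n{msg['content']}"}
    formatted_messages.append(msg)
  return formatted_messages
```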
Description
Add support for the Tunix Agentic GRPO Learner, which enables asynchronous rollouts leveraging an online vLLM server.
To enable the Agentic GRPO Learner, this PR introduces the rl.use_agentic_rollout flag. Similarly, the maximum concurrency for the online vLLM server is set using the rl.max_concurrency argument. Other arguments relevant to the Agentic GRPO Learner are also included in this PR.
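A hedged sketch of the new knobs as key/value overrides; the flag names come from this PR's description, while the dict form and values are illustrative:

```python
# rl.use_agentic_rollout and rl.max_concurrency are the flags this PR adds;
# everything else here (override structure, values) is an assumption.
rl_overrides = {
    "rl.use_agentic_rollout": True,  # use the Tunix Agentic GRPO Learner (async rollouts)
    "rl.max_concurrency": 64,        # max concurrent requests to the online vLLM server
}
```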
Tests
Standard GRPO for qwen3-0.6b on v6e: 523.08s
Agentic GRPO for qwen3-0.6b on v6e: 363.93s
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.