# JetStream Offline Engine for GRPO
This pull request introduces the JetStream Offline Engine (`offline_engine.py`), an inference engine built to support sampling for GRPO, but also intended for generic RL sampling and batch inference. The Offline Engine works on both McJAX and Pathways. It currently supports continuous batching and batch prefill; support for multisampling is yet to be added.

### Example Usage
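A minimal usage sketch is below. The import path, constructor, and `batch_inference` method are assumptions for illustration; the actual API in `offline_engine.py` may differ.

```python
# Hypothetical sketch of driving the Offline Engine for batch inference;
# names here are illustrative and may not match offline_engine.py exactly.
from MaxText import pyconfig
from MaxText.inference.offline_engine import OfflineEngine

# Load a MaxText config (exact initialization arguments are assumed here).
config = pyconfig.initialize(["", "MaxText/configs/inference.yml"])
engine = OfflineEngine(config)

# Each input is a sequence of token ids; the engine prefills each prompt
# and then decodes completions with continuous batching.
prompt_tokens = [[1, 42, 7], [1, 99, 3, 5]]
results = engine.batch_inference(prompt_tokens)
for result in results:
    print(result)
```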
## Logical Axis Updates for supporting DP
This PR also introduces a new logical axis to support sequence parallelism during prefill. This is required to run the Offline Engine with our version of data parallelism, in which prefill shards the sequence dimension and decode shards the batch dimension, as sketched below.
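To make the layout concrete, here is a small, hypothetical illustration of the two sharding schemes using `jax.sharding`; it is not code from this PR.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustration only: one "data" mesh axis across all devices.
mesh = Mesh(np.array(jax.devices()), ("data",))

# Prefill processes long prompts, so the sequence dimension is sharded.
prefill_sharding = NamedSharding(mesh, P(None, "data"))  # (batch, seq): shard seq

# Decode runs many sequences one token at a time, so the batch dimension is sharded.
decode_sharding = NamedSharding(mesh, P("data", None))   # (batch, seq): shard batch
```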
- Added `PREFILL_LENGTH` to `MaxText/common_types.py` to define prefill activation lengths.
- Updated `MaxText/configs/base.yml` and `MaxText/configs/inference.yml` to include `prefill_activation_length` and `prefill_activation_norm_length`.
- Updated `MaxText/layers/deepseek.py`, `MaxText/layers/llama2.py`, and `MaxText/layers/gemma2.py` to dynamically adjust activation sharding based on `MODEL_MODE_PREFILL` (see the sketch after this list).
- Updated `MaxText/layers/attentions.py` to apply prefill-specific logical constraints to attention weights, masks, and outputs.
- Updated `MaxText/layers/embeddings.py` to accept `model_mode` as an argument for embedding operations.
- Updated `MaxText/layers/linears.py` to apply prefill-specific logical constraints.
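The layer changes follow a common pattern: pick the logical constraint based on the model mode. The sketch below is an illustration under assumptions, not the exact MaxText code; `MODEL_MODE_PREFILL` is redefined locally and the axis names mirror the config keys above.

```python
import flax.linen as nn

# Illustrative stand-in; MaxText defines MODEL_MODE_PREFILL in common_types.py.
MODEL_MODE_PREFILL = "prefill"

def constrain_activations(x, model_mode):
    # In prefill mode, constrain the sequence axis so it can be sharded
    # (sequence parallelism); otherwise keep the usual batch-sharded layout.
    if model_mode == MODEL_MODE_PREFILL:
        spec = ("activation_batch", "prefill_activation_length", "activation_embed")
    else:
        spec = ("activation_batch", "activation_length", "activation_embed")
    return nn.with_logical_constraint(x, spec)
```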
## RL specific changes

- Added `@jax.jit` to `log_prob_of_chosen_token` in `MaxText/inference_utils.py` for improved performance (see the sketch after this list).
- Updated `prefill_packing.py` to avoid pre-compiling the batch prefill function, since pre-compilation did not work on Pathways.
- Updated `prefill_packing.py` to conditionally return prompt token log probabilities, which are needed for the GRPO loss calculation.
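As context for the first item, here is a minimal sketch of a jitted log-prob gather; the actual signature and shapes in `MaxText/inference_utils.py` may differ.

```python
import jax
import jax.numpy as jnp

@jax.jit
def log_prob_of_chosen_token(logits, chosen_index):
    # Assumed shapes for illustration: logits [batch, seq, vocab] and
    # chosen_index [batch, seq] (integer token ids).
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Gather the log-probability of each chosen token.
    return jnp.take_along_axis(log_probs, chosen_index[..., None], axis=-1)[..., 0]
```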
## Tests

- Added `offline_engine_test.py`.
## Checklist
Before submitting this PR, please make sure (put X in square brackets):