JetStream Offline Engine #1829


Open: wants to merge 8 commits into main from offline-engine-seq-parallel-fork

Conversation

wenxindongwork (Collaborator)

JetStream Offline Engine for GRPO

This pull request introduces the JetStream Offline Engine (offline_engine.py), an inference engine implemented to support sampling in GRPO and also intended for generic RL sampling and batch inference.

Example Usage:

# `maxtext_config` is assumed to be an already-initialized MaxText config;
# OfflineEngine comes from the new offline_engine.py, and max_logging is
# MaxText's logging helper (MaxText/max_logging.py).
import jax

offline_engine = OfflineEngine(
    config=maxtext_config,
    params=None,  # no pre-loaded params are passed in
    enable_batch_prefill=True,
)

# Prompts are supplied as pre-tokenized integer arrays.
input_data = [
    jax.numpy.arange(80),
    jax.numpy.arange(90),
    jax.numpy.arange(100),
]

results = offline_engine.batch_inference(input_data)

# Each completion output carries generated token ids, decodable with the
# engine's tokenizer.
for completion_output in results:
    text = offline_engine.tokenizer.decode(completion_output.token_ids)
    max_logging.log(f"Output: {text}")

The Offline Engine works on both McJAX and Pathways. It currently supports continuous batching and batch prefill; support for multisampling has yet to be added.

Logical Axis Updates for supporting DP

This PR also introduces a new logical axis to support sequence parallelism during prefill. This is required to run the Offline Engine with our version of data parallelism, in which prefill shards the sequence dimension and decode shards the batch dimension.

  • Added PREFILL_LENGTH to MaxText/common_types.py to define prefill activation lengths.
  • Updated logical axis rules in MaxText/configs/base.yml and MaxText/configs/inference.yml to include prefill_activation_length and prefill_activation_norm_length.
  • Modified logical constraints in MaxText/layers/deepseek.py, MaxText/layers/llama2.py, and MaxText/layers/gemma2.py to dynamically adjust activation sharding based on MODEL_MODE_PREFILL (see the sketch after this list).
  • Modified MaxText/layers/attentions.py to apply prefill-specific logical constraints to attention weights, masks, and outputs.
  • Modified MaxText/layers/embeddings.py to accept model_mode as an argument for embedding operations.
  • Modified MaxText/layers/linears.py to apply prefill-specific logical constraints.
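
For illustration, below is a minimal sketch of how a layer can switch to the prefill-specific logical axis names when running in MODEL_MODE_PREFILL. This is not the exact diff: the constant value, helper name, and axis names are assumptions based on the description above.

# Minimal sketch only; names below are assumptions mirroring this PR's description,
# not the actual MaxText code.
from flax import linen as nn

MODEL_MODE_PREFILL = "prefill"  # assumed value; the real constant lives in MaxText/common_types.py

def constrain_activations(inputs, model_mode):
    """Apply a mode-dependent logical sharding constraint to [batch, seq, embed] activations."""
    if model_mode == MODEL_MODE_PREFILL:
        # Prefill shards the sequence dimension via the new prefill_activation_length axis.
        names = ("activation_batch", "prefill_activation_length", "activation_embed")
    else:
        # Decode/generate keeps the regular activation axes (batch-sharded).
        names = ("activation_batch", "activation_length", "activation_embed")
    return nn.with_logical_constraint(inputs, names)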

RL-specific changes

  • Added @jax.jit to log_prob_of_chosen_token in MaxText/inference_utils.py for improved performance (a sketch of such a helper follows this list).
  • Added logic in prefill_packing.py to avoid pre-compiling the batch prefill function, since pre-compilation did not work on Pathways.
  • Added logic in prefill_packing.py to conditionally return prompt token log probabilities, which are needed for the GRPO loss calculation.
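
For reference, here is a minimal sketch of what a jitted chosen-token log-probability helper can look like. The signature and shapes are assumptions, not the exact code in MaxText/inference_utils.py.

# Minimal sketch only; the real log_prob_of_chosen_token may differ.
import jax
import jax.numpy as jnp

@jax.jit
def log_prob_of_chosen_token(logits, chosen_ids):
    """Log-probability of each chosen token.

    logits: [batch, seq, vocab] unnormalized scores.
    chosen_ids: [batch, seq] token ids whose log-probs are gathered.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return jnp.take_along_axis(log_probs, chosen_ids[..., None], axis=-1)[..., 0]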

Tests

Added offline_engine_test.py
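
As a rough illustration of the kind of check such a test can contain (the actual structure of offline_engine_test.py may differ):

# Hypothetical sketch, not the actual contents of offline_engine_test.py:
# checks that batch_inference returns one non-empty completion per prompt.
import jax.numpy as jnp

def check_batch_inference(offline_engine):
    prompts = [jnp.arange(8), jnp.arange(16), jnp.arange(32)]
    results = offline_engine.batch_inference(prompts)
    assert len(results) == len(prompts)
    for output in results:
        assert len(output.token_ids) > 0
        assert isinstance(offline_engine.tokenizer.decode(output.token_ids), str)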

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@wenxindongwork force-pushed the offline-engine-seq-parallel-fork branch from ca1d82d to db9613e on June 13, 2025 03:13
@khatwanimohit (Collaborator)

@vipannalla @mitalisi Can you help review this PR?
