
Conversation

@wenxindongwork wenxindongwork (Collaborator) commented Mar 17, 2025

This PR integrates MaxEngine with Kithara's MaxTextModel to improve inference performance. Previously, model inference (i.e. MaxTextModel.generate()) was supported via a naive autoregressive for-loop without a KV Cache, which resulted in low tokens-per-second throughput for longer sequences. Now, inference for MaxTextModel is backed by a KV Cache system.
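To illustrate why a KV Cache speeds up autoregressive decoding, here is a toy sketch (not Kithara's or MaxText's actual code; the `hidden` and `next_token` functions are made-up stand-ins for per-token state computation and attention-based decoding). The naive loop recomputes every token's state at every step, while the cached version computes each token's state exactly once:

```python
import numpy as np

# Toy stand-in for a per-token hidden state (in a real model, the K/V
# projections that the KV Cache would store).
def hidden(tok):
    return np.array([float(tok), tok * 0.5])

# Toy decode rule: attend over all states so far (mean), pick a token.
def next_token(states):
    return int(np.sum(np.mean(states, axis=0))) % 10

def generate_naive(prompt, steps):
    """No KV Cache: recompute every hidden state from scratch each step."""
    toks = list(prompt)
    for _ in range(steps):
        states = np.stack([hidden(t) for t in toks])  # O(n) work per step
        toks.append(next_token(states))
    return toks

def generate_cached(prompt, steps):
    """With a KV Cache: each token's state is computed exactly once."""
    toks = list(prompt)
    cache = [hidden(t) for t in toks]  # prefill: states for the prompt
    for _ in range(steps):
        tok = next_token(np.stack(cache))
        toks.append(tok)
        cache.append(hidden(tok))      # only the new token's state is computed
    return toks
```

Both variants produce identical outputs; the cached version does O(n) total state computations instead of O(n²), which is where the tokens-per-second improvement for long sequences comes from.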

Fast inference is required for supporting online and offline evaluation, as well as on-policy RLHF.

Note that Kithara's KerasHubModel.generate() is already backed by a KV Cache.

Key Changes

  • Added JetStream as a submodule
  • Patched MaxText's MaxEngine to accept a parameterized model
  • Redesigned the model generation API to support a wider range of input formats:
    • String inputs (single or batched)
    • Token inputs as integer lists or numpy arrays
  • Added a configurable max_prefill_predict_length parameter that determines the maximum prefill length
  • Updated documentation
  • Added support for batched inference with progress tracking
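As a sketch of the redesigned generation API's input handling, the normalization below shows one way the listed input formats (single/batched strings, token lists, numpy arrays) could be funneled into a common batched form. This is illustrative only; the function name and behavior are assumptions, not Kithara's actual implementation:

```python
import numpy as np

def normalize_inputs(inputs):
    """Normalize generate() inputs into a batch.

    Strings are returned as a list of strings (left to the tokenizer);
    token inputs are returned as a list of per-example token lists.
    Hypothetical helper, not Kithara's real API.
    """
    if isinstance(inputs, str):
        return [inputs]                              # single string -> batch of one
    if isinstance(inputs, np.ndarray):
        return [row.tolist() for row in np.atleast_2d(inputs)]
    if isinstance(inputs, list):
        if all(isinstance(x, str) for x in inputs):
            return inputs                            # already a batch of strings
        if all(isinstance(x, int) for x in inputs):
            return [inputs]                          # single token list -> batch of one
        return inputs                                # assume a list of token lists
    raise TypeError(f"Unsupported input type: {type(inputs)}")
```

For example, `normalize_inputs("hello")`, `normalize_inputs([1, 2, 3])`, and `normalize_inputs(np.array([1, 2, 3]))` all come out as a batch of one example.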

Testing

  • Added unit tests verifying compatibility with MaxText models

@lienchen0526 lienchen0526 (Collaborator) left a comment


I left some Python readability comments here. We can discuss which ones would be good practice for both of us.

@wenxindongwork (Collaborator, Author)

Thank you, Jerry, for the thorough code review! And thanks for catching the typo :)

@wenxindongwork wenxindongwork merged commit 36686dc into main Mar 20, 2025
1 check passed