
Update MaxEngine to optionally return token log probability#1626

Merged
copybara-service[bot] merged 1 commit into main from maxtext_logp
Apr 29, 2025

Conversation

@wenxindongwork
Collaborator

@wenxindongwork wenxindongwork commented Apr 24, 2025

Description

This PR and AI-Hypercomputer/JetStream#260 together enables MaxEngine to return token log probability when return_log_prob is set to True in the global config.

This change is needed for calculating GRPO loss.
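As a minimal sketch of the idea behind this change (not MaxEngine's actual implementation; the function and variable names here are illustrative), the per-token log probability can be gathered from the decoder logits in JAX like so:

```python
import jax
import jax.numpy as jnp


def token_log_prob(logits, token_ids):
  """Return the log probability of each sampled token.

  logits: [batch, seq, vocab] unnormalized decoder scores.
  token_ids: [batch, seq] ids of the sampled tokens.
  """
  # Normalize over the vocabulary axis to get log probabilities.
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  # Pick out the log prob of the token actually sampled at each position.
  return jnp.take_along_axis(log_probs, token_ids[..., None], axis=-1)[..., 0]


logits = jnp.array([[[2.0, 1.0, 0.1]]])  # [1, 1, 3]
tokens = jnp.array([[0]])                # sampled token id 0
lp = token_log_prob(logits, tokens)      # [1, 1], always <= 0
```

A quantity like this is what GRPO-style losses need per generated token, which is why the engine returns it alongside the tokens rather than forcing a second forward pass.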

Tests

Manually tested

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Collaborator

@bvandermoon bvandermoon left a comment


Do you know why the gpu_unit_tests are failing with this change?

FAILED MaxText/tests/decode_tests.py::DecodeTests::test_gpu_base - TypeError: ResultTokens.__init__() got an unexpected keyword argument 'log_prob'
FAILED MaxText/tests/decode_tests.py::DecodeTests::test_gpu_int8 - TypeError: ResultTokens.__init__() got an unexpected keyword argument 'log_prob'
FAILED MaxText/tests/decode_tests.py::DecodeTests::test_gpu_pdb_lt_1 - TypeError: ResultTokens.__init__() got an unexpected keyword argument 'log_prob'
FAILED MaxText/tests/maxengine_test.py::MaxEngineTest::test_basic_prefill - TypeError: ResultTokens.__init__() got an unexpected keyword argument 'log_prob'

@wenxindongwork
Collaborator Author

Do you know why the gpu_unit_tests are failing with this change?

FAILED MaxText/tests/decode_tests.py::DecodeTests::test_gpu_base - TypeError: ResultTokens.__init__() got an unexpected keyword argument 'log_prob'
FAILED MaxText/tests/decode_tests.py::DecodeTests::test_gpu_int8 - TypeError: ResultTokens.__init__() got an unexpected keyword argument 'log_prob'
FAILED MaxText/tests/decode_tests.py::DecodeTests::test_gpu_pdb_lt_1 - TypeError: ResultTokens.__init__() got an unexpected keyword argument 'log_prob'
FAILED MaxText/tests/maxengine_test.py::MaxEngineTest::test_basic_prefill - TypeError: ResultTokens.__init__() got an unexpected keyword argument 'log_prob'

The JetStream PR has to go through first, since it updates the ResultTokens API.
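The failures above make sense given that ordering: the tests construct `ResultTokens` with a `log_prob` keyword that the older JetStream API does not accept. An illustrative stand-in (field names are assumptions for this sketch; see the JetStream PR for the real signature) of the shape of the updated container:

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass
class ResultTokensSketch:
  """Illustrative stand-in for JetStream's ResultTokens after the API update.

  This is not the real class; it only shows why an optional log_prob field
  keeps older call sites working while letting the engine attach log probs.
  """
  data: jnp.ndarray                        # packed token/validity/length data
  log_prob: Optional[jnp.ndarray] = None   # per-token log probs when enabled


# Old call sites still work because the new field defaults to None.
r = ResultTokensSketch(data=jnp.zeros((1, 3)))
# New call sites can attach log probabilities.
r_with_lp = ResultTokensSketch(data=jnp.zeros((1, 3)), log_prob=jnp.array([[-0.5]]))
```

Against the pre-update class, the second constructor call raises exactly the `TypeError: ... unexpected keyword argument 'log_prob'` seen in the CI logs.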

@wenxindongwork wenxindongwork marked this pull request as draft April 24, 2025 20:38
Comment thread MaxText/inference_utils.py Outdated
@vipannalla
Collaborator

This PR includes changes from many other PRs which have already landed, so it's hard to read the specific code changes. Can you rebase your branch and update this PR?

Collaborator

@vipannalla vipannalla left a comment


Thanks, looks good

@copybara-service copybara-service Bot merged commit 2e4af93 into main Apr 29, 2025
18 checks passed
@copybara-service copybara-service Bot deleted the maxtext_logp branch April 29, 2025 18:20


6 participants