-
Notifications
You must be signed in to change notification settings - Fork 243
[tx] Support top_k sampling
#680
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@gemini-code-assist review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request successfully implements top_k sampling, including the core logic and API integration. The changes are well-tested. I've provided a few suggestions to enhance performance and improve the readability and robustness of the tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request implements top_k sampling. The changes look good and include necessary logic in the generation pipeline and new tests. I've provided a few suggestions to improve code maintainability and performance. Specifically, I've suggested refactoring duplicated test code, simplifying assertions, and optimizing the apply_top_k function for better performance.
skyrl-tx/tests/tinker/test_engine.py
Outdated
| # Values below threshold should be -inf | ||
| assert jnp.isinf(filtered[0]) and filtered[0] < 0 | ||
| assert jnp.isinf(filtered[1]) and filtered[1] < 0 | ||
| assert jnp.isinf(filtered[2]) and filtered[2] < 0 | ||
| # Top 2 values should be unchanged | ||
| assert filtered[3] == 4.0 | ||
| assert filtered[4] == 5.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assertions to check the filtered logits can be simplified by comparing against an expected array. This makes the test more concise and easier to read.
| # Values below threshold should be -inf | |
| assert jnp.isinf(filtered[0]) and filtered[0] < 0 | |
| assert jnp.isinf(filtered[1]) and filtered[1] < 0 | |
| assert jnp.isinf(filtered[2]) and filtered[2] < 0 | |
| # Top 2 values should be unchanged | |
| assert filtered[3] == 4.0 | |
| assert filtered[4] == 5.0 | |
| expected = jnp.array([-jnp.inf, -jnp.inf, -jnp.inf, 4.0, 5.0]) | |
| assert jnp.array_equal(filtered, expected) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto above
9bf5c08 to
bbc5056
Compare
|
@tyler-griggs lmk if you have any comments for top_k Rebased to most recent main |
bbc5056 to
1766982
Compare
|
@pcmoritz @tyler-griggs lmk if you have thoughts, happy to rebase again to most recent before you merge |
tyler-griggs
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @agolajko !!
skyrl-tx/tests/tinker/test_engine.py
Outdated
| # Values below threshold should be -inf | ||
| assert jnp.isinf(filtered[0]) and filtered[0] < 0 | ||
| assert jnp.isinf(filtered[1]) and filtered[1] < 0 | ||
| assert jnp.isinf(filtered[2]) and filtered[2] < 0 | ||
| # Top 2 values should be unchanged | ||
| assert filtered[3] == 4.0 | ||
| assert filtered[4] == 5.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto above
Resolve merge conflicts to combine: - top_k sampling support from feature branch - stop_strings support from main branch Both features are now available in SamplingParams. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Tyler Griggs <131809874+tyler-griggs@users.noreply.github.com>
…nto feature/implement-top-k
|
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request successfully implements top_k sampling. The changes are well-structured, touching the API, type definitions, and the core generator logic. The implementation of apply_top_k_batch is efficient and JIT-friendly. The new tests in test_api.py and test_generator.py provide good coverage for the new functionality.
I have a few suggestions to improve the assertions in the tests to make them more robust and concise. Overall, this is a solid contribution.
pcmoritz
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for implementing this @agolajko , I implemented @tyler-griggs 's suggestion of a fast path if there is no top_k filtering, and also used jax.lax.top_k so we don't need to do the sorting :)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
|
I'll merge this now, but there are some improvements that could be made as a follow up:
|
Implements the
top_kpart of the sampling API requested in #533Tests
test_top_k_filteringintest_generator.py: tests the core logic of the samplingtest_sample_top_kintest_api.py: checks the API can be called with thetop_kparameterDiscussed with @pcmoritz on slack