@agolajko
Contributor

@agolajko agolajko commented Nov 18, 2025

Implements the top_k part of the sampling API requested in #533

Tests

  1. test_top_k_filtering in test_generator.py: tests the core logic of the sampling
  2. test_sample_top_k in test_api.py: checks the API can be called with the top_k parameter

Discussed with @pcmoritz on Slack

@agolajko
Contributor Author

@gemini-code-assist review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request successfully implements top_k sampling, including the core logic and API integration. The changes are well-tested. I've provided a few suggestions to enhance performance and improve the readability and robustness of the tests.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request implements top_k sampling. The changes look good and include necessary logic in the generation pipeline and new tests. I've provided a few suggestions to improve code maintainability and performance. Specifically, I've suggested refactoring duplicated test code, simplifying assertions, and optimizing the apply_top_k function for better performance.

Comment on lines 370 to 516
# Values below threshold should be -inf
assert jnp.isinf(filtered[0]) and filtered[0] < 0
assert jnp.isinf(filtered[1]) and filtered[1] < 0
assert jnp.isinf(filtered[2]) and filtered[2] < 0
# Top 2 values should be unchanged
assert filtered[3] == 4.0
assert filtered[4] == 5.0
Contributor

medium

The assertions to check the filtered logits can be simplified by comparing against an expected array. This makes the test more concise and easier to read.

Suggested change
- # Values below threshold should be -inf
- assert jnp.isinf(filtered[0]) and filtered[0] < 0
- assert jnp.isinf(filtered[1]) and filtered[1] < 0
- assert jnp.isinf(filtered[2]) and filtered[2] < 0
- # Top 2 values should be unchanged
- assert filtered[3] == 4.0
- assert filtered[4] == 5.0
+ expected = jnp.array([-jnp.inf, -jnp.inf, -jnp.inf, 4.0, 5.0])
+ assert jnp.array_equal(filtered, expected)

Member

Ditto above
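
For reference, the array-comparison style suggested in this thread can be tried in a self-contained test. This is only a sketch: `top_k_filter` is a hypothetical stand-in written for illustration, not the function from this PR, and it assumes a 1-D logits vector with a Python-int `k`. Ties at the threshold are all kept in this version.

```python
import jax
import jax.numpy as jnp


def top_k_filter(logits: jax.Array, k: int) -> jax.Array:
    """Hypothetical helper: keep the k largest logits, set the rest to -inf."""
    # jax.lax.top_k returns values in descending order,
    # so the last returned value is the k-th largest logit.
    top_values, _ = jax.lax.top_k(logits, k)
    threshold = top_values[..., -1:]
    return jnp.where(logits >= threshold, logits, -jnp.inf)


def test_top_k_filtering():
    logits = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
    filtered = top_k_filter(logits, k=2)
    # One comparison covers both the masked and the untouched entries,
    # because -inf compares equal to -inf.
    expected = jnp.array([-jnp.inf, -jnp.inf, -jnp.inf, 4.0, 5.0])
    assert jnp.array_equal(filtered, expected)
```

`jnp.array_equal` is safe here because none of the entries are NaN; a test that could produce NaN would need a different comparison.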

@pcmoritz pcmoritz self-assigned this Nov 18, 2025
@pcmoritz pcmoritz added the tx label Nov 18, 2025
@tyler-griggs tyler-griggs changed the title from "[tx] top_k re #533" to "[tx] Support top_k sampling" Nov 30, 2025
@agolajko agolajko force-pushed the feature/implement-top-k branch from 9bf5c08 to bbc5056 December 1, 2025 18:23
@agolajko
Contributor Author

agolajko commented Dec 1, 2025

@tyler-griggs lmk if you have any comments for top_k

Rebased to most recent main

@agolajko
Contributor Author

agolajko commented Dec 8, 2025

@pcmoritz @tyler-griggs lmk if you have thoughts, happy to rebase again to most recent before you merge

Member

@tyler-griggs tyler-griggs left a comment

Thanks @agolajko !!

pcmoritz and others added 8 commits December 29, 2025 11:01
Resolve merge conflicts to combine:
- top_k sampling support from feature branch
- stop_strings support from main branch

Both features are now available in SamplingParams.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Tyler Griggs <131809874+tyler-griggs@users.noreply.github.com>
@pcmoritz
Collaborator

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request successfully implements top_k sampling. The changes are well-structured, touching the API, type definitions, and the core generator logic. The implementation of apply_top_k_batch is efficient and JIT-friendly. The new tests in test_api.py and test_generator.py provide good coverage for the new functionality.

I have a few suggestions to improve the assertions in the tests to make them more robust and concise. Overall, this is a solid contribution.

Collaborator

@pcmoritz pcmoritz left a comment

Thanks a lot for implementing this, @agolajko! I implemented @tyler-griggs's suggestion of a fast path for when there is no top_k filtering, and also used jax.lax.top_k so we don't need to do the sorting :)
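
To illustrate the two ideas in this comment, here is a rough sketch of a batched filter with a fast path and a `jax.lax.top_k` threshold. It is not the merged code: the names `apply_top_k_batch_sketch` and `_filter_top_k` are made up for this example, and the convention that `top_k <= 0` means "no filtering" for a row is an assumption.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="max_k")
def _filter_top_k(logits, top_k, max_k):
    # logits: [batch, vocab]; top_k: [batch] ints.
    # Assumption: top_k <= 0 means "no filtering" for that row.
    top_values, _ = jax.lax.top_k(logits, max_k)  # [batch, max_k], descending
    # Index of each row's k-th largest value inside top_values.
    idx = jnp.clip(top_k, 1, max_k) - 1
    threshold = jnp.take_along_axis(top_values, idx[:, None], axis=-1)
    keep = (logits >= threshold) | (top_k[:, None] <= 0)
    return jnp.where(keep, logits, -jnp.inf)


def apply_top_k_batch_sketch(logits, top_k):
    """Hypothetical wrapper: skip all work when no row requests top_k."""
    max_k = int(jnp.max(top_k))
    if max_k <= 0:
        return logits  # fast path: nothing to filter
    # max_k is a static argument, so each new value triggers a recompile.
    return _filter_top_k(logits, top_k, max_k=min(max_k, logits.shape[-1]))
```

`jax.lax.top_k` only extracts the `max_k` largest entries per row, which avoids a full sort over the vocabulary. Reading `max_k` with `int(...)` forces a device-to-host sync in this sketch; a real pipeline would more likely compute it from host-side request parameters.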

pcmoritz and others added 6 commits December 29, 2025 12:00
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@pcmoritz
Collaborator

I'll merge this now, but there are some improvements that could be made as a follow up:

  • Bin max_top_k to avoid JIT recompilation (we could also restrict top_k to be smaller than some reasonable value)
  • Do sampling only among the top_k / max_top_k values; this will be more performant but requires reorganizing the code
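
Both follow-ups might look roughly like the sketch below. Everything here is hypothetical and not part of this PR: `bin_max_top_k` rounds the static bound up to a power of two (one possible binning scheme), and `sample_among_top_k` assumes every row has `1 <= top_k <= max_k`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def bin_max_top_k(max_top_k: int, vocab_size: int) -> int:
    """Hypothetical binning: round up to a power of two, capped at vocab_size.

    Fewer distinct static values means fewer JIT recompilations.
    """
    if max_top_k <= 1:
        return 1
    return min(1 << (max_top_k - 1).bit_length(), vocab_size)


@partial(jax.jit, static_argnames="max_k")
def sample_among_top_k(key, logits, top_k, max_k):
    # logits: [batch, vocab]; top_k: [batch] ints with 1 <= top_k <= max_k.
    top_values, top_indices = jax.lax.top_k(logits, max_k)  # [batch, max_k]
    # Mask out candidates ranked beyond each row's own k.
    ranks = jnp.arange(max_k)[None, :]
    masked = jnp.where(ranks < top_k[:, None], top_values, -jnp.inf)
    # Sample a rank among the max_k candidates instead of the full vocabulary.
    choice = jax.random.categorical(key, masked, axis=-1)  # [batch]
    # Map the sampled rank back to a vocabulary token id.
    return jnp.take_along_axis(top_indices, choice[:, None], axis=-1)[:, 0]
```

A possible call, with `max_k` taken from the binned value:

```python
logits = jnp.array([[1.0, 5.0, 3.0], [9.0, 2.0, 4.0]])
top_k = jnp.array([2, 1])
max_k = bin_max_top_k(int(top_k.max()), logits.shape[-1])
tokens = sample_among_top_k(jax.random.PRNGKey(0), logits, top_k, max_k=max_k)
```

The categorical draw then runs over `max_k` candidates per row rather than the whole vocabulary, which is where the expected speedup would come from.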

@pcmoritz pcmoritz merged commit 465c3b6 into NovaSky-AI:main Dec 29, 2025
4 checks passed