[tx] Some simplifications and speedups of sampling#717

Merged
pcmoritz merged 28 commits into NovaSky-AI:main from pcmoritz:tx-attention-mask
Dec 6, 2025
Conversation

@pcmoritz
Collaborator

Implemented this while travelling; it still needs to be cleaned up. It puts the whole prefill + decode loop into one jit, avoids recomputing the logprobs, and removes some expensive state that was being passed around.

@pcmoritz pcmoritz added the tx label Nov 28, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces significant simplifications and performance improvements to the sampling process by fusing the prefill and decoding steps into a single JIT-compiled function. This is an excellent change that also improves code clarity by using jax.lax.scan more idiomatically to accumulate outputs and making the DecodeState more lightweight. The changes look solid, and I have a couple of suggestions to further enhance performance and numerical stability.
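The scan-based accumulation the review praises can be illustrated with a minimal toy sketch (not the PR's actual code; `decode_step` and its state are hypothetical stand-ins): the carry holds the per-sequence decode state, and `jax.lax.scan` automatically stacks the per-step `(token, logprob, is_stop)` outputs along a leading time axis, so no pre-allocated output arrays need to be threaded through the loop.

```python
import jax
import jax.numpy as jnp

def decode_step(state, _):
    # Hypothetical toy "decode": the state is just a counter per sequence.
    next_token = state + 1
    logprob = -jnp.log1p(next_token.astype(jnp.float32))
    is_stop = next_token >= 3
    # Return (new_carry, per-step outputs); scan stacks the outputs for us.
    return next_token, (next_token, logprob, is_stop)

init_state = jnp.zeros(2, dtype=jnp.int32)  # batch of 2 sequences
_, (tokens, logprobs, stops) = jax.lax.scan(decode_step, init_state, None, length=4)
print(tokens.shape)  # (4, 2): outputs stacked along axis 0 (time), then batch
```

The stacked outputs come back as `(steps, batch)`, which is why the PR later swaps axes to get `(batch, steps)`.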

Comment on lines +209 to +210

```python
zero_temp_mask = (temperatures == 0.0)[:, None]
inv_temperatures = (1.0 / jnp.maximum(temperatures, 1e-10))[:, None]
```
Contributor

medium

When temperature is 0, inv_temperatures becomes a very large number (1e10) due to jnp.maximum(temperatures, 1e-10). Although the result of sampling with this large value is correctly masked out later in the greedy path, the intermediate computation log_probs * s.inv_temperatures still occurs. This can lead to numerical instability (e.g., producing inf or nan) and is an unnecessary computation. It would be safer and cleaner to avoid the division by a small number for the T=0 case altogether.

Suggested change

```diff
 zero_temp_mask = (temperatures == 0.0)[:, None]
-inv_temperatures = (1.0 / jnp.maximum(temperatures, 1e-10))[:, None]
+inv_temperatures = 1.0 / jnp.where(zero_temp_mask, 1.0, temperatures[:, None])
```
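A quick standalone check (toy values, not the PR's code) shows the difference the reviewer is pointing at: with `jnp.maximum` the T = 0 row gets an inverse temperature of 1e10, which can overflow to `inf` when multiplied into large logits, whereas the `jnp.where` form gives those rows a harmless divisor of 1.0.

```python
import jax.numpy as jnp

temperatures = jnp.array([0.0, 0.5, 1.0])

inv_max = 1.0 / jnp.maximum(temperatures, 1e-10)                      # [1e10, 2.0, 1.0]
inv_where = 1.0 / jnp.where(temperatures == 0.0, 1.0, temperatures)   # [1.0, 2.0, 1.0]

logits = jnp.array([3.0e30, 1.0, 1.0])
assert not jnp.isfinite(logits * inv_max).all()   # 3e40 overflows float32 to inf
assert jnp.isfinite(logits * inv_where).all()     # masked row uses divisor 1.0
```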

@pcmoritz
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request significantly refactors the generation logic by consolidating the prefill and decoding steps into a single JIT-compiled function, _prefill_and_decode. This is a great improvement for both performance and code clarity. The DecodeState is simplified, and several helper functions have been cleanly integrated into the main loop. The use of jax.lax.scan to accumulate outputs is more efficient than the previous approach of updating arrays in each step. The overall changes make the code more idiomatic and easier to follow. I have one minor suggestion to further optimize the decode loop by hoisting a constant computation.

Comment on lines +105 to +111

```python
def decode_fn(s: DecodeState, _) -> tuple[DecodeState, tuple[jax.Array, jax.Array, jax.Array]]:
    """Decode one token step. Returns (state, (token, logprob, is_stop)) for scan accumulation."""
    # Sample next token
    split_keys = jax.vmap(jax.random.split)(s.rngs)
    rngs, sample_keys = split_keys[:, 0], split_keys[:, 1]

    zero_temp_mask = temperatures == 0.0
```
Contributor

medium

For a minor optimization, zero_temp_mask can be computed once outside the decode_fn since temperatures is constant within the scan loop. This avoids redundant computation in each decoding step.

Suggested change

```diff
+zero_temp_mask = temperatures == 0.0
 def decode_fn(s: DecodeState, _) -> tuple[DecodeState, tuple[jax.Array, jax.Array, jax.Array]]:
     """Decode one token step. Returns (state, (token, logprob, is_stop)) for scan accumulation."""
     # Sample next token
     split_keys = jax.vmap(jax.random.split)(s.rngs)
     rngs, sample_keys = split_keys[:, 0], split_keys[:, 1]
-    zero_temp_mask = temperatures == 0.0
```

@pcmoritz pcmoritz changed the title [tx] [WIP] Some simplifications and speedups of sampling [tx] Some simplifications and speedups of sampling Dec 1, 2025
Comment on lines +209 to +211

```python
has_stop,
first_stop_idx + 1,  # Include the stop token
jnp.array([sp.max_tokens for sp in sampling_params]),
```
Member

If I understand correctly how per-request max_tokens are handled here, it seems possible to violate this maximum bound if different requests in the batch have different max_tokens. Or is this a known current limitation?

Collaborator Author

Ah I see: some of the sequences stop, but only after their respective max_tokens, and then won't get truncated. You are right, this is indeed a bug, let me fix it :)
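One way to patch the bug being discussed is to clamp each sequence's effective length to its own budget, so a stop token found after max_tokens no longer extends the output. This is a hedged sketch with hypothetical variable names, not the PR's actual fix:

```python
import jax.numpy as jnp

# Per-request inputs (toy values): index of the first stop token, whether a
# stop token was found at all, and each request's max_tokens budget.
first_stop_idx = jnp.array([5, 1, 7])
has_stop = jnp.array([True, True, False])
max_tokens = jnp.array([4, 8, 6])

lengths = jnp.where(has_stop, first_stop_idx + 1, max_tokens)  # include the stop token
lengths = jnp.minimum(lengths, max_tokens)  # never exceed the per-request budget
print(lengths)  # [4 2 6]: request 0 stopped late, so it is truncated to 4
```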

Member

@tyler-griggs tyler-griggs left a comment

Overall looks great, very nice!

pcmoritz and others added 2 commits December 4, 2025 01:32
Resolved merge conflict in tx/utils/generator.py by:
- Keeping the HEAD branch's simpler _prefill_and_decode JIT pattern with
  local decode_fn capturing closure variables
- Adding upstream's prompt_logprobs feature (compute_prompt_logprobs function,
  prompt_logprobs parameter, and GenerateOutput.prompt_logprobs field)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
@pcmoritz
Collaborator Author

pcmoritz commented Dec 5, 2025

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and valuable refactoring of the text generation logic. By consolidating the prefill and decoding steps into a single JIT-compiled function, you've made a change that should substantially improve performance by reducing compilation overhead and enabling more effective compiler optimizations. The simplification of the DecodeState and the use of jax.lax.scan to accumulate outputs rather than passing large, mutable state arrays is a great improvement for both code clarity and efficiency. The test suite has also been thoughtfully expanded to cover an important edge case regarding max_tokens and stop sequences. Overall, this is a high-quality contribution. I've added a couple of minor suggestions to further enhance code conciseness, in line with your comment about further cleanup.

Comment on lines +171 to +173

```python
new_tokens = jnp.swapaxes(tokens_stacked, 0, 1).squeeze(-1)
new_logprobs = jnp.swapaxes(logprobs_stacked, 0, 1).squeeze(-1)
is_stop = jnp.swapaxes(is_stop_stacked, 0, 1).squeeze(-1)
```
Contributor

medium

To improve conciseness, you can use a list comprehension or a generator expression to apply the swapaxes and squeeze operations to all three tensors at once. This avoids repeating the same logic and makes the code more compact.

Suggested change

```diff
-new_tokens = jnp.swapaxes(tokens_stacked, 0, 1).squeeze(-1)
-new_logprobs = jnp.swapaxes(logprobs_stacked, 0, 1).squeeze(-1)
-is_stop = jnp.swapaxes(is_stop_stacked, 0, 1).squeeze(-1)
+new_tokens, new_logprobs, is_stop = (jnp.swapaxes(tensor, 0, 1).squeeze(-1) for tensor in (tokens_stacked, logprobs_stacked, is_stop_stacked))
```

@pcmoritz
Collaborator Author

pcmoritz commented Dec 5, 2025

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces significant simplifications and performance improvements to the sampling process by consolidating the prefill and decoding steps into a single JIT-compiled function. The changes result in cleaner code, removing the need to pass around extensive state and avoiding potential recomputations. The updated test cases effectively validate the new logic, including edge cases for stopping conditions. My review includes one suggestion to enhance the readability of the sampling logic, aligning with your goal of cleaning up the implementation.

Comment on lines +127 to +133

```python
zero_temp_mask = temperatures == 0.0
scaled_logits = s.logits / jnp.where(zero_temp_mask, 1.0, temperatures)[:, None]
sampled = jax.vmap(lambda key, logit: jax.random.categorical(key, logit, axis=-1))(
    sample_keys, scaled_logits
)
greedy = jnp.argmax(s.logits, axis=-1)
next_token = jnp.where(zero_temp_mask[:, None], greedy[:, None], sampled[:, None])
```
Contributor

medium

The sampling logic is functionally correct but could be restructured for better readability, especially given your comment about needing to clean up the code. The current implementation computes sampled tokens for all requests, including zero-temperature ones where they are immediately discarded. While this pattern is common in JAX for vectorization, reordering the operations to compute greedy tokens first and using a more direct jnp.where at the end can improve clarity.

Suggested change

```diff
 zero_temp_mask = temperatures == 0.0
+greedy_tokens = jnp.argmax(s.logits, axis=-1)
 scaled_logits = s.logits / jnp.where(zero_temp_mask, 1.0, temperatures)[:, None]
-sampled = jax.vmap(lambda key, logit: jax.random.categorical(key, logit, axis=-1))(
+sampled_tokens = jax.vmap(lambda key, logit: jax.random.categorical(key, logit, axis=-1))(
     sample_keys, scaled_logits
 )
-greedy = jnp.argmax(s.logits, axis=-1)
-next_token = jnp.where(zero_temp_mask[:, None], greedy[:, None], sampled[:, None])
+next_token = jnp.where(zero_temp_mask, greedy_tokens, sampled_tokens)[:, None]
```
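The restructured sampling can be sanity-checked standalone with toy data (not the PR's code; `logits` and `temperatures` here are made up): rows with temperature 0 must take the argmax exactly, while the remaining rows sample from the temperature-scaled distribution.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([[0.0, 10.0, 0.0], [0.0, 10.0, 0.0]])
temperatures = jnp.array([0.0, 1.0])  # row 0 greedy, row 1 sampled
sample_keys = jax.random.split(key, 2)

zero_temp_mask = temperatures == 0.0
greedy_tokens = jnp.argmax(logits, axis=-1)
# Divisor 1.0 for masked rows avoids dividing by zero temperature.
scaled_logits = logits / jnp.where(zero_temp_mask, 1.0, temperatures)[:, None]
sampled_tokens = jax.vmap(jax.random.categorical)(sample_keys, scaled_logits)
next_token = jnp.where(zero_temp_mask, greedy_tokens, sampled_tokens)[:, None]
assert next_token[0, 0] == 1  # zero-temperature row is exactly the argmax
```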

@pcmoritz pcmoritz merged commit edfa55d into NovaSky-AI:main Dec 6, 2025
4 of 5 checks passed
dzorlu pushed a commit to fleet-ai/SkyRL that referenced this pull request Feb 4, 2026
Implemented this while travelling, still needs to be cleaned up. It puts
the whole prefill + decoding into one jit, avoids recomputation of the
logprobs and removes passing around some expensive state.

---------

Co-authored-by: Claude <noreply@anthropic.com>