[tx] Some simplifications and speedups of sampling #717
pcmoritz merged 28 commits into NovaSky-AI:main
Conversation
Code Review
This pull request introduces significant simplifications and performance improvements to the sampling process by fusing the prefill and decoding steps into a single JIT-compiled function. This is an excellent change that also improves code clarity by using jax.lax.scan more idiomatically to accumulate outputs and making the DecodeState more lightweight. The changes look solid, and I have a couple of suggestions to further enhance performance and numerical stability.
skyrl-tx/tx/utils/generator.py
Outdated
```python
zero_temp_mask = (temperatures == 0.0)[:, None]
inv_temperatures = (1.0 / jnp.maximum(temperatures, 1e-10))[:, None]
```
When temperature is 0, inv_temperatures becomes a very large number (1e10) due to jnp.maximum(temperatures, 1e-10). Although the result of sampling with this large value is correctly masked out later in the greedy path, the intermediate computation log_probs * s.inv_temperatures still occurs. This can lead to numerical instability (e.g., producing inf or nan) and is an unnecessary computation. It would be safer and cleaner to avoid the division by a small number for the T=0 case altogether.
Suggested change:

```diff
 zero_temp_mask = (temperatures == 0.0)[:, None]
-inv_temperatures = (1.0 / jnp.maximum(temperatures, 1e-10))[:, None]
+inv_temperatures = 1.0 / jnp.where(zero_temp_mask, 1.0, temperatures[:, None])
```
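To make the numerical concern concrete, here is a small NumPy sketch of the two ways of deriving the inverse temperature (NumPy's `where` and `maximum` behave the same as their `jnp` counterparts here; the values are made up for illustration):

```python
import numpy as np

temperatures = np.array([0.0, 0.5, 1.0], dtype=np.float32)

# Clamp-then-invert: a zero temperature turns into a huge ~1e10 multiplier
inv_old = 1.0 / np.maximum(temperatures, 1e-10)

# Masked divide: zero-temperature rows get a harmless divisor of 1.0,
# since their sampled result is overwritten by the greedy path anyway
zero_temp_mask = temperatures == 0.0
inv_new = 1.0 / np.where(zero_temp_mask, 1.0, temperatures)

print(inv_old)  # first entry is ~1e10
print(inv_new)  # first entry is exactly 1.0
```

The masked version never produces a huge scale factor, so nothing downstream can overflow because of the T=0 rows.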
/gemini review
Code Review
This pull request significantly refactors the generation logic by consolidating the prefill and decoding steps into a single JIT-compiled function, _prefill_and_decode. This is a great improvement for both performance and code clarity. The DecodeState is simplified, and several helper functions have been cleanly integrated into the main loop. The use of jax.lax.scan to accumulate outputs is more efficient than the previous approach of updating arrays in each step. The overall changes make the code more idiomatic and easier to follow. I have one minor suggestion to further optimize the decode loop by hoisting a constant computation.
skyrl-tx/tx/utils/generator.py
Outdated
```python
def decode_fn(s: DecodeState, _) -> tuple[DecodeState, tuple[jax.Array, jax.Array, jax.Array]]:
    """Decode one token step. Returns (state, (token, logprob, is_stop)) for scan accumulation."""
    # Sample next token
    split_keys = jax.vmap(jax.random.split)(s.rngs)
    rngs, sample_keys = split_keys[:, 0], split_keys[:, 1]

    zero_temp_mask = temperatures == 0.0
```
For a minor optimization, zero_temp_mask can be computed once outside the decode_fn since temperatures is constant within the scan loop. This avoids redundant computation in each decoding step.
Suggested change:

```diff
+zero_temp_mask = temperatures == 0.0
 def decode_fn(s: DecodeState, _) -> tuple[DecodeState, tuple[jax.Array, jax.Array, jax.Array]]:
     """Decode one token step. Returns (state, (token, logprob, is_stop)) for scan accumulation."""
     # Sample next token
     split_keys = jax.vmap(jax.random.split)(s.rngs)
     rngs, sample_keys = split_keys[:, 0], split_keys[:, 1]
-    zero_temp_mask = temperatures == 0.0
```
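As a minimal illustration of the scan-accumulation pattern the reviews refer to (a toy carry and per-step output, not the actual decode state):

```python
import jax
import jax.numpy as jnp

# Constant captured by closure -- computed once, outside the scanned body,
# analogous to hoisting zero_temp_mask out of decode_fn
offset = jnp.array(10)

def step(carry, _):
    carry = carry + 1
    # The second return value is stacked across all steps by lax.scan,
    # which is how per-step tokens/logprobs/is_stop can be accumulated
    # without carrying preallocated output arrays in the state
    return carry, carry + offset

final_carry, stacked = jax.lax.scan(step, jnp.array(0), xs=None, length=4)
print(final_carry)  # 4
print(stacked)      # [11 12 13 14]
```

Under `jit`, XLA would likely hoist the loop-invariant computation anyway, which is why the reviewer calls this a minor optimization; doing it explicitly mainly improves readability.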
skyrl-tx/tx/utils/generator.py
Outdated
```python
has_stop,
first_stop_idx + 1,  # Include the stop token
jnp.array([sp.max_tokens for sp in sampling_params]),
```
If I understand correctly how per-request max_tokens are handled here, it seems possible to violate this maximum bound if different requests in the batch have different max_tokens. Or is this a known current limitation?
Ah I see: some of the sequences stop, but they stop after their respective max_tokens, and then they won't get truncated. You are right, this is indeed a bug, let me fix it :)
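The bug and a possible fix can be sketched with plain NumPy (hypothetical per-request values; the real code operates on the scan outputs):

```python
import numpy as np

# Hypothetical batch: per-request stop flags, stop positions, and limits
has_stop = np.array([True, True, False])
first_stop_idx = np.array([3, 10, 7])
max_tokens = np.array([8, 8, 8])

# Buggy length: request 1 found its stop sequence only after max_tokens,
# so it would keep 11 tokens despite a limit of 8
num_tokens = np.where(has_stop, first_stop_idx + 1, max_tokens)

# One way to fix it: also clamp to the per-request max_tokens
num_tokens_fixed = np.minimum(num_tokens, max_tokens)

print(num_tokens)        # [ 4 11  8]
print(num_tokens_fixed)  # [4 8 8]
```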
tyler-griggs left a comment
Overall looks great, very nice!
Resolved merge conflict in tx/utils/generator.py by:
- Keeping the HEAD branch's simpler _prefill_and_decode JIT pattern with local decode_fn capturing closure variables
- Adding upstream's prompt_logprobs feature (compute_prompt_logprobs function, prompt_logprobs parameter, and GenerateOutput.prompt_logprobs field)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
/gemini review
Code Review
This pull request introduces a significant and valuable refactoring of the text generation logic. By consolidating the prefill and decoding steps into a single JIT-compiled function, you've made a change that should substantially improve performance by reducing compilation overhead and enabling more effective compiler optimizations. The simplification of the DecodeState and the use of jax.lax.scan to accumulate outputs rather than passing large, mutable state arrays is a great improvement for both code clarity and efficiency. The test suite has also been thoughtfully expanded to cover an important edge case regarding max_tokens and stop sequences. Overall, this is a high-quality contribution. I've added a couple of minor suggestions to further enhance code conciseness, in line with your comment about further cleanup.
skyrl-tx/tx/utils/generator.py
Outdated
```python
new_tokens = jnp.swapaxes(tokens_stacked, 0, 1).squeeze(-1)
new_logprobs = jnp.swapaxes(logprobs_stacked, 0, 1).squeeze(-1)
is_stop = jnp.swapaxes(is_stop_stacked, 0, 1).squeeze(-1)
```
To improve conciseness, you can use a list comprehension or a generator expression to apply the swapaxes and squeeze operations to all three tensors at once. This avoids repeating the same logic and makes the code more compact.
Suggested change:

```diff
-new_tokens = jnp.swapaxes(tokens_stacked, 0, 1).squeeze(-1)
-new_logprobs = jnp.swapaxes(logprobs_stacked, 0, 1).squeeze(-1)
-is_stop = jnp.swapaxes(is_stop_stacked, 0, 1).squeeze(-1)
+new_tokens, new_logprobs, is_stop = (
+    jnp.swapaxes(tensor, 0, 1).squeeze(-1) for tensor in (tokens_stacked, logprobs_stacked, is_stop_stacked)
+)
```
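For reference, the shape manipulation both versions perform, sketched with NumPy on dummy data: lax.scan stacks per-step outputs along a leading axis of shape (steps, batch, 1), and the swapaxes plus squeeze turns each into (batch, steps):

```python
import numpy as np

steps, batch = 4, 2
tokens_stacked = np.arange(steps * batch).reshape(steps, batch, 1)
logprobs_stacked = tokens_stacked * 0.5
is_stop_stacked = tokens_stacked % 2 == 0

# Same transform applied to all three stacked arrays at once
new_tokens, new_logprobs, is_stop = (
    np.swapaxes(t, 0, 1).squeeze(-1) for t in (tokens_stacked, logprobs_stacked, is_stop_stacked)
)
print(new_tokens.shape)  # (2, 4)
```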
/gemini review
Code Review
This pull request introduces significant simplifications and performance improvements to the sampling process by consolidating the prefill and decoding steps into a single JIT-compiled function. The changes result in cleaner code, removing the need to pass around extensive state and avoiding potential recomputations. The updated test cases effectively validate the new logic, including edge cases for stopping conditions. My review includes one suggestion to enhance the readability of the sampling logic, aligning with your goal of cleaning up the implementation.
```python
zero_temp_mask = temperatures == 0.0
scaled_logits = s.logits / jnp.where(zero_temp_mask, 1.0, temperatures)[:, None]
sampled = jax.vmap(lambda key, logit: jax.random.categorical(key, logit, axis=-1))(
    sample_keys, scaled_logits
)
greedy = jnp.argmax(s.logits, axis=-1)
next_token = jnp.where(zero_temp_mask[:, None], greedy[:, None], sampled[:, None])
```
The sampling logic is functionally correct but could be restructured for better readability, especially given your comment about needing to clean up the code. The current implementation computes sampled tokens for all requests, including zero-temperature ones where they are immediately discarded. While this pattern is common in JAX for vectorization, reordering the operations to compute greedy tokens first and using a more direct jnp.where at the end can improve clarity.
Suggested change:

```diff
-zero_temp_mask = temperatures == 0.0
-scaled_logits = s.logits / jnp.where(zero_temp_mask, 1.0, temperatures)[:, None]
-sampled = jax.vmap(lambda key, logit: jax.random.categorical(key, logit, axis=-1))(
-    sample_keys, scaled_logits
-)
-greedy = jnp.argmax(s.logits, axis=-1)
-next_token = jnp.where(zero_temp_mask[:, None], greedy[:, None], sampled[:, None])
+zero_temp_mask = temperatures == 0.0
+greedy_tokens = jnp.argmax(s.logits, axis=-1)
+scaled_logits = s.logits / jnp.where(zero_temp_mask, 1.0, temperatures)[:, None]
+sampled_tokens = jax.vmap(lambda key, logit: jax.random.categorical(key, logit, axis=-1))(
+    sample_keys, scaled_logits
+)
+next_token = jnp.where(zero_temp_mask, greedy_tokens, sampled_tokens)[:, None]
```
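A self-contained NumPy sketch of the suggested ordering, with Gumbel-max sampling standing in for `jax.random.categorical` (which is based on the same trick); the logits and temperatures here are made up:

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 5)).astype(np.float32)
temperatures = np.array([0.0, 1.0, 0.7], dtype=np.float32)

zero_temp_mask = temperatures == 0.0
greedy_tokens = np.argmax(logits, axis=-1)

# Divide by 1.0 on zero-temperature rows so no large intermediates appear
scaled_logits = logits / np.where(zero_temp_mask, 1.0, temperatures)[:, None]

# Gumbel-max sampling: argmax of logits + Gumbel noise ~ categorical draw
gumbel = -np.log(-np.log(rng.uniform(size=scaled_logits.shape)))
sampled_tokens = np.argmax(scaled_logits + gumbel, axis=-1)

# Zero-temperature rows take the greedy token, the rest take the sample
next_token = np.where(zero_temp_mask, greedy_tokens, sampled_tokens)[:, None]
print(next_token.shape)  # (3, 1)
```

Computing `sampled_tokens` even for the zero-temperature rows is wasted work, but it keeps the whole batch on one vectorized path, which is the usual trade-off in JAX.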
Implemented this while travelling; it still needs to be cleaned up. It puts the whole prefill + decoding into one jit, avoids recomputation of the logprobs, and removes passing around some expensive state.

---------

Co-authored-by: Claude <noreply@anthropic.com>