Conversation
Code Review
This pull request introduces support for Llama3 models, including a new implementation in tx.models.llama3, along with comprehensive unit and integration tests. A significant improvement is the refactoring of common layers like RMSNorm and SwiGLUMLP into a shared tx.layers.common module, which cleans up the existing Qwen3 model implementation and promotes code reuse.
My review has identified two main issues. First, a critical bug is introduced by the change to allow string-based stop sequences in the sampling parameters. The API is updated, but the backend generation logic is not, which will cause runtime failures. Second, there is a high-severity bug in the is_lora_param method within the new Llama3ForCausalLM model that prevents LoRA parameters from being identified for training, rendering LoRA fine-tuning ineffective.
Apart from these issues, the overall implementation and testing strategy are excellent.
skyrl-tx/tx/tinker/api.py
Outdated
```diff
  max_tokens: int | None = None
  seed: int | None = None
- stop: Sequence[int] | None = None
+ stop: Sequence[int] | Sequence[str] | None = None
```
The stop parameter in SamplingParams now accepts Sequence[str], but the generation logic is not updated to handle string-based stop sequences. The current implementation in tx.utils.generator.GeneratorMixin.generate expects a list of integer token IDs and will raise an error when it tries to convert a list containing strings into a jnp.array of integers. This will lead to a runtime crash for any request that uses string stop sequences.
To fix this, the string stop sequences need to be tokenized before being used in the generation loop. This likely requires passing the tokenizer to the generation function or handling tokenization within the API layer before creating the generation request.
@pcmoritz, any thoughts about this PR?
@tyler-griggs, got thoughts about this PR? |
tyler-griggs
left a comment
Thanks for writing this up @atemaguer!
skyrl-tx/tx/tinker/api.py
Outdated
```diff
  max_tokens: int | None = None
  seed: int | None = None
- stop: Sequence[int] | None = None
+ stop: Sequence[int] | Sequence[str] | None = None
```
I'm curious why these updates to stop were needed -- could you please explain?
I think it's because the Tinker client API update supports providing stop sequences as well. Also, some of the examples were failing, especially when using Llama 3 as the base model.
It makes sense to update this if the Tinker client API should support stop sequences, but I don't think our generation logic actually supports stop sequences right now. It seems like we should push this set of changes into a later PR, when we actually add support for stop strings rather than just stop tokens.
Based on the observation you made that Qwen3 and Llama3 are very similar, could we simplify this even further such that Qwen3 inherits from Llama3 and adds fairly light-weight additional logic? E.g., the transformers library takes this pattern: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modular_qwen3.py
We could likely further reduce the duplicate code between the two models, which will make it easier to keep them consistent going forward. But let me know if this actually doesn't make sense.
@tyler-griggs, let me know what you think of these new updates.
skyrl-tx/tx/models/llama3.py
Outdated
```python
self.config = config

# Token embeddings
self.embed_tokens = nnx.Embed(
```
I assume we'd like to have LoRA in the embedding layer for Llama3 (as we do in Qwen3). Is the plan to do this in a follow up PR?
skyrl-tx/tx/models/qwen3.py
Outdated
```python
# Get rope_theta from config
from tx.layers.common import apply_rope

rope_theta = getattr(self.config, "rope_theta", 500000.0)
```
Ideally we would not set default parameters in the code like this, but rather inherit the defaults from the base config itself. Is there a reason to use getattr here or can we avoid it?
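A minimal sketch of that point, assuming a dataclass-style config (the real tx config plumbing may differ): declare the default once on the config class, and let call sites read the attribute directly.

```python
from dataclasses import dataclass


@dataclass
class RopeConfig:
    # Hypothetical config class: the default lives here, not at call sites.
    rope_theta: float = 500000.0


def inv_freq(config: RopeConfig, dim: int) -> list[float]:
    # No getattr fallback needed: the attribute is guaranteed to exist.
    return [config.rope_theta ** (-2 * i / dim) for i in range(dim // 2)]
```

This keeps the default in one place, so a model with a different base frequency only has to override the config field.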
/gemini review
Code Review
This pull request successfully adds support for Llama3 models and refactors the Qwen3 implementation to share common layers, which is a great step towards code reuse. The new Llama3 model and its associated tests are well-implemented. My review primarily focuses on improving the inheritance structure between the Llama3 and Qwen3 models. Currently, several Qwen3 classes inherit from their Llama3 counterparts but completely override methods with nearly identical code, leading to significant code duplication. I've provided suggestions to refactor the Llama3 base classes to be more extensible, which will reduce duplication and improve long-term maintainability. I also noted a minor issue with a local import.
skyrl-tx/tx/models/qwen3.py
Outdated
```diff
- class Qwen3Attention(nnx.Module):
+ class Qwen3Attention(Llama3Attention):
```
While it's great that Qwen3Attention inherits from Llama3Attention, the __call__ method is almost a complete copy of the parent's implementation, with the only addition being the QK-Norm step. This code duplication makes maintenance harder.
Consider refactoring Llama3Attention.__call__ to include a hook for processing Q and K tensors, which can be overridden by subclasses. This would allow Qwen3Attention to simply implement the QK-Norm logic without duplicating the entire attention mechanism.
For example, in Llama3Attention:
```python
def _process_q_k(self, q: jax.Array, k: jax.Array) -> tuple[jax.Array, jax.Array]:
    # No-op for Llama3
    return q, k

def __call__(self, ...):
    # ... project q, k, v ...
    q, k = self._process_q_k(q, k)
    # ... apply RoPE, attention, etc. ...
```

Then, Qwen3Attention can just override _process_q_k and inherit __call__:

```python
class Qwen3Attention(Llama3Attention):
    # ... __init__ ...
    def _process_q_k(self, q: jax.Array, k: jax.Array) -> tuple[jax.Array, jax.Array]:
        q = self.q_norm(q)
        k = self.k_norm(k)
        return q, k
```

Since llama3.py is part of this PR, this refactoring should be straightforward to apply.
skyrl-tx/tx/models/qwen3.py
Outdated
```diff
- class Qwen3Model(nnx.Module):
+ class Qwen3Model(Llama3Model):
```
Qwen3Model inherits from Llama3Model but completely overrides both __init__ and __call__ without calling super(). This makes the inheritance misleading and introduces significant code duplication, which will be a maintenance burden. The only differences are the types of the embedding and decoder layers, and the arguments passed to embed_tokens.
A better approach would be to make Llama3Model more generic so that Qwen3Model can reuse its structure. For example, you could parameterize the layer classes in Llama3Model.__init__:
```python
# In Llama3Model
def __init__(self, config, *, dtype, rngs, embedding_cls=nnx.Embed, decoder_layer_cls=Llama3DecoderLayer, **embedding_kwargs):
    ...
    self.embed_tokens = embedding_cls(..., **embedding_kwargs)
    self.layers = nnx.List([decoder_layer_cls(...)])
    ...
```

This would simplify Qwen3Model to:

```python
class Qwen3Model(Llama3Model):
    def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
        # Pass Qwen3-specific classes and LoRA-related kwargs for LoRAEmbed
        super().__init__(config, dtype=dtype, rngs=rngs,
                         embedding_cls=LoRAEmbed,
                         decoder_layer_cls=Qwen3DecoderLayer,
                         max_lora_adapters=config.max_lora_adapters,
                         max_lora_rank=config.max_lora_rank)
```

A similar refactoring could be applied to __call__ to handle the different arguments for embed_tokens.
skyrl-tx/tx/models/qwen3.py
Outdated
```diff
- class Qwen3ForCausalLM(nnx.Module, GeneratorMixin):
+ class Qwen3ForCausalLM(Llama3ForCausalLM):
```
Similar to Qwen3Model, Qwen3ForCausalLM inherits from Llama3ForCausalLM but overrides __init__ and __call__ with almost identical code, which is not ideal for maintainability. The key differences are the underlying model class (Qwen3Model vs Llama3Model) and the lm_head implementation (LoRALinear vs nnx.Linear).
To improve this, you could refactor Llama3ForCausalLM to be more configurable, allowing subclasses to specify the model class and lm_head implementation. This would allow Qwen3ForCausalLM to inherit the core logic and only define what's different, greatly reducing code duplication.
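A hedged sketch of that pattern, with the nnx/LoRA machinery replaced by plain-Python stand-ins (the class bodies here are placeholders, not the real implementations): the base class builds whatever its class attributes name, so a subclass only declares what differs.

```python
class Llama3Model:
    def __init__(self, config):
        self.config = config


class Qwen3Model(Llama3Model):
    pass


class Linear:
    def __init__(self, config):
        self.config = config


class LoRALinear(Linear):
    pass


class Llama3ForCausalLM:
    # Subclasses point these at different implementations instead of
    # copying __init__ wholesale.
    model_cls = Llama3Model
    lm_head_cls = Linear

    def __init__(self, config):
        self.model = self.model_cls(config)
        self.lm_head = self.lm_head_cls(config)


class Qwen3ForCausalLM(Llama3ForCausalLM):
    # Only the parts that differ are declared; construction logic is inherited.
    model_cls = Qwen3Model
    lm_head_cls = LoRALinear
```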
skyrl-tx/tx/models/qwen3.py
Outdated
```python
k = self.k_norm(k)

# Get rope_theta from config
from tx.layers.common import apply_rope
```
@tyler-griggs, let me know what you think about these new updates.
Btw @pcmoritz, it looks like this PR will conflict with the FSDP support changes. Any chance you could review and merge this PR first and then add the FSDP changes after?
```diff
@@ -1,8 +1,7 @@
 """Background engine for processing training requests."""

 import argparse
```
I'm a little surprised by the diff in this file; maybe you need to merge main into the PR? These changes should already be on the current master :)
@pcmoritz, I merged the changes from main. Let me know if there's anything else that needs changing or improving.
Thanks a lot for all the work, I'll get this PR merged next :) I'm planning to do a little bit of restructuring; it's probably easiest to just edit it directly :)
/gemini review
Code Review
This pull request adds support for Llama3 models, including LoRA training and generation. The changes are well-structured, introducing a generic ModelConfig and moving common layers like RMSNorm and apply_rope into separate files for reuse. New tests for the Llama3 model and its LoRA training capabilities are also included, which is great.
I've found one critical issue in the new Llama3Attention implementation related to batched decoding that needs to be addressed. This same issue appears to exist in the Qwen3Attention model as well and should be fixed there too. I also have a minor suggestion for code cleanup in one of the new test files. Overall, this is a solid contribution.
```python
k = jax.lax.dynamic_update_slice(k_cache, k, (0, cache_position, 0, 0))
v = jax.lax.dynamic_update_slice(v_cache, v, (0, cache_position, 0, 0))
```
The use of jax.lax.dynamic_update_slice with hardcoded start indices (0, cache_position, 0, 0) will not work correctly for batched decoding (i.e., when batch size > 1). It will only update the KV cache for the first sequence in the batch.
To support batching correctly, you should use the .at[...].set(...) syntax, which is aware of batch dimensions.
```diff
- k = jax.lax.dynamic_update_slice(k_cache, k, (0, cache_position, 0, 0))
- v = jax.lax.dynamic_update_slice(v_cache, v, (0, cache_position, 0, 0))
+ k = k_cache.at[:, cache_position:cache_position+1].set(k)
+ v = v_cache.at[:, cache_position:cache_position+1].set(v)
```
```python
optimizer.update(lora_params, lora_grads)

print(f"Step {step}: loss = {float(loss):.4f}")
```
This reverts commit 0e0ab53.
pcmoritz
left a comment
I updated the code now. It turns out that trying to share too much between llama3 and qwen3 actually makes the code harder to understand without much benefit (e.g., vLLM, SGLang, and torchtitan also don't share), so I refactored the code to share only the truly common layers.
In order to be able to run the rl_loop.py with #657, we need to implement string stop sequences.
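A sketch of one way to do that (illustrative only, not the eventual tx implementation): match stop strings against the decoded text at each step and truncate at the earliest hit, which also handles stop strings that span token boundaries.

```python
def truncate_at_stop(text: str, stop_strings: list[str]) -> tuple[str, bool]:
    """Cut `text` at the earliest occurrence of any stop string.

    Returns the (possibly truncated) text and whether a stop was hit.
    """
    cut: int | None = None
    for s in stop_strings:
        idx = text.find(s)
        if idx != -1 and (cut is None or idx < cut):
            cut = idx
    if cut is None:
        return text, False
    return text[:cut], True
```

In a decode loop this would run on the incrementally decoded output, and a `True` result would end generation for that sequence.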
This PR adds Llama 3.2 model support to tx. Llama3 and Qwen3 mostly share the same architecture, except for the QK-Norm layers present in Qwen3 but absent in Llama3, so both models share certain layers.

---------

Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>