
[tx] Add LLama3 support#657

Merged
pcmoritz merged 34 commits into NovaSky-AI:main from atemaguer:feat-llama3-support
Dec 15, 2025

Conversation

@atemaguer
Contributor

This PR adds Llama 3.2 model support to Tx. Llama3 and Qwen3 mostly share the same architecture, except for the QK-Norm layers present in Qwen3 but absent in Llama3, so the two models share certain layers.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for Llama3 models, including a new implementation in tx.models.llama3, along with comprehensive unit and integration tests. A significant improvement is the refactoring of common layers like RMSNorm and SwiGLUMLP into a shared tx.layers.common module, which cleans up the existing Qwen3 model implementation and promotes code reuse.

My review has identified two main issues. First, a critical bug is introduced by the change to allow string-based stop sequences in the sampling parameters. The API is updated, but the backend generation logic is not, which will cause runtime failures. Second, there is a high-severity bug in the is_lora_param method within the new Llama3ForCausalLM model that prevents LoRA parameters from being identified for training, rendering LoRA fine-tuning ineffective.

Apart from these issues, the overall implementation and testing strategy are excellent.

  max_tokens: int | None = None
  seed: int | None = None
- stop: Sequence[int] | None = None
+ stop: Sequence[int] | Sequence[str] | None = None
Contributor


critical

The stop parameter in SamplingParams now accepts Sequence[str], but the generation logic is not updated to handle string-based stop sequences. The current implementation in tx.utils.generator.GeneratorMixin.generate expects a list of integer token IDs and will raise an error when it tries to convert a list containing strings into a jnp.array of integers. This will lead to a runtime crash for any request that uses string stop sequences.

To fix this, the string stop sequences need to be tokenized before being used in the generation loop. This likely requires passing the tokenizer to the generation function or handling tokenization within the API layer before creating the generation request.

@pcmoritz pcmoritz added the tx label Nov 12, 2025
@atemaguer
Contributor Author

@pcmoritz , any thoughts about this PR?

@atemaguer
Contributor Author

@tyler-griggs, got thoughts about this PR?

Member

@tyler-griggs tyler-griggs left a comment


Thanks for writing this up @atemaguer!

  max_tokens: int | None = None
  seed: int | None = None
- stop: Sequence[int] | None = None
+ stop: Sequence[int] | Sequence[str] | None = None
Member


I'm curious why these updates to stop were needed -- could you please explain?

Contributor Author


I think it's because the Tinker client API update supports providing stop sequences as well. Also, some of the examples were failing, especially when using Llama 3 as the base model.

Member


It makes sense to update this if the Tinker client API should support stop sequences, but I don't think our generation logic actually supports stop sequences right now. It seems like we should push this set of changes into a later PR, when we actually add support for stop string sequences rather than just stop tokens.

Member


Based on the observation you made that Qwen3 and Llama3 are very similar, could we simplify this even further such that Qwen3 inherits from Llama3 and adds fairly light-weight additional logic? E.g., the transformers library takes this pattern: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modular_qwen3.py

We could likely further reduce the duplicate code between the two models, which will make it easier to keep them consistent going forward. But let me know if this actually doesn't make sense.

Contributor Author


On it

@atemaguer
Contributor Author

@tyler-griggs , let me know what you think of these new updates

self.config = config

# Token embeddings
self.embed_tokens = nnx.Embed(
Member


I assume we'd like to have LoRA in the embedding layer for Llama3 (as we do in Qwen3). Is the plan to do this in a follow up PR?

# Get rope_theta from config
from tx.layers.common import apply_rope

rope_theta = getattr(self.config, "rope_theta", 500000.0)
Member


Ideally we would not set default parameters in the code like this, but rather inherit the defaults from the base config itself. Is there a reason to use getattr here or can we avoid it?
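A small illustrative sketch of what the reviewer is asking for: the default lives on the config dataclass itself, so call sites never need `getattr` with a fallback. Field names and values here mirror the discussion but are assumptions, not the actual tx config:

```python
from dataclasses import dataclass

@dataclass
class Llama3Config:
    # Illustrative fields; only rope_theta is taken from the discussion.
    hidden_size: int = 2048
    num_attention_heads: int = 32
    # The default is declared once on the config, so attention code can
    # write `self.config.rope_theta` instead of
    # `getattr(self.config, "rope_theta", 500000.0)`.
    rope_theta: float = 500000.0

config = Llama3Config()
theta = config.rope_theta  # 500000.0, no getattr fallback needed
```

Centralizing the default also means a checkpoint-specific override (e.g. `Llama3Config(rope_theta=10000.0)`) is visible in one place rather than scattered across `getattr` call sites.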

@tyler-griggs
Member

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully adds support for Llama3 models and refactors the Qwen3 implementation to share common layers, which is a great step towards code reuse. The new Llama3 model and its associated tests are well-implemented. My review primarily focuses on improving the inheritance structure between the Llama3 and Qwen3 models. Currently, several Qwen3 classes inherit from their Llama3 counterparts but completely override methods with nearly identical code, leading to significant code duplication. I've provided suggestions to refactor the Llama3 base classes to be more extensible, which will reduce duplication and improve long-term maintainability. I also noted a minor issue with a local import.



- class Qwen3Attention(nnx.Module):
+ class Qwen3Attention(Llama3Attention):
Contributor


high

While it's great that Qwen3Attention inherits from Llama3Attention, the __call__ method is almost a complete copy of the parent's implementation, with the only addition being the QK-Norm step. This code duplication makes maintenance harder.

Consider refactoring Llama3Attention.__call__ to include a hook for processing Q and K tensors, which can be overridden by subclasses. This would allow Qwen3Attention to simply implement the QK-Norm logic without duplicating the entire attention mechanism.

For example, in Llama3Attention:

    def _process_q_k(self, q: jax.Array, k: jax.Array) -> tuple[jax.Array, jax.Array]:
        # No-op for Llama3
        return q, k

    def __call__(self, ...):
        # ... project q, k, v ...
        q, k = self._process_q_k(q, k)
        # ... apply RoPE, attention, etc. ...

Then, Qwen3Attention can just override _process_q_k and inherit __call__:

class Qwen3Attention(Llama3Attention):
    # ... __init__ ...
    def _process_q_k(self, q: jax.Array, k: jax.Array) -> tuple[jax.Array, jax.Array]:
        q = self.q_norm(q)
        k = self.k_norm(k)
        return q, k

Since llama3.py is part of this PR, this refactoring should be straightforward to apply.



- class Qwen3Model(nnx.Module):
+ class Qwen3Model(Llama3Model):
Contributor


high

Qwen3Model inherits from Llama3Model but completely overrides both __init__ and __call__ without calling super(). This makes the inheritance misleading and introduces significant code duplication, which will be a maintenance burden. The only differences are the types of the embedding and decoder layers, and the arguments passed to embed_tokens.

A better approach would be to make Llama3Model more generic so that Qwen3Model can reuse its structure. For example, you could parameterize the layer classes in Llama3Model.__init__:

# In Llama3Model
def __init__(self, config, *, dtype, rngs, embedding_cls=nnx.Embed, decoder_layer_cls=Llama3DecoderLayer, **embedding_kwargs):
    ...
    self.embed_tokens = embedding_cls(..., **embedding_kwargs)
    self.layers = nnx.List([decoder_layer_cls(...)])
    ...

This would simplify Qwen3Model to:

class Qwen3Model(Llama3Model):
    def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
        # Pass Qwen3-specific classes and LoRA-related kwargs for LoRAEmbed
        super().__init__(config, dtype=dtype, rngs=rngs,
                         embedding_cls=LoRAEmbed,
                         decoder_layer_cls=Qwen3DecoderLayer,
                         max_lora_adapters=config.max_lora_adapters,
                         max_lora_rank=config.max_lora_rank)

A similar refactoring could be applied to __call__ to handle the different arguments for embed_tokens.



- class Qwen3ForCausalLM(nnx.Module, GeneratorMixin):
+ class Qwen3ForCausalLM(Llama3ForCausalLM):
Contributor


high

Similar to Qwen3Model, Qwen3ForCausalLM inherits from Llama3ForCausalLM but overrides __init__ and __call__ with almost identical code, which is not ideal for maintainability. The key differences are the underlying model class (Qwen3Model vs Llama3Model) and the lm_head implementation (LoRALinear vs nnx.Linear).

To improve this, you could refactor Llama3ForCausalLM to be more configurable, allowing subclasses to specify the model class and lm_head implementation. This would allow Qwen3ForCausalLM to inherit the core logic and only define what's different, greatly reducing code duplication.
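One way to sketch the suggested refactor is with overridable class attributes: the base class owns `__init__`, and subclasses only swap the pieces that differ. The toy stand-in classes below are assumptions for illustration, not the actual tx modules:

```python
# Toy stand-ins for the real modules, to keep the sketch self-contained.
class Llama3Model:
    def __init__(self, config):
        self.config = config

class Qwen3Model(Llama3Model):
    pass

class Linear:  # stand-in for nnx.Linear
    def __init__(self, config):
        self.config = config

class LoRALinear(Linear):  # stand-in for the LoRA-aware head
    pass

class Llama3ForCausalLM:
    # Hooks subclasses override instead of copying __init__ wholesale.
    model_cls = Llama3Model
    lm_head_cls = Linear

    def __init__(self, config):
        self.model = self.model_cls(config)
        self.lm_head = self.lm_head_cls(config)

class Qwen3ForCausalLM(Llama3ForCausalLM):
    # Only the pieces that differ from Llama3 are swapped out.
    model_cls = Qwen3Model
    lm_head_cls = LoRALinear
```

With this shape, the shared `__call__` logic can also live on the base class, since both subclasses expose the same `self.model` / `self.lm_head` attributes.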

k = self.k_norm(k)

# Get rope_theta from config
from tx.layers.common import apply_rope
Contributor


medium

This local import should be moved to the top of the file to follow standard Python conventions and improve readability.

@atemaguer
Contributor Author

@tyler-griggs , let me know what you think about these new updates

@atemaguer
Contributor Author

Btw @pcmoritz, it looks like this PR will conflict with the FSDP support changes. Any chance you could review and merge this PR first, and then add the FSDP changes after?

@@ -1,8 +1,7 @@
"""Background engine for processing training requests."""

import argparse
Collaborator


I'm a little surprised about the diff in this file; maybe you need to merge main into the PR? These changes should already be on the current master :)

@atemaguer
Contributor Author

@pcmoritz, I merged the changes from main. Let me know if there's anything else that needs changing or improving.

@pcmoritz
Collaborator

Thanks a lot for all the work, I'll get this PR merged next :)

I'm planning to do a little bit of restructuring; it is probably easiest to just edit it directly :)

@pcmoritz
Collaborator

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for Llama3 models, including LoRA training and generation. The changes are well-structured, introducing a generic ModelConfig and moving common layers like RMSNorm and apply_rope into separate files for reuse. New tests for the Llama3 model and its LoRA training capabilities are also included, which is great.

I've found one critical issue in the new Llama3Attention implementation related to batched decoding that needs to be addressed. This same issue appears to exist in the Qwen3Attention model as well and should be fixed there too. I also have a minor suggestion for code cleanup in one of the new test files. Overall, this is a solid contribution.

Comment on lines +100 to +101
k = jax.lax.dynamic_update_slice(k_cache, k, (0, cache_position, 0, 0))
v = jax.lax.dynamic_update_slice(v_cache, v, (0, cache_position, 0, 0))
Contributor


critical

The use of jax.lax.dynamic_update_slice with hardcoded start indices (0, cache_position, 0, 0) will not work correctly for batched decoding (i.e., when batch size > 1). It will only update the KV cache for the first sequence in the batch.

To support batching correctly, you should use the .at[...].set(...) syntax, which is aware of batch dimensions.

Suggested change

- k = jax.lax.dynamic_update_slice(k_cache, k, (0, cache_position, 0, 0))
- v = jax.lax.dynamic_update_slice(v_cache, v, (0, cache_position, 0, 0))
+ k = k_cache.at[:, cache_position:cache_position+1].set(k)
+ v = v_cache.at[:, cache_position:cache_position+1].set(v)
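A NumPy stand-in for the suggested update (plain slice assignment here mirrors JAX's `.at[:, pos:pos+1].set(k)`), showing that the batched form writes the new K entry for every sequence in the batch. Shapes follow the `(batch, seq_len, num_heads, head_dim)` layout from the diff; the concrete sizes are illustrative:

```python
import numpy as np

# Illustrative cache sizes; the real shapes come from the model config.
batch, seq_len, heads, dim = 2, 8, 4, 16
k_cache = np.zeros((batch, seq_len, heads, dim))
k_new = np.ones((batch, 1, heads, dim))  # one decode step of K per sequence

cache_position = 3
# NumPy equivalent of k_cache.at[:, pos:pos+1].set(k_new) in JAX:
# the full batch dimension is addressed, so all sequences get updated.
k_cache[:, cache_position:cache_position + 1] = k_new

# Every sequence in the batch now has its step-3 entry filled in,
# and every other position is untouched.
assert (k_cache[:, cache_position] == 1.0).all()
assert (k_cache[:, cache_position + 1] == 0.0).all()
```

In JAX the `.at[...].set(...)` form also stays functional (it returns a new array), which is why the suggested change rebinds `k` and `v` rather than mutating in place.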


optimizer.update(lora_params, lora_grads)

print(f"Step {step}: loss = {float(loss):.4f}")
Contributor


medium

This print statement appears to be for debugging purposes. It's best to remove it from the final test code to keep the test output clean. If you need to output information during tests, consider using Python's logging module.
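A minimal sketch of the suggested swap, routing the per-step loss through the standard `logging` module so normal test runs stay quiet. The logger name is a placeholder, not the actual tx module path:

```python
import logging

# Placeholder logger name for illustration.
logger = logging.getLogger("tx.tests.lora")

def report_step(step: int, loss: float) -> None:
    # Emitted only when the test runner enables DEBUG-level logging,
    # unlike print(), which always writes to stdout.
    logger.debug("Step %d: loss = %.4f", step, loss)

report_step(0, 1.2345)  # silent under the default WARNING level
```

A test runner (e.g. `pytest --log-cli-level=DEBUG`) can then opt in to the output without it polluting normal runs.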

Collaborator

@pcmoritz pcmoritz left a comment


I updated the code now. It turns out that trying to share too much between llama3 and qwen3 actually just makes the code harder to understand and doesn't have many benefits (e.g. vllm, sglang, and torchtitan also don't share), so I refactored the code to only share truly common layers.

@pcmoritz pcmoritz merged commit 1fe9dd4 into NovaSky-AI:main Dec 15, 2025
4 of 5 checks passed
pcmoritz added a commit that referenced this pull request Dec 17, 2025
In order to be able to run the rl_loop.py with
#657, we need to implement
string stop sequences.
dzorlu pushed a commit to fleet-ai/SkyRL that referenced this pull request Feb 4, 2026
This PR adds Llama 3.2 model support to Tx. Llama3 and Qwen3 mostly share the same architecture, except for the QK-Norm layers present in Qwen3 but absent in Llama3, so the two models share certain layers.

---------

Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
dzorlu pushed a commit to fleet-ai/SkyRL that referenced this pull request Feb 4, 2026
In order to be able to run the rl_loop.py with
NovaSky-AI#657, we need to implement
string stop sequences.