Skip to content

Conversation

@sufubao
Copy link
Collaborator

@sufubao sufubao commented Nov 18, 2025

No description provided.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @sufubao, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the Qwen3next model into the LightLLM framework, enhancing its capability to support advanced model architectures. A significant part of this integration involves a sophisticated hybrid prefix caching mechanism and a refactored memory management system tailored to Qwen3next's unique linear attention and shared expert layers. The changes also introduce several high-performance Triton kernels to ensure efficient execution of these new model components, alongside general improvements to model loading parallelism.

Highlights

  • Qwen3next Model Support: Introduced comprehensive support for the Qwen3next model architecture, including its unique linear attention mechanism and shared expert layers, by adding new layer inference and weight classes.
  • Hybrid Radix Cache for Prefix Caching: Implemented a new HybridRadixCache system that extends the existing radix cache to manage both token indices and state buffer indices, crucial for efficiently handling intermediate states in models like Qwen3next.
  • Memory Management Refactoring: Refactored the core memory management by introducing a BaseAllocator class and a specialized Qwen3NextMemoryManager to handle distinct cache types (KV cache, conv_state, ssm_state) required by the Qwen3next model's linear attention.
  • New Triton Kernels for Performance: Added several highly optimized Triton kernels, including causal_conv1d, gated_rmsnorm, gemma_rmsnorm, fused_gdn_gating, and a suite of Flash Linear Attention (FLA) operations, to accelerate Qwen3next's computational graphs.
  • Increased Weight Loading Parallelism: Increased the default number of workers for loading Hugging Face model weights from 1 to 16, which can significantly speed up model initialization.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Qwen3Next model, which appears to be a hybrid model using both standard and linear attention mechanisms. The changes are extensive, touching upon memory management, request management, prefix caching, and the core model inference logic. A significant part of the work is the introduction of a HybridRadixCache that caches not only the KV state but also the state buffers required for the linear attention layers. A specialized backend for Qwen3Next is also added to handle caching intermediate states during chunked prefilling. The refactoring of MemoryManager into a BaseAllocator is a good improvement for code structure.

My review has identified one critical issue related to model selection logic that needs to be addressed. I've also provided suggestions to improve performance and maintainability in the new complex components. Overall, this is a substantial and well-engineered contribution to support a new, complex model architecture.

class HybridTreeNode(TreeNode):
def __init__(self):
super().__init__()
self.children_list: List[HybridTreeNode] = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The HybridTreeNode uses a list to store its children. In methods like _insert_helper_no_recursion and _match_prefix_helper_no_recursion, this list is iterated linearly to find a matching child. If a node has a large number of children (i.e., a high branching factor), this linear scan could become a performance bottleneck, changing the lookup complexity from O(N) to O(1) on average. Consider using a dictionary mapping the first token ID of the child's key to the child node, similar to the implementation in the base RadixCache, to maintain efficient lookups.

Comment on lines +27 to +345
class Qwen3NextTransformerLayerInfer(Qwen3MOETransformerLayerInfer):
def __init__(self, layer_num, network_config, mode=[]):
super().__init__(layer_num, network_config, mode)
self.is_linear = (layer_num + 1) % network_config["full_attention_interval"] != 0
self.partial_rotary_factor = network_config.get("partial_rotary_factor", 1.0)

if self.is_linear:
self.linear_attn_infer = Qwen3NextGatedDeltaNetInfer(network_config, layer_num, self.tp_world_size_)
return

@override
def _bind_norm(self):
pass

def _ffn_with_shared_expert(
self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
) -> torch.Tensor:
input = input.view(-1, self.embed_dim_)
up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input)
ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype)
silu_and_mul_fwd(up_gate_out, ffn1_out)
ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out)
shared_expert_out = F.sigmoid(layer_weight.shared_expert_gate.mm(input)) * ffn2_out
moe_out = self._ffn(input, infer_state, layer_weight)
return shared_expert_out + moe_out

@override
def _att_norm(
self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
) -> torch.Tensor:
out = self.alloc_tensor(input.shape, input.dtype)
gemma_rmsnorm_forward(input, layer_weight.att_norm_weight_.weight, self.eps_, out=out)
return out

@override
def _ffn_norm(
self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
) -> torch.Tensor:
out = self.alloc_tensor(input.shape, input.dtype)
gemma_rmsnorm_forward(input, layer_weight.ffn_norm_weight_.weight, self.eps_, out=out)
return out

@override
def _get_qkv(
self,
input: torch.Tensor,
infer_state: LlamaInferStateInfo,
layer_weight: Qwen3NextTransformerLayerWeight,
) -> Tuple[torch.Tensor, torch.Tensor]:
input = input.view(-1, self.embed_dim_)
q = layer_weight.q_proj.mm(input)
cache_kv = layer_weight.kv_proj.mm(
input.view(-1, self.embed_dim_),
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

gemma_rmsnorm_forward(
q.view(-1, self.head_dim_),
layer_weight.q_norm_weight_.weight,
eps=self.eps_,
out=q.view(-1, self.head_dim_),
)

cache_kv[:, : self.tp_k_head_num_, :] = gemma_rmsnorm_forward(
cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
layer_weight.k_norm_weight_.weight,
eps=self.eps_,
).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])

rotary_emb_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
partial_rotary_factor=self.partial_rotary_factor,
)
return q, cache_kv

@override
def _get_o(
self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
) -> torch.Tensor:
input = input * layer_weight._gate
layer_weight._gate = None
o_tensor = layer_weight.o_proj.mm(input)
return o_tensor

def _context_full_attn(
self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
):
q, cache_kv = self._get_qkv(input, infer_state, layer_weight)
input = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
return o

def context_forward(
self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
if self.is_linear:
o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=True, infer_cls=self)
else:
layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1))
o = self._context_full_attn(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight)
input1 = None
if self.tp_world_size_ > 1:
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

def _token_full_attn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight):
q, cache_kv = self._get_qkv(input, infer_state, layer_weight)
input = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._token_attention_kernel(q, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
return o

def token_forward(
self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
if self.is_linear:
o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=False, infer_cls=self)
else:
layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1))
o = self._token_full_attn(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight)
input1 = None
if self.tp_world_size_ > 1:
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings


class Qwen3NextGatedDeltaNetInfer:
def __init__(self, network_config, layer_idx, tp_world_size_):
self.network_config_ = network_config
self.layer_idx_ = layer_idx
self.tp_world_size_ = tp_world_size_
self.num_v_heads = self.network_config_["linear_num_value_heads"]
self.num_k_heads = self.network_config_["linear_num_key_heads"]
self.head_k_dim = self.network_config_["linear_key_head_dim"]
self.head_v_dim = self.network_config_["linear_value_head_dim"]
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.conv_kernel_dim = self.network_config_["linear_conv_kernel_dim"]
self.activation = self.network_config_["hidden_act"]
self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_
self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_
self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_
self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_
self.tp_key_dim = self.key_dim // self.tp_world_size_
self.tp_value_dim = self.value_dim // self.tp_world_size_
assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads"
self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads

def _fix_query_key_value_ba_ordering(self, mixed_qkvzba):
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
"""
mixed_qkvz, mixed_ba = torch.split(mixed_qkvzba, [self.tp_qkvz_dim, self.tp_ba_dim], dim=-1)

mixed_qkvz = mixed_qkvz.view(
-1,
self.tp_num_k_heads,
self.head_k_dim + self.head_k_dim + (self.head_v_dim + self.head_v_dim) * self.num_v_heads_per_k_head,
)
mixed_ba = mixed_ba.view(-1, self.tp_num_k_heads, 2 * self.num_v_heads_per_k_head)

qkvz_split_list = [
self.head_k_dim,
self.head_k_dim,
(self.num_v_heads_per_k_head * self.head_v_dim),
(self.num_v_heads_per_k_head * self.head_v_dim),
]
(query, key, value, z) = torch.split(mixed_qkvz, qkvz_split_list, dim=2)
(b, a) = torch.split(mixed_ba, [self.num_v_heads_per_k_head, self.num_v_heads_per_k_head], dim=2)

query = query.reshape(-1, self.tp_num_k_heads * self.head_k_dim)
key = key.reshape(-1, self.tp_num_k_heads * self.head_k_dim)
value = value.reshape(-1, self.tp_num_v_heads * self.head_v_dim)
z = z.reshape(-1, self.tp_num_v_heads, self.head_v_dim)
b = b.reshape(-1, self.tp_num_v_heads)
a = a.reshape(-1, self.tp_num_v_heads)

return query, key, value, z, b, a

def _rearrange_mixed_qkv(self, mixed_qkv):
if mixed_qkv is None:
return None, None, None
query, key, value = torch.split(
mixed_qkv,
[self.tp_key_dim, self.tp_key_dim, self.tp_value_dim],
dim=-1,
)
query, key = map(lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), (query, key))
value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
return query, key, value

def _linear_attn(
self,
input: torch.Tensor,
infer_state: LlamaInferStateInfo,
layer_weight: Qwen3NextTransformerLayerWeight,
is_prefill: bool,
infer_cls: Qwen3NextTransformerLayerInfer,
):
assert layer_weight.is_linear, "layer_weight must be linear"
assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager)
assert isinstance(infer_state.req_manager, Qwen3NextReqManager)
input = input.view(-1, infer_cls.embed_dim_)
buffer_idx = infer_state.req_manager.req_to_buffer_indexes[infer_state.b_req_idx]
conv_states, ssm_states = infer_state.mem_manager.get_state_cache_buffer(self.layer_idx_)

mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
q, k, v, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba)
mixed_qkv = torch.cat([q, k, v], dim=-1)

if is_prefill:
mixed_qkv = mixed_qkv.transpose(0, 1)
out_tensor = infer_cls.alloc_tensor(mixed_qkv.shape, mixed_qkv.dtype, device=mixed_qkv.device)
causal_conv1d_fn(
mixed_qkv,
layer_weight.linear_conv1d.mm_param.weight.transpose(0, 1),
layer_weight.linear_conv1d.mm_param.bias,
conv_states.transpose(1, 2),
infer_state.b1_cu_q_seq_len,
out=out_tensor,
cache_indices=buffer_idx,
activation=self.activation,
)
mixed_qkv = out_tensor.transpose(0, 1)
else:
mixed_qkv = causal_conv1d_update(
mixed_qkv,
conv_states.transpose(1, 2),
layer_weight.linear_conv1d.mm_param.weight.transpose(0, 1),
layer_weight.linear_conv1d.mm_param.bias,
self.activation,
conv_state_indices=buffer_idx,
validate_data=True,
)

# Rearrange mixed_qkv to query, key, value
query, key, value = self._rearrange_mixed_qkv(mixed_qkv)

# Compute beta and g
beta = b.sigmoid()
g = fused_gdn_gating(layer_weight.linear_A_log.weight, a, layer_weight.linear_dt_bias.weight)
g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))

if is_prefill:
initial_state = ssm_states[buffer_idx].contiguous()
(core_attn_out, last_recurrent_state,) = chunk_gated_delta_rule(
q=query,
k=key,
v=value,
g=g,
beta=beta,
initial_state=initial_state,
output_final_state=True,
cu_seqlens=infer_state.b1_cu_q_seq_len,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
# Update SSM state with final state
ssm_states[buffer_idx, ...] = last_recurrent_state.to(ssm_states.dtype)
else:
batch_size = input.shape[0]
cu_seqlens = torch.arange(0, batch_size + 1, dtype=torch.int32, device=input.device)
(core_attn_out, last_recurrent_state,) = fused_recurrent_gated_delta_rule(
q=query,
k=key,
v=value,
g=g,
beta=beta,
initial_state=ssm_states,
inplace_final_state=True,
cu_seqlens=cu_seqlens,
ssm_state_indices=buffer_idx,
use_qk_l2norm_in_kernel=True,
)

z_shape_og = z.shape
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
norm_out = infer_cls.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device)
gated_rmsnorm_forward(
core_attn_out,
layer_weight.linear_norm.weight,
layer_weight.linear_norm.bias,
infer_cls.eps_,
z,
out=norm_out,
)
core_attn_out = norm_out.reshape(z_shape_og)
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")

output = layer_weight.linear_out_proj.mm(core_attn_out)
if infer_cls.tp_world_size_ > 1:
all_reduce(output, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False)
return output
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The classes Qwen3NextTransformerLayerInfer and Qwen3NextGatedDeltaNetInfer implement the complex, hybrid attention logic for the Qwen3Next model. The code is very dense and lacks comments, which makes it difficult to understand and maintain. Adding high-level comments to explain the purpose of key methods (e.g., _fix_query_key_value_ba_ordering, _linear_attn) and the overall data flow would be highly beneficial. For instance, a comment explaining what "linear attention" means in this context (i.e., a state-space model) and how it interacts with the convolutional and recurrent kernels would greatly improve readability.

@ModelTC ModelTC deleted a comment from gemini-code-assist bot Nov 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants