-
Notifications
You must be signed in to change notification settings - Fork 285
Support Qwen3next and prefix cache. #1115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Summary of ChangesHello @sufubao, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates the Qwen3next model into the LightLLM framework, enhancing its capability to support advanced model architectures. A significant part of this integration involves a sophisticated hybrid prefix caching mechanism and a refactored memory management system tailored to Qwen3next's unique linear attention and shared expert layers. The changes also introduce several high-performance Triton kernels to ensure efficient execution of these new model components, alongside general improvements to model loading parallelism. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces support for the Qwen3Next model, which appears to be a hybrid model using both standard and linear attention mechanisms. The changes are extensive, touching upon memory management, request management, prefix caching, and the core model inference logic. A significant part of the work is the introduction of a HybridRadixCache that caches not only the KV state but also the state buffers required for the linear attention layers. A specialized backend for Qwen3Next is also added to handle caching intermediate states during chunked prefilling. The refactoring of MemoryManager into a BaseAllocator is a good improvement for code structure.
My review has identified one critical issue related to model selection logic that needs to be addressed. I've also provided suggestions to improve performance and maintainability in the new complex components. Overall, this is a substantial and well-engineered contribution to support a new, complex model architecture.
| class HybridTreeNode(TreeNode): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.children_list: List[HybridTreeNode] = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The HybridTreeNode uses a list to store its children. In methods like _insert_helper_no_recursion and _match_prefix_helper_no_recursion, this list is iterated linearly to find a matching child. If a node has a large number of children (i.e., a high branching factor), this linear scan could become a performance bottleneck, changing the lookup complexity from O(N) to O(1) on average. Consider using a dictionary mapping the first token ID of the child's key to the child node, similar to the implementation in the base RadixCache, to maintain efficient lookups.
| class Qwen3NextTransformerLayerInfer(Qwen3MOETransformerLayerInfer): | ||
| def __init__(self, layer_num, network_config, mode=[]): | ||
| super().__init__(layer_num, network_config, mode) | ||
| self.is_linear = (layer_num + 1) % network_config["full_attention_interval"] != 0 | ||
| self.partial_rotary_factor = network_config.get("partial_rotary_factor", 1.0) | ||
|
|
||
| if self.is_linear: | ||
| self.linear_attn_infer = Qwen3NextGatedDeltaNetInfer(network_config, layer_num, self.tp_world_size_) | ||
| return | ||
|
|
||
| @override | ||
| def _bind_norm(self): | ||
| pass | ||
|
|
||
| def _ffn_with_shared_expert( | ||
| self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight | ||
| ) -> torch.Tensor: | ||
| input = input.view(-1, self.embed_dim_) | ||
| up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) | ||
| ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) | ||
| silu_and_mul_fwd(up_gate_out, ffn1_out) | ||
| ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) | ||
| shared_expert_out = F.sigmoid(layer_weight.shared_expert_gate.mm(input)) * ffn2_out | ||
| moe_out = self._ffn(input, infer_state, layer_weight) | ||
| return shared_expert_out + moe_out | ||
|
|
||
| @override | ||
| def _att_norm( | ||
| self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight | ||
| ) -> torch.Tensor: | ||
| out = self.alloc_tensor(input.shape, input.dtype) | ||
| gemma_rmsnorm_forward(input, layer_weight.att_norm_weight_.weight, self.eps_, out=out) | ||
| return out | ||
|
|
||
| @override | ||
| def _ffn_norm( | ||
| self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight | ||
| ) -> torch.Tensor: | ||
| out = self.alloc_tensor(input.shape, input.dtype) | ||
| gemma_rmsnorm_forward(input, layer_weight.ffn_norm_weight_.weight, self.eps_, out=out) | ||
| return out | ||
|
|
||
| @override | ||
| def _get_qkv( | ||
| self, | ||
| input: torch.Tensor, | ||
| infer_state: LlamaInferStateInfo, | ||
| layer_weight: Qwen3NextTransformerLayerWeight, | ||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | ||
| input = input.view(-1, self.embed_dim_) | ||
| q = layer_weight.q_proj.mm(input) | ||
| cache_kv = layer_weight.kv_proj.mm( | ||
| input.view(-1, self.embed_dim_), | ||
| ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) | ||
|
|
||
| gemma_rmsnorm_forward( | ||
| q.view(-1, self.head_dim_), | ||
| layer_weight.q_norm_weight_.weight, | ||
| eps=self.eps_, | ||
| out=q.view(-1, self.head_dim_), | ||
| ) | ||
|
|
||
| cache_kv[:, : self.tp_k_head_num_, :] = gemma_rmsnorm_forward( | ||
| cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), | ||
| layer_weight.k_norm_weight_.weight, | ||
| eps=self.eps_, | ||
| ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) | ||
|
|
||
| rotary_emb_fwd( | ||
| q.view(-1, self.tp_q_head_num_, self.head_dim_), | ||
| cache_kv[:, : self.tp_k_head_num_, :], | ||
| infer_state.position_cos, | ||
| infer_state.position_sin, | ||
| partial_rotary_factor=self.partial_rotary_factor, | ||
| ) | ||
| return q, cache_kv | ||
|
|
||
| @override | ||
| def _get_o( | ||
| self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight | ||
| ) -> torch.Tensor: | ||
| input = input * layer_weight._gate | ||
| layer_weight._gate = None | ||
| o_tensor = layer_weight.o_proj.mm(input) | ||
| return o_tensor | ||
|
|
||
| def _context_full_attn( | ||
| self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight | ||
| ): | ||
| q, cache_kv = self._get_qkv(input, infer_state, layer_weight) | ||
| input = None | ||
| self._post_cache_kv(cache_kv, infer_state, layer_weight) | ||
| o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) | ||
| q = None | ||
| o = self._get_o(o, infer_state, layer_weight) | ||
| if self.tp_world_size_ > 1: | ||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||
| return o | ||
|
|
||
| def context_forward( | ||
| self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight | ||
| ): | ||
| input1 = self._att_norm(input_embdings, infer_state, layer_weight) | ||
| if self.is_linear: | ||
| o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=True, infer_cls=self) | ||
| else: | ||
| layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)) | ||
| o = self._context_full_attn(input1, infer_state, layer_weight) | ||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||
| o = None | ||
|
|
||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||
| ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight) | ||
| input1 = None | ||
| if self.tp_world_size_ > 1: | ||
| all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||
| input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) | ||
| return input_embdings | ||
|
|
||
| def _token_full_attn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight): | ||
| q, cache_kv = self._get_qkv(input, infer_state, layer_weight) | ||
| input = None | ||
| self._post_cache_kv(cache_kv, infer_state, layer_weight) | ||
| o = self._token_attention_kernel(q, infer_state, layer_weight) | ||
| q = None | ||
| o = self._get_o(o, infer_state, layer_weight) | ||
| if self.tp_world_size_ > 1: | ||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||
| return o | ||
|
|
||
| def token_forward( | ||
| self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight | ||
| ): | ||
| input1 = self._att_norm(input_embdings, infer_state, layer_weight) | ||
| if self.is_linear: | ||
| o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=False, infer_cls=self) | ||
| else: | ||
| layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)) | ||
| o = self._token_full_attn(input1, infer_state, layer_weight) | ||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||
| o = None | ||
|
|
||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||
| ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight) | ||
| input1 = None | ||
| if self.tp_world_size_ > 1: | ||
| all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||
| input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) | ||
| return input_embdings | ||
|
|
||
|
|
||
| class Qwen3NextGatedDeltaNetInfer: | ||
| def __init__(self, network_config, layer_idx, tp_world_size_): | ||
| self.network_config_ = network_config | ||
| self.layer_idx_ = layer_idx | ||
| self.tp_world_size_ = tp_world_size_ | ||
| self.num_v_heads = self.network_config_["linear_num_value_heads"] | ||
| self.num_k_heads = self.network_config_["linear_num_key_heads"] | ||
| self.head_k_dim = self.network_config_["linear_key_head_dim"] | ||
| self.head_v_dim = self.network_config_["linear_value_head_dim"] | ||
| self.key_dim = self.head_k_dim * self.num_k_heads | ||
| self.value_dim = self.head_v_dim * self.num_v_heads | ||
| self.conv_kernel_dim = self.network_config_["linear_conv_kernel_dim"] | ||
| self.activation = self.network_config_["hidden_act"] | ||
| self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_ | ||
| self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_ | ||
| self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_ | ||
| self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_ | ||
| self.tp_key_dim = self.key_dim // self.tp_world_size_ | ||
| self.tp_value_dim = self.value_dim // self.tp_world_size_ | ||
| assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads" | ||
| self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads | ||
|
|
||
| def _fix_query_key_value_ba_ordering(self, mixed_qkvzba): | ||
| """ | ||
| Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. | ||
| """ | ||
| mixed_qkvz, mixed_ba = torch.split(mixed_qkvzba, [self.tp_qkvz_dim, self.tp_ba_dim], dim=-1) | ||
|
|
||
| mixed_qkvz = mixed_qkvz.view( | ||
| -1, | ||
| self.tp_num_k_heads, | ||
| self.head_k_dim + self.head_k_dim + (self.head_v_dim + self.head_v_dim) * self.num_v_heads_per_k_head, | ||
| ) | ||
| mixed_ba = mixed_ba.view(-1, self.tp_num_k_heads, 2 * self.num_v_heads_per_k_head) | ||
|
|
||
| qkvz_split_list = [ | ||
| self.head_k_dim, | ||
| self.head_k_dim, | ||
| (self.num_v_heads_per_k_head * self.head_v_dim), | ||
| (self.num_v_heads_per_k_head * self.head_v_dim), | ||
| ] | ||
| (query, key, value, z) = torch.split(mixed_qkvz, qkvz_split_list, dim=2) | ||
| (b, a) = torch.split(mixed_ba, [self.num_v_heads_per_k_head, self.num_v_heads_per_k_head], dim=2) | ||
|
|
||
| query = query.reshape(-1, self.tp_num_k_heads * self.head_k_dim) | ||
| key = key.reshape(-1, self.tp_num_k_heads * self.head_k_dim) | ||
| value = value.reshape(-1, self.tp_num_v_heads * self.head_v_dim) | ||
| z = z.reshape(-1, self.tp_num_v_heads, self.head_v_dim) | ||
| b = b.reshape(-1, self.tp_num_v_heads) | ||
| a = a.reshape(-1, self.tp_num_v_heads) | ||
|
|
||
| return query, key, value, z, b, a | ||
|
|
||
| def _rearrange_mixed_qkv(self, mixed_qkv): | ||
| if mixed_qkv is None: | ||
| return None, None, None | ||
| query, key, value = torch.split( | ||
| mixed_qkv, | ||
| [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], | ||
| dim=-1, | ||
| ) | ||
| query, key = map(lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), (query, key)) | ||
| value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) | ||
| return query, key, value | ||
|
|
||
| def _linear_attn( | ||
| self, | ||
| input: torch.Tensor, | ||
| infer_state: LlamaInferStateInfo, | ||
| layer_weight: Qwen3NextTransformerLayerWeight, | ||
| is_prefill: bool, | ||
| infer_cls: Qwen3NextTransformerLayerInfer, | ||
| ): | ||
| assert layer_weight.is_linear, "layer_weight must be linear" | ||
| assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager) | ||
| assert isinstance(infer_state.req_manager, Qwen3NextReqManager) | ||
| input = input.view(-1, infer_cls.embed_dim_) | ||
| buffer_idx = infer_state.req_manager.req_to_buffer_indexes[infer_state.b_req_idx] | ||
| conv_states, ssm_states = infer_state.mem_manager.get_state_cache_buffer(self.layer_idx_) | ||
|
|
||
| mixed_qkvzba = layer_weight.linear_in_proj.mm(input) | ||
| q, k, v, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba) | ||
| mixed_qkv = torch.cat([q, k, v], dim=-1) | ||
|
|
||
| if is_prefill: | ||
| mixed_qkv = mixed_qkv.transpose(0, 1) | ||
| out_tensor = infer_cls.alloc_tensor(mixed_qkv.shape, mixed_qkv.dtype, device=mixed_qkv.device) | ||
| causal_conv1d_fn( | ||
| mixed_qkv, | ||
| layer_weight.linear_conv1d.mm_param.weight.transpose(0, 1), | ||
| layer_weight.linear_conv1d.mm_param.bias, | ||
| conv_states.transpose(1, 2), | ||
| infer_state.b1_cu_q_seq_len, | ||
| out=out_tensor, | ||
| cache_indices=buffer_idx, | ||
| activation=self.activation, | ||
| ) | ||
| mixed_qkv = out_tensor.transpose(0, 1) | ||
| else: | ||
| mixed_qkv = causal_conv1d_update( | ||
| mixed_qkv, | ||
| conv_states.transpose(1, 2), | ||
| layer_weight.linear_conv1d.mm_param.weight.transpose(0, 1), | ||
| layer_weight.linear_conv1d.mm_param.bias, | ||
| self.activation, | ||
| conv_state_indices=buffer_idx, | ||
| validate_data=True, | ||
| ) | ||
|
|
||
| # Rearrange mixed_qkv to query, key, value | ||
| query, key, value = self._rearrange_mixed_qkv(mixed_qkv) | ||
|
|
||
| # Compute beta and g | ||
| beta = b.sigmoid() | ||
| g = fused_gdn_gating(layer_weight.linear_A_log.weight, a, layer_weight.linear_dt_bias.weight) | ||
| g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta)) | ||
|
|
||
| if is_prefill: | ||
| initial_state = ssm_states[buffer_idx].contiguous() | ||
| (core_attn_out, last_recurrent_state,) = chunk_gated_delta_rule( | ||
| q=query, | ||
| k=key, | ||
| v=value, | ||
| g=g, | ||
| beta=beta, | ||
| initial_state=initial_state, | ||
| output_final_state=True, | ||
| cu_seqlens=infer_state.b1_cu_q_seq_len, | ||
| head_first=False, | ||
| use_qk_l2norm_in_kernel=True, | ||
| ) | ||
| # Update SSM state with final state | ||
| ssm_states[buffer_idx, ...] = last_recurrent_state.to(ssm_states.dtype) | ||
| else: | ||
| batch_size = input.shape[0] | ||
| cu_seqlens = torch.arange(0, batch_size + 1, dtype=torch.int32, device=input.device) | ||
| (core_attn_out, last_recurrent_state,) = fused_recurrent_gated_delta_rule( | ||
| q=query, | ||
| k=key, | ||
| v=value, | ||
| g=g, | ||
| beta=beta, | ||
| initial_state=ssm_states, | ||
| inplace_final_state=True, | ||
| cu_seqlens=cu_seqlens, | ||
| ssm_state_indices=buffer_idx, | ||
| use_qk_l2norm_in_kernel=True, | ||
| ) | ||
|
|
||
| z_shape_og = z.shape | ||
| core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) | ||
| z = z.reshape(-1, z.shape[-1]) | ||
| norm_out = infer_cls.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) | ||
| gated_rmsnorm_forward( | ||
| core_attn_out, | ||
| layer_weight.linear_norm.weight, | ||
| layer_weight.linear_norm.bias, | ||
| infer_cls.eps_, | ||
| z, | ||
| out=norm_out, | ||
| ) | ||
| core_attn_out = norm_out.reshape(z_shape_og) | ||
| core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") | ||
|
|
||
| output = layer_weight.linear_out_proj.mm(core_attn_out) | ||
| if infer_cls.tp_world_size_ > 1: | ||
| all_reduce(output, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) | ||
| return output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The classes Qwen3NextTransformerLayerInfer and Qwen3NextGatedDeltaNetInfer implement the complex, hybrid attention logic for the Qwen3Next model. The code is very dense and lacks comments, which makes it difficult to understand and maintain. Adding high-level comments to explain the purpose of key methods (e.g., _fix_query_key_value_ba_ordering, _linear_attn) and the overall data flow would be highly beneficial. For instance, a comment explaining what "linear attention" means in this context (i.e., a state-space model) and how it interacts with the convolutional and recurrent kernels would greatly improve readability.
No description provided.