diff --git a/pyproject.toml b/pyproject.toml
index 6a99e408..8aa762f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,7 @@ mac = [
 gpu = [
     "mlx-lm==0.28.0",
     "mlx[cpu]==0.29.1",
-    "sglang[all]==0.5.2",
+    "sglang[all]==0.5.4.post1",
 ]
 
 benchmark = [
diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py
index 68b55bda..0676a941 100644
--- a/src/parallax/server/http_server.py
+++ b/src/parallax/server/http_server.py
@@ -345,12 +345,14 @@ async def v1_chat_completions(raw_request: fastapi.Request):
     # Check if request_json has "rid", otherwise generate new one
     request_id = request_json.get("rid")
     if request_id is None:
-        request_id = uuid.uuid4()
-        request_json["rid"] = str(request_id)
+        request_id = str(uuid.uuid4())
+        request_json["rid"] = request_id
     app.state.http_handler.create_request(request_json)
     app.state.http_handler.send_request(request_json)
 
     req = app.state.http_handler.processing_requests.get(request_id)
+    if req is None:
+        return create_error_response("Request not found", "RequestNotFoundError")
     is_stream = req.stream
 
     if is_stream:
diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py
index 4ce2a89f..18ac7df7 100755
--- a/src/parallax/sglang/batch_info.py
+++ b/src/parallax/sglang/batch_info.py
@@ -5,6 +5,7 @@
 ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 """
 
+from types import SimpleNamespace
 from typing import List
 
 import torch
@@ -67,11 +68,16 @@ def form_sgl_batch_prefill(
 ) -> ForwardBatch:
     """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow"""
     sgl_reqs = transform_requests_to_sglang(requests)
+    dummy_tree_cache = SimpleNamespace(
+        page_size=model_runner.server_args.page_size,
+        device=model_runner.device,
+        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
+    )
     schedule_batch = ScheduleBatch.init_new(
         reqs=sgl_reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
-        tree_cache=None,
+        tree_cache=dummy_tree_cache,
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
@@ -193,6 +199,8 @@ def form_sgl_batch_decode(
 
 def release_cuda_request(running_batch: ScheduleBatch, request_id: str):
     """Release KV Cache and other resources for finished/aborted requests."""
+    if running_batch is None or running_batch.is_empty():
+        return
     seq_lens_cpu = running_batch.seq_lens.cpu().numpy()
     idx = find_index(running_batch, request_id)
     req = running_batch.reqs.pop(idx)
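
The dummy_tree_cache above is a duck-typed stand-in: instead of passing tree_cache=None, the prefill path now hands ScheduleBatch.init_new an object that exposes just the attributes read on this code path (page_size, device, token_to_kv_pool_allocator). A minimal sketch of the pattern in plain Python, with illustrative names only and no sglang APIs, assuming the callee never checks the concrete cache type:

    from types import SimpleNamespace

    class KVAllocator:
        """Toy allocator; only available_size() is called below."""

        def available_size(self):
            return 4096

    def init_batch(tree_cache):
        # The callee only reads attributes, so any object exposing them --
        # a real cache or a SimpleNamespace -- satisfies the implicit interface.
        return {
            "page_size": tree_cache.page_size,
            "device": tree_cache.device,
            "free_tokens": tree_cache.token_to_kv_pool_allocator.available_size(),
        }

    dummy_tree_cache = SimpleNamespace(
        page_size=1,
        device="cpu",
        token_to_kv_pool_allocator=KVAllocator(),
    )
    print(init_batch(dummy_tree_cache))  # works without a real tree cache

The release_cuda_request guard follows the same defensive idea: returning early on a missing or empty batch avoids dereferencing seq_lens on a batch that holds no requests.
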
diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index a005a03f..37c075aa 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -7,7 +7,6 @@
 import logging
 import os
 import random
-import sys
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import sglang
@@ -70,6 +69,7 @@ def __init__(
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
        use_npu_communicator: bool,
+        use_torch_symm_mem: bool = False,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
         pp_start_layer: int = 0,
@@ -87,6 +87,7 @@ def __init__(
             use_hpu_communicator=use_hpu_communicator,
             use_xpu_communicator=use_xpu_communicator,
             use_npu_communicator=use_npu_communicator,
+            use_torch_symm_mem=use_torch_symm_mem,
             use_message_queue_broadcaster=use_message_queue_broadcaster,
             group_name=group_name,
         )
@@ -437,7 +438,7 @@ def monkey_patch_make_layers(
     # circula imports
     from sglang.srt.distributed import get_pp_group
     from sglang.srt.layers.utils import PPMissingLayer
-    from sglang.srt.offloader import get_offloader
+    from sglang.srt.utils.offloader import get_offloader
 
     assert not pp_size or num_hidden_layers >= pp_size
     start_layer, end_layer = get_pp_group().pp_start_layer, get_pp_group().pp_end_layer
@@ -460,15 +461,15 @@
 
 ## TODO: Move this when sgalang supports qwen3_next pipeline parallelism
 def monkey_patch_qwen3_next():
-    from parallax.sglang.monkey_patch import (
-        qwen3_next_model as parallax_qwen3_next_model_module,
-    )
     from parallax.sglang.monkey_patch.qwen3_next_config import (
-        monkey_patch_linear_layer_ids,
+        apply_qwen3_next_config_monkey_patch,
+    )
+    from parallax.sglang.monkey_patch.qwen3_next_model import (
+        apply_qwen3_next_monkey_patch,
     )
 
-    sys.modules["sglang.srt.models.qwen3_next"] = parallax_qwen3_next_model_module
-    sglang.srt.configs.qwen3_next.Qwen3NextConfig.linear_layer_ids = monkey_patch_linear_layer_ids
+    apply_qwen3_next_monkey_patch()
+    apply_qwen3_next_config_monkey_patch()
 
 
 ## TODO: Move this when sgalang supports gpt_oss pipeline parallelism
@@ -553,6 +554,11 @@
         attention_backend = "triton"
         moe_runner_backend = "triton_kernel"
 
+    architectures = config.get("architectures", [])
+    if architectures and any("Qwen3Next" in arch for arch in architectures):
+        logger.debug(f"Qwen3-Next model detected, setting kv_block_size to 1")
+        kv_block_size = 1
+
     server_args = form_sgl_server_args(
         original_model_path,
         dtype,
@@ -574,8 +580,10 @@
     model_config.hf_config.tie_word_embeddings = False
     model_config.hf_config.start_layer = start_layer
     model_config.hf_config.end_layer = end_layer
+    logger.debug(f"model_start_layer: {model_config.hf_config.start_layer}")
     logger.debug(f"model_end_layer: {model_config.hf_config.end_layer}")
 
+
     model_runner = ParallaxModelRunner(
         model_config=model_config,
         mem_fraction_static=kv_cache_memory_fraction,
diff --git a/src/parallax/sglang/monkey_patch/qwen3_next_config.py b/src/parallax/sglang/monkey_patch/qwen3_next_config.py
index 44527a87..489add3f 100644
--- a/src/parallax/sglang/monkey_patch/qwen3_next_config.py
+++ b/src/parallax/sglang/monkey_patch/qwen3_next_config.py
@@ -1,17 +1,3 @@
-# coding=utf-8
-# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 """Qwen3Hybrid model configuration"""
 
 import enum
@@ -31,10 +17,51 @@ class HybridLayerType(enum.Enum):
 
 @property
 def monkey_patch_linear_layer_ids(self):
-    return [
+    """Return linear-attention layer ids restricted to the PP slice.
+
+    This is intended to be bound as a property on
+    `sglang.srt.configs.qwen3_next.Qwen3NextConfig`.
+ """ + lst = [ i for i, type_value in enumerate(self.layers_block_type) if type_value == HybridLayerType.linear_attention.value and i >= self.start_layer and i < self.end_layer ] + # If no matching layer id, return at least [-1] + # just for pp + return lst if lst else [-1] + + +@property +def monkey_patch_full_attention_layer_ids(self): + """Return full-attention layer ids restricted to the PP slice. + + This is intended to be bound as a property on + `sglang.srt.configs.qwen3_next.Qwen3NextConfig`. + """ + lst = [ + i + for i, type_value in enumerate(self.layers_block_type) + if type_value == HybridLayerType.full_attention.value + and i >= self.start_layer + and i < self.end_layer + ] + # If no matching layer id, return at least [-1] + # just for pp + return lst if lst else [-1] + + +def apply_qwen3_next_config_monkey_patch(): + """Bind monkey-patch helpers to the upstream Qwen3NextConfig class. + + We attach the two helpers above as properties so callers can access + `config.linear_layer_ids` / `config.full_attention_layer_ids` the same + way upstream expects. + """ + + import sglang.srt.configs.qwen3_next as s + + s.Qwen3NextConfig.linear_layer_ids = monkey_patch_linear_layer_ids + s.Qwen3NextConfig.full_attention_layer_ids = monkey_patch_full_attention_layer_ids diff --git a/src/parallax/sglang/monkey_patch/qwen3_next_model.py b/src/parallax/sglang/monkey_patch/qwen3_next_model.py index 219d1e4f..04872a49 100644 --- a/src/parallax/sglang/monkey_patch/qwen3_next_model.py +++ b/src/parallax/sglang/monkey_patch/qwen3_next_model.py @@ -1,757 +1,55 @@ -import enum import logging -from typing import Any, Iterable, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple -import torch from sglang.srt.configs.qwen3_next import Qwen3NextConfig -from sglang.srt.distributed import divide, get_pp_group -from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation -from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated -from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader -from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes -from sglang.srt.layers.dp_attention import ( - get_attention_tp_rank, - get_attention_tp_size, - is_dp_attention_enabled, -) -from sglang.srt.layers.layernorm import GemmaRMSNorm -from sglang.srt.layers.linear import ( - ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import get_rope -from sglang.srt.layers.utils import PPMissingLayer, get_layer_id -from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.model_loader.weight_utils import ( - default_weight_loader, - sharded_weight_loader, -) -from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock -from sglang.srt.utils import add_prefix, is_cuda, make_layers, set_weight_attrs +from sglang.srt.utils import is_cuda from torch import nn logger = logging.getLogger(__name__) _is_cuda 
= is_cuda() -import triton -import triton.language as tl - -@triton.jit -def fused_qkvzba_split_reshape_cat_kernel( - mixed_qkv, - z, - b, - a, - mixed_qkvz, - mixed_ba, - NUM_HEADS_QK: tl.constexpr, - NUM_HEADS_V: tl.constexpr, - HEAD_QK: tl.constexpr, - HEAD_V: tl.constexpr, -): - i_bs, i_qk = tl.program_id(0), tl.program_id(1) - QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2 - BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2 - QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V - q_end: tl.constexpr = HEAD_QK - blk_q_ptr = ( - mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(0, q_end) - ) - k_end: tl.constexpr = q_end + HEAD_QK - blk_k_ptr = ( - mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end) - ) - v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V - blk_v_ptr = ( - mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end) - ) - z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V - blk_z_ptr = ( - mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end) - ) - blk_q_st_ptr = ( - mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T + i_qk * HEAD_QK + tl.arange(0, HEAD_QK) - ) - blk_k_st_ptr = ( - mixed_qkv - + i_bs * NUM_HEADS_QK * QKV_DIM_T - + NUM_HEADS_QK * HEAD_QK - + i_qk * HEAD_QK - + tl.arange(0, HEAD_QK) - ) - blk_v_st_ptr = ( - mixed_qkv - + i_bs * NUM_HEADS_QK * QKV_DIM_T - + NUM_HEADS_QK * HEAD_QK * 2 - + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK - + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK) - ) - blk_z_st_ptr = ( - z - + i_bs * NUM_HEADS_V * HEAD_V - + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK - + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK) - ) - tl.store(blk_q_st_ptr, tl.load(blk_q_ptr)) - tl.store(blk_k_st_ptr, tl.load(blk_k_ptr)) - tl.store(blk_v_st_ptr, tl.load(blk_v_ptr)) - tl.store(blk_z_st_ptr, tl.load(blk_z_ptr)) - b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK - a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK - for i in tl.static_range(b_end): - blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i - blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i - tl.store(blk_b_st_ptr, tl.load(blk_b_ptr)) - for i in tl.static_range(b_end, a_end): - blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i - blk_a_st_ptr = a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end) - tl.store(blk_a_st_ptr, tl.load(blk_a_ptr)) - - -def fused_qkvzba_split_reshape_cat( - mixed_qkvz, - mixed_ba, - num_heads_qk, - num_heads_v, - head_qk, - head_v, -): - batch, seq_len = mixed_qkvz.shape[0], 1 - qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v - mixed_qkv = torch.empty( - [batch * seq_len, qkv_dim_t], - dtype=mixed_qkvz.dtype, - device=mixed_qkvz.device, - ) - z = torch.empty( - [batch * seq_len, num_heads_v, head_v], - dtype=mixed_qkvz.dtype, - device=mixed_qkvz.device, - ) - b = torch.empty( - [batch * seq_len, num_heads_v], - dtype=mixed_ba.dtype, - device=mixed_ba.device, +# ---- Minimal method-level monkey patch to reuse sglang source ---- +def apply_qwen3_next_monkey_patch(): + """Apply minimal monkey patches to sglang's qwen3_next to support PP without copying code. + + We override only a few methods: + - Qwen3NextModel.__init__: build layers with PP slicing, gate embed/norm by first/last rank. 
+ - Qwen3NextModel.forward: accept/return PPProxyTensors between stages. + - Qwen3NextForCausalLM.__init__: remove single-rank assertion, keep original wiring. + - Qwen3NextForCausalLM.forward: only last rank computes logits; others pass proxies. + - Qwen3NextForCausalLM.load_weights: pre-filter weights by layer_id to load only local slice. + """ + import torch + from sglang.srt.distributed import get_pp_group + from sglang.srt.layers.dp_attention import is_dp_attention_enabled + from sglang.srt.layers.layernorm import GemmaRMSNorm + from sglang.srt.layers.utils import PPMissingLayer, get_layer_id + from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, ) - a = torch.empty_like(b) - grid = (batch * seq_len, num_heads_qk) - fused_qkvzba_split_reshape_cat_kernel[grid]( - mixed_qkv, - z, - b, - a, - mixed_qkvz, - mixed_ba, - num_heads_qk, - num_heads_v, - head_qk, - head_v, - num_warps=1, - num_stages=3, - ) - return mixed_qkv, z, b, a - - -# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) -@triton.jit -def fused_gdn_gating_kernel( - g, - A_log, - a, - dt_bias, - seq_len, - NUM_HEADS: tl.constexpr, - beta: tl.constexpr, - threshold: tl.constexpr, - BLK_HEADS: tl.constexpr, -): - i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) - head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) - off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off - mask = head_off < NUM_HEADS - blk_A_log = tl.load(A_log + head_off, mask=mask) - blk_a = tl.load(a + off, mask=mask) - blk_bias = tl.load(dt_bias + head_off, mask=mask) - x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) - softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) - blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x - tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) - - -def fused_gdn_gating( - A_log: torch.Tensor, - a: torch.Tensor, - dt_bias: torch.Tensor, - beta: float = 1.0, - threshold: float = 20.0, -) -> torch.Tensor: - batch, num_heads = a.shape - seq_len = 1 - grid = (batch, seq_len, triton.cdiv(num_heads, 8)) - g = torch.empty_like(a, dtype=torch.float32) - fused_gdn_gating_kernel[grid]( - g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1 - ) - return g - - -class Qwen3GatedDeltaNet(nn.Module): - def __init__( - self, - config: Qwen3NextConfig, - layer_id: int, - alt_stream: Optional[torch.cuda.Stream] = None, - ) -> None: - super().__init__() - self.config = config - self.attn_tp_rank = get_attention_tp_rank() - self.attn_tp_size = get_attention_tp_size() - self.hidden_size = config.hidden_size - self.num_v_heads = config.linear_num_value_heads - self.num_k_heads = config.linear_num_key_heads - self.head_k_dim = config.linear_key_head_dim - self.head_v_dim = config.linear_value_head_dim - self.key_dim = self.head_k_dim * self.num_k_heads - self.value_dim = self.head_v_dim * self.num_v_heads - self.alt_stream = alt_stream - - self.conv_kernel_size = config.linear_conv_kernel_dim - self.layer_id = layer_id - self.activation = config.hidden_act - self.layer_norm_epsilon = config.rms_norm_eps - - # QKV - self.conv_dim = self.key_dim * 2 + self.value_dim - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.conv_dim, - bias=False, - quant_config=None, - tp_rank=self.attn_tp_rank, - tp_size=self.attn_tp_size, - ) - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - # projection of the input hidden states - 
projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 - projection_size_ba = self.num_v_heads * 2 - - self.in_proj_qkvz = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=projection_size_qkvz, - bias=False, - tp_rank=self.attn_tp_rank, - tp_size=self.attn_tp_size, - ) - self.in_proj_ba = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=projection_size_ba, - bias=False, - tp_rank=self.attn_tp_rank, - tp_size=self.attn_tp_size, - ) - - query_key_settings = (self.key_dim, 0, False) - value_settings = (self.value_dim, 0, False) - - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, - { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - query_key_settings, - query_key_settings, - value_settings, - ], - self.attn_tp_size, - self.attn_tp_rank, - ) - }, - ) - - # selective projection used to make dt, B and C input dependent - - # time step projection (discretization) - # instantiate once and copy inv_dt in init_weights of PretrainedModel - self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads // self.attn_tp_size)) - - A = torch.empty(divide(self.num_v_heads, self.attn_tp_size), dtype=torch.float32).uniform_( - 0, 16 - ) - self.A_log = nn.Parameter(torch.log(A)) - self.A_log._no_weight_decay = True - - set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) - set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - - self.norm = RMSNormGated( - self.head_v_dim, - eps=self.layer_norm_epsilon, - group_size=None, - norm_before_gate=True, - device=torch.cuda.current_device(), - dtype=config.torch_dtype, - ) - - self.out_proj = RowParallelLinear( - self.value_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - reduce_results=False, - tp_rank=self.attn_tp_rank, - tp_size=self.attn_tp_size, - ) - - def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): - """ - Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. 
- """ - new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( - self.num_k_heads // self.attn_tp_size, - ( - self.head_k_dim - + self.head_k_dim - + (self.head_v_dim + self.head_v_dim) * self.num_v_heads // self.num_k_heads - ), - ) - new_tensor_shape_ba = mixed_ba.size()[:-1] + ( - self.num_k_heads // self.attn_tp_size, - 2 * self.num_v_heads // self.num_k_heads, - ) - - mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) - mixed_ba = mixed_ba.view(*new_tensor_shape_ba) - - split_arg_list_qkvz = [ - self.head_k_dim, - self.head_k_dim, - (self.num_v_heads // self.num_k_heads * self.head_v_dim), - (self.num_v_heads // self.num_k_heads * self.head_v_dim), - ] - split_arg_list_ba = [ - self.num_v_heads // self.num_k_heads, - self.num_v_heads // self.num_k_heads, - ] - - # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] - # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] - (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) - (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) - - # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] - value = value.reshape(value.size(0), -1, self.head_v_dim) - z = z.reshape(z.size(0), -1, self.head_v_dim) - b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size) - a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size) - - return query, key, value, z, b, a - - def _forward_input_proj(self, hidden_states: torch.Tensor): - DUAL_STREAM_TOKEN_THRESHOLD = 1024 - seq_len, _ = hidden_states.shape - if seq_len < DUAL_STREAM_TOKEN_THRESHOLD: - current_stream = torch.cuda.current_stream() - self.alt_stream.wait_stream(current_stream) - projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) - with torch.cuda.stream(self.alt_stream): - projected_states_ba, _ = self.in_proj_ba(hidden_states) - current_stream.wait_stream(self.alt_stream) - else: - projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) - projected_states_ba, _ = self.in_proj_ba(hidden_states) - return projected_states_qkvz, projected_states_ba - - def forward( - self, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - ): - seq_len, _ = hidden_states.shape - is_cuda_graph = forward_batch.forward_mode.is_cuda_graph() - - projected_states_qkvz, projected_states_ba = self._forward_input_proj(hidden_states) - - if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph: - mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( - projected_states_qkvz, - projected_states_ba, - triton.cdiv(self.num_k_heads, self.attn_tp_size), - triton.cdiv(self.num_v_heads, self.attn_tp_size), - self.head_k_dim, - self.head_v_dim, - ) - else: - query, key, value, z, b, a = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba - ) - query, key, value = map(lambda x: x.reshape(x.shape[0], -1), (query, key, value)) - mixed_qkv = torch.cat((query, key, value), dim=-1) - # mixed_qkv = rearrange(mixed_qkv, "b l d -> b d l") - - # 2. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view( - self.conv1d.weight.size(0), self.conv1d.weight.size(2) - ) - - kwargs = { - "mixed_qkv": mixed_qkv, - "conv_weights": conv_weights, - "bias": self.conv1d.bias, - "activation": self.activation, - "key_dim": self.key_dim, - "value_dim": self.value_dim, - "attention_tp_size": self.attn_tp_size, - "head_k_dim": self.head_k_dim, - "head_v_dim": self.head_v_dim, - "a": a, - "b": b, - "A_log": self.A_log, - "dt_bias": self.dt_bias, - "layer_id": self.layer_id, - "seq_len": seq_len, - "z": z, - } - - core_attn_out = forward_batch.attn_backend.forward( - q=None, - k=None, - v=None, - layer=None, - forward_batch=forward_batch, - **kwargs, - ) - - z_shape_og = z.shape - # reshape input data into 2D tensor - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1) - - output, _ = self.out_proj(core_attn_out) - return output - - -class Qwen3HybridLinearDecoderLayer(nn.Module): - - def __init__( - self, - config: Qwen3NextConfig, - layer_id: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - alt_stream: Optional[torch.cuda.Stream] = None, - ) -> None: - super().__init__() - self.config = config - self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream) - - # Qwen3Next all layers are sparse and have no nextn now - self.is_layer_sparse = True - is_previous_layer_sparse = True - self.layer_id = layer_id - - self.layer_scatter_modes = LayerScatterModes.init_new( - layer_id=layer_id, - num_layers=config.num_hidden_layers, - is_layer_sparse=self.is_layer_sparse, - is_previous_layer_sparse=is_previous_layer_sparse, - ) - - if self.is_layer_sparse: - self.mlp = Qwen2MoeSparseMoeBlock( - layer_id=layer_id, - config=config, - quant_config=quant_config, - alt_stream=alt_stream, - ) - else: - self.mlp = Qwen2MoeMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.layer_communicator = LayerCommunicator( - layer_scatter_modes=self.layer_scatter_modes, - input_layernorm=self.input_layernorm, - post_attention_layernorm=self.post_attention_layernorm, - allow_reduce_scatter=True, - ) - - def forward( - self, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - **kwargs, - ): - forward_batch = kwargs.get("forward_batch", None) - - hidden_states, residual = self.layer_communicator.prepare_attn( - hidden_states, residual, forward_batch - ) - - if not forward_batch.forward_mode.is_idle(): - hidden_states = self.linear_attn( - hidden_states, - forward_batch, - ) - # Fully Connected - hidden_states, residual = self.layer_communicator.prepare_mlp( - hidden_states, residual, forward_batch - ) - - use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(forward_batch) - hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter) - - hidden_states, residual = self.layer_communicator.postprocess_layer( - hidden_states, residual, forward_batch - ) - - return hidden_states, residual - - -class Qwen3HybridAttentionDecoderLayer(nn.Module): - - def __init__( - self, - config: Qwen3NextConfig, 
- layer_id: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - alt_stream: Optional[torch.cuda.Stream] = None, - ) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.attn_tp_rank = get_attention_tp_rank() - self.attn_tp_size = get_attention_tp_size() - self.total_num_heads = config.num_attention_heads - assert self.total_num_heads % self.attn_tp_size == 0 - self.num_heads = self.total_num_heads // self.attn_tp_size - self.total_num_kv_heads = config.num_key_value_heads - if self.total_num_kv_heads >= self.attn_tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % self.attn_tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. - assert self.attn_tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size) - self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = getattr(config, "rope_theta", 10000) - self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.rope_scaling = getattr(config, "rope_scaling", None) - self.partial_rotary_factor = config.partial_rotary_factor - self.layer_id = layer_id - - self.attn_output_gate = getattr(config, "attn_output_gate", True) - if self.attn_output_gate: - logger.warning_once("using attn output gate!") - - self.rotary_emb = get_rope( - head_size=self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - rope_scaling=self.rope_scaling, - base=self.rope_theta, - partial_rotary_factor=self.partial_rotary_factor, - is_neox_style=True, - dtype=torch.get_default_dtype(), # see impl of get_rope - ) - - self.qkv_proj = QKVParallelLinear( - config.hidden_size, - self.head_dim, - self.total_num_heads * (1 + self.attn_output_gate), - self.total_num_kv_heads, - bias=False, - quant_config=quant_config, - tp_rank=self.attn_tp_rank, - tp_size=self.attn_tp_size, - ) - - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=False, - tp_rank=self.attn_tp_rank, - tp_size=self.attn_tp_size, - ) - - self.attn = RadixAttention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - layer_id=layer_id, - prefix=f"{prefix}.attn", - ) - - # Qwen3Next all layers are sparse and have no nextn now - self.is_layer_sparse = True - is_previous_layer_sparse = True - - self.layer_scatter_modes = LayerScatterModes.init_new( - layer_id=layer_id, - num_layers=config.num_hidden_layers, - is_layer_sparse=self.is_layer_sparse, - is_previous_layer_sparse=is_previous_layer_sparse, - ) - - if self.is_layer_sparse: - self.mlp = Qwen2MoeSparseMoeBlock( - layer_id=layer_id, - config=config, - quant_config=quant_config, - alt_stream=alt_stream, - ) - else: - self.mlp = Qwen2MoeMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.q_norm = 
GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + from sglang.srt.model_executor.forward_batch_info import PPProxyTensors + from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import add_prefix, is_cuda, make_layers - self.layer_communicator = LayerCommunicator( - layer_scatter_modes=self.layer_scatter_modes, - input_layernorm=self.input_layernorm, - post_attention_layernorm=self.post_attention_layernorm, - allow_reduce_scatter=True, + try: + import sglang.srt.models.qwen3_next as m + except Exception as e: # Fallback: keep current module as-is + logger.warning( + f"Failed to import sglang.srt.models.qwen3_next for monkey patch: {e}. Using local copy." ) + return - self.alt_stream = alt_stream - - def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - # overlap qk norm - if self.alt_stream is not None and get_is_capture_mode(): - current_stream = torch.cuda.current_stream() - self.alt_stream.wait_stream(current_stream) - q_by_head = q.reshape(-1, self.head_dim) - q_by_head = self.q_norm(q_by_head) - with torch.cuda.stream(self.alt_stream): - k_by_head = k.reshape(-1, self.head_dim) - k_by_head = self.k_norm(k_by_head) - current_stream.wait_stream(self.alt_stream) - else: - q_by_head = q.reshape(-1, self.head_dim) - q_by_head = self.q_norm(q_by_head) - k_by_head = k.reshape(-1, self.head_dim) - k_by_head = self.k_norm(k_by_head) - q = q_by_head.view(q.shape) - k = k_by_head.view(k.shape) - return q, k - - def self_attention( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - - if self.attn_output_gate: - q_gate, k, v = qkv.split([self.q_size * 2, self.kv_size, self.kv_size], dim=-1) - orig_shape = q_gate.shape[:-1] - q_gate = q_gate.view(*orig_shape, self.num_heads, -1) - q, gate = torch.chunk(q_gate, 2, dim=-1) - q = q.reshape(*orig_shape, -1) - gate = gate.reshape(*orig_shape, -1) - else: - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - q, k = self._apply_qk_norm(q, k) - - q, k = self.rotary_emb(positions, q, k) - - attn_output = self.attn(q, k, v, forward_batch) - - if self.attn_output_gate: - gate = torch.sigmoid(gate) - attn_output = attn_output * gate - - output, _ = self.o_proj(attn_output) - return output - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - forward_batch: ForwardBatch, - **kwargs: Any, + # --- Patch Qwen3NextModel.__init__ --- + def _pp_model_init( + self, config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "" ): - hidden_states, residual = self.layer_communicator.prepare_attn( - hidden_states, residual, forward_batch - ) - - if not forward_batch.forward_mode.is_idle(): - hidden_states = self.self_attention( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) - - # Fully Connected - hidden_states, residual = self.layer_communicator.prepare_mlp( - hidden_states, residual, forward_batch - ) - use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(forward_batch) - hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter) - - hidden_states, residual = self.layer_communicator.postprocess_layer( - hidden_states, residual, forward_batch - ) - - return hidden_states, residual - - -ALL_DECODER_LAYER_TYPES = { - "attention": 
Qwen3HybridAttentionDecoderLayer, - "linear_attention": Qwen3HybridLinearDecoderLayer, -} - - -class Qwen3NextModel(nn.Module): - def __init__( - self, - config: Qwen3NextConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() + nn.Module.__init__(self) self.config = config self.pp_group = get_pp_group() - - alt_stream = torch.cuda.Stream() if _is_cuda else None + alt_stream = torch.cuda.Stream() if is_cuda() else None if self.pp_group.is_first_rank: self.embed_tokens = VocabParallelEmbedding( @@ -764,7 +62,7 @@ def __init__( self.embed_tokens = PPMissingLayer() def get_layer(idx: int, prefix: str): - layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]] + layer_class = m.ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]] return layer_class( config, idx, @@ -786,27 +84,25 @@ def get_layer(idx: int, prefix: str): self.norm = PPMissingLayer(return_tuple=True) self.infer_count = 0 - def forward( + # --- Patch Qwen3NextModel.forward --- + def _pp_model_forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - forward_batch: ForwardBatch, - # mamba_cache_params: MambaCacheParams, + forward_batch, inputs_embeds: Optional[torch.Tensor] = None, pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> torch.Tensor: - - # pass a sequence index tensor, that is required for - # proper continuous batching computation including - # chunked prefill + **kwargs, + ): if self.pp_group.is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = ( + inputs_embeds if inputs_embeds is not None else self.embed_tokens(input_ids) + ) residual = None else: - assert pp_proxy_tensors is not None + assert ( + pp_proxy_tensors is not None + ), "pp_proxy_tensors must be provided on non-first PP ranks" hidden_states = pp_proxy_tensors["hidden_states"] residual = pp_proxy_tensors["residual"] @@ -819,217 +115,111 @@ def forward( residual=residual, forward_batch=forward_batch, ) + if not self.pp_group.is_last_rank: - return PPProxyTensors( - { - "hidden_states": hidden_states, - "residual": residual, - } - ) + return PPProxyTensors({"hidden_states": hidden_states, "residual": residual}) else: if hidden_states.shape[0] != 0: if residual is None: hidden_states = self.norm(hidden_states) else: hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states - return hidden_states - - -class HybridLayerType(enum.Enum): - full_attention = "attention" - swa_attention = "swa_attention" - linear_attention = "linear_attention" - mamba2 = "mamba" - - -class Qwen3NextForCausalLM(nn.Module): - fall_back_to_pt_during_load = False - - def __init__( + # --- Patch Qwen3NextForCausalLM.__init__ (remove single-rank assert) --- + def _pp_for_causal_init( self, config: Qwen3NextConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - ) -> None: - super().__init__() + ): + nn.Module.__init__(self) self.config = config self.pp_group = get_pp_group() - # logger.info(f"self.pp_group.ranks: {self.pp_group.ranks}") - # logger.info(f"self.pp_group.rank: {self.pp_group.rank}") - # logger.info(f"PP first - last: {self.pp_group.first_rank} - {self.pp_group.last_rank}") - # logger.info(f"PP rank: {self.pp_group.rank}") - # logger.info(f"PP is_first: {self.pp_group.is_first_rank}") - # logger.info(f"PP is_last: {self.pp_group.is_last_rank}") - # assert self.pp_group.is_first_rank and self.pp_group.is_last_rank self.quant_config = quant_config - 
self.model = Qwen3NextModel(config, quant_config, prefix=add_prefix("model", prefix)) + self.model = m.Qwen3NextModel(config, quant_config, prefix=add_prefix("model", prefix)) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config, org_num_embeddings=config.vocab_size, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], - ) - self.lm_head = self.lm_head.float() - self.logits_processor = LogitsProcessor(config) + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + ).float() + self.logits_processor = m.LogitsProcessor(config) + # --- Patch Qwen3NextForCausalLM.forward --- @torch.no_grad() - def forward( + def _pp_for_causal_forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - forward_batch: ForwardBatch, + forward_batch, inputs_embeds: Optional[torch.Tensor] = None, pp_proxy_tensors: Optional[PPProxyTensors] = None, **kwargs, ): hidden_states = self.model( - input_ids, positions, forward_batch, inputs_embeds, pp_proxy_tensors + input_ids, + positions, + forward_batch, + inputs_embeds, + pp_proxy_tensors, ) - if self.pp_group.is_last_rank: return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch) else: return hidden_states - def get_embed_and_head(self): - return self.model.embed_tokens.weight, self.lm_head.weight - - def set_embed_and_head(self, embed, head): - del self.model.embed_tokens.weight - del self.lm_head.weight - self.model.embed_tokens.weight = embed - self.lm_head.weight = head - torch.cuda.empty_cache() - torch.cuda.synchronize() + # --- Patch Qwen3NextForCausalLM.load_weights (filter by PP slice) --- + orig_load_weights = m.Qwen3NextForCausalLM.load_weights - def load_weights( + def _pp_load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False ) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts, - ) - + """Filter incoming weights to only those relevant for this PP slice. + + Rules: + - Layer weights: keep only if layer_id in [start, end). + - Non-layer weights (layer_id is None): + * keep if they correspond to names present in current params_dict (e.g., model.norm on last rank, + embed on first rank, lm_head on all ranks), or + * keep if they match known mapping keywords (so original loader can rename and resolve), or + * keep if they are explicitly skipped by original loader (e.g., rotary_emb.inv_freq), harmless to pass. + This prevents KeyError like 'model.norm.weight' on non-last ranks where norm is a PPMissingLayer. 
+ """ + start = getattr(self.model, "start_layer", None) + end = getattr(self.model, "end_layer", None) params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - # preid = -1 - for name, loaded_weight in weights: + mapping_keywords = ( + "q_proj", + "k_proj", + "v_proj", + "gate_proj", + "up_proj", + "down_proj", + "self_attn", + ) + + filtered: list[tuple[str, torch.Tensor]] = [] + for name, w in weights: layer_id = get_layer_id(name) - if ( - layer_id is not None - and hasattr(self.model, "start_layer") - and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer) - ): - continue - - if is_mtp: - - if "mtp" not in name: - continue - - if name in [ - "mtp.fc.weight", - "mtp.pre_fc_norm_embedding.weight", - "mtp.pre_fc_norm_hidden.weight", - ]: - name = name.replace("mtp.", "") - else: - name = name.replace("mtp", "model") - - if not is_mtp and "mtp" in name: - continue - - if "rotary_emb.inv_freq" in name: - continue - - if ".self_attn." in name: - name = name.replace(".self_attn", "") - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - - # TODO(fix mtp loading) - if "mlp.experts" in name: - continue - - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip layers on other devices. - # if is_pp_missing_parameter(name, self): - # continue - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader") - weight_loader(param, loaded_weight, shard_id) - break + if layer_id is not None: + if (start is None) or (start <= layer_id < end): + filtered.append((name, w)) else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - # if is_pp_missing_parameter(name, self): - # continue - # Skip loading extra bias for GPTQ models. - if ( - name.endswith(".bias") or name.endswith("_bias") - ) and name not in params_dict: - continue - - param = params_dict[name] - - weight_loader = getattr(param, "weight_loader") - weight_loader( - param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id, - ) - break - else: - # Skip loading extra bias for GPTQ models. 
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    # if is_pp_missing_parameter(name, self):
-                    #     continue
-                    if name not in params_dict:
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                    weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
-
-    @classmethod
-    def get_model_config_for_expert_location(cls, config):
-        return ModelConfigForExpertLocation(
-            num_layers=config.num_hidden_layers,
-            num_logical_experts=config.num_experts,
-            num_groups=None,
-        )
-
-
-EntryClass = Qwen3NextForCausalLM
+                if (
+                    (name in params_dict)
+                    or any(k in name for k in mapping_keywords)
+                    or ("rotary_emb.inv_freq" in name)
+                ):
+                    filtered.append((name, w))
+
+        return orig_load_weights(self, filtered, is_mtp=is_mtp)
+
+    # Bind patches
+    m.Qwen3NextModel.__init__ = _pp_model_init  # type: ignore
+    m.Qwen3NextModel.forward = _pp_model_forward  # type: ignore
+    m.Qwen3NextForCausalLM.__init__ = _pp_for_causal_init  # type: ignore
+    m.Qwen3NextForCausalLM.forward = _pp_for_causal_forward  # type: ignore
+    m.Qwen3NextForCausalLM.load_weights = _pp_load_weights  # type: ignore
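
The binding block above is plain attribute assignment on the upstream classes: each replacement function takes self, so once it is assigned to the class it behaves like an ordinary method, and the saved orig_load_weights reference lets the patched loader delegate to the stock implementation after pre-filtering. A self-contained sketch of that wrap-and-rebind pattern, using toy classes rather than sglang's API:

    class Loader:
        def load_weights(self, weights):
            return [name for name, _ in weights]

    # Keep a handle to the original before rebinding, mirroring orig_load_weights.
    _orig_load_weights = Loader.load_weights

    def filtered_load_weights(self, weights, keep_prefix="layers.0."):
        # Pre-filter the iterable, then delegate to the original method.
        kept = [(n, w) for n, w in weights if n.startswith(keep_prefix)]
        return _orig_load_weights(self, kept)

    Loader.load_weights = filtered_load_weights  # method-level monkey patch

    demo = [("layers.0.q_proj", 1), ("layers.1.q_proj", 2)]
    print(Loader().load_weights(demo))  # ['layers.0.q_proj']

Rebinding individual methods in place, instead of swapping the whole module via sys.modules as the previous revision did, leaves the rest of the upstream file (ALL_DECODER_LAYER_TYPES, the Triton kernels, and so on) untouched.
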
diff --git a/src/parallax/sglang/monkey_patch/triton_backend.py b/src/parallax/sglang/monkey_patch/triton_backend.py
index 394177ca..02cd81df 100644
--- a/src/parallax/sglang/monkey_patch/triton_backend.py
+++ b/src/parallax/sglang/monkey_patch/triton_backend.py
@@ -4,7 +4,7 @@
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.utils import get_bool_env_var, get_device_core_count
+from sglang.srt.utils import get_bool_env_var, get_device_core_count, get_int_env_var
 
 
 def parallax_triton_backend_init(
@@ -35,15 +35,35 @@
     self.num_head = model_runner.model_config.num_attention_heads // get_attention_tp_size()
     self.num_kv_head = model_runner.model_config.get_num_kv_heads(get_attention_tp_size())
     # Modifies layer id to support pipeline parallel
-    self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(
-        model_runner.pp_start_layer
-    ).shape[-1]
+    if model_runner.hybrid_gdn_config is not None:
+        # For hybrid linear models, layer_id = 0 may not be full attention
+        self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
+    else:
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(
+            model_runner.pp_start_layer
+        ).shape[-1]
     self.max_context_len = model_runner.model_config.context_len
     self.device = model_runner.device
     self.device_core_count = get_device_core_count(model_runner.gpu_id)
     self.static_kv_splits = get_bool_env_var("SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false")
     self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+    # Decide whether to enable deterministic inference with batch-invariant operations
+    self.enable_deterministic = model_runner.server_args.enable_deterministic_inference
+
+    # Configure deterministic inference settings
+    if self.enable_deterministic:
+        # Use fixed split tile size for batch invariance
+        self.split_tile_size = get_int_env_var("SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256)
+        # Set static_kv_splits to False to use deterministic logic instead
+        self.static_kv_splits = False
+    else:
+        self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size
+
+    if self.split_tile_size is not None:
+        self.max_kv_splits = (
+            self.max_context_len + self.split_tile_size - 1
+        ) // self.split_tile_size
 
     # Check arguments
     assert not (
         model_runner.sliding_window_size is not None
@@ -75,7 +95,11 @@
     self.mask_indptr = torch.zeros((max_bs + 1,), dtype=torch.int64, device=model_runner.device)
 
     # Initialize forward metadata
-    self.forward_metadata = None
+    from sglang.srt.layers.attention.triton_backend import ForwardMetadata
+
+    self.forward_metadata: Optional[ForwardMetadata] = None
+
+    self.cuda_graph_custom_mask = None
 
 
 def apply_triton_backend_init_monkey_patch():