2 changes: 1 addition & 1 deletion pyproject.toml
@@ -47,7 +47,7 @@ mac = [
gpu = [
"mlx-lm==0.28.0",
"mlx[cpu]==0.29.1",
"sglang[all]==0.5.2",
"sglang[all]==0.5.4.post1",
]

benchmark = [
6 changes: 4 additions & 2 deletions src/parallax/server/http_server.py
@@ -345,12 +345,14 @@ async def v1_chat_completions(raw_request: fastapi.Request):
# Check if request_json has "rid", otherwise generate new one
request_id = request_json.get("rid")
if request_id is None:
request_id = uuid.uuid4()
request_json["rid"] = str(request_id)
request_id = str(uuid.uuid4())
request_json["rid"] = request_id

app.state.http_handler.create_request(request_json)
app.state.http_handler.send_request(request_json)
req = app.state.http_handler.processing_requests.get(request_id)
if req is None:
return create_error_response("Request not found", "RequestNotFoundError")
is_stream = req.stream

if is_stream:
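The `rid` handling above means a client can pin its own request id by sending one in the body; otherwise the server mints `str(uuid.uuid4())` and uses it as the key into `processing_requests`. A minimal client-side sketch — the host, port, and model name here are assumptions for illustration, not values from this PR:

import uuid

import requests

payload = {
    "model": "qwen3-next",  # hypothetical model name
    "messages": [{"role": "user", "content": "hello"}],
    "rid": str(uuid.uuid4()),  # optional: pin the request id client-side
}
# Assumed local deployment; adjust the URL to your server.
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())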
10 changes: 9 additions & 1 deletion src/parallax/sglang/batch_info.py
@@ -5,6 +5,7 @@
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
"""

from types import SimpleNamespace
from typing import List

import torch
@@ -67,11 +68,16 @@ def form_sgl_batch_prefill(
) -> ForwardBatch:
"""Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow"""
sgl_reqs = transform_requests_to_sglang(requests)
dummy_tree_cache = SimpleNamespace(
page_size=model_runner.server_args.page_size,
device=model_runner.device,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
)
schedule_batch = ScheduleBatch.init_new(
reqs=sgl_reqs,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
tree_cache=None,
tree_cache=dummy_tree_cache,
model_config=model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
@@ -193,6 +199,8 @@ def form_sgl_batch_decode(

def release_cuda_request(running_batch: ScheduleBatch, request_id: str):
"""Release KV Cache and other resources for finished/aborted requests."""
if running_batch is None or running_batch.is_empty():
return
seq_lens_cpu = running_batch.seq_lens.cpu().numpy()
idx = find_index(running_batch, request_id)
req = running_batch.reqs.pop(idx)
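The `dummy_tree_cache` above is plain duck typing: during prefill, `ScheduleBatch.init_new` only reads a handful of attributes off the tree cache, so a `SimpleNamespace` carrying exactly those attributes can stand in for a full radix cache. A self-contained sketch of the pattern (the names here are illustrative, not SGLang's):

from types import SimpleNamespace

def make_tree_cache_stub(page_size, device, allocator):
    # Satisfies any consumer that only performs attribute reads such as
    # stub.page_size, stub.device, or stub.token_to_kv_pool_allocator.
    return SimpleNamespace(
        page_size=page_size,
        device=device,
        token_to_kv_pool_allocator=allocator,
    )

stub = make_tree_cache_stub(page_size=1, device="cuda", allocator=object())
assert stub.page_size == 1 and stub.device == "cuda"

In the same file, the new early return in `release_cuda_request` guards the empty-batch case before any `seq_lens` access.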
24 changes: 16 additions & 8 deletions src/parallax/sglang/model_runner.py
@@ -7,7 +7,6 @@
import logging
import os
import random
import sys
from typing import Any, Dict, List, Optional, Tuple, Union

import sglang
@@ -70,6 +69,7 @@ def __init__(
use_hpu_communicator: bool,
use_xpu_communicator: bool,
use_npu_communicator: bool,
use_torch_symm_mem: bool = False,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
pp_start_layer: int = 0,
Expand All @@ -87,6 +87,7 @@ def __init__(
use_hpu_communicator=use_hpu_communicator,
use_xpu_communicator=use_xpu_communicator,
use_npu_communicator=use_npu_communicator,
use_torch_symm_mem=use_torch_symm_mem,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
@@ -437,7 +438,7 @@ def monkey_patch_make_layers(
# circular imports
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.offloader import get_offloader
from sglang.srt.utils.offloader import get_offloader

assert not pp_size or num_hidden_layers >= pp_size
start_layer, end_layer = get_pp_group().pp_start_layer, get_pp_group().pp_end_layer
@@ -460,15 +461,15 @@ def monkey_patch_make_layers(

## TODO: Move this when sglang supports qwen3_next pipeline parallelism
def monkey_patch_qwen3_next():
from parallax.sglang.monkey_patch import (
qwen3_next_model as parallax_qwen3_next_model_module,
)
from parallax.sglang.monkey_patch.qwen3_next_config import (
monkey_patch_linear_layer_ids,
apply_qwen3_next_config_monkey_patch,
)
from parallax.sglang.monkey_patch.qwen3_next_model import (
apply_qwen3_next_monkey_patch,
)

sys.modules["sglang.srt.models.qwen3_next"] = parallax_qwen3_next_model_module
sglang.srt.configs.qwen3_next.Qwen3NextConfig.linear_layer_ids = monkey_patch_linear_layer_ids
apply_qwen3_next_monkey_patch()
apply_qwen3_next_config_monkey_patch()


## TODO: Move this when sglang supports gpt_oss pipeline parallelism
@@ -553,6 +554,11 @@ def initialize_sgl_model_runner(
attention_backend = "triton"
moe_runner_backend = "triton_kernel"

architectures = config.get("architectures", [])
if architectures and any("Qwen3Next" in arch for arch in architectures):
logger.debug("Qwen3-Next model detected, setting kv_block_size to 1")
kv_block_size = 1

server_args = form_sgl_server_args(
original_model_path,
dtype,
@@ -574,8 +580,10 @@ def initialize_sgl_model_runner(
model_config.hf_config.tie_word_embeddings = False
model_config.hf_config.start_layer = start_layer
model_config.hf_config.end_layer = end_layer

logger.debug(f"model_start_layer: {model_config.hf_config.start_layer}")
logger.debug(f"model_end_layer: {model_config.hf_config.end_layer}")

model_runner = ParallaxModelRunner(
model_config=model_config,
mem_fraction_static=kv_cache_memory_fraction,
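The Qwen3-Next detection above keys off the `architectures` list that Hugging Face configs expose. A standalone sketch of the same check, with hand-built config dicts rather than anything loaded from disk:

def wants_kv_block_size_one(config):
    # Hybrid linear-attention models such as Qwen3-Next get per-token KV
    # pages in this PR, i.e. kv_block_size = 1.
    architectures = config.get("architectures", [])
    return any("Qwen3Next" in arch for arch in architectures)

assert wants_kv_block_size_one({"architectures": ["Qwen3NextForCausalLM"]})
assert not wants_kv_block_size_one({"architectures": ["LlamaForCausalLM"]})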
57 changes: 42 additions & 15 deletions src/parallax/sglang/monkey_patch/qwen3_next_config.py
@@ -1,17 +1,3 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3Hybrid model configuration"""

import enum
@@ -31,10 +17,51 @@ class HybridLayerType(enum.Enum):

@property
def monkey_patch_linear_layer_ids(self):
return [
"""Return linear-attention layer ids restricted to the PP slice.

This is intended to be bound as a property on
`sglang.srt.configs.qwen3_next.Qwen3NextConfig`.
"""
lst = [
i
for i, type_value in enumerate(self.layers_block_type)
if type_value == HybridLayerType.linear_attention.value
and i >= self.start_layer
and i < self.end_layer
]
# No matching layer ids in this PP slice; return the sentinel [-1] so
# downstream pipeline-parallel code still sees a non-empty list.
return lst if lst else [-1]


@property
def monkey_patch_full_attention_layer_ids(self):
"""Return full-attention layer ids restricted to the PP slice.

This is intended to be bound as a property on
`sglang.srt.configs.qwen3_next.Qwen3NextConfig`.
"""
lst = [
i
for i, type_value in enumerate(self.layers_block_type)
if type_value == HybridLayerType.full_attention.value
and i >= self.start_layer
and i < self.end_layer
]
# No matching layer ids in this PP slice; return the sentinel [-1] so
# downstream pipeline-parallel code still sees a non-empty list.
return lst if lst else [-1]


def apply_qwen3_next_config_monkey_patch():
"""Bind monkey-patch helpers to the upstream Qwen3NextConfig class.

We attach the two helpers above as properties so callers can access
`config.linear_layer_ids` / `config.full_attention_layer_ids` the same
way upstream expects.
"""

import sglang.srt.configs.qwen3_next as s

s.Qwen3NextConfig.linear_layer_ids = monkey_patch_linear_layer_ids
s.Qwen3NextConfig.full_attention_layer_ids = monkey_patch_full_attention_layer_ids
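The binding in `apply_qwen3_next_config_monkey_patch` relies on a small Python mechanic: a `property` object created at module level can be attached to a class after the fact, and instances then see it as a computed attribute. A toy illustration (the names here are made up):

@property
def doubled(self):
    return self.value * 2

class Point:
    def __init__(self, value):
        self.value = value

# Same move as s.Qwen3NextConfig.linear_layer_ids = monkey_patch_linear_layer_ids
Point.doubled = doubled
assert Point(21).doubled == 42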