From ecd40f8028bbbb28bf7553ed99e064fe51218621 Mon Sep 17 00:00:00 2001
From: TianyiZhao1437
Date: Mon, 29 Sep 2025 19:07:06 +0800
Subject: [PATCH 1/7] fix: mlx model name map & fix decode wrong seq length for hidden states

---
 src/parallax/server/executor.py | 29 ++++++++++++++++++++++++-----
 1 file changed, 24 insertions(+), 5 deletions(-)

diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py
index e6638d0d..b3a413b1 100644
--- a/src/parallax/server/executor.py
+++ b/src/parallax/server/executor.py
@@ -62,6 +62,13 @@
 logger = get_logger(__name__)
 
 
+"""Currently hard code model name for MAC"""
+MLX_MODEL_NAME_MAP = {
+    "openai/gpt-oss-20b": "mlx-community/gpt-oss-20b-MXFP4-Q8",
+    "openai/gpt-oss-120b": "mlx-community/gpt-oss-120b-MXFP4-Q8",
+}
+
+
 class Executor:
     """High-level executor for managing model shards, scheduler, and cache pool on each Peer."""
 
@@ -128,6 +135,9 @@ def __init__(
             self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False)
             self.cur_batch = None
         else:
+            mlx_model_repo = MLX_MODEL_NAME_MAP.get(model_repo, None)
+            if mlx_model_repo is not None:
+                model_repo = mlx_model_repo
             logger.debug(
                 f"Initializing MLX sharded model loader for repo={model_repo}, layers=[{start_layer}, {end_layer})"
             )
@@ -948,12 +958,21 @@ def _prepare_next_batch_requests(
                 hidden_state_for_req = hidden_states[i : i + 1]
             else:
                 # Other peers get a 3D array of hidden states
-                true_length = int(lengths[i])
-                if hidden_states.ndim == 3:
-                    hidden_state_for_req = hidden_states[i, :true_length, :]
+                if src_request.is_prefill:
+                    true_length = int(lengths[i])
+                    if hidden_states.ndim == 3:
+                        hidden_state_for_req = hidden_states[i, :true_length, :]
+                    else:
+                        hidden_state_for_req = hidden_states[
+                            pre_length : pre_length + true_length, :
+                        ]
+                    pre_length += true_length
                 else:
-                    hidden_state_for_req = hidden_states[pre_length : pre_length + true_length, :]
-                    pre_length += true_length
+                    if hidden_states.ndim == 3:
+                        hidden_state_for_req = hidden_states[i, :, :]
+                    else:
+                        hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :]
+                    pre_length += 1
 
             next_req = self._prepare_next_single_request(src_request, hidden_state_for_req)
             batched_requests.append(next_req)

From 86c51d0e30ba0fa67ec476376a05026561d18aea Mon Sep 17 00:00:00 2001
From: TianyiZhao1437
Date: Mon, 29 Sep 2025 19:34:22 +0800
Subject: [PATCH 2/7] update

---
 src/parallax/launch.py          | 19 +++++++++++++++++++
 src/parallax/server/executor.py | 10 ----------
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/src/parallax/launch.py b/src/parallax/launch.py
index 14674f7e..29740a7a 100644
--- a/src/parallax/launch.py
+++ b/src/parallax/launch.py
@@ -21,10 +21,17 @@
 from parallax.server.executor import Executor
 from parallax.server.http_server import launch_http_server
 from parallax.server.server_args import parse_args
+from parallax.utils.utils import get_current_device
 from parallax_utils.logging_config import get_logger
 
 logger = get_logger("parallax.launch")
 
+"""Currently hard code model name for MAC"""
+MLX_MODEL_NAME_MAP = {
+    "openai/gpt-oss-20b": "mlx-community/gpt-oss-20b-MXFP4-Q8",
+    "openai/gpt-oss-120b": "mlx-community/gpt-oss-120b-MXFP4-Q8",
+}
+
 if __name__ == "__main__":
     multiprocessing.set_start_method("spawn", force=True)
     try:
@@ -39,6 +46,12 @@
         logger.debug(f"executor_input_addr: {args.executor_input_ipc}")
         logger.debug(f"executor_output_addr: {args.executor_output_ipc}")
         gradient_server = None
+        # Hard code for mlx-community models
+        if get_current_device() == "mlx":
+            mlx_model_repo = MLX_MODEL_NAME_MAP.get(args.model_path, None)
+            if mlx_model_repo is not None:
+                args.model_path = mlx_model_repo
+                logger.debug(f"Replace mlx model path: {mlx_model_repo}")
         if args.scheduler_addr is None:
             # only launch http server on head node
             if args.start_layer == 0:
@@ -86,6 +99,12 @@
             args.start_layer = gradient_server.block_start_index
             args.end_layer = gradient_server.block_end_index
             args.model_path = gradient_server.model_name
+            # Hard code for mlx-community models
+            if get_current_device() == "mlx":
+                mlx_model_repo = MLX_MODEL_NAME_MAP.get(args.model_path, None)
+                if mlx_model_repo is not None:
+                    args.model_path = mlx_model_repo
+                    logger.debug(f"Replace mlx model path: {mlx_model_repo}")
         logger.debug(
             f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}"
         )
diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py
index b3a413b1..637f7ffd 100644
--- a/src/parallax/server/executor.py
+++ b/src/parallax/server/executor.py
@@ -62,13 +62,6 @@
 logger = get_logger(__name__)
 
 
-"""Currently hard code model name for MAC"""
-MLX_MODEL_NAME_MAP = {
-    "openai/gpt-oss-20b": "mlx-community/gpt-oss-20b-MXFP4-Q8",
-    "openai/gpt-oss-120b": "mlx-community/gpt-oss-120b-MXFP4-Q8",
-}
-
-
 class Executor:
     """High-level executor for managing model shards, scheduler, and cache pool on each Peer."""
 
@@ -135,9 +128,6 @@ def __init__(
             self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False)
             self.cur_batch = None
         else:
-            mlx_model_repo = MLX_MODEL_NAME_MAP.get(model_repo, None)
-            if mlx_model_repo is not None:
-                model_repo = mlx_model_repo
             logger.debug(
                 f"Initializing MLX sharded model loader for repo={model_repo}, layers=[{start_layer}, {end_layer})"
             )

From 9d21ae783edcd3105abf7a0ed82fbc2af8d1e7d5 Mon Sep 17 00:00:00 2001
From: TianyiZhao1437
Date: Tue, 30 Sep 2025 11:50:55 +0800
Subject: [PATCH 3/7] add sglang gpt-oss monkey patch

---
 src/parallax/sglang/model_runner.py          |   8 +
 .../sglang/monkey_patch/gpt_oss_model.py     | 184 ++++++++++++++++++
 2 files changed, 192 insertions(+)
 create mode 100644 src/parallax/sglang/monkey_patch/gpt_oss_model.py

diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index b043bdef..58c7977a 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -470,6 +470,13 @@ def monkey_patch_qwen3_next():
     sglang.srt.configs.qwen3_next.Qwen3NextConfig.linear_layer_ids = monkey_patch_linear_layer_ids
 
 
+## TODO: Move this when sgalang supports gpt_oss pipeline parallelism
+def monkey_patch_gpt_oss():
+    from parallax.sglang.monkey_patch.gpt_oss_model import apply_gpt_oss_monkey_patch
+
+    apply_gpt_oss_monkey_patch()
+
+
 def form_sgl_server_args(
     model_path: str,
     dtype: str = "bfloat16",
@@ -500,6 +507,7 @@ def apply_parallax_monkey_patch():
     )
     sglang.srt.utils.make_layers = monkey_patch_make_layers
     monkey_patch_qwen3_next()
+    monkey_patch_gpt_oss()
 
 
 def initialize_sgl_model_runner(
diff --git a/src/parallax/sglang/monkey_patch/gpt_oss_model.py b/src/parallax/sglang/monkey_patch/gpt_oss_model.py
new file mode 100644
index 00000000..b9d7417d
--- /dev/null
+++ b/src/parallax/sglang/monkey_patch/gpt_oss_model.py
@@ -0,0 +1,184 @@
+import math
+
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+    get_moe_tensor_parallel_rank,
+    get_moe_tensor_parallel_world_size,
+)
+from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.gpt_oss import GptOssForCausalLM
+
+
+def _parallax_load_mxfp4_experts_weights(self, weights):
+    params_dict = dict(self.named_parameters())
+    loaded_params: set[str] = set()
+    mxfp4_block = 32
+
+    moe_tp_rank = get_moe_tensor_parallel_rank()
+    moe_tp_size = get_moe_tensor_parallel_world_size()
+    moe_ep_rank = get_moe_expert_parallel_rank()
+    moe_ep_size = get_moe_expert_parallel_world_size()
+
+    intermediate_size = self.config.intermediate_size
+    assert (
+        intermediate_size % mxfp4_block == 0
+    ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
+    intermediate_size_block = intermediate_size // mxfp4_block
+
+    per_rank_intermediate_size_block = math.ceil(intermediate_size_block / moe_tp_size)
+
+    per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
+
+    # Calculate common slicing bounds for current rank
+    assert self.config.num_local_experts % moe_ep_size == 0
+    moe_num_global_experts = self.config.num_local_experts
+    moe_num_local_experts = self.config.num_local_experts // moe_ep_size
+
+    moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
+    moe_tp_rank_end = min((moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
+
+    moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
+    moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
+
+    for name, weight in weights:
+        layer_id = get_layer_id(name)
+        if (
+            layer_id is not None
+            and hasattr(self.model, "start_layer")
+            and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer)
+        ):
+            continue
+        weight = weight.cuda()
+
+        if "gate_up_proj_blocks" in name:
+            # Handle MLP gate and up projection weights
+            new_name = name.replace("gate_up_proj_blocks", "w13_weight")
+
+            # flat weight from (E, 2 * N, block_size, entry_per_block)
+            # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
+            weight = weight.view(moe_num_global_experts, 2 * intermediate_size, -1).contiguous()
+
+            narrow_weight = weight[
+                moe_ep_rank_start:moe_ep_rank_end,
+                2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                ...,
+            ]
+
+            param = params_dict[new_name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(
+                param,
+                narrow_weight,
+                weight_name=new_name,
+                shard_id=None,
+                expert_id=None,
+            )
+            loaded_params.add(new_name)
+
+        elif "down_proj_blocks" in name:
+            # Handle MLP down projection weights
+            new_name = name.replace("down_proj_blocks", "w2_weight")
+            # same flatten here, but since 2 mx4 value are packed in 1
+            # uint8, divide by 2
+            weight = weight.view(moe_num_global_experts, -1, intermediate_size // 2).contiguous()
+            narrow_weight = weight[
+                moe_ep_rank_start:moe_ep_rank_end,
+                ...,
+                moe_tp_rank_start // 2 : moe_tp_rank_end // 2,
+            ]
+
+            param = params_dict[new_name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(
+                param,
+                narrow_weight,
+                weight_name=new_name,
+                shard_id=None,
+                expert_id=None,
+            )
+            loaded_params.add(new_name)
+
+        elif "gate_up_proj_scales" in name:
+            # Handle MLP gate and up projection weights scale
+            new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
+            narrow_weight = weight[
+                moe_ep_rank_start:moe_ep_rank_end,
+                2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                ...,
+            ]
+
+            param = params_dict[new_name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(
+                param,
+                narrow_weight,
+                weight_name=new_name,
+                shard_id=None,
+                expert_id=None,
+            )
+            loaded_params.add(new_name)
+
+        elif "down_proj_scales" in name:
+            # Handle MLP down projection weights
+            new_name = name.replace("down_proj_scales", "w2_weight_scale")
+            narrow_weight = weight[
+                moe_ep_rank_start:moe_ep_rank_end,
+                ...,
+                moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block,
+            ]
+
+            param = params_dict[new_name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(
+                param,
+                narrow_weight,
+                weight_name=new_name,
+                shard_id=None,
+                expert_id=None,
+            )
+            loaded_params.add(new_name)
+        elif "gate_up_proj_bias" in name:
+            # Handle MLP gate and up projection biases
+            new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
+
+            narrow_weight = weight[
+                moe_ep_rank_start:moe_ep_rank_end,
+                2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+            ]
+
+            param = params_dict[new_name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(
+                param,
+                narrow_weight,
+                weight_name=new_name,
+                shard_id=None,
+                expert_id=None,
+            )
+            loaded_params.add(new_name)
+
+        elif "down_proj_bias" in name:
+            narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...]
+            if moe_tp_rank != 0:
+                narrow_weight = torch.zeros_like(narrow_weight)
+
+            # Handle MLP down projection bias
+            new_name = name.replace("down_proj_bias", "w2_weight_bias")
+            param = params_dict[new_name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(
+                param,
+                narrow_weight,
+                weight_name=new_name,
+                shard_id=None,
+                expert_id=None,
+            )
+            loaded_params.add(new_name)
+
+    return loaded_params
+
+
+def apply_gpt_oss_monkey_patch():
+    GptOssForCausalLM._load_mxfp4_experts_weights = _parallax_load_mxfp4_experts_weights

From e954961e35108f9c73908cd668363cd67a4d4706 Mon Sep 17 00:00:00 2001
From: TianyiZhao1437
Date: Tue, 30 Sep 2025 12:29:46 +0800
Subject: [PATCH 4/7] Fix gpt-oss sgl arguments

---
 src/parallax/sglang/model_runner.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index 58c7977a..13038757 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -480,7 +480,7 @@ def monkey_patch_gpt_oss():
 def form_sgl_server_args(
     model_path: str,
     dtype: str = "bfloat16",
-    attention_backend: str = "torch_native",
+    attention_backend: str = "flashinfer",
     kv_block_size: int = 64,
     moe_runner_backend="auto",
 ):
@@ -533,6 +533,15 @@ def initialize_sgl_model_runner(
     dtype = config.get("torch_dtype", "bfloat16")
     nccl_port = random.randint(4000, 5000)
 
+    # Handling mxfp4 arguments
+    quant_method = config.get("quant_method", None)
+    quantization_config = config.get("quantization_config", None)
+    if quant_method is None and quantization_config is not None:
+        quant_method = quantization_config.get("quant_method", None)
+    if quant_method == "mxfp4":
+        attention_backend = "triton"
+        moe_runner_backend = "triton_kernel"
+
     server_args = form_sgl_server_args(
         original_model_path,
         dtype,

From 55938055332bfc1aba5fcb5e91032e8e40df7cbd Mon Sep 17 00:00:00 2001
From: TianyiZhao1437
Date: Tue, 30 Sep 2025 13:23:50 +0800
Subject: [PATCH 5/7] Add triton backend suport pipeline parallel patch

---
 src/parallax/sglang/model_runner.py          | 10 +++
 .../sglang/monkey_patch/triton_backend.py    | 84 +++++++++++++++++++
 2 files changed, 94 insertions(+)
 create mode 100644 src/parallax/sglang/monkey_patch/triton_backend.py

diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index 13038757..b7df739d 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -477,6 +477,15 @@ def monkey_patch_gpt_oss():
     apply_gpt_oss_monkey_patch()
 
 
+## TODO: Move this when sgalang supports triton backend pipeline parallelism
+def monkey_patch_triton_backend_init():
+    from parallax.sglang.monkey_patch.triton_backend import (
+        apply_triton_backend_init_monkey_patch,
+    )
+
+    apply_triton_backend_init_monkey_patch()
+
+
 def form_sgl_server_args(
     model_path: str,
     dtype: str = "bfloat16",
@@ -508,6 +517,7 @@ def apply_parallax_monkey_patch():
     sglang.srt.utils.make_layers = monkey_patch_make_layers
     monkey_patch_qwen3_next()
     monkey_patch_gpt_oss()
+    monkey_patch_triton_backend_init()
 
 
 def initialize_sgl_model_runner(
diff --git a/src/parallax/sglang/monkey_patch/triton_backend.py b/src/parallax/sglang/monkey_patch/triton_backend.py
new file mode 100644
index 00000000..f3bd35a8
--- /dev/null
+++ b/src/parallax/sglang/monkey_patch/triton_backend.py
@@ -0,0 +1,84 @@
+from typing import Optional
+
+import torch
+from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.utils import get_bool_env_var, get_device_core_count
+
+
+def parallax_triton_backend_init(
+    self,
+    model_runner: ModelRunner,
+    skip_prefill: bool = False,
+    kv_indptr_buf: Optional[torch.Tensor] = None,
+):
+    # Lazy import to avoid the initialization of cuda context
+    from sglang.srt.layers.attention.triton_ops.decode_attention import (
+        decode_attention_fwd,
+    )
+    from sglang.srt.layers.attention.triton_ops.extend_attention import (
+        extend_attention_fwd,
+    )
+
+    super().__init__()
+
+    self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
+    self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
+
+    # Parse args
+    self.skip_prefill = skip_prefill
+    max_bs = model_runner.req_to_token_pool.size
+    self.sliding_window_size = model_runner.sliding_window_size
+    self.req_to_token = model_runner.req_to_token_pool.req_to_token
+    self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
+    self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+    self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+    self.num_head = model_runner.model_config.num_attention_heads // get_attention_tp_size()
+    self.num_kv_head = model_runner.model_config.get_num_kv_heads(get_attention_tp_size())
+    # Modifies layer id to support pipeline parallel
+    self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(
+        model_runner.pp_start_layer
+    ).shape[-1]
+    self.max_context_len = model_runner.model_config.context_len
+    self.device = model_runner.device
+    self.device_core_count = get_device_core_count(model_runner.gpu_id)
+    self.static_kv_splits = get_bool_env_var("SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false")
+    self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+
+    # Check arguments
+    assert not (
+        model_runner.sliding_window_size is not None
+        and model_runner.model_config.is_encoder_decoder
+    ), "Sliding window and cross attention are not supported together"
+
+    # Initialize buffers
+    # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
+    if kv_indptr_buf is None:
+        self.kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+    else:
+        self.kv_indptr = kv_indptr_buf
+
+    # If sliding window is enabled, we might need two sets of buffers
+    # because of interleaved attention types (e.g. for Gemma3)
+    self.window_kv_indptr = None
+    if self.sliding_window_size is not None and self.sliding_window_size > 0:
+        if kv_indptr_buf is None:
+            self.window_kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            # When provided a buffer, create a clone for the second buffer
+            self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
+
+    if not self.skip_prefill:
+        self.qo_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+
+        self.mask_indptr = torch.zeros((max_bs + 1,), dtype=torch.int64, device=model_runner.device)
+
+    # Initialize forward metadata
+    self.forward_metadata = None
+
+
+def apply_triton_backend_init_monkey_patch():
+    TritonAttnBackend.__init__ = parallax_triton_backend_init

From dc265a8d20255b962ae65cf7e539f950a3619c07 Mon Sep 17 00:00:00 2001
From: TianyiZhao1437
Date: Tue, 30 Sep 2025 13:35:33 +0800
Subject: [PATCH 6/7] Update triton_backend.py

---
 src/parallax/sglang/monkey_patch/triton_backend.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/parallax/sglang/monkey_patch/triton_backend.py b/src/parallax/sglang/monkey_patch/triton_backend.py
index f3bd35a8..1c4c5a99 100644
--- a/src/parallax/sglang/monkey_patch/triton_backend.py
+++ b/src/parallax/sglang/monkey_patch/triton_backend.py
@@ -21,8 +21,6 @@ def parallax_triton_backend_init(
         extend_attention_fwd,
     )
 
-    super().__init__()
-
     self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
     self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
@@ -82,3 +80,4 @@ def parallax_triton_backend_init(
 
 def apply_triton_backend_init_monkey_patch():
     TritonAttnBackend.__init__ = parallax_triton_backend_init
+

From 72ad4bb92ac0555209e102cac7a60b7ad95f670f Mon Sep 17 00:00:00 2001
From: TianyiZhao1437
Date: Tue, 30 Sep 2025 13:38:20 +0800
Subject: [PATCH 7/7] fix pre commit check

---
 src/parallax/sglang/monkey_patch/triton_backend.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/parallax/sglang/monkey_patch/triton_backend.py b/src/parallax/sglang/monkey_patch/triton_backend.py
index 1c4c5a99..394177ca 100644
--- a/src/parallax/sglang/monkey_patch/triton_backend.py
+++ b/src/parallax/sglang/monkey_patch/triton_backend.py
@@ -80,4 +80,3 @@ def parallax_triton_backend_init(
 
 def apply_triton_backend_init_monkey_patch():
     TritonAttnBackend.__init__ = parallax_triton_backend_init
-