From 08ec04eb2213d25a3ccf9eea11fc32dd618ca198 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Thu, 30 Oct 2025 16:00:38 +0800 Subject: [PATCH 1/6] update --- src/parallax/sglang/model_runner.py | 319 +----------------- src/parallax/sglang/monkey_patch.py | 27 ++ .../glm4_moe_model.py | 0 .../gpt_oss_model.py | 0 .../minimax_m2_model.py | 0 .../monkey_patch_utils/model_parallel.py | 311 +++++++++++++++++ .../qwen3_next_config.py | 0 .../qwen3_next_model.py | 0 .../triton_backend.py | 0 9 files changed, 339 insertions(+), 318 deletions(-) create mode 100644 src/parallax/sglang/monkey_patch.py rename src/parallax/sglang/{monkey_patch => monkey_patch_utils}/glm4_moe_model.py (100%) rename src/parallax/sglang/{monkey_patch => monkey_patch_utils}/gpt_oss_model.py (100%) rename src/parallax/sglang/{monkey_patch => monkey_patch_utils}/minimax_m2_model.py (100%) create mode 100644 src/parallax/sglang/monkey_patch_utils/model_parallel.py rename src/parallax/sglang/{monkey_patch => monkey_patch_utils}/qwen3_next_config.py (100%) rename src/parallax/sglang/{monkey_patch => monkey_patch_utils}/qwen3_next_model.py (100%) rename src/parallax/sglang/{monkey_patch => monkey_patch_utils}/triton_backend.py (100%) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 0692dd34..ad3e4130 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -43,69 +43,13 @@ from torch.distributed import Backend from parallax.utils.tokenizer_utils import load_tokenizer - -# from parallax.sglang.monkey_patch.model_runner import ModelRunner as SGLModelRunner +from parallax.sglang.monkey_patch import apply_parallax_monkey_patch logger = logging.getLogger(__name__) _is_cpu_amx_available = cpu_has_amx_support() -class ParallaxGroupCoordinator(SGLGroupCoordinator): - """ - Parallax GroupCoordinator module. - pp_start_layer, pp_end_layer, hidden_layers are necessary for decentralized inference. - Also change the definition of first_rank/last_rank. 
- """ - - def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - use_pynccl: bool, - use_pymscclpp: bool, - use_custom_allreduce: bool, - use_hpu_communicator: bool, - use_xpu_communicator: bool, - use_npu_communicator: bool, - use_torch_symm_mem: bool = False, - use_message_queue_broadcaster: bool = False, - group_name: Optional[str] = None, - pp_start_layer: int = 0, - pp_end_layer: int = 0, - hidden_layers: int = 0, - ): - """Add pp_start_layer, pp_end_layer, hidden_layers for decentralized model""" - super().__init__( - group_ranks=group_ranks, - local_rank=local_rank, - torch_distributed_backend=torch_distributed_backend, - use_pynccl=use_pynccl, - use_pymscclpp=use_pymscclpp, - use_custom_allreduce=use_custom_allreduce, - use_hpu_communicator=use_hpu_communicator, - use_xpu_communicator=use_xpu_communicator, - use_npu_communicator=use_npu_communicator, - use_torch_symm_mem=use_torch_symm_mem, - use_message_queue_broadcaster=use_message_queue_broadcaster, - group_name=group_name, - ) - self.pp_start_layer = pp_start_layer - self.pp_end_layer = pp_end_layer - self.hidden_layers = hidden_layers - - @property - def is_first_rank(self): - """Return whether the caller is the first process in the group""" - return self.pp_start_layer == 0 - - @property - def is_last_rank(self): - """Return whether the caller is the last process in the group""" - return self.pp_end_layer == self.hidden_layers - - class ParallaxModelRunner(SGLModelRunner): """ Parallax ModelRunner module. @@ -258,250 +202,6 @@ def init_torch_distributed(self): return min_per_gpu_memory -def monkey_patch_init_model_parallel_group( - group_ranks: List[List[int]], - local_rank: int, - backend: str, - use_custom_allreduce: Optional[bool] = None, - use_message_queue_broadcaster: bool = False, - group_name: Optional[str] = None, - use_mscclpp_allreduce: Optional[bool] = None, - pp_start_layer: int = 0, - pp_end_layer: int = 0, - hidden_layers: int = 0, -) -> SGLGroupCoordinator: - """A monkey patch to replace sglang.srt.distributed.parallel_state.init_model_parallel_group""" - if use_custom_allreduce is None: - use_custom_allreduce = sglang.srt.distributed.parallel_state._ENABLE_CUSTOM_ALL_REDUCE - if use_mscclpp_allreduce is None: - use_mscclpp_allreduce = sglang.srt.distributed.parallel_state._ENABLE_MSCCLPP_ALL_REDUCE - return ParallaxGroupCoordinator( - group_ranks=group_ranks, - local_rank=local_rank, - torch_distributed_backend=backend, - use_pynccl=not is_npu(), - use_pymscclpp=use_mscclpp_allreduce, - use_custom_allreduce=use_custom_allreduce, - use_hpu_communicator=True, - use_xpu_communicator=True, - use_npu_communicator=True, - use_message_queue_broadcaster=use_message_queue_broadcaster, - group_name=group_name, - pp_start_layer=pp_start_layer, - pp_end_layer=pp_end_layer, - hidden_layers=hidden_layers, - ) - - -def monkey_patch_initialize_model_parallel( - tensor_model_parallel_size: int = 1, - expert_model_parallel_size: int = 1, - pipeline_model_parallel_size: int = 1, - backend: Optional[str] = None, - duplicate_tp_group: bool = False, - pp_start_layer: int = 0, - pp_end_layer: int = 0, - hidden_layers: int = 0, -) -> None: - """A monkey patch to replace sglang.srt.distributed.parallel_state.initialize_model_parallel""" - # Get world size and rank. Ensure some consistencies. 
- assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - backend = backend or torch.distributed.get_backend(get_world_group().device_group) - - if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: - raise RuntimeError( - f"world_size ({world_size}) is not equal to " - f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " - f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" - ) - - # Build the tensor model-parallel groups. - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - assert ( - sglang.srt.distributed.parallel_state._TP is None - ), "tensor model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) - group_ranks.append(ranks) - - # message queue broadcaster is only used in tensor model parallel group - sglang.srt.distributed.parallel_state._TP = ( - sglang.srt.distributed.parallel_state.init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_message_queue_broadcaster=get_bool_env_var( - "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" - ), - group_name="tp", - ) - ) - - if duplicate_tp_group: - global _PDMUX_PREFILL_TP_GROUP - assert ( - _PDMUX_PREFILL_TP_GROUP is None - ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized" - _PDMUX_PREFILL_TP_GROUP = sglang.srt.distributed.parallel_state.init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_message_queue_broadcaster=get_bool_env_var( - "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" - ), - group_name="pdmux_prefill_tp", - ) - sglang.srt.distributed.parallel_state._TP.pynccl_comm.disabled = False - _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False - - moe_ep_size = expert_model_parallel_size - - moe_tp_size = tensor_model_parallel_size // moe_ep_size - assert ( - sglang.srt.distributed.parallel_state._MOE_EP is None - ), "expert model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - for j in range(moe_tp_size): - st = i * tensor_model_parallel_size + j - en = (i + 1) * tensor_model_parallel_size + j - ranks = list(range(st, en, moe_tp_size)) - group_ranks.append(ranks) - - sglang.srt.distributed.parallel_state._MOE_EP = ( - sglang.srt.distributed.parallel_state.init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False, - group_name="moe_ep", - ) - ) - - assert ( - sglang.srt.distributed.parallel_state._MOE_TP is None - ), "expert model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - for j in range(moe_ep_size): - st = i * tensor_model_parallel_size + j * moe_tp_size - en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size - ranks = list(range(st, en)) - group_ranks.append(ranks) - - sglang.srt.distributed.parallel_state._MOE_TP = ( - sglang.srt.distributed.parallel_state.init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False, - group_name="moe_tp", - ) - ) - - # Build the pipeline model-parallel groups. 
- num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - assert ( - sglang.srt.distributed.parallel_state._PP is None - ), "pipeline model parallel group is already initialized" - group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group_ranks.append(ranks) - # pipeline parallel does not need custom allreduce - sglang.srt.distributed.parallel_state._PP = ( - sglang.srt.distributed.parallel_state.init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False, - group_name="pp", - pp_start_layer=pp_start_layer, - pp_end_layer=pp_end_layer, - hidden_layers=hidden_layers, - ) - ) - - -def monkey_patch_make_layers( - num_hidden_layers: int, - layer_fn: LayerFn, - pp_rank: Optional[int] = None, - pp_size: Optional[int] = None, - prefix: str = "", - return_tuple: bool = True, - offloader_kwargs: Dict[str, Any] = {}, -) -> Tuple[int, int, torch.nn.ModuleList]: - """A monkey patch to replace sglang.srt.utils.make_layers""" - # circula imports - from sglang.srt.distributed import get_pp_group - from sglang.srt.layers.utils import PPMissingLayer - from sglang.srt.utils.offloader import get_offloader - - assert not pp_size or num_hidden_layers >= pp_size - start_layer, end_layer = get_pp_group().pp_start_layer, get_pp_group().pp_end_layer - - modules = torch.nn.ModuleList( - [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)] - + get_offloader().wrap_modules( - ( - layer_fn(idx=idx, prefix=add_prefix(idx, prefix)) - for idx in range(start_layer, end_layer) - ), - **offloader_kwargs, - ) - + [PPMissingLayer(return_tuple=return_tuple) for _ in range(end_layer, num_hidden_layers)] - ) - if pp_rank is None or pp_size is None: - return modules - return modules, start_layer, end_layer - - -## TODO: Move this when sgalang supports qwen3_next pipeline parallelism -def monkey_patch_qwen3_next(): - from parallax.sglang.monkey_patch.qwen3_next_config import ( - apply_qwen3_next_config_monkey_patch, - ) - from parallax.sglang.monkey_patch.qwen3_next_model import ( - apply_qwen3_next_monkey_patch, - ) - - apply_qwen3_next_monkey_patch() - apply_qwen3_next_config_monkey_patch() - - -## TODO: Move this when sgalang supports gpt_oss pipeline parallelism -def monkey_patch_gpt_oss(): - from parallax.sglang.monkey_patch.gpt_oss_model import apply_gpt_oss_monkey_patch - - apply_gpt_oss_monkey_patch() - - -## TODO: Move this when sgalang supports triton backend pipeline parallelism -def monkey_patch_triton_backend_init(): - from parallax.sglang.monkey_patch.triton_backend import ( - apply_triton_backend_init_monkey_patch, - ) - - apply_triton_backend_init_monkey_patch() - - -def monkey_patch_minimax_m2_model(): - from parallax.sglang.monkey_patch.minimax_m2_model import ( - apply_minimax_m2_monkey_patch, - ) - - apply_minimax_m2_monkey_patch() - - -def monkey_patch_glm4_moe_model(): - from parallax.sglang.monkey_patch.glm4_moe_model import apply_glm4_moe_monkey_patch - - apply_glm4_moe_monkey_patch() - - def form_sgl_server_args( model_path: str, dtype: str = "bfloat16", @@ -521,23 +221,6 @@ def form_sgl_server_args( return sgl_server_args -def apply_parallax_monkey_patch(): - """Apply all monkey patch""" - # Function patch - sglang.srt.distributed.parallel_state.init_model_parallel_group = ( - monkey_patch_init_model_parallel_group - ) - sglang.srt.distributed.parallel_state.initialize_model_parallel = ( - 
monkey_patch_initialize_model_parallel - ) - sglang.srt.utils.make_layers = monkey_patch_make_layers - monkey_patch_qwen3_next() - monkey_patch_gpt_oss() - monkey_patch_triton_backend_init() - monkey_patch_minimax_m2_model() - monkey_patch_glm4_moe_model() - - def initialize_sgl_model_runner( original_model_path: str, start_layer: int, diff --git a/src/parallax/sglang/monkey_patch.py b/src/parallax/sglang/monkey_patch.py new file mode 100644 index 00000000..1a6715a4 --- /dev/null +++ b/src/parallax/sglang/monkey_patch.py @@ -0,0 +1,27 @@ +from parallax.sglang.monkey_patch_utils.qwen3_next_config import ( + apply_qwen3_next_config_monkey_patch, +) +from parallax.sglang.monkey_patch_utils.qwen3_next_model import ( + apply_qwen3_next_monkey_patch, +) +from parallax.sglang.monkey_patch_utils.gpt_oss_model import apply_gpt_oss_monkey_patch +from parallax.sglang.monkey_patch_utils.minimax_m2_model import ( + apply_minimax_m2_monkey_patch, +) +from parallax.sglang.monkey_patch_utils.glm4_moe_model import apply_glm4_moe_monkey_patch +from parallax.sglang.monkey_patch_utils.triton_backend import ( + apply_triton_backend_init_monkey_patch, +) +from parallax.sglang.monkey_patch_utils.model_parallel import ( + apply_model_parallel_monkey_patch, +) + + +def apply_parallax_monkey_patch(): + apply_qwen3_next_monkey_patch() + apply_qwen3_next_config_monkey_patch() + apply_gpt_oss_monkey_patch() + apply_minimax_m2_monkey_patch() + apply_glm4_moe_monkey_patch() + apply_triton_backend_init_monkey_patch() + apply_model_parallel_monkey_patch() diff --git a/src/parallax/sglang/monkey_patch/glm4_moe_model.py b/src/parallax/sglang/monkey_patch_utils/glm4_moe_model.py similarity index 100% rename from src/parallax/sglang/monkey_patch/glm4_moe_model.py rename to src/parallax/sglang/monkey_patch_utils/glm4_moe_model.py diff --git a/src/parallax/sglang/monkey_patch/gpt_oss_model.py b/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py similarity index 100% rename from src/parallax/sglang/monkey_patch/gpt_oss_model.py rename to src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py diff --git a/src/parallax/sglang/monkey_patch/minimax_m2_model.py b/src/parallax/sglang/monkey_patch_utils/minimax_m2_model.py similarity index 100% rename from src/parallax/sglang/monkey_patch/minimax_m2_model.py rename to src/parallax/sglang/monkey_patch_utils/minimax_m2_model.py diff --git a/src/parallax/sglang/monkey_patch_utils/model_parallel.py b/src/parallax/sglang/monkey_patch_utils/model_parallel.py new file mode 100644 index 00000000..cdf1be26 --- /dev/null +++ b/src/parallax/sglang/monkey_patch_utils/model_parallel.py @@ -0,0 +1,311 @@ +import logging +import os +import random +from typing import Any, Dict, List, Optional, Tuple, Union + +import sglang +import sglang.srt.distributed.parallel_state +import torch +from mlx_lm.utils import get_model_path, load_config +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import ( + get_tp_group, + get_world_group, + init_distributed_environment, + set_custom_all_reduce, + set_mscclpp_all_reduce, +) +from sglang.srt.distributed.parallel_state import ( + GroupCoordinator as SGLGroupCoordinator, +) +from sglang.srt.layers.dp_attention import ( + get_attention_tp_group, + initialize_dp_attention, +) +from sglang.srt.layers.moe import initialize_moe_config +from sglang.srt.model_executor.model_runner import ModelRunner as SGLModelRunner +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import ( + LayerFn, + add_prefix, + 
cpu_has_amx_support, + get_available_gpu_memory, + get_bool_env_var, + is_npu, + monkey_patch_p2p_access_check, +) +from torch.distributed import Backend + +from parallax.utils.tokenizer_utils import load_tokenizer + +# from parallax.sglang.monkey_patch.model_runner import ModelRunner as SGLModelRunner + +logger = logging.getLogger(__name__) + +_is_cpu_amx_available = cpu_has_amx_support() + + +class ParallaxGroupCoordinator(SGLGroupCoordinator): + """ + Parallax GroupCoordinator module. + pp_start_layer, pp_end_layer, hidden_layers are necessary for decentralized inference. + Also change the definition of first_rank/last_rank. + """ + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_pynccl: bool, + use_pymscclpp: bool, + use_custom_allreduce: bool, + use_hpu_communicator: bool, + use_xpu_communicator: bool, + use_npu_communicator: bool, + use_torch_symm_mem: bool = False, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + pp_start_layer: int = 0, + pp_end_layer: int = 0, + hidden_layers: int = 0, + ): + """Add pp_start_layer, pp_end_layer, hidden_layers for decentralized model""" + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + use_pynccl=use_pynccl, + use_pymscclpp=use_pymscclpp, + use_custom_allreduce=use_custom_allreduce, + use_hpu_communicator=use_hpu_communicator, + use_xpu_communicator=use_xpu_communicator, + use_npu_communicator=use_npu_communicator, + use_torch_symm_mem=use_torch_symm_mem, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + ) + self.pp_start_layer = pp_start_layer + self.pp_end_layer = pp_end_layer + self.hidden_layers = hidden_layers + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.pp_start_layer == 0 + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.pp_end_layer == self.hidden_layers + + +def monkey_patch_init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + use_custom_allreduce: Optional[bool] = None, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + use_mscclpp_allreduce: Optional[bool] = None, + pp_start_layer: int = 0, + pp_end_layer: int = 0, + hidden_layers: int = 0, +) -> SGLGroupCoordinator: + """A monkey patch to replace sglang.srt.distributed.parallel_state.init_model_parallel_group""" + if use_custom_allreduce is None: + use_custom_allreduce = sglang.srt.distributed.parallel_state._ENABLE_CUSTOM_ALL_REDUCE + if use_mscclpp_allreduce is None: + use_mscclpp_allreduce = sglang.srt.distributed.parallel_state._ENABLE_MSCCLPP_ALL_REDUCE + return ParallaxGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=not is_npu(), + use_pymscclpp=use_mscclpp_allreduce, + use_custom_allreduce=use_custom_allreduce, + use_hpu_communicator=True, + use_xpu_communicator=True, + use_npu_communicator=True, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + pp_start_layer=pp_start_layer, + pp_end_layer=pp_end_layer, + hidden_layers=hidden_layers, + ) + + +def monkey_patch_initialize_model_parallel( + tensor_model_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: 
Optional[str] = None, + duplicate_tp_group: bool = False, + pp_start_layer: int = 0, + pp_end_layer: int = 0, + hidden_layers: int = 0, +) -> None: + """A monkey patch to replace sglang.srt.distributed.parallel_state.initialize_model_parallel""" + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + # Build the tensor model-parallel groups. + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + assert ( + sglang.srt.distributed.parallel_state._TP is None + ), "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + sglang.srt.distributed.parallel_state._TP = ( + sglang.srt.distributed.parallel_state.init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=get_bool_env_var( + "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" + ), + group_name="tp", + ) + ) + + if duplicate_tp_group: + global _PDMUX_PREFILL_TP_GROUP + assert ( + _PDMUX_PREFILL_TP_GROUP is None + ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized" + _PDMUX_PREFILL_TP_GROUP = sglang.srt.distributed.parallel_state.init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=get_bool_env_var( + "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" + ), + group_name="pdmux_prefill_tp", + ) + sglang.srt.distributed.parallel_state._TP.pynccl_comm.disabled = False + _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False + + moe_ep_size = expert_model_parallel_size + + moe_tp_size = tensor_model_parallel_size // moe_ep_size + assert ( + sglang.srt.distributed.parallel_state._MOE_EP is None + ), "expert model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + for j in range(moe_tp_size): + st = i * tensor_model_parallel_size + j + en = (i + 1) * tensor_model_parallel_size + j + ranks = list(range(st, en, moe_tp_size)) + group_ranks.append(ranks) + + sglang.srt.distributed.parallel_state._MOE_EP = ( + sglang.srt.distributed.parallel_state.init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, + group_name="moe_ep", + ) + ) + + assert ( + sglang.srt.distributed.parallel_state._MOE_TP is None + ), "expert model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + for j in range(moe_ep_size): + st = i * tensor_model_parallel_size + j * moe_tp_size + en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size + ranks = list(range(st, en)) + group_ranks.append(ranks) + + sglang.srt.distributed.parallel_state._MOE_TP = ( + sglang.srt.distributed.parallel_state.init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, + 
group_name="moe_tp", + ) + ) + + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + assert ( + sglang.srt.distributed.parallel_state._PP is None + ), "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + sglang.srt.distributed.parallel_state._PP = ( + sglang.srt.distributed.parallel_state.init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, + group_name="pp", + pp_start_layer=pp_start_layer, + pp_end_layer=pp_end_layer, + hidden_layers=hidden_layers, + ) + ) + + +def monkey_patch_make_layers( + num_hidden_layers: int, + layer_fn: LayerFn, + pp_rank: Optional[int] = None, + pp_size: Optional[int] = None, + prefix: str = "", + return_tuple: bool = True, + offloader_kwargs: Dict[str, Any] = {}, +) -> Tuple[int, int, torch.nn.ModuleList]: + """A monkey patch to replace sglang.srt.utils.make_layers""" + # circula imports + from sglang.srt.distributed import get_pp_group + from sglang.srt.layers.utils import PPMissingLayer + from sglang.srt.utils.offloader import get_offloader + + assert not pp_size or num_hidden_layers >= pp_size + start_layer, end_layer = get_pp_group().pp_start_layer, get_pp_group().pp_end_layer + + modules = torch.nn.ModuleList( + [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)] + + get_offloader().wrap_modules( + ( + layer_fn(idx=idx, prefix=add_prefix(idx, prefix)) + for idx in range(start_layer, end_layer) + ), + **offloader_kwargs, + ) + + [PPMissingLayer(return_tuple=return_tuple) for _ in range(end_layer, num_hidden_layers)] + ) + if pp_rank is None or pp_size is None: + return modules + return modules, start_layer, end_layer + + +def apply_model_parallel_monkey_patch(): + sglang.srt.distributed.parallel_state.init_model_parallel_group = ( + monkey_patch_init_model_parallel_group + ) + sglang.srt.distributed.parallel_state.initialize_model_parallel = ( + monkey_patch_initialize_model_parallel + ) + sglang.srt.utils.make_layers = monkey_patch_make_layers diff --git a/src/parallax/sglang/monkey_patch/qwen3_next_config.py b/src/parallax/sglang/monkey_patch_utils/qwen3_next_config.py similarity index 100% rename from src/parallax/sglang/monkey_patch/qwen3_next_config.py rename to src/parallax/sglang/monkey_patch_utils/qwen3_next_config.py diff --git a/src/parallax/sglang/monkey_patch/qwen3_next_model.py b/src/parallax/sglang/monkey_patch_utils/qwen3_next_model.py similarity index 100% rename from src/parallax/sglang/monkey_patch/qwen3_next_model.py rename to src/parallax/sglang/monkey_patch_utils/qwen3_next_model.py diff --git a/src/parallax/sglang/monkey_patch/triton_backend.py b/src/parallax/sglang/monkey_patch_utils/triton_backend.py similarity index 100% rename from src/parallax/sglang/monkey_patch/triton_backend.py rename to src/parallax/sglang/monkey_patch_utils/triton_backend.py From ee4e07f3894e25c4272a4fb9b99c22e04c70bc69 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Thu, 30 Oct 2025 16:01:26 +0800 Subject: [PATCH 2/6] pre-commit --- src/parallax/sglang/model_runner.py | 10 +------- src/parallax/sglang/monkey_patch.py | 18 ++++++++------- .../monkey_patch_utils/model_parallel.py | 23 +------------------ 3 files changed, 12 insertions(+), 39 deletions(-) diff --git 
a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index ad3e4130..6771e2eb 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -7,7 +7,6 @@ import logging import os import random -from typing import Any, Dict, List, Optional, Tuple, Union import sglang import sglang.srt.distributed.parallel_state @@ -21,9 +20,6 @@ set_custom_all_reduce, set_mscclpp_all_reduce, ) -from sglang.srt.distributed.parallel_state import ( - GroupCoordinator as SGLGroupCoordinator, -) from sglang.srt.layers.dp_attention import ( get_attention_tp_group, initialize_dp_attention, @@ -32,18 +28,14 @@ from sglang.srt.model_executor.model_runner import ModelRunner as SGLModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( - LayerFn, - add_prefix, cpu_has_amx_support, get_available_gpu_memory, get_bool_env_var, - is_npu, monkey_patch_p2p_access_check, ) -from torch.distributed import Backend -from parallax.utils.tokenizer_utils import load_tokenizer from parallax.sglang.monkey_patch import apply_parallax_monkey_patch +from parallax.utils.tokenizer_utils import load_tokenizer logger = logging.getLogger(__name__) diff --git a/src/parallax/sglang/monkey_patch.py b/src/parallax/sglang/monkey_patch.py index 1a6715a4..ae03f44f 100644 --- a/src/parallax/sglang/monkey_patch.py +++ b/src/parallax/sglang/monkey_patch.py @@ -1,20 +1,22 @@ +from parallax.sglang.monkey_patch_utils.glm4_moe_model import ( + apply_glm4_moe_monkey_patch, +) +from parallax.sglang.monkey_patch_utils.gpt_oss_model import apply_gpt_oss_monkey_patch +from parallax.sglang.monkey_patch_utils.minimax_m2_model import ( + apply_minimax_m2_monkey_patch, +) +from parallax.sglang.monkey_patch_utils.model_parallel import ( + apply_model_parallel_monkey_patch, +) from parallax.sglang.monkey_patch_utils.qwen3_next_config import ( apply_qwen3_next_config_monkey_patch, ) from parallax.sglang.monkey_patch_utils.qwen3_next_model import ( apply_qwen3_next_monkey_patch, ) -from parallax.sglang.monkey_patch_utils.gpt_oss_model import apply_gpt_oss_monkey_patch -from parallax.sglang.monkey_patch_utils.minimax_m2_model import ( - apply_minimax_m2_monkey_patch, -) -from parallax.sglang.monkey_patch_utils.glm4_moe_model import apply_glm4_moe_monkey_patch from parallax.sglang.monkey_patch_utils.triton_backend import ( apply_triton_backend_init_monkey_patch, ) -from parallax.sglang.monkey_patch_utils.model_parallel import ( - apply_model_parallel_monkey_patch, -) def apply_parallax_monkey_patch(): diff --git a/src/parallax/sglang/monkey_patch_utils/model_parallel.py b/src/parallax/sglang/monkey_patch_utils/model_parallel.py index cdf1be26..28302b96 100644 --- a/src/parallax/sglang/monkey_patch_utils/model_parallel.py +++ b/src/parallax/sglang/monkey_patch_utils/model_parallel.py @@ -1,43 +1,22 @@ import logging -import os -import random from typing import Any, Dict, List, Optional, Tuple, Union import sglang import sglang.srt.distributed.parallel_state import torch -from mlx_lm.utils import get_model_path, load_config -from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.distributed import ( - get_tp_group, - get_world_group, - init_distributed_environment, - set_custom_all_reduce, - set_mscclpp_all_reduce, -) +from sglang.srt.distributed import get_world_group from sglang.srt.distributed.parallel_state import ( GroupCoordinator as SGLGroupCoordinator, ) -from sglang.srt.layers.dp_attention import ( - get_attention_tp_group, - initialize_dp_attention, -) -from 
sglang.srt.layers.moe import initialize_moe_config -from sglang.srt.model_executor.model_runner import ModelRunner as SGLModelRunner -from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( LayerFn, add_prefix, cpu_has_amx_support, - get_available_gpu_memory, get_bool_env_var, is_npu, - monkey_patch_p2p_access_check, ) from torch.distributed import Backend -from parallax.utils.tokenizer_utils import load_tokenizer - # from parallax.sglang.monkey_patch.model_runner import ModelRunner as SGLModelRunner logger = logging.getLogger(__name__) From e139d343a1ff697e1c3971b4c1dc66d61e8da0da Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Thu, 30 Oct 2025 16:04:11 +0800 Subject: [PATCH 3/6] update --- src/parallax/sglang/model_runner.py | 4 ++-- src/parallax/sglang/monkey_patch.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 6771e2eb..2bf8a0f5 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -34,7 +34,7 @@ monkey_patch_p2p_access_check, ) -from parallax.sglang.monkey_patch import apply_parallax_monkey_patch +from parallax.sglang.monkey_patch import apply_parallax_sglang_monkey_patch from parallax.utils.tokenizer_utils import load_tokenizer logger = logging.getLogger(__name__) @@ -229,7 +229,7 @@ def initialize_sgl_model_runner( - config: model config driven by mlx-lm - tokenizer: tokenizer driven by mlx-lm """ - apply_parallax_monkey_patch() + apply_parallax_sglang_monkey_patch() model_path = get_model_path(original_model_path)[0] config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) diff --git a/src/parallax/sglang/monkey_patch.py b/src/parallax/sglang/monkey_patch.py index ae03f44f..99261f5b 100644 --- a/src/parallax/sglang/monkey_patch.py +++ b/src/parallax/sglang/monkey_patch.py @@ -19,7 +19,7 @@ ) -def apply_parallax_monkey_patch(): +def apply_parallax_sgalng_monkey_patch(): apply_qwen3_next_monkey_patch() apply_qwen3_next_config_monkey_patch() apply_gpt_oss_monkey_patch() From f1148384dec236632bb2b5324d7e5cc630ab919a Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Thu, 30 Oct 2025 16:29:09 +0800 Subject: [PATCH 4/6] update --- src/parallax/sglang/monkey_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/sglang/monkey_patch.py b/src/parallax/sglang/monkey_patch.py index 99261f5b..d0c35e72 100644 --- a/src/parallax/sglang/monkey_patch.py +++ b/src/parallax/sglang/monkey_patch.py @@ -19,7 +19,7 @@ ) -def apply_parallax_sgalng_monkey_patch(): +def apply_parallax_sglang_monkey_patch(): apply_qwen3_next_monkey_patch() apply_qwen3_next_config_monkey_patch() apply_gpt_oss_monkey_patch() From 61444bfc0a8feb7d26b2608d47366a7570e406c5 Mon Sep 17 00:00:00 2001 From: Alien mac air <2214632589@qq.com> Date: Fri, 31 Oct 2025 16:44:36 +0800 Subject: [PATCH 5/6] add comment --- src/parallax/sglang/monkey_patch.py | 2 ++ .../monkey_patch_utils/glm4_moe_model.py | 12 +++++++++++- .../monkey_patch_utils/gpt_oss_model.py | 9 ++++++++- .../monkey_patch_utils/minimax_m2_model.py | 8 +++++++- .../monkey_patch_utils/model_parallel.py | 19 +++++++++++++++++++ .../monkey_patch_utils/qwen3_next_config.py | 10 ++++++---- .../monkey_patch_utils/qwen3_next_model.py | 1 + .../monkey_patch_utils/triton_backend.py | 5 +++++ 8 files changed, 59 insertions(+), 7 deletions(-) diff --git a/src/parallax/sglang/monkey_patch.py 
b/src/parallax/sglang/monkey_patch.py index d0c35e72..3bf53067 100644 --- a/src/parallax/sglang/monkey_patch.py +++ b/src/parallax/sglang/monkey_patch.py @@ -19,6 +19,8 @@ ) +## Here is some patch func for sglang +## Hopefully, when sglang support pipeline parallelism natively, we can remove these patches def apply_parallax_sglang_monkey_patch(): apply_qwen3_next_monkey_patch() apply_qwen3_next_config_monkey_patch() diff --git a/src/parallax/sglang/monkey_patch_utils/glm4_moe_model.py b/src/parallax/sglang/monkey_patch_utils/glm4_moe_model.py index 233e000d..305a86d6 100644 --- a/src/parallax/sglang/monkey_patch_utils/glm4_moe_model.py +++ b/src/parallax/sglang/monkey_patch_utils/glm4_moe_model.py @@ -1,3 +1,5 @@ +## This is a patch file for sglang glm4_moe model to support pipeline parallelism + import logging from typing import Iterable, Optional, Tuple @@ -52,6 +54,9 @@ def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], params_dict = dict(self.named_parameters()) weight_names = [] for name, loaded_weight in weights: + ############################################################################ + ## TODO: remove when sglang code support pipeline parallelism + ## This is a patch code for sgalng if "lm_head" in name: pp_group = getattr(self, "pp_group", None) or get_pp_group() if not pp_group.is_last_rank: @@ -64,6 +69,8 @@ def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer) ): continue + ## End of patch + ############################################################################ weight_names.append(name) if not is_nextn: @@ -182,12 +189,15 @@ def pp_forward( if isinstance(hidden_states, PPProxyTensors): return hidden_states - + ################################################################################ + ## Patch for PP: only last PP rank compute logits pp_group = getattr(self, "pp_group", None) or get_pp_group() if pp_group.is_last_rank: return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch) else: return hidden_states + ## End of patch + ################################################################################ glm4_moe_module.Glm4MoeForCausalLM.forward = pp_forward glm4_moe_module.Glm4MoeForCausalLM.load_weights = monkey_patch_load_weights diff --git a/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py b/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py index b9d7417d..44149363 100644 --- a/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py +++ b/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py @@ -1,5 +1,7 @@ -import math +## This is a patch file for sglang GPT-OSS model to support loading mxFP4 MoE experts weights +import math +import torch from sglang.srt.distributed import ( get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size, @@ -43,6 +45,9 @@ def _parallax_load_mxfp4_experts_weights(self, weights): moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts for name, weight in weights: + ############################################################################ + ## TODO: remove when sglang code support pipeline parallelism + ## This is a patch code for sgalng layer_id = get_layer_id(name) if ( layer_id is not None @@ -50,6 +55,8 @@ def _parallax_load_mxfp4_experts_weights(self, weights): and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer) ): continue + ## End of patch + ############################################################################ 
weight = weight.cuda() if "gate_up_proj_blocks" in name: diff --git a/src/parallax/sglang/monkey_patch_utils/minimax_m2_model.py b/src/parallax/sglang/monkey_patch_utils/minimax_m2_model.py index f14e59bf..ead6c3a0 100644 --- a/src/parallax/sglang/monkey_patch_utils/minimax_m2_model.py +++ b/src/parallax/sglang/monkey_patch_utils/minimax_m2_model.py @@ -1,3 +1,5 @@ +## This is a patch file for sglang MiniMax M2 model to support pipeline parallelism + import logging from typing import Iterable, Optional, Set, Tuple @@ -139,12 +141,16 @@ def pp_forward( if isinstance(hidden_states, PPProxyTensors): return hidden_states - + ########################################################################## + ## TODO: remove when sglang code supports pipeline parallelism + ## This is a patch code for sglang pp_group = getattr(self, "pp_group", None) or get_pp_group() if pp_group.is_last_rank: return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch) else: return hidden_states + ## End of patch + ########################################################################## m2_module.MiniMaxM2ForCausalLM.__init__ = pp_init m2_module.MiniMaxM2ForCausalLM.forward = pp_forward diff --git a/src/parallax/sglang/monkey_patch_utils/model_parallel.py b/src/parallax/sglang/monkey_patch_utils/model_parallel.py index 28302b96..ee631c3b 100644 --- a/src/parallax/sglang/monkey_patch_utils/model_parallel.py +++ b/src/parallax/sglang/monkey_patch_utils/model_parallel.py @@ -1,3 +1,22 @@ +"""Parallax model-parallel monkey patches for sglang. + +Summary: +- ParallaxGroupCoordinator (subclasses sglang.srt.distributed.parallel_state.GroupCoordinator): + adds pp_start_layer, pp_end_layer, hidden_layers and redefines is_first_rank/is_last_rank to use + layer ranges. +- monkey_patch_init_model_parallel_group: replaces + sglang.srt.distributed.parallel_state.init_model_parallel_group to return ParallaxGroupCoordinator. +- monkey_patch_initialize_model_parallel: replaces + sglang.srt.distributed.parallel_state.initialize_model_parallel and passes PP layer bounds when + creating pipeline-parallel groups. +- monkey_patch_make_layers: replaces sglang.srt.utils.make_layers; uses + get_pp_group().pp_start_layer/end_layer to instantiate local layers and PPMissingLayer placeholders + for non-local layers. + +These are minimal, reversible patches to support decentralized per-layer pipeline parallelism. Remove +when upstream sglang provides native support. +""" + import logging from typing import Any, Dict, List, Optional, Tuple, Union diff --git a/src/parallax/sglang/monkey_patch_utils/qwen3_next_config.py b/src/parallax/sglang/monkey_patch_utils/qwen3_next_config.py index 489add3f..76f8a2e1 100644 --- a/src/parallax/sglang/monkey_patch_utils/qwen3_next_config.py +++ b/src/parallax/sglang/monkey_patch_utils/qwen3_next_config.py @@ -15,6 +15,7 @@ class HybridLayerType(enum.Enum): mamba2 = "mamba" +## overwrite due to pipeline parallelism @property def monkey_patch_linear_layer_ids(self): """Return linear-attention layer ids restricted to the PP slice. @@ -29,11 +30,12 @@ def monkey_patch_linear_layer_ids(self): and i >= self.start_layer and i < self.end_layer ] - # If no matching layer id, return at least [-1] - # just for pp + ## If no matching layer id, return at least [-1] + ## It is for the memory pool to calculate tokens return lst if lst else [-1] +## overwrite due to pipeline parallelism @property def monkey_patch_full_attention_layer_ids(self): """Return full-attention layer ids restricted to the PP slice.
@@ -48,8 +50,8 @@ def monkey_patch_full_attention_layer_ids(self): and i >= self.start_layer and i < self.end_layer ] - # If no matching layer id, return at least [-1] - # just for pp + ## If no matching layer id, return at least [-1] + ## It is for the memory pool to calculate tokens return lst if lst else [-1] diff --git a/src/parallax/sglang/monkey_patch_utils/qwen3_next_model.py b/src/parallax/sglang/monkey_patch_utils/qwen3_next_model.py index 04872a49..bd82312e 100644 --- a/src/parallax/sglang/monkey_patch_utils/qwen3_next_model.py +++ b/src/parallax/sglang/monkey_patch_utils/qwen3_next_model.py @@ -11,6 +11,7 @@ # ---- Minimal method-level monkey patch to reuse sglang source ---- +# Needed because Qwen3NextModel does not support pipeline parallelism (PP) natively def apply_qwen3_next_monkey_patch(): """Apply minimal monkey patches to sglang's qwen3_next to support PP without copying code. diff --git a/src/parallax/sglang/monkey_patch_utils/triton_backend.py b/src/parallax/sglang/monkey_patch_utils/triton_backend.py index 02cd81df..cfd08074 100644 --- a/src/parallax/sglang/monkey_patch_utils/triton_backend.py +++ b/src/parallax/sglang/monkey_patch_utils/triton_backend.py @@ -39,9 +39,14 @@ def parallax_triton_backend_init( # For hybrid linear models, layer_id = 0 may not be full attention self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() else: + + ################################################################################ + ## Patch for PP: get pp_start_layer self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer( model_runner.pp_start_layer ).shape[-1] + ## End of patch + ################################################################################ self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device self.device_core_count = get_device_core_count(model_runner.gpu_id) From a009ed0e090c92a474dff7db08f2d07f6f7f6f73 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 3 Nov 2025 15:06:15 +0800 Subject: [PATCH 6/6] update --- src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py b/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py index 44149363..acc26458 100644 --- a/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py +++ b/src/parallax/sglang/monkey_patch_utils/gpt_oss_model.py @@ -1,6 +1,7 @@ ## This is a patch file for sglang GPT-OSS model to support loading mxFP4 MoE experts weights import math + import torch from sglang.srt.distributed import ( get_moe_expert_parallel_rank,
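
Taken together, the series funnels every sglang patch through a single entry point, apply_parallax_sglang_monkey_patch(), which initialize_sgl_model_runner() now calls before any sglang parallel state is created. The following minimal sketch is not part of the patch series; it only illustrates the layer-slicing idea behind monkey_patch_make_layers: a peer that owns layers [pp_start_layer, pp_end_layer) instantiates just those layers and fills the remaining positions with cheap placeholders, in the spirit of sglang's PPMissingLayer. The placeholder class, layer_fn, and the hard-coded layer range here are illustrative assumptions, not names from the patch.

from torch import nn


class PPMissingPlaceholder(nn.Module):
    """Stand-in for layers owned by other pipeline stages (illustrative, not sglang's class)."""

    def forward(self, x):
        # Identity pass-through: this stage never runs the layer it replaces.
        return x


def make_local_layers(num_hidden_layers, layer_fn, pp_start_layer, pp_end_layer):
    """Instantiate only the layers in [pp_start_layer, pp_end_layer); placeholders elsewhere."""
    return nn.ModuleList(
        [PPMissingPlaceholder() for _ in range(pp_start_layer)]
        + [layer_fn(idx) for idx in range(pp_start_layer, pp_end_layer)]
        + [PPMissingPlaceholder() for _ in range(pp_end_layer, num_hidden_layers)]
    )


if __name__ == "__main__":
    # Hypothetical 8-layer model where this peer serves layers 2, 3 and 4.
    layers = make_local_layers(
        num_hidden_layers=8,
        layer_fn=lambda idx: nn.Linear(16, 16),
        pp_start_layer=2,
        pp_end_layer=5,
    )
    print([type(m).__name__ for m in layers])

Under this slicing, ParallaxGroupCoordinator.is_first_rank reduces to pp_start_layer == 0 and is_last_rank to pp_end_layer == hidden_layers, which matches the pp_forward patches above that compute logits only on the last pipeline rank.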