From 1d1091a273f7aa9327961c221884d0a94e4201fe Mon Sep 17 00:00:00 2001 From: TianyiZhao1437 Date: Thu, 30 Oct 2025 12:39:53 +0800 Subject: [PATCH 1/9] initial commit for supporting TP --- src/parallax/launch.py | 73 ++++++--- src/parallax/server/executor.py | 226 ++++++++++++++++------------ src/parallax/server/server_args.py | 3 + src/parallax/sglang/model_runner.py | 8 +- src/parallax/utils/utils.py | 18 +++ 5 files changed, 207 insertions(+), 121 deletions(-) diff --git a/src/parallax/launch.py b/src/parallax/launch.py index e4443c23..dccd802f 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -23,7 +23,7 @@ from parallax.server.executor import Executor from parallax.server.http_server import launch_http_server from parallax.server.server_args import parse_args -from parallax.utils.utils import get_current_device +from parallax.utils.utils import get_current_device, load_model_config_only from parallax_utils.ascii_anime import display_parallax_join from parallax_utils.logging_config import get_logger, set_log_level @@ -44,12 +44,23 @@ "zai-org/GLM-4.6": "mlx-community/GLM-4.6-4bit", } + +def run_executor_process(args): + """Run executor as a subprocess""" + try: + executor = Executor.create_from_args(args) + executor.run_loop() + except Exception as e: + logger.exception(e) + finally: + executor.shutdown() + + if __name__ == "__main__": multiprocessing.set_start_method("spawn", force=True) gradient_server = None http_server_process = None - executor = None try: args = parse_args() set_log_level(args.log_level) @@ -70,7 +81,9 @@ if mlx_model_repo is not None: args.model_path = mlx_model_repo logger.debug(f"Replace mlx model path: {mlx_model_repo}") + if args.scheduler_addr is None: + # Launch without scheduler if args.log_level != "DEBUG": display_parallax_join(args.model_path) check_latest_release() @@ -78,14 +91,15 @@ # only launch http server on head node if args.start_layer == 0: http_server_process = launch_http_server(args) - executor = Executor.create_from_args(args) + + config = load_model_config_only(args.model_path) launch_p2p_server( initial_peers=args.initial_peers, scheduler_addr=args.scheduler_addr, relay_servers=args.relay_servers, pp_start_layer=args.start_layer, pp_end_layer=args.end_layer, - hidden_layers=executor.config.get("num_hidden_layers"), + hidden_layers=config.get("num_hidden_layers"), tcp_port=args.tcp_port, udp_port=args.udp_port, dht_prefix=args.dht_prefix, @@ -99,6 +113,7 @@ max_sequence_length=args.max_sequence_length, ) else: + # Join scheduler gradient_server = launch_p2p_server( initial_peers=args.initial_peers, scheduler_addr=args.scheduler_addr, @@ -139,35 +154,45 @@ # only launch http server on head node if args.start_layer == 0: http_server_process = launch_http_server(args) - executor = Executor.create_from_args(args) + + executor_process_pool = [] + tp_rank_range = range(args.tp_size) + for tp_rank in tp_rank_range: + args.tp_rank = tp_rank + # executor = Executor.create_from_args(args) + proc = multiprocessing.Process( + target=run_executor_process, + args=args, + ) + proc.start() + executor_process_pool.append(proc) if gradient_server is not None: gradient_server.status = ServerState.READY - executor.run_loop() except KeyboardInterrupt: logger.debug("Received interrupt signal, shutting down...") except Exception as e: logger.exception(e) finally: - t = None - if http_server_process is not None: + thread_pool = [] - def terminate_http_server_process(process): - logger.debug("Terminating HTTP server process...") - try: - 
process.kill() - process.join() - except Exception as e: - logger.error(f"Failed to terminate HTTP server process: {e}") - - if http_server_process is not None: - t = threading.Thread( - target=terminate_http_server_process, args=(http_server_process,) - ) - t.start() + def terminate_subprocess(process): + logger.debug("Terminating subprocess...") + try: + process.kill() + process.join() + except Exception as e: + logger.error(f"Failed to terminate subprocess: {e}") + + if http_server_process is not None: + t = threading.Thread(target=terminate_subprocess, args=(http_server_process,)) + thread_pool.append(t) + t.start() if gradient_server is not None: gradient_server.shutdown() - if executor is not None: - executor.shutdown() - if t is not None: + for executor_process in executor_process_pool: + t = threading.Thread(target=terminate_subprocess, args=(executor_process,)) + thread_pool.append(t) + t.start() + for t in thread_pool: t.join() diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 4b6b6d6c..7336d140 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -96,6 +96,9 @@ def __init__( # GPU/SGLang Specialized Configs attention_backend: Optional[str] = "torch_native", moe_runner_backend: Optional[str] = "auto", + # Tensor Parallel Configs + tp_rank: Optional[int] = 0, + tp_size: Optional[int] = 1, ): # Backend self.device = get_current_device() @@ -118,10 +121,14 @@ def __init__( attention_backend, kv_block_size, moe_runner_backend, + tp_rank, + tp_size, ) logger.debug( f"CUDA model runner initialized. num_layers={self.config.get('num_hidden_layers')}" ) + self.tp_group = self.model_runner.tp_group() + self.tp_cpu_group = self.tp_group.cpu_group # SGL KV Cache Manager is already initialized in ScheduleBatch # TODO: Replace ScheduleBatch to Parallax inflight batch self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False) @@ -147,6 +154,7 @@ def __init__( self.is_first_peer = start_layer == 0 self.is_last_peer = end_layer == self.config.get("num_hidden_layers") self.num_shard_layers = end_layer - start_layer + self.tp_rank = tp_rank # Metrics throttling for per-layer latency updates self.layer_latency_update_every = int(max(1, layer_latency_update_every)) @@ -280,83 +288,106 @@ def create_from_args(cls, args: argparse.Namespace): """Create executor from command line arguments.""" return cls(**create_executor_config(args)) + def _tensor_parallel_broadcast_byobj(self, broadcast_obj): + """Wrapper for broadcast pyobject in TP group""" + from sglang.srt.utils import broadcast_pyobj + + broadcast_pyobj( + broadcast_obj, + self.tp_group.rank, + self.tp_cpu_group, + src=self.tp_group.ranks[0], + ) + def recv_requests_from_http(self) -> List[Request]: """Receives requests from http frontend""" - recv_reqs = [] - while True: - try: - raw_request = self.recv_from_ipc_socket.recv_pyobj(zmq.NOBLOCK) - - # Check if this is an abort request - if isinstance(raw_request, dict) and raw_request.get("type") == "abort": - logger.debug( - f"Received abort request from HTTP for request ID: {raw_request.get('rid')}" - ) - self.scheduler.cancel_request(raw_request.get("rid")) - else: - # Normal request processing - do tokenization and form InitialRequest - req = self._handle_raw_request(raw_request) - recv_reqs.append(req) - except zmq.ZMQError: - break - except Exception as e: - logger.exception(f"Error receiving http request: {e}") + if self.tp_rank == 0: + recv_reqs = [] + while True: + try: + raw_request = 
self.recv_from_ipc_socket.recv_pyobj(zmq.NOBLOCK) + + # Check if this is an abort request + if isinstance(raw_request, dict) and raw_request.get("type") == "abort": + logger.debug( + f"Received abort request from HTTP for request ID: {raw_request.get('rid')}" + ) + self.scheduler.cancel_request(raw_request.get("rid")) + else: + # Normal request processing - do tokenization and form InitialRequest + req = self._handle_raw_request(raw_request) + recv_reqs.append(req) + except zmq.ZMQError: + break + except Exception as e: + logger.exception(f"Error receiving http request: {e}") + else: + recv_reqs = None + if self.tp_size > 1: + self._tensor_parallel_broadcast_byobj(recv_reqs) if recv_reqs: logger.debug(f"Received {len(recv_reqs)} HTTP requests") return recv_reqs def recv_requests_from_peer(self) -> List[Request]: """Receives requests from the RPC server.""" - recv_reqs = [] - while True: - try: - recv_req = self.recv_from_peer_socket.recv_multipart(zmq.NOBLOCK) - assert len(recv_req) == 2, f"Received invalid request: {recv_req}" - if recv_req[0] == b"forward": - # Create a new ForwardRequest instance and parse from bytes - forward_request = forward_pb2.ForwardRequest() - forward_request.ParseFromString(recv_req[1]) - recv_req = proto_to_request(forward_request, self.device) - - # Convert hidden_states dtype if necessary - if recv_req is not None and len(recv_req) > 0: - for req in recv_req: - if req.hidden_states is not None: - if req.hidden_states.dtype != self.dtype: - logger.debug( - f"Converting hidden_states dtype from {req.hidden_states.dtype} to {self.dtype} for request {req.request_id}" - ) - if self.device == "cuda": - req.hidden_states = req.hidden_states.to(self.dtype) - elif self.device == "mlx": - req.hidden_states = req.hidden_states.astype(self.dtype) - else: - raise ValueError(f"Unsupported device type: {self.device}") - - # Move current position for first peer - if self.is_first_peer: - for req in recv_req: - req.current_position += 1 - recv_reqs.extend(recv_req) - elif recv_req[0] == b"abort": - abort_request = forward_pb2.AbortRequest() - abort_request.ParseFromString(recv_req[1]) - recv_req = proto_to_abort_request(abort_request) - recv_reqs.extend(recv_req) - else: - raise ValueError(f"Unknown request type: {recv_req[0]}") - # First peer is responsible for tokenization - # if self.is_first_peer and isinstance(recv_req, InitialRequest): - # recv_req.input_ids = self.tokenizer.encode(recv_req.prompt) - # recv_req.prompt_len = len(recv_req.input_ids) - # recv_req.max_total_length = min( - # recv_req.max_total_length, recv_req.prompt_len + recv_req.max_new_tokens - # ) - - except zmq.ZMQError: - break - except Exception as e: - logger.exception(f"Error receiving or deserializing request: {e}") + if self.tp_rank == 0: + recv_reqs = [] + while True: + try: + recv_req = self.recv_from_peer_socket.recv_multipart(zmq.NOBLOCK) + assert len(recv_req) == 2, f"Received invalid request: {recv_req}" + if recv_req[0] == b"forward": + # Create a new ForwardRequest instance and parse from bytes + forward_request = forward_pb2.ForwardRequest() + forward_request.ParseFromString(recv_req[1]) + recv_req = proto_to_request(forward_request, self.device) + + # Convert hidden_states dtype if necessary + if recv_req is not None and len(recv_req) > 0: + for req in recv_req: + if req.hidden_states is not None: + if req.hidden_states.dtype != self.dtype: + logger.debug( + f"Converting hidden_states dtype from {req.hidden_states.dtype} to {self.dtype} for request {req.request_id}" + ) + if 
self.device == "cuda": + req.hidden_states = req.hidden_states.to(self.dtype) + elif self.device == "mlx": + req.hidden_states = req.hidden_states.astype(self.dtype) + else: + raise ValueError( + f"Unsupported device type: {self.device}" + ) + + # Move current position for first peer + if self.is_first_peer: + for req in recv_req: + req.current_position += 1 + recv_reqs.extend(recv_req) + elif recv_req[0] == b"abort": + abort_request = forward_pb2.AbortRequest() + abort_request.ParseFromString(recv_req[1]) + recv_req = proto_to_abort_request(abort_request) + recv_reqs.extend(recv_req) + else: + raise ValueError(f"Unknown request type: {recv_req[0]}") + # First peer is responsible for tokenization + # if self.is_first_peer and isinstance(recv_req, InitialRequest): + # recv_req.input_ids = self.tokenizer.encode(recv_req.prompt) + # recv_req.prompt_len = len(recv_req.input_ids) + # recv_req.max_total_length = min( + # recv_req.max_total_length, recv_req.prompt_len + recv_req.max_new_tokens + # ) + + except zmq.ZMQError: + break + except Exception as e: + logger.exception(f"Error receiving or deserializing request: {e}") + else: + recv_reqs = None + if self.tp_size > 1: + self._tensor_parallel_broadcast_byobj(recv_reqs) if recv_reqs: logger.debug(f"Received {len(recv_reqs)} peer requests") return recv_reqs @@ -963,33 +994,38 @@ def _prepare_next_batch_requests( self, requests: List[Request], hidden_states: Any, lengths: Any ) -> List[Request]: """Prepares a batch of requests for the next stage of the pipeline.""" - batched_requests = [] - pre_length = 0 - for i, src_request in enumerate(requests): - if self.is_last_peer: - # Last peer gets a 1D array of token IDs - hidden_state_for_req = hidden_states[i : i + 1] - else: - # Other peers get a 3D array of hidden states - if src_request.is_prefill: - true_length = int(lengths[i]) - if hidden_states.ndim == 3: - hidden_state_for_req = hidden_states[i, :true_length, :] - else: - hidden_state_for_req = hidden_states[ - pre_length : pre_length + true_length, : - ] - pre_length += true_length + if self.tp_rank == 0: + batched_requests = [] + pre_length = 0 + for i, src_request in enumerate(requests): + if self.is_last_peer: + # Last peer gets a 1D array of token IDs + hidden_state_for_req = hidden_states[i : i + 1] else: - if hidden_states.ndim == 3: - hidden_state_for_req = hidden_states[i, :, :] + # Other peers get a 3D array of hidden states + if src_request.is_prefill: + true_length = int(lengths[i]) + if hidden_states.ndim == 3: + hidden_state_for_req = hidden_states[i, :true_length, :] + else: + hidden_state_for_req = hidden_states[ + pre_length : pre_length + true_length, : + ] + pre_length += true_length else: - hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :] - pre_length += 1 + if hidden_states.ndim == 3: + hidden_state_for_req = hidden_states[i, :, :] + else: + hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :] + pre_length += 1 - next_req = self._prepare_next_single_request(src_request, hidden_state_for_req) - batched_requests.append(next_req) + next_req = self._prepare_next_single_request(src_request, hidden_state_for_req) + batched_requests.append(next_req) + else: + batched_requests = None + if self.tp_size > 1: + self._tensor_parallel_broadcast_byobj(batched_requests) return batched_requests def _process_batch_cuda( @@ -1133,7 +1169,7 @@ def run_loop(self): self._handle_input_requests(incoming_requests) # 3. 
Send finished batch to next peer - if len(self.finished_batch) > 0 and self.is_first_peer: + if len(self.finished_batch) > 0 and self.is_first_peer and self.tp_rank == 0: self.send_to_peer_socket.send_multipart( [b"abort", abort_request_to_proto(self.finished_batch).SerializeToString()] ) @@ -1189,7 +1225,7 @@ def run_loop(self): if self.is_last_peer and self.is_first_peer: # Single node: handle locally self._handle_input_requests(next_batch) - else: + elif self.tp_rank == 0: # Send output to next peer self.send_to_peer_socket.send_multipart( [ @@ -1251,5 +1287,7 @@ def create_executor_config(args: argparse.Namespace): "executor_output_ipc_addr": args.executor_output_ipc, "attention_backend": args.attention_backend, "moe_runner_backend": args.moe_runner_backend, + "tp_rank": args.tp_rank, + "tp_size": args.tp_size, } return config diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 5b1429e0..5c39f91d 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -150,6 +150,9 @@ def parse_args() -> argparse.Namespace: help="Choose the GPU moe kernels", ) + # Tensor parallel configuration + parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size") + # Logging and debugging parser.add_argument( "--log-level", diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 0692dd34..cabdab7a 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -546,6 +546,8 @@ def initialize_sgl_model_runner( attention_backend: str, kv_block_size: int, moe_runner_backend: str, + tp_rank: int, + tp_size: int, ): """ Creates a SGL ModelRunner object. @@ -603,9 +605,9 @@ def initialize_sgl_model_runner( model_runner = ParallaxModelRunner( model_config=model_config, mem_fraction_static=kv_cache_memory_fraction, - gpu_id=0, - tp_rank=0, - tp_size=1, + gpu_id=tp_rank, # Currently reuse tp_rank to only support TP. 
+ tp_rank=tp_rank, + tp_size=tp_size, pp_rank=0, pp_size=1, moe_ep_rank=0, diff --git a/src/parallax/utils/utils.py b/src/parallax/utils/utils.py index c2cfff1f..7980fdbe 100644 --- a/src/parallax/utils/utils.py +++ b/src/parallax/utils/utils.py @@ -1,5 +1,7 @@ """Utility functions.""" +import json +from pathlib import Path from typing import List import mlx.core as mx @@ -7,6 +9,7 @@ import psutil import torch import zmq +from huggingface_hub import hf_hub_download def is_cuda_available(): @@ -266,3 +269,18 @@ def combine_padding_and_causal_masks( padding_mask_float = (padding_mask - 1) * inf_value padding_mask_float = padding_mask_float.astype(dtype) return causal_mask + padding_mask_float + + +def load_model_config_only(name: str) -> dict: + local_path = Path(name) + if local_path.exists(): + config_path = local_path / "config.json" + with open(config_path, "r") as f: + return json.load(f) + + config_file = hf_hub_download(repo_id=name, filename="config.json") + with open(config_file, "r") as f: + return json.load(f) + + config = _load_config_only(model_name) + return config From af9a06b762011d3cd1e14f8978671aa540c424d3 Mon Sep 17 00:00:00 2001 From: TianyiZhao1437 Date: Thu, 30 Oct 2025 04:55:22 +0000 Subject: [PATCH 2/9] update --- src/parallax/server/executor.py | 1 + tests/test_server_args.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 7336d140..d61f7e91 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -154,6 +154,7 @@ def __init__( self.is_first_peer = start_layer == 0 self.is_last_peer = end_layer == self.config.get("num_hidden_layers") self.num_shard_layers = end_layer - start_layer + self.tp_size = tp_size self.tp_rank = tp_rank # Metrics throttling for per-layer latency updates diff --git a/tests/test_server_args.py b/tests/test_server_args.py index deaf5f42..3343163b 100644 --- a/tests/test_server_args.py +++ b/tests/test_server_args.py @@ -88,6 +88,8 @@ def test_create_config(self): executor_output_ipc="///ipc/2", attention_backend="torch_native", moe_runner_backend="auto", + tp_rank=0, + tp_size=1, ) config = create_executor_config(args) From 4e1576aebbc92ea2df83813fdbc3c83aa6da4831 Mon Sep 17 00:00:00 2001 From: TianyiZhao1437 Date: Tue, 4 Nov 2025 11:58:55 +0000 Subject: [PATCH 3/9] update --- src/parallax/launch.py | 23 ++++++++++------------- src/parallax/server/executor.py | 2 +- src/parallax/utils/utils.py | 18 ++++-------------- 3 files changed, 15 insertions(+), 28 deletions(-) diff --git a/src/parallax/launch.py b/src/parallax/launch.py index dccd802f..d25faff6 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -13,6 +13,7 @@ --end-layer 28 """ +import argparse import multiprocessing import os import tempfile @@ -23,7 +24,7 @@ from parallax.server.executor import Executor from parallax.server.http_server import launch_http_server from parallax.server.server_args import parse_args -from parallax.utils.utils import get_current_device, load_model_config_only +from parallax.utils.utils import get_current_device, fetch_model_from_hf from parallax_utils.ascii_anime import display_parallax_join from parallax_utils.logging_config import get_logger, set_log_level @@ -47,13 +48,8 @@ def run_executor_process(args): """Run executor as a subprocess""" - try: - executor = Executor.create_from_args(args) - executor.run_loop() - except Exception as e: - logger.exception(e) - finally: - executor.shutdown() + executor = Executor.create_from_args(args) 
+ executor.run_loop() if __name__ == "__main__": @@ -61,6 +57,7 @@ def run_executor_process(args): gradient_server = None http_server_process = None + executor_process_pool = [] try: args = parse_args() set_log_level(args.log_level) @@ -88,11 +85,11 @@ def run_executor_process(args): display_parallax_join(args.model_path) check_latest_release() + config = fetch_model_from_hf(args.model_path) # only launch http server on head node if args.start_layer == 0: http_server_process = launch_http_server(args) - config = load_model_config_only(args.model_path) launch_p2p_server( initial_peers=args.initial_peers, scheduler_addr=args.scheduler_addr, @@ -151,18 +148,18 @@ def run_executor_process(args): display_parallax_join(args.model_path) check_latest_release() + fetch_model_from_hf(args.model_path) # only launch http server on head node if args.start_layer == 0: http_server_process = launch_http_server(args) - executor_process_pool = [] tp_rank_range = range(args.tp_size) for tp_rank in tp_rank_range: - args.tp_rank = tp_rank - # executor = Executor.create_from_args(args) + args_copy = argparse.Namespace(**vars(args)) + args_copy.tp_rank = tp_rank proc = multiprocessing.Process( target=run_executor_process, - args=args, + args=(args_copy,), ) proc.start() executor_process_pool.append(proc) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index d61f7e91..265add86 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -127,7 +127,7 @@ def __init__( logger.debug( f"CUDA model runner initialized. num_layers={self.config.get('num_hidden_layers')}" ) - self.tp_group = self.model_runner.tp_group() + self.tp_group = self.model_runner.tp_group self.tp_cpu_group = self.tp_group.cpu_group # SGL KV Cache Manager is already initialized in ScheduleBatch # TODO: Replace ScheduleBatch to Parallax inflight batch diff --git a/src/parallax/utils/utils.py b/src/parallax/utils/utils.py index 7980fdbe..451e0c54 100644 --- a/src/parallax/utils/utils.py +++ b/src/parallax/utils/utils.py @@ -9,8 +9,7 @@ import psutil import torch import zmq -from huggingface_hub import hf_hub_download - +from mlx_lm.utils import get_model_path, load_config def is_cuda_available(): """Check backend supports cuda""" @@ -271,16 +270,7 @@ def combine_padding_and_causal_masks( return causal_mask + padding_mask_float -def load_model_config_only(name: str) -> dict: - local_path = Path(name) - if local_path.exists(): - config_path = local_path / "config.json" - with open(config_path, "r") as f: - return json.load(f) - - config_file = hf_hub_download(repo_id=name, filename="config.json") - with open(config_file, "r") as f: - return json.load(f) - - config = _load_config_only(model_name) +def fetch_model_from_hf(name: str): + model_path = get_model_path(name)[0] + config = load_config(model_path) return config From 78e2108c6b87832e92c66fcb294d4735695a4ee5 Mon Sep 17 00:00:00 2001 From: TianyiZhao1437 Date: Fri, 7 Nov 2025 03:48:16 +0000 Subject: [PATCH 4/9] update --- src/parallax/server/executor.py | 7 +++++-- src/parallax/server/server_args.py | 4 ++++ src/parallax/sglang/model_runner.py | 5 ++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 265add86..069dcbba 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -99,6 +99,7 @@ def __init__( # Tensor Parallel Configs tp_rank: Optional[int] = 0, tp_size: Optional[int] = 1, + nccl_port: Optional[int] = 4001, ): # 
Backend self.device = get_current_device() @@ -123,6 +124,7 @@ def __init__( moe_runner_backend, tp_rank, tp_size, + nccl_port, ) logger.debug( f"CUDA model runner initialized. num_layers={self.config.get('num_hidden_layers')}" @@ -823,10 +825,10 @@ def _handle_cuda_input_requests(self, requests: List[Request]): def _handle_input_requests(self, requests: List[Request]): """Update requests states and status in scheduler and cache manager.""" - if len(requests) > 0: - logger.debug(f"Handling {len(requests)} requests.") if not requests: return + if len(requests) > 0: + logger.debug(f"Handling {len(requests)} requests.") if self.device == "cuda": self._handle_cuda_input_requests(requests) @@ -1290,5 +1292,6 @@ def create_executor_config(args: argparse.Namespace): "moe_runner_backend": args.moe_runner_backend, "tp_rank": args.tp_rank, "tp_size": args.tp_size, + "nccl_port": args.nccl_port, } return config diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 5c39f91d..2f47e18a 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -153,6 +153,10 @@ def parse_args() -> argparse.Namespace: # Tensor parallel configuration parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size") + parser.add_argument( + "--nccl-port", type=int, default=4001, help="The port for NCCL distributed environment setup." + ) + # Logging and debugging parser.add_argument( "--log-level", diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index cabdab7a..0429f887 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -505,6 +505,7 @@ def monkey_patch_glm4_moe_model(): def form_sgl_server_args( model_path: str, dtype: str = "bfloat16", + tp_size: int = 1, attention_backend: str = "flashinfer", kv_block_size: int = 64, moe_runner_backend="auto", @@ -517,6 +518,7 @@ def form_sgl_server_args( page_size=kv_block_size, mem_fraction_static=0.85, moe_runner_backend=moe_runner_backend, + tp_size=tp_size, ) return sgl_server_args @@ -548,6 +550,7 @@ def initialize_sgl_model_runner( moe_runner_backend: str, tp_rank: int, tp_size: int, + nccl_port: int, ): """ Creates a SGL ModelRunner object. 
@@ -561,7 +564,6 @@ def initialize_sgl_model_runner( config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") - nccl_port = random.randint(4000, 5000) # Handling mxfp4 arguments quant_method = config.get("quant_method", None) @@ -580,6 +582,7 @@ def initialize_sgl_model_runner( server_args = form_sgl_server_args( original_model_path, dtype, + tp_size, attention_backend, kv_block_size, moe_runner_backend, From 390a7437534f6825919b1b275736b4f28326c7c7 Mon Sep 17 00:00:00 2001 From: TianyiZhao1437 Date: Fri, 7 Nov 2025 04:04:58 +0000 Subject: [PATCH 5/9] update --- tests/test_server_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_server_args.py b/tests/test_server_args.py index 3343163b..93601403 100644 --- a/tests/test_server_args.py +++ b/tests/test_server_args.py @@ -90,6 +90,7 @@ def test_create_config(self): moe_runner_backend="auto", tp_rank=0, tp_size=1, + nccl_port=4001, ) config = create_executor_config(args) From 768b8ffdc61bf179c582457054c4fe6f98b48c79 Mon Sep 17 00:00:00 2001 From: TianyiZhao1437 Date: Fri, 7 Nov 2025 04:27:07 +0000 Subject: [PATCH 6/9] update --- src/parallax/launch.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/parallax/launch.py b/src/parallax/launch.py index d25faff6..7936dab5 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -57,7 +57,7 @@ def run_executor_process(args): gradient_server = None http_server_process = None - executor_process_pool = [] + executor_procs = [] try: args = parse_args() set_log_level(args.log_level) @@ -162,34 +162,34 @@ def run_executor_process(args): args=(args_copy,), ) proc.start() - executor_process_pool.append(proc) + executor_procs.append(proc) if gradient_server is not None: gradient_server.status = ServerState.READY + for executor_process in executor_procs: + executor_process.join() except KeyboardInterrupt: logger.debug("Received interrupt signal, shutting down...") except Exception as e: logger.exception(e) finally: - thread_pool = [] - - def terminate_subprocess(process): - logger.debug("Terminating subprocess...") - try: - process.kill() - process.join() - except Exception as e: - logger.error(f"Failed to terminate subprocess: {e}") - + t = None if http_server_process is not None: - t = threading.Thread(target=terminate_subprocess, args=(http_server_process,)) - thread_pool.append(t) - t.start() + + def terminate_http_server_process(process): + logger.debug("Terminating HTTP server process...") + try: + process.kill() + process.join() + except Exception as e: + logger.error(f"Failed to terminate HTTP server process: {e}") + + if http_server_process is not None: + t = threading.Thread( + target=terminate_http_server_process, args=(http_server_process,) + ) + t.start() if gradient_server is not None: gradient_server.shutdown() - for executor_process in executor_process_pool: - t = threading.Thread(target=terminate_subprocess, args=(executor_process,)) - thread_pool.append(t) - t.start() - for t in thread_pool: + if t is not None: t.join() From b7c02a0198e21f2fdf0b9b3a054c9d8f280c9728 Mon Sep 17 00:00:00 2001 From: TianyiZhao1437 Date: Fri, 7 Nov 2025 11:00:56 +0000 Subject: [PATCH 7/9] update --- src/parallax/launch.py | 48 +++++++++++++++++----------- src/parallax/server/executor.py | 55 +++++++++++++++++---------------- 2 files changed, 59 insertions(+), 44 deletions(-) diff --git a/src/parallax/launch.py 
b/src/parallax/launch.py index 7936dab5..5a05baf2 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -48,9 +48,24 @@ def run_executor_process(args): """Run executor as a subprocess""" - executor = Executor.create_from_args(args) - executor.run_loop() + try: + executor = Executor.create_from_args(args) + executor.run_loop() + except KeyboardInterrupt: + logger.debug("Received interrupt signal, shutting down...") + except Exception as e: + logger.exception(e) + finally: + executor.shutdown() +def terminate_subprocess(process): + """Kill a subprocess""" + logger.debug("Terminating subprocess...") + try: + process.kill() + process.join() + except Exception as e: + logger.error(f"Failed to terminate subprocess: {e}") if __name__ == "__main__": multiprocessing.set_start_method("spawn", force=True) @@ -173,23 +188,20 @@ def run_executor_process(args): except Exception as e: logger.exception(e) finally: - t = None + thread_pool = [] + for executor_proc in executor_procs: + t = threading.Thread( + target=terminate_subprocess, args=(executor_proc,) + ) + t.start() + thread_pool.append(t) if http_server_process is not None: - - def terminate_http_server_process(process): - logger.debug("Terminating HTTP server process...") - try: - process.kill() - process.join() - except Exception as e: - logger.error(f"Failed to terminate HTTP server process: {e}") - - if http_server_process is not None: - t = threading.Thread( - target=terminate_http_server_process, args=(http_server_process,) - ) - t.start() + t = threading.Thread( + target=terminate_subprocess, args=(http_server_process,) + ) + t.start() + thread_pool.append(t) if gradient_server is not None: gradient_server.shutdown() - if t is not None: + for t in thread_pool: t.join() diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 069dcbba..898e2a0e 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -295,12 +295,13 @@ def _tensor_parallel_broadcast_byobj(self, broadcast_obj): """Wrapper for broadcast pyobject in TP group""" from sglang.srt.utils import broadcast_pyobj - broadcast_pyobj( + broadcast_result = broadcast_pyobj( broadcast_obj, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0], ) + return broadcast_result def recv_requests_from_http(self) -> List[Request]: """Receives requests from http frontend""" @@ -327,7 +328,7 @@ def recv_requests_from_http(self) -> List[Request]: else: recv_reqs = None if self.tp_size > 1: - self._tensor_parallel_broadcast_byobj(recv_reqs) + recv_reqs = self._tensor_parallel_broadcast_byobj(recv_reqs) if recv_reqs: logger.debug(f"Received {len(recv_reqs)} HTTP requests") return recv_reqs @@ -390,7 +391,7 @@ def recv_requests_from_peer(self) -> List[Request]: else: recv_reqs = None if self.tp_size > 1: - self._tensor_parallel_broadcast_byobj(recv_reqs) + recv_reqs = self._tensor_parallel_broadcast_byobj(recv_reqs) if recv_reqs: logger.debug(f"Received {len(recv_reqs)} peer requests") return recv_reqs @@ -795,17 +796,18 @@ def _handle_cuda_input_requests(self, requests: List[Request]): self.scheduler.enque_request(original_req) # detokenize and send to http server - req_dict = { - "prompt_tokens": len(req.input_ids), - "next_token_id": req.next_token_id, - "rid": req.request_id, - } - if req.next_token_id == self.tokenizer.eos_token_id: - req_dict["eos"] = True - if original_req.status == RequestStatus.FINISHED_MAX_LENGTH: - req_dict["length"] = True - if hasattr(self, "send_to_ipc_socket"): - 
self.send_to_ipc_socket.send_pyobj(req_dict) + if self.tp_rank == 0: + req_dict = { + "prompt_tokens": len(req.input_ids), + "next_token_id": req.next_token_id, + "rid": req.request_id, + } + if req.next_token_id == self.tokenizer.eos_token_id: + req_dict["eos"] = True + if original_req.status == RequestStatus.FINISHED_MAX_LENGTH: + req_dict["length"] = True + if hasattr(self, "send_to_ipc_socket"): + self.send_to_ipc_socket.send_pyobj(req_dict) else: raise TypeError(f"First peer received unexpected request type: {type(req)}") else: @@ -878,17 +880,18 @@ def _handle_input_requests(self, requests: List[Request]): self.scheduler.enque_request(original_req) # detokenize and send to http server - req_dict = { - "prompt_tokens": len(req.input_ids), - "next_token_id": req.next_token_id, - "rid": req.request_id, - } - if req.next_token_id == self.tokenizer.eos_token_id: - req_dict["eos"] = True - if original_req.status == RequestStatus.FINISHED_MAX_LENGTH: - req_dict["length"] = True - if hasattr(self, "send_to_ipc_socket"): - self.send_to_ipc_socket.send_pyobj(req_dict) + if self.tp_rank == 0: + req_dict = { + "prompt_tokens": len(req.input_ids), + "next_token_id": req.next_token_id, + "rid": req.request_id, + } + if req.next_token_id == self.tokenizer.eos_token_id: + req_dict["eos"] = True + if original_req.status == RequestStatus.FINISHED_MAX_LENGTH: + req_dict["length"] = True + if hasattr(self, "send_to_ipc_socket"): + self.send_to_ipc_socket.send_pyobj(req_dict) else: raise TypeError(f"First peer received unexpected request type: {type(req)}") @@ -1028,7 +1031,7 @@ def _prepare_next_batch_requests( batched_requests = None if self.tp_size > 1: - self._tensor_parallel_broadcast_byobj(batched_requests) + batched_requests = self._tensor_parallel_broadcast_byobj(batched_requests) return batched_requests def _process_batch_cuda( From cc12b7f1e2dbe2df3bad286d7c5507e4ecd9747e Mon Sep 17 00:00:00 2001 From: TianyiZhao1437 Date: Fri, 7 Nov 2025 12:05:36 +0000 Subject: [PATCH 8/9] update --- src/parallax/server/executor.py | 44 ++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 898e2a0e..0bae8f5a 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -302,6 +302,14 @@ def _tensor_parallel_broadcast_byobj(self, broadcast_obj): src=self.tp_group.ranks[0], ) return broadcast_result + + def _join_requests(self, left_reqs: List[Request], right_reqs: List[Request]): + """Merge two request lists""" + if not left_reqs: + return right_reqs + if not right_reqs: + return left_reqs + return left_reqs + right_reqs def recv_requests_from_http(self) -> List[Request]: """Receives requests from http frontend""" @@ -327,10 +335,7 @@ def recv_requests_from_http(self) -> List[Request]: logger.exception(f"Error receiving http request: {e}") else: recv_reqs = None - if self.tp_size > 1: - recv_reqs = self._tensor_parallel_broadcast_byobj(recv_reqs) - if recv_reqs: - logger.debug(f"Received {len(recv_reqs)} HTTP requests") + return recv_reqs def recv_requests_from_peer(self) -> List[Request]: @@ -390,10 +395,7 @@ def recv_requests_from_peer(self) -> List[Request]: logger.exception(f"Error receiving or deserializing request: {e}") else: recv_reqs = None - if self.tp_size > 1: - recv_reqs = self._tensor_parallel_broadcast_byobj(recv_reqs) - if recv_reqs: - logger.debug(f"Received {len(recv_reqs)} peer requests") + return recv_reqs def _prepare_cuda_prefill_batch(self, 
batched_requests: List[Request]) -> Dict[str, Any]: @@ -827,8 +829,10 @@ def _handle_cuda_input_requests(self, requests: List[Request]): def _handle_input_requests(self, requests: List[Request]): """Update requests states and status in scheduler and cache manager.""" - if not requests: + if self.tp_rank == 0 and not requests: return + if self.tp_size > 1: + requests = self._tensor_parallel_broadcast_byobj(requests) if len(requests) > 0: logger.debug(f"Handling {len(requests)} requests.") @@ -1030,8 +1034,6 @@ def _prepare_next_batch_requests( else: batched_requests = None - if self.tp_size > 1: - batched_requests = self._tensor_parallel_broadcast_byobj(batched_requests) return batched_requests def _process_batch_cuda( @@ -1168,31 +1170,33 @@ def run_loop(self): # 1. Ingest new requests from the http frontend if self.is_first_peer: http_requests = self.recv_requests_from_http() - self._handle_input_requests(http_requests) # 2. Ingest new requests from the RPC server incoming_requests = self.recv_requests_from_peer() - self._handle_input_requests(incoming_requests) - # 3. Send finished batch to next peer + # 3. Merge requests and handle requests + received_requests = self._join_requests(http_requests, incoming_requests) + self._handle_input_requests(received_requests) + + # 4. Send finished batch to next peer if len(self.finished_batch) > 0 and self.is_first_peer and self.tp_rank == 0: self.send_to_peer_socket.send_multipart( [b"abort", abort_request_to_proto(self.finished_batch).SerializeToString()] ) self.finished_batch = [] - # 4. Check if we should form a batch + # 5. Check if we should form a batch if not self.scheduler.should_dispatch(): time.sleep(0.01) # prevent busy waiting continue - # 5. Form a batch from the scheduler's queue + # 6. Form a batch from the scheduler's queue batch_to_process = self.scheduler.form_batch() if not batch_to_process: continue logger.debug(f"Formed batch with {len(batch_to_process)} requests.") - # 6. Process the batch + # 7. Process the batch try: prepared_inputs_dict = self._prepare_batch_inputs(batch_to_process) @@ -1220,15 +1224,15 @@ def run_loop(self): self._decode_steps_since_metric = 0 except Exception: pass - # 7. Prepare requests for the next stage in the pipeline + # 8. Prepare requests for the next stage in the pipeline next_batch = self._prepare_next_batch_requests( requests=prepared_inputs["requests"], hidden_states=output, lengths=prepared_inputs["lengths"], ) - # 8. Dispatch to the appropriate destination - if self.is_last_peer and self.is_first_peer: + # 9. 
Dispatch to the appropriate destination + if self.is_last_peer and self.is_first_peer and self.tp_rank == 0: # Single node: handle locally self._handle_input_requests(next_batch) elif self.tp_rank == 0: From 775619131e8c1835b4fab9c6f0f742ec821744c6 Mon Sep 17 00:00:00 2001 From: TianyiZhao1437 Date: Fri, 7 Nov 2025 20:12:17 +0800 Subject: [PATCH 9/9] fix pylint issues --- src/parallax/launch.py | 12 +++++------- src/parallax/server/executor.py | 2 +- src/parallax/server/server_args.py | 5 ++++- src/parallax/sglang/model_runner.py | 1 - src/parallax/utils/utils.py | 3 +-- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/parallax/launch.py b/src/parallax/launch.py index 5a05baf2..e9b1ad68 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -24,7 +24,7 @@ from parallax.server.executor import Executor from parallax.server.http_server import launch_http_server from parallax.server.server_args import parse_args -from parallax.utils.utils import get_current_device, fetch_model_from_hf +from parallax.utils.utils import fetch_model_from_hf, get_current_device from parallax_utils.ascii_anime import display_parallax_join from parallax_utils.logging_config import get_logger, set_log_level @@ -58,6 +58,7 @@ def run_executor_process(args): finally: executor.shutdown() + def terminate_subprocess(process): """Kill a subprocess""" logger.debug("Terminating subprocess...") @@ -67,6 +68,7 @@ def terminate_subprocess(process): except Exception as e: logger.error(f"Failed to terminate subprocess: {e}") + if __name__ == "__main__": multiprocessing.set_start_method("spawn", force=True) @@ -190,15 +192,11 @@ def terminate_subprocess(process): finally: thread_pool = [] for executor_proc in executor_procs: - t = threading.Thread( - target=terminate_subprocess, args=(executor_proc,) - ) + t = threading.Thread(target=terminate_subprocess, args=(executor_proc,)) t.start() thread_pool.append(t) if http_server_process is not None: - t = threading.Thread( - target=terminate_subprocess, args=(http_server_process,) - ) + t = threading.Thread(target=terminate_subprocess, args=(http_server_process,)) t.start() thread_pool.append(t) if gradient_server is not None: diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 0bae8f5a..efab0bd1 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -302,7 +302,7 @@ def _tensor_parallel_broadcast_byobj(self, broadcast_obj): src=self.tp_group.ranks[0], ) return broadcast_result - + def _join_requests(self, left_reqs: List[Request], right_reqs: List[Request]): """Merge two request lists""" if not left_reqs: diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 2f47e18a..ae6a0290 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -154,7 +154,10 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size") parser.add_argument( - "--nccl-port", type=int, default=4001, help="The port for NCCL distributed environment setup." 
+ "--nccl-port", + type=int, + default=4001, + help="The port for NCCL distributed environment setup.", ) # Logging and debugging diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 0429f887..c195cc65 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -6,7 +6,6 @@ import logging import os -import random from typing import Any, Dict, List, Optional, Tuple, Union import sglang diff --git a/src/parallax/utils/utils.py b/src/parallax/utils/utils.py index 451e0c54..322de252 100644 --- a/src/parallax/utils/utils.py +++ b/src/parallax/utils/utils.py @@ -1,7 +1,5 @@ """Utility functions.""" -import json -from pathlib import Path from typing import List import mlx.core as mx @@ -11,6 +9,7 @@ import zmq from mlx_lm.utils import get_model_path, load_config + def is_cuda_available(): """Check backend supports cuda""" return torch.cuda.is_available()
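
Note (illustrative, not part of the patches): the ingestion pattern this series introduces is that only TP rank 0 drains the ZMQ sockets, and the decoded requests are then broadcast to the remaining ranks so every executor process forms the same batch. The patches route this through sglang.srt.utils.broadcast_pyobj on the TP CPU group; the standalone analogue below uses plain torch.distributed with a hypothetical drain_socket callable, so it sketches the idea rather than the exact code path added here.

# sketch only -- assumes one process per TP rank and an already-initialized
# gloo (CPU) process group, since Python objects are exchanged as CPU tensors
import torch.distributed as dist

def recv_and_broadcast(drain_socket):
    """Rank 0 receives requests; every rank returns the same list."""
    if dist.get_rank() == 0:
        payload = [drain_socket()]   # e.g. a non-blocking recv loop over the ZMQ socket
    else:
        payload = [None]             # placeholder, overwritten by the broadcast below
    dist.broadcast_object_list(payload, src=0)
    return payload[0]

Doing the object broadcast over a CPU group is presumably why the executor caches self.tp_group.cpu_group; the heavy tensor traffic stays on the regular TP group while control-plane objects travel over gloo.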