diff --git a/src/parallax/launch.py b/src/parallax/launch.py
index 46f1368c..5be4cd0e 100644
--- a/src/parallax/launch.py
+++ b/src/parallax/launch.py
@@ -13,15 +13,21 @@
 --end-layer 28
 """
 
+import argparse
 import multiprocessing
 import os
 import tempfile
 import threading
 
 from parallax.p2p.server import ServerState, launch_p2p_server
-from parallax.server.executor import Executor
+from parallax.server.executor import (
+    Executor,
+    run_executor_process,
+    stop_executor_process,
+)
 from parallax.server.http_server import launch_http_server, stop_http_server
 from parallax.server.server_args import parse_args
+from parallax.utils.utils import fetch_model_from_hf, initialize_nccl_port
 from parallax_utils.ascii_anime import display_parallax_join
 from parallax_utils.logging_config import get_logger, set_log_level
 from parallax_utils.version_check import check_latest_release
@@ -35,6 +41,7 @@
     gradient_server = None
     http_server_process = None
    executor = None
+    executor_procs = []
     try:
         args = parse_args()
         set_log_level(args.log_level)
@@ -43,28 +50,31 @@
         args.send_to_peer_addr = f"ipc://{tempfile.NamedTemporaryFile().name}"
         args.executor_input_ipc = f"ipc://{tempfile.NamedTemporaryFile().name}"
         args.executor_output_ipc = f"ipc://{tempfile.NamedTemporaryFile().name}"
+        if args.nccl_port is None:
+            args.nccl_port = initialize_nccl_port()
 
         # Silence tokenizer warnings
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
         logger.debug(f"executor_input_addr: {args.executor_input_ipc}")
         logger.debug(f"executor_output_addr: {args.executor_output_ipc}")
+        logger.debug(f"nccl_port: {args.nccl_port}")
 
         if args.scheduler_addr is None:
             if args.log_level != "DEBUG":
                 display_parallax_join(args.model_path)
             check_latest_release()
+            config = fetch_model_from_hf(args.model_path)
             # only launch http server on head node
             if args.start_layer == 0:
                 http_server_process = launch_http_server(args)
-            executor = Executor.create_from_args(args)
             launch_p2p_server(
                 initial_peers=args.initial_peers,
                 scheduler_addr=args.scheduler_addr,
                 relay_servers=args.relay_servers,
                 pp_start_layer=args.start_layer,
                 pp_end_layer=args.end_layer,
-                hidden_layers=executor.config.get("num_hidden_layers"),
+                hidden_layers=config.get("num_hidden_layers"),
                 tcp_port=args.tcp_port,
                 udp_port=args.udp_port,
                 dht_prefix=args.dht_prefix,
@@ -81,7 +91,18 @@
             )
             if gradient_server is not None:
                 gradient_server.status = ServerState.READY
-            executor.run_loop()
+            tp_rank_range = range(args.tp_size)
+            for tp_rank in tp_rank_range:
+                args_copy = argparse.Namespace(**vars(args))
+                args_copy.tp_rank = tp_rank
+                proc = multiprocessing.Process(
+                    target=run_executor_process,
+                    args=(args_copy,),
+                )
+                proc.start()
+                executor_procs.append(proc)
+            for executor_process in executor_procs:
+                executor_process.join()
         else:
             gradient_server = launch_p2p_server(
                 initial_peers=args.initial_peers,
@@ -107,7 +128,10 @@
             args.start_layer = gradient_server.block_start_index
             args.end_layer = gradient_server.block_end_index
             args.model_path = gradient_server.model_name
-            # Hard code for mlx-community models
+            # TODO: Implement inter-process communication to enable TP in scheduler mode.
+            # For now, scheduler mode only supports tp_rank=0.
+            args.tp_rank = 0
+
             logger.debug(
                 f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}"
             )
@@ -174,13 +198,27 @@
     except Exception as e:
         logger.exception(e)
     finally:
-        t = None
+        thread_pool = []
+
+        # Shutdown http server
         if http_server_process is not None:
             t = threading.Thread(target=stop_http_server, args=(http_server_process,))
             t.start()
+            thread_pool.append(t)
+
+        # Shutdown gradient server
         if gradient_server is not None:
             gradient_server.shutdown()
+
+        # Shutdown executor subprocesses (non-scheduler mode)
+        for executor_process in executor_procs:
+            t = threading.Thread(target=stop_executor_process, args=(executor_process,))
+            t.start()
+            thread_pool.append(t)
+
+        # Shutdown the main-process executor (scheduler mode)
         if executor is not None:
             executor.shutdown()
-        if t is not None:
+
+        for t in thread_pool:
             t.join()
diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py
index dead8b84..933c397f 100644
--- a/src/parallax/server/executor.py
+++ b/src/parallax/server/executor.py
@@ -97,6 +97,10 @@ def __init__(
         # GPU/SGLang Specialized Configs
         attention_backend: Optional[str] = "torch_native",
         moe_runner_backend: Optional[str] = "auto",
+        # Tensor Parallel Configs
+        tp_rank: Optional[int] = 0,
+        tp_size: Optional[int] = 1,
+        nccl_port: Optional[int] = None,
         # Optional gradient server for layer reallocation detection
         gradient_server: Optional[Any] = None,
     ):
@@ -121,10 +125,15 @@ def __init__(
                 attention_backend,
                 kv_block_size,
                 moe_runner_backend,
+                tp_rank,
+                tp_size,
+                nccl_port,
             )
             logger.debug(
                 f"CUDA model runner initialized. num_layers={self.config.get('num_hidden_layers')}"
             )
+            self.tp_group = self.model_runner.tp_group
+            self.tp_cpu_group = self.tp_group.cpu_group
             # SGL KV Cache Manager is already initialized in ScheduleBatch
             # TODO: Replace ScheduleBatch to Parallax inflight batch
             self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False)
@@ -153,6 +162,8 @@ def __init__(
         self.is_first_peer = start_layer == 0
         self.is_last_peer = end_layer == self.config.get("num_hidden_layers")
         self.num_shard_layers = end_layer - start_layer
+        self.tp_size = tp_size
+        self.tp_rank = tp_rank
 
         # Metrics throttling for per-layer latency updates
         self.layer_latency_update_every = int(max(1, layer_latency_update_every))
@@ -265,108 +276,135 @@ def __init__(
         )
 
         # Communication Related
-        self.zmq_context = zmq.Context()
-        if recv_from_peer_addr:
-            self.recv_from_peer_socket = get_zmq_socket(
-                self.zmq_context, zmq.PULL, recv_from_peer_addr, bind=False
-            )
-        if send_to_peer_addr:
-            self.send_to_peer_socket = get_zmq_socket(
-                self.zmq_context, zmq.PUSH, send_to_peer_addr, bind=False
-            )
-        if executor_input_ipc_addr:
-            self.recv_from_ipc_socket = get_zmq_socket(
-                self.zmq_context, zmq.PULL, executor_input_ipc_addr, bind=False
-            )
-        if executor_output_ipc_addr:
-            self.send_to_ipc_socket = get_zmq_socket(
-                self.zmq_context, zmq.PUSH, executor_output_ipc_addr, bind=False
-            )
+        if self.tp_rank == 0:
+            self.zmq_context = zmq.Context()
+            if recv_from_peer_addr:
+                self.recv_from_peer_socket = get_zmq_socket(
+                    self.zmq_context, zmq.PULL, recv_from_peer_addr, bind=False
+                )
+            if send_to_peer_addr:
+                self.send_to_peer_socket = get_zmq_socket(
+                    self.zmq_context, zmq.PUSH, send_to_peer_addr, bind=False
+                )
+            if executor_input_ipc_addr:
+                self.recv_from_ipc_socket = get_zmq_socket(
+                    self.zmq_context, zmq.PULL, executor_input_ipc_addr, bind=False
+                )
+            if executor_output_ipc_addr:
+                self.send_to_ipc_socket = get_zmq_socket(
+                    self.zmq_context, zmq.PUSH, executor_output_ipc_addr, bind=False
+                )
 
     @classmethod
     def create_from_args(cls, args: argparse.Namespace, gradient_server=None):
         """Create executor from command line arguments."""
         return cls(**create_executor_config(args, gradient_server))
 
+    def _tensor_parallel_broadcast_byobj(self, broadcast_obj):
+        """Broadcast a Python object from TP rank 0 to the rest of the TP group."""
+        from sglang.srt.utils import broadcast_pyobj
+
+        broadcast_result = broadcast_pyobj(
+            broadcast_obj,
+            self.tp_group.rank,
+            self.tp_cpu_group,
+            src=self.tp_group.ranks[0],
+        )
+        return broadcast_result
+
+    def _join_requests(self, left_reqs: List[Request], right_reqs: List[Request]):
+        """Merge two request lists; either side may be None."""
+        if not left_reqs:
+            return right_reqs
+        if not right_reqs:
+            return left_reqs
+        return left_reqs + right_reqs
+
     def recv_requests_from_http(self) -> List[Request]:
         """Receives requests from http frontend"""
-        recv_reqs = []
-        while True:
-            try:
-                raw_request = self.recv_from_ipc_socket.recv_pyobj(zmq.NOBLOCK)
+        if self.tp_rank == 0:
+            recv_reqs = []
+            while True:
+                try:
+                    raw_request = self.recv_from_ipc_socket.recv_pyobj(zmq.NOBLOCK)
+
+                    # Check if this is an abort request
+                    if isinstance(raw_request, dict) and raw_request.get("type") == "abort":
+                        logger.debug(
+                            f"Received abort request from HTTP for request ID: {raw_request.get('rid')}"
+                        )
+                        self.scheduler.cancel_request(raw_request.get("rid"))
+                    else:
+                        # Normal request processing - do tokenization and form InitialRequest
+                        req = self._handle_raw_request(raw_request)
+                        recv_reqs.append(req)
+                except zmq.ZMQError:
+                    break
+                except Exception as e:
+                    logger.exception(f"Error receiving http request: {e}")
+        else:
+            recv_reqs = None
 
-                # Check if this is an abort request
-                if isinstance(raw_request, dict) and raw_request.get("type") == "abort":
-                    logger.debug(
-                        f"Received abort request from HTTP for request ID: {raw_request.get('rid')}"
-                    )
-                    self.scheduler.cancel_request(raw_request.get("rid"))
-                else:
-                    # Normal request processing - do tokenization and form InitialRequest
-                    req = self._handle_raw_request(raw_request)
-                    recv_reqs.append(req)
-            except zmq.ZMQError:
-                break
-            except Exception as e:
-                logger.exception(f"Error receiving http request: {e}")
-        if recv_reqs:
-            logger.debug(f"Received {len(recv_reqs)} HTTP requests")
         return recv_reqs
 
     def recv_requests_from_peer(self) -> List[Request]:
         """Receives requests from the RPC server."""
-        recv_reqs = []
-        while True:
-            try:
-                recv_req = self.recv_from_peer_socket.recv_multipart(zmq.NOBLOCK)
-                assert len(recv_req) == 2, f"Received invalid request: {recv_req}"
-                if recv_req[0] == b"forward":
-                    # Create a new ForwardRequest instance and parse from bytes
-                    forward_request = forward_pb2.ForwardRequest()
-                    forward_request.ParseFromString(recv_req[1])
-                    recv_req = proto_to_request(forward_request, self.device)
-
-                    # Convert hidden_states dtype if necessary
-                    if recv_req is not None and len(recv_req) > 0:
-                        for req in recv_req:
-                            if req.hidden_states is not None:
-                                if req.hidden_states.dtype != self.dtype:
-                                    logger.debug(
-                                        f"Converting hidden_states dtype from {req.hidden_states.dtype} to {self.dtype} for request {req.request_id}"
-                                    )
-                                    if self.device == "cuda":
-                                        req.hidden_states = req.hidden_states.to(self.dtype)
-                                    elif self.device == "mlx":
-                                        req.hidden_states = req.hidden_states.astype(self.dtype)
-                                    else:
-                                        raise ValueError(f"Unsupported device type: {self.device}")
-
-                    # Move current position for first peer
-                    if self.is_first_peer:
-                        for req in recv_req:
-                            req.current_position += 1
-                    recv_reqs.extend(recv_req)
-                elif recv_req[0] == b"abort":
-                    abort_request = forward_pb2.AbortRequest()
-                    abort_request.ParseFromString(recv_req[1])
-                    recv_req = proto_to_abort_request(abort_request)
-                    recv_reqs.extend(recv_req)
-                else:
-                    raise ValueError(f"Unknown request type: {recv_req[0]}")
-                # First peer is responsible for tokenization
-                # if self.is_first_peer and isinstance(recv_req, InitialRequest):
-                #     recv_req.input_ids = self.tokenizer.encode(recv_req.prompt)
-                #     recv_req.prompt_len = len(recv_req.input_ids)
-                #     recv_req.max_total_length = min(
-                #         recv_req.max_total_length, recv_req.prompt_len + recv_req.max_new_tokens
-                #     )
-
-            except zmq.ZMQError:
-                break
-            except Exception as e:
-                logger.exception(f"Error receiving or deserializing request: {e}")
-        if recv_reqs:
-            logger.debug(f"Received {len(recv_reqs)} peer requests")
+        if self.tp_rank == 0:
+            recv_reqs = []
+            while True:
+                try:
+                    recv_req = self.recv_from_peer_socket.recv_multipart(zmq.NOBLOCK)
+                    assert len(recv_req) == 2, f"Received invalid request: {recv_req}"
+                    if recv_req[0] == b"forward":
+                        # Create a new ForwardRequest instance and parse from bytes
+                        forward_request = forward_pb2.ForwardRequest()
+                        forward_request.ParseFromString(recv_req[1])
+                        recv_req = proto_to_request(forward_request, self.device)
+
+                        # Convert hidden_states dtype if necessary
+                        if recv_req is not None and len(recv_req) > 0:
+                            for req in recv_req:
+                                if req.hidden_states is not None:
+                                    if req.hidden_states.dtype != self.dtype:
+                                        logger.debug(
+                                            f"Converting hidden_states dtype from {req.hidden_states.dtype} to {self.dtype} for request {req.request_id}"
+                                        )
+                                        if self.device == "cuda":
+                                            req.hidden_states = req.hidden_states.to(self.dtype)
+                                        elif self.device == "mlx":
+                                            req.hidden_states = req.hidden_states.astype(self.dtype)
+                                        else:
+                                            raise ValueError(
+                                                f"Unsupported device type: {self.device}"
+                                            )
+
+                        # Move current position for first peer
+                        if self.is_first_peer:
+                            for req in recv_req:
+                                req.current_position += 1
+                        recv_reqs.extend(recv_req)
+                    elif recv_req[0] == b"abort":
+                        abort_request = forward_pb2.AbortRequest()
+                        abort_request.ParseFromString(recv_req[1])
+                        recv_req = proto_to_abort_request(abort_request)
+                        recv_reqs.extend(recv_req)
+                    else:
+                        raise ValueError(f"Unknown request type: {recv_req[0]}")
+                    # First peer is responsible for tokenization
+                    # if self.is_first_peer and isinstance(recv_req, InitialRequest):
+                    #     recv_req.input_ids = self.tokenizer.encode(recv_req.prompt)
+                    #     recv_req.prompt_len = len(recv_req.input_ids)
+                    #     recv_req.max_total_length = min(
+                    #         recv_req.max_total_length, recv_req.prompt_len + recv_req.max_new_tokens
+                    #     )
+
+                except zmq.ZMQError:
+                    break
+                except Exception as e:
+                    logger.exception(f"Error receiving or deserializing request: {e}")
+        else:
+            recv_reqs = None
+
         return recv_reqs
 
     def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]:
@@ -769,17 +807,18 @@ def _handle_cuda_input_requests(self, requests: List[Request]):
                         self.scheduler.enque_request(original_req)
 
                     # detokenize and send to http server
-                    req_dict = {
-                        "prompt_tokens": len(req.input_ids),
-                        "next_token_id": req.next_token_id,
-                        "rid": req.request_id,
-                    }
-                    if req.next_token_id == self.tokenizer.eos_token_id:
-                        req_dict["eos"] = True
-                    if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
-                        req_dict["length"] = True
-                    if hasattr(self, "send_to_ipc_socket"):
-                        self.send_to_ipc_socket.send_pyobj(req_dict)
+                    if self.tp_rank == 0:
+                        req_dict = {
+                            "prompt_tokens": len(req.input_ids),
+                            "next_token_id": req.next_token_id,
+                            "rid": req.request_id,
+                        }
+                        if req.next_token_id == self.tokenizer.eos_token_id:
+                            req_dict["eos"] = True
+                        if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
+                            req_dict["length"] = True
+                        if hasattr(self, "send_to_ipc_socket"):
+                            self.send_to_ipc_socket.send_pyobj(req_dict)
                 else:
                     raise TypeError(f"First peer received unexpected request type: {type(req)}")
             else:
@@ -799,10 +838,12 @@ def _handle_cuda_input_requests(self, requests: List[Request]):
 
     def _handle_input_requests(self, requests: List[Request]):
         """Update requests states and status in scheduler and cache manager."""
+        if self.tp_rank == 0 and not requests:
+            return
+        if self.tp_size > 1:
+            requests = self._tensor_parallel_broadcast_byobj(requests)
         if len(requests) > 0:
             logger.debug(f"Handling {len(requests)} requests.")
-        if not requests:
-            return
 
         if self.device == "cuda":
             self._handle_cuda_input_requests(requests)
@@ -850,17 +891,18 @@ def _handle_input_requests(self, requests: List[Request]):
                     self.scheduler.enque_request(original_req)
 
                     # detokenize and send to http server
-                    req_dict = {
-                        "prompt_tokens": len(req.input_ids),
-                        "next_token_id": req.next_token_id,
-                        "rid": req.request_id,
-                    }
-                    if req.next_token_id == self.tokenizer.eos_token_id:
-                        req_dict["eos"] = True
-                    if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
-                        req_dict["length"] = True
-                    if hasattr(self, "send_to_ipc_socket"):
-                        self.send_to_ipc_socket.send_pyobj(req_dict)
+                    if self.tp_rank == 0:
+                        req_dict = {
+                            "prompt_tokens": len(req.input_ids),
+                            "next_token_id": req.next_token_id,
+                            "rid": req.request_id,
+                        }
+                        if req.next_token_id == self.tokenizer.eos_token_id:
+                            req_dict["eos"] = True
+                        if original_req.status == RequestStatus.FINISHED_MAX_LENGTH:
+                            req_dict["length"] = True
+                        if hasattr(self, "send_to_ipc_socket"):
+                            self.send_to_ipc_socket.send_pyobj(req_dict)
                 else:
                     raise TypeError(f"First peer received unexpected request type: {type(req)}")
 
@@ -967,32 +1009,35 @@ def _prepare_next_batch_requests(
         self, requests: List[Request], hidden_states: Any, lengths: Any
     ) -> List[Request]:
         """Prepares a batch of requests for the next stage of the pipeline."""
-        batched_requests = []
-        pre_length = 0
-        for i, src_request in enumerate(requests):
-            if self.is_last_peer:
-                # Last peer gets a 1D array of token IDs
-                hidden_state_for_req = hidden_states[i : i + 1]
-            else:
-                # Other peers get a 3D array of hidden states
-                if src_request.is_prefill:
-                    true_length = int(lengths[i])
-                    if hidden_states.ndim == 3:
-                        hidden_state_for_req = hidden_states[i, :true_length, :]
-                    else:
-                        hidden_state_for_req = hidden_states[
-                            pre_length : pre_length + true_length, :
-                        ]
-                        pre_length += true_length
+        if self.tp_rank == 0:
+            batched_requests = []
+            pre_length = 0
+            for i, src_request in enumerate(requests):
+                if self.is_last_peer:
+                    # Last peer gets a 1D array of token IDs
+                    hidden_state_for_req = hidden_states[i : i + 1]
                 else:
-                    if hidden_states.ndim == 3:
-                        hidden_state_for_req = hidden_states[i, :, :]
+                    # Other peers get a 3D array of hidden states
+                    if src_request.is_prefill:
+                        true_length = int(lengths[i])
+                        if hidden_states.ndim == 3:
+                            hidden_state_for_req = hidden_states[i, :true_length, :]
+                        else:
+                            hidden_state_for_req = hidden_states[
+                                pre_length : pre_length + true_length, :
+                            ]
+                            pre_length += true_length
                     else:
-                        hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :]
-                        pre_length += 1
+                        if hidden_states.ndim == 3:
+                            hidden_state_for_req = hidden_states[i, :, :]
+                        else:
+                            hidden_state_for_req = hidden_states[pre_length : pre_length + 1, :]
+                            pre_length += 1
 
-            next_req = self._prepare_next_single_request(src_request, hidden_state_for_req)
-            batched_requests.append(next_req)
+                next_req = self._prepare_next_single_request(src_request, hidden_state_for_req)
+                batched_requests.append(next_req)
+        else:
+            batched_requests = None
 
         return batched_requests
 
@@ -1154,14 +1199,18 @@ def run_loop(self):
             # 1. Ingest new requests from the http frontend
             if self.is_first_peer:
                 http_requests = self.recv_requests_from_http()
-                self._handle_input_requests(http_requests)
+            else:
+                http_requests = None
 
             # 2. Ingest new requests from the RPC server
             incoming_requests = self.recv_requests_from_peer()
 
-            # 3. Send finished batch to next peer
-            if len(self.finished_batch) > 0 and self.is_first_peer:
+            # 3. Merge and handle requests
+            received_requests = self._join_requests(http_requests, incoming_requests)
+            self._handle_input_requests(received_requests)
+
+            # 4. Send finished batch to next peer
+            if len(self.finished_batch) > 0 and self.is_first_peer and self.tp_rank == 0:
                 self.send_to_peer_socket.send_multipart(
                     [b"abort", abort_request_to_proto(self.finished_batch).SerializeToString()]
                 )
@@ -1175,9 +1224,9 @@ def run_loop(self):
                     self._should_stop = True
                     break
 
-            # 4. Admit requests into running set up to capacity, then form batch
+            # 5. Admit requests into running set up to capacity, then form batch
             self.scheduler.admit_requests()
-            # 4.1 Check for request timeouts and abort timed out requests
+            # 5.1 Check for request timeouts and abort timed out requests
             try:
                 timed_out_reqs = self.scheduler.get_timed_out_requests()
                 if timed_out_reqs:
@@ -1235,21 +1284,24 @@ def run_loop(self):
                     )
 
                     # 8. Dispatch to the appropriate destination
-                    if self.is_last_peer and self.is_first_peer:
-                        # Single node: handle locally
-                        self._handle_input_requests(next_batch)
-                    else:
-                        # Send output to next peer
-                        self.send_to_peer_socket.send_multipart(
-                            [
-                                b"forward",
-                                request_to_proto(next_batch, self.device).SerializeToString(),
-                            ]
-                        )
-                    logger.debug(
-                        f"Processed batch of type {batch_type} with {len(next_batch)} requests "
-                        f"in {(time.time() - start_time) * 1000:.3f} ms"
-                    )
+                    if self.tp_rank == 0:
+                        if self.is_last_peer and self.is_first_peer:
+                            # Single node: handle locally
+                            self._handle_input_requests(next_batch)
+                        else:
+                            # Send output to next peer
+                            self.send_to_peer_socket.send_multipart(
+                                [
+                                    b"forward",
+                                    request_to_proto(
+                                        next_batch, self.device
+                                    ).SerializeToString(),
+                                ]
+                            )
+                        logger.debug(
+                            f"Processed batch of type {batch_type} with {len(next_batch)} requests "
+                            f"in {(time.time() - start_time) * 1000:.3f} ms"
+                        )
 
             except Exception as e:
                 logger.exception(f"Error processing batch: {e}")
@@ -1281,17 +1333,43 @@ def shutdown(self):
             pass
 
         try:
-            self.recv_from_peer_socket.close()
-            self.send_to_peer_socket.close()
-            self.recv_from_ipc_socket.close()
-            self.send_to_ipc_socket.close()
-            self.zmq_context.term()
+            if self.tp_rank == 0:
+                self.recv_from_peer_socket.close()
+                self.send_to_peer_socket.close()
+                self.recv_from_ipc_socket.close()
+                self.send_to_ipc_socket.close()
+                self.zmq_context.term()
         except Exception as e:
             logger.debug(f"Error closing sockets (may already be closed): {e}")
 
         logger.debug("Executor shutdown complete.")
 
 
+def run_executor_process(args, gradient_server=None):
+    """Run an executor inside a subprocess."""
+    executor = None
+    try:
+        executor = Executor.create_from_args(args, gradient_server)
+        executor.run_loop()
+    except KeyboardInterrupt:
+        logger.debug("Received interrupt signal, shutting down...")
+    except Exception as e:
+        logger.exception(e)
+    finally:
+        if executor is not None:
+            executor.shutdown()
+
+
+def stop_executor_process(executor_process):
+    """Terminate an executor subprocess."""
+    logger.debug("Terminating executor subprocess...")
+    try:
+        executor_process.kill()
+        executor_process.join()
+    except Exception as e:
+        logger.error(f"Failed to terminate executor subprocess: {e}")
+
+
 def create_executor_config(args: argparse.Namespace, gradient_server=None):
     """Create executor configuration from command line arguments."""
 
@@ -1315,6 +1393,9 @@ def create_executor_config(args: argparse.Namespace, gradient_server=None):
         "executor_output_ipc_addr": args.executor_output_ipc,
         "attention_backend": args.attention_backend,
         "moe_runner_backend": args.moe_runner_backend,
+        "tp_rank": args.tp_rank,
+        "tp_size": args.tp_size,
+        "nccl_port": args.nccl_port,
         "gradient_server": gradient_server,
     }
     return config
diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py
index 6e014b18..89a0558b 100644
--- a/src/parallax/server/server_args.py
+++ b/src/parallax/server/server_args.py
@@ -171,6 +171,16 @@ def parse_args() -> argparse.Namespace:
         help="Choose the GPU moe kernels",
     )
 
+    # Tensor parallel configuration
+    parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
+
+    parser.add_argument(
+        "--nccl-port",
+        type=int,
+        default=None,
+        help="The port for NCCL distributed environment setup",
+    )
+
     # Logging and debugging
     parser.add_argument(
         "--log-level",
diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index ff8f9f4a..bef39146 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -6,7 +6,6 @@
 
 import logging
 import os
-import random
 
 import sglang
 import sglang.srt.distributed.parallel_state
@@ -203,6 +202,7 @@ def init_torch_distributed(self):
 def form_sgl_server_args(
     model_path: str,
     dtype: str = "bfloat16",
+    tp_size: int = 1,
     attention_backend: str = "flashinfer",
     kv_block_size: int = 64,
     moe_runner_backend="auto",
@@ -215,6 +215,7 @@ def form_sgl_server_args(
         page_size=kv_block_size,
         mem_fraction_static=0.85,
         moe_runner_backend=moe_runner_backend,
+        tp_size=tp_size,
     )
     return sgl_server_args
 
@@ -227,6 +228,9 @@ def initialize_sgl_model_runner(
     attention_backend: str,
     kv_block_size: int,
     moe_runner_backend: str,
+    tp_rank: int,
+    tp_size: int,
+    nccl_port: int,
 ):
     """
     Creates a SGL ModelRunner object.
@@ -252,7 +256,6 @@ def initialize_sgl_model_runner(
     config = load_config(model_path)
     tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
     dtype = config.get("torch_dtype", "bfloat16")
-    nccl_port = random.randint(4000, 5000)
 
     # Handling mxfp4 arguments
     quant_method = config.get("quant_method", None)
@@ -271,6 +274,7 @@ def initialize_sgl_model_runner(
     server_args = form_sgl_server_args(
         str(model_path),
         dtype,
+        tp_size,
         attention_backend,
         kv_block_size,
         moe_runner_backend,
@@ -296,9 +300,9 @@ def initialize_sgl_model_runner(
     model_runner = ParallaxModelRunner(
         model_config=model_config,
         mem_fraction_static=kv_cache_memory_fraction,
-        gpu_id=0,
-        tp_rank=0,
-        tp_size=1,
+        gpu_id=tp_rank,  # Single-node TP only: reuse tp_rank as the local GPU id.
+        tp_rank=tp_rank,
+        tp_size=tp_size,
         pp_rank=0,
         pp_size=1,
         moe_ep_rank=0,
diff --git a/src/parallax/utils/utils.py b/src/parallax/utils/utils.py
index c2cfff1f..5a3a2642 100644
--- a/src/parallax/utils/utils.py
+++ b/src/parallax/utils/utils.py
@@ -1,5 +1,7 @@
 """Utility functions."""
 
+import random
+import socket
 from typing import List
 
 import mlx.core as mx
@@ -7,6 +9,7 @@
 import psutil
 import torch
 import zmq
+from mlx_lm.utils import get_model_path, load_config
 
 
 def is_cuda_available():
@@ -266,3 +269,40 @@ def combine_padding_and_causal_masks(
     padding_mask_float = (padding_mask - 1) * inf_value
     padding_mask_float = padding_mask_float.astype(dtype)
     return causal_mask + padding_mask_float
+
+
+def fetch_model_from_hf(name: str):
+    """Fetch a model from the Hugging Face Hub and return its config."""
+    model_path = get_model_path(name)[0]
+    config = load_config(model_path)
+    return config
+
+
+def is_port_available(port: int):
+    """
+    Copied from SGLang.
+    Return whether a port is available.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        try:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            s.bind(("", port))
+            s.listen(1)
+            return True
+        except socket.error:
+            return False
+        except OverflowError:
+            return False
+
+
+def initialize_nccl_port():
+    """Pick an available port for NCCL distributed initialization."""
+    nccl_port = random.randint(4000, 5000)
+    while True:
+        if is_port_available(nccl_port):
+            break
+        if nccl_port < 60000:
+            nccl_port += 42
+        else:
+            nccl_port -= 43
+    return nccl_port
diff --git a/tests/test_server_args.py b/tests/test_server_args.py
index deaf5f42..93601403 100644
--- a/tests/test_server_args.py
+++ b/tests/test_server_args.py
@@ -88,6 +88,9 @@ def test_create_config(self):
             executor_output_ipc="///ipc/2",
             attention_backend="torch_native",
             moe_runner_backend="auto",
+            tp_rank=0,
+            tp_size=1,
+            nccl_port=4001,
         )
         config = create_executor_config(args)
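
Tensor parallelism is opt-in through the new flags: `--tp-size N` makes `launch.py` spawn one executor subprocess per TP rank (non-scheduler mode only; scheduler mode still pins `tp_rank=0`), and `--nccl-port` fixes the rendezvous port, defaulting to a free port probed by `initialize_nccl_port()`. Assuming the existing `--model-path`, `--start-layer`, and `--end-layer` flags (the model name here is a placeholder), a two-GPU peer would be started roughly as `python src/parallax/launch.py --model-path Qwen/Qwen3-0.6B --start-layer 0 --end-layer 28 --tp-size 2`.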
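For reviewers unfamiliar with sglang's `broadcast_pyobj`, the sketch below shows the rank-0-source broadcast pattern that `_tensor_parallel_broadcast_byobj` relies on: rank 0 is the only rank holding real data (it owns the ZMQ sockets), while every other rank passes a placeholder and receives the object over the TP CPU group. This is a minimal stand-in using plain `torch.distributed` with a gloo group rather than sglang's helper; the function names, rendezvous address, and two-process setup are illustrative only.

```python
import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def broadcast_pyobj_sketch(obj, rank, group, src=0):
    """Broadcast any picklable object from `src` to all ranks in `group`."""
    if rank == src:
        payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.long)
    else:
        size = torch.zeros(1, dtype=torch.long)
    dist.broadcast(size, src=src, group=group)  # ship the byte length first
    if rank != src:
        payload = torch.empty(size.item(), dtype=torch.uint8)
    dist.broadcast(payload, src=src, group=group)  # then the pickled bytes
    return obj if rank == src else pickle.loads(payload.numpy().tobytes())


def worker(rank, world_size):
    # Illustrative rendezvous address; any free local port works.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29511", rank=rank, world_size=world_size
    )
    # Rank 0 plays the role of the socket-owning executor; others pass None.
    requests = [{"rid": "req-1", "prompt": "hello"}] if rank == 0 else None
    requests = broadcast_pyobj_sketch(requests, rank, dist.group.WORLD)
    print(f"rank {rank} sees {requests}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

The same shape explains why `_handle_input_requests` must be reached by every rank when `tp_size > 1`: the non-zero ranks block in the broadcast until rank 0 has work to publish.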
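The port helpers in `utils.py` can be exercised on their own. A quick sanity check, assuming the package is importable from the source tree:

```python
from parallax.utils.utils import initialize_nccl_port, is_port_available

# Starts from a random port in [4000, 5000] and walks in steps of 42 until a
# port binds; nothing holds the port after probing, so a race with another
# process between probing and the actual NCCL setup is possible in principle.
port = initialize_nccl_port()
assert is_port_available(port)
print(f"NCCL rendezvous port: {port}")
```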