diff --git a/docker/Dockerfile.blackwell b/docker/Dockerfile.blackwell
index a0dff681..7eb6c397 100644
--- a/docker/Dockerfile.blackwell
+++ b/docker/Dockerfile.blackwell
@@ -1,6 +1,6 @@
-FROM lmsysorg/sglang:v0.5.3rc1
+FROM lmsysorg/sglang:v0.5.4.post1
 
-ENV SGL_ENABLE_JIT_DEEPGEMM=0
+ENV SGLANG_ENABLE_JIT_DEEPGEMM=0
 
 WORKDIR /parallax
 
@@ -9,4 +9,3 @@ COPY src ./src
 COPY pyproject.toml ./pyproject.toml
 
 RUN pip install -e '.[gpu]'
-RUN pip install https://github.com/sgl-project/whl/releases/download/v0.3.7/sgl_kernel-0.3.7+cu128-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall
diff --git a/docker/Dockerfile.hopper b/docker/Dockerfile.hopper
index c18bb662..cea3e5a3 100755
--- a/docker/Dockerfile.hopper
+++ b/docker/Dockerfile.hopper
@@ -1,6 +1,6 @@
-FROM lmsysorg/sglang:v0.5.3rc1
+FROM lmsysorg/sglang:v0.5.4.post1
 
-ENV SGL_ENABLE_JIT_DEEPGEMM=0
+ENV SGLANG_ENABLE_JIT_DEEPGEMM=0
 
 WORKDIR /parallax
 
diff --git a/src/backend/server/rpc_connection_handler.py b/src/backend/server/rpc_connection_handler.py
index 88548bba..67e4d2cc 100644
--- a/src/backend/server/rpc_connection_handler.py
+++ b/src/backend/server/rpc_connection_handler.py
@@ -153,6 +153,7 @@ def get_layer_allocation(self, current_node_id):
                 ),
                 "start_layer": start_layer,
                 "end_layer": end_layer,
+                "tp_size": node.hardware.num_gpus,
             }
         return {}
 
@@ -182,6 +183,7 @@ def build_node(self, node_json: dict):
 
     def build_hardware(self, hardware_json):
         node_id = hardware_json.get("node_id")
+        num_gpus = hardware_json.get("num_gpus")
         tflops_fp16 = hardware_json.get("tflops_fp16")
         gpu_name = hardware_json.get("gpu_name")
         memory_gb = hardware_json.get("memory_gb")
@@ -189,6 +191,7 @@ def build_hardware(self, hardware_json):
         device = hardware_json.get("device")
         return NodeHardwareInfo(
             node_id=node_id,
+            num_gpus=num_gpus,
             tflops_fp16=tflops_fp16,
             gpu_name=gpu_name,
             memory_gb=memory_gb,
diff --git a/src/backend/server/scheduler_manage.py b/src/backend/server/scheduler_manage.py
index 99784f4d..b0796ef9 100644
--- a/src/backend/server/scheduler_manage.py
+++ b/src/backend/server/scheduler_manage.py
@@ -115,6 +115,7 @@ def build_node_info(self, node):
         return {
             "node_id": node.node_id,
             "status": NODE_STATUS_AVAILABLE if node.is_active else NODE_STATUS_WAITING,
+            "gpu_num": node.hardware.num_gpus,
             "gpu_name": node.hardware.gpu_name,
             "gpu_memory": node.hardware.memory_gb,
         }
diff --git a/src/parallax/cli.py b/src/parallax/cli.py
index 7b117608..7cb10e3a 100644
--- a/src/parallax/cli.py
+++ b/src/parallax/cli.py
@@ -226,7 +226,7 @@ def join_command(args, passthrough_args: list[str] | None = None):
 
     # Set environment variable for the subprocess
     env = os.environ.copy()
-    env["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
+    env["SGLANG_ENABLE_JIT_DEEPGEMM"] = "0"
 
     # Build the command to run the launch.py script
    passthrough_args = passthrough_args or []
diff --git a/src/parallax/launch.py b/src/parallax/launch.py
index 5be4cd0e..168d4d83 100644
--- a/src/parallax/launch.py
+++ b/src/parallax/launch.py
@@ -2,7 +2,11 @@
 Launch the Parallax server.
 
 This script is used to launch the Parallax server.
-It will start the P2P server and the executor.
+It will start the following services:
+    1. Executor with tp_rank=0 in the main process.
+    2. Executor for each tp_rank > 0, one subprocess per rank.
+    3. HTTP server as a subprocess.
+    4. P2P server as a thread in the main process.
 
 Example command:
 python src/parallax/launch.py \
@@ -41,7 +45,7 @@
     gradient_server = None
     http_server_process = None
     executor = None
-    executor_procs = []
+    executor_subprocs = []
     try:
         args = parse_args()
         set_log_level(args.log_level)
@@ -75,6 +79,7 @@
                 pp_start_layer=args.start_layer,
                 pp_end_layer=args.end_layer,
                 hidden_layers=config.get("num_hidden_layers"),
+                tp_size=args.tp_size,
                 tcp_port=args.tcp_port,
                 udp_port=args.udp_port,
                 dht_prefix=args.dht_prefix,
@@ -91,8 +96,9 @@
             )
             if gradient_server is not None:
                 gradient_server.status = ServerState.READY
-            tp_rank_range = range(args.tp_size)
-            for tp_rank in tp_rank_range:
+
+            # For each tp_rank > 0, create a subprocess and run executor
+            for tp_rank in range(1, args.tp_size):
                 args_copy = argparse.Namespace(**vars(args))
                 args_copy.tp_rank = tp_rank
                 proc = multiprocessing.Process(
@@ -100,9 +106,11 @@
                     args=(args_copy,),
                 )
                 proc.start()
-                executor_procs.append(proc)
-            for executor_process in executor_procs:
-                executor_process.join()
+                executor_subprocs.append(proc)
+            # Launch executor with tp_rank=0 in the main process
+            args.tp_rank = 0
+            executor = Executor.create_from_args(args)
+            executor.run_loop()
         else:
             gradient_server = launch_p2p_server(
                 initial_peers=args.initial_peers,
@@ -111,6 +119,7 @@
                 pp_start_layer=args.start_layer,
                 pp_end_layer=args.end_layer,
                 hidden_layers=None,
+                tp_size=args.tp_size,
                 tcp_port=args.tcp_port,
                 udp_port=args.udp_port,
                 dht_prefix=args.dht_prefix,
@@ -128,9 +137,7 @@
             args.start_layer = gradient_server.block_start_index
             args.end_layer = gradient_server.block_end_index
             args.model_path = gradient_server.model_name
-            # TODO: Implement inter-process communication to enable TP.
-            # For scheduler mode, currently only support tp_rank=0
-            args.tp_rank = 0
+            args.tp_size = gradient_server.tp_size
             logger.debug(
                 f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}"
             )
@@ -148,6 +155,18 @@
             # Main execution loop with layer reallocation support
             while True:
                 try:
+                    # For each tp_rank > 0, create a subprocess and run executor
+                    for tp_rank in range(1, args.tp_size):
+                        args_copy = argparse.Namespace(**vars(args))
+                        args_copy.tp_rank = tp_rank
+                        proc = multiprocessing.Process(
+                            target=run_executor_process,
+                            args=(args_copy,),
+                        )
+                        proc.start()
+                        executor_subprocs.append(proc)
+                    # Launch executor with tp_rank=0 in the main process
+                    args.tp_rank = 0
                     executor = Executor.create_from_args(args, gradient_server=gradient_server)
                     if gradient_server is not None:
                         gradient_server.status = ServerState.READY
@@ -159,7 +178,18 @@
                     logger.warning(
                         "Layer allocation changed! Reloading executor with new layers..."
                     )
+
+                    # Shutdown all executor subprocesses
+                    thread_pool = []
+                    for executor_process in executor_subprocs:
+                        t = threading.Thread(
+                            target=stop_executor_process, args=(executor_process,)
+                        )
+                        t.start()
+                        thread_pool.append(t)
                     executor.shutdown()
+                    for t in thread_pool:
+                        t.join()
 
                     if args.start_layer == 0:
                         http_server_process = stop_http_server(http_server_process)
@@ -210,13 +240,13 @@
         if gradient_server is not None:
             gradient_server.shutdown()
 
-        # Shutdown executor subprocess for scheduler mode
-        for executor_process in executor_procs:
+        # Shutdown executor subprocesses
+        for executor_process in executor_subprocs:
             t = threading.Thread(target=stop_executor_process, args=(executor_process,))
             t.start()
             thread_pool.append(t)
 
-        # Shutdown executor main process for non-scheduler mode
+        # Shutdown executor main process
         if executor is not None:
             executor.shutdown()
 
diff --git a/src/parallax/p2p/server.py b/src/parallax/p2p/server.py
index 85e5d091..9f59fff9 100644
--- a/src/parallax/p2p/server.py
+++ b/src/parallax/p2p/server.py
@@ -201,6 +201,7 @@ def __init__(
         block_start_index: int = 0,
         block_end_index: int = 1,
         hidden_layers: int = 128,
+        tp_size: int = 1,
         dht_prefix: str = "gradient",
         host_maddrs: List[str] = [],
         http_port: Optional[int] = None,
@@ -220,6 +221,7 @@ def __init__(
         self.block_start_index = block_start_index
         self.block_end_index = block_end_index
         self.hidden_layers = hidden_layers
+        self.tp_size = tp_size
         self.dht_prefix = dht_prefix
         self.host_maddrs = host_maddrs
         self.announce_maddrs = announce_maddrs
@@ -346,6 +348,7 @@ def run(self):
                 self.block_start_index = response.get("start_layer")
                 self.block_end_index = response.get("end_layer")
                 self.model_name = response.get("model_name")
+                self.tp_size = response.get("tp_size")
 
                 # Publish executor metrics to backend on each update
                 def _publish_metrics(_snapshot):
@@ -738,6 +741,7 @@ def launch_p2p_server(
     pp_start_layer: int,
     pp_end_layer: int,
     hidden_layers: int,
+    tp_size: int,
     tcp_port: int,
     udp_port: int,
     dht_prefix: str,
@@ -761,6 +765,7 @@ def launch_p2p_server(
         block_start_index=pp_start_layer,
         block_end_index=pp_end_layer,
         hidden_layers=hidden_layers,
+        tp_size=tp_size,
         dht_prefix=dht_prefix,
         host_maddrs=[f"/ip4/0.0.0.0/tcp/{tcp_port}", f"/ip4/0.0.0.0/udp/{udp_port}/quic-v1"],
         announce_maddrs=announce_maddrs,
diff --git a/src/parallax/server/server_info.py b/src/parallax/server/server_info.py
index ac66a4db..0e2234e9 100644
--- a/src/parallax/server/server_info.py
+++ b/src/parallax/server/server_info.py
@@ -32,6 +32,7 @@ class HardwareInfo:
     total_ram_gb: float
     chip: str
     tflops_fp16: float
+    num_gpus: int
 
     def dumps(self) -> Dict[str, Any]:
         """Serializes the HardwareInfo object to a dictionary."""
@@ -99,7 +100,7 @@ def detect(cls) -> "AppleSiliconHardwareInfo":
                 "Please add it to the _APPLE_PEAK_FP16 dictionary."
             ) from e
 
-        return cls(total_ram_gb=round(total_gb, 1), chip=chip, tflops_fp16=flops)
+        return cls(num_gpus=1, total_ram_gb=round(total_gb, 1), chip=chip, tflops_fp16=flops)
 
 
 @dataclass
@@ -143,6 +144,7 @@ def detect(cls) -> "NvidiaHardwareInfo":
         if torch is None or not torch.cuda.is_available():
             raise RuntimeError("CUDA not available; cannot detect NVIDIA hardware")
 
+        device_count = torch.cuda.device_count()
         device_index = torch.cuda.current_device()
         props = torch.cuda.get_device_properties(device_index)
         name = getattr(props, "name", f"cuda:{device_index}")
@@ -156,6 +158,7 @@ def detect(cls) -> "NvidiaHardwareInfo":
         spec = cls._match_gpu_specs(name, total_vram_gb)
 
         return cls(
+            num_gpus=device_count,
             total_ram_gb=round(total_gb, 1),
             chip=name,
             tflops_fp16=float(spec["tflops_fp16"]),
@@ -179,6 +182,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
         # Fallback to a conservative default
         return {
             "node_id": node_id,
+            "num_gpus": 1,
             "tflops_fp16": 50.0,
             "gpu_name": "Unknown",
             "memory_gb": 16.0,
@@ -189,6 +193,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
     if isinstance(hw, NvidiaHardwareInfo):
         return {
             "node_id": node_id,
+            "num_gpus": hw.num_gpus,
             "tflops_fp16": hw.tflops_fp16,
             "gpu_name": hw.chip,
             "memory_gb": hw.vram_gb,
@@ -200,6 +205,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
         est_bandwidth = 100.0
         return {
             "node_id": node_id,
+            "num_gpus": hw.num_gpus,
             "tflops_fp16": hw.tflops_fp16,
             "gpu_name": hw.chip,
             "memory_gb": hw.total_ram_gb,
@@ -209,6 +215,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
     # Generic fallback
     return {
         "node_id": node_id,
+        "num_gpus": hw.num_gpus,
         "tflops_fp16": hw.tflops_fp16,
         "gpu_name": "Unknown",
         "memory_gb": 16.0,
diff --git a/src/scheduling/README.md b/src/scheduling/README.md
index 13b1cf13..c7b37e75 100644
--- a/src/scheduling/README.md
+++ b/src/scheduling/README.md
@@ -117,13 +117,13 @@ model = ModelInfo(  # instantiate with your model's parameters
 
 n0 = Node(
     node_id="node-0",
-    hardware=NodeHardwareInfo(node_id="node-0", tflops_fp16=180.0, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
+    hardware=NodeHardwareInfo(node_id="node-0", tflops_fp16=180.0, num_gpus=1, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
     model_info=model,
 )
 
 n1 = Node(
     node_id="node-1",
-    hardware=NodeHardwareInfo(node_id="node-1", tflops_fp16=180.0, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
+    hardware=NodeHardwareInfo(node_id="node-1", tflops_fp16=180.0, num_gpus=1, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
     model_info=model,
 )
diff --git a/src/scheduling/layer_allocation.py b/src/scheduling/layer_allocation.py
index cee8799c..9c305ec1 100644
--- a/src/scheduling/layer_allocation.py
+++ b/src/scheduling/layer_allocation.py
@@ -287,7 +287,9 @@ def should_global_rebalance(self) -> bool:
         if len(layer_heap) < 2:
             return False
 
-        total_cluster_memory = sum(node.hardware.memory_gb for node in self.nodes)
+        total_cluster_memory = sum(
+            (node.hardware.num_gpus * node.hardware.memory_gb) for node in self.nodes
+        )
         if total_cluster_memory == 0:
             raise ValueError("Total cluster memory is zero")
 
diff --git a/src/scheduling/node.py b/src/scheduling/node.py
index 6603ed53..cb597dcd 100644
--- a/src/scheduling/node.py
+++ b/src/scheduling/node.py
@@ -31,6 +31,7 @@ class NodeHardwareInfo:
     """
 
     node_id: str
+    num_gpus: int
     tflops_fp16: float
     gpu_name: str
     memory_gb: float
@@ -272,7 +273,12 @@ def get_decoder_layer_capacity(
        Capacity is measured using the parameter memory budget on the device.
         """
         available_memory_bytes = floor(
-            self.hardware.memory_gb * 1024 * 1024 * 1024 * self.param_hosting_ratio
+            self.hardware.num_gpus
+            * self.hardware.memory_gb
+            * 1024
+            * 1024
+            * 1024
+            * self.param_hosting_ratio
         )
         if include_input_embed:
             available_memory_bytes -= self.model_info.embedding_io_bytes
@@ -300,7 +306,14 @@ def per_decoder_layer_kv_cache_memory(self) -> Optional[int]:
         if self.num_current_layers == 0:
             return None
         return floor(
-            (self.hardware.memory_gb * 1024 * 1024 * 1024 * self.kv_cache_ratio)
+            (
+                self.hardware.num_gpus
+                * self.hardware.memory_gb
+                * 1024
+                * 1024
+                * 1024
+                * self.kv_cache_ratio
+            )
             / self.num_current_layers
         )
 
diff --git a/tests/scheduler_tests/test_layer_allocation.py b/tests/scheduler_tests/test_layer_allocation.py
index 972950f0..46afa6f0 100644
--- a/tests/scheduler_tests/test_layer_allocation.py
+++ b/tests/scheduler_tests/test_layer_allocation.py
@@ -26,10 +26,10 @@ def _build_node(gpu_type: str, model: ModelInfo, id_suffix: str = "") -> Node:
     hw_map = {
-        "a100-80g": NodeHardwareInfo("a100-80g" + id_suffix, 312.0, "", 80.0, 2039.0, "cuda"),
-        "a100-40g": NodeHardwareInfo("a100-40g" + id_suffix, 312.0, "", 40.0, 1935.0, "cuda"),
-        "rtx5090": NodeHardwareInfo("rtx5090" + id_suffix, 165, "", 32.0, 1792.0, "cuda"),
-        "rtx4090": NodeHardwareInfo("rtx4090" + id_suffix, 82.6, "", 24.0, 1008.0, "cuda"),
+        "a100-80g": NodeHardwareInfo("a100-80g" + id_suffix, 1, 312.0, "", 80.0, 2039.0, "cuda"),
+        "a100-40g": NodeHardwareInfo("a100-40g" + id_suffix, 1, 312.0, "", 40.0, 1935.0, "cuda"),
+        "rtx5090": NodeHardwareInfo("rtx5090" + id_suffix, 1, 165, "", 32.0, 1792.0, "cuda"),
+        "rtx4090": NodeHardwareInfo("rtx4090" + id_suffix, 1, 82.6, "", 24.0, 1008.0, "cuda"),
     }
     hw = hw_map[gpu_type]
     return Node(node_id=hw.node_id, hardware=hw, model_info=model)
diff --git a/tests/scheduler_tests/test_utils.py b/tests/scheduler_tests/test_utils.py
index efdb549f..33e767a6 100644
--- a/tests/scheduler_tests/test_utils.py
+++ b/tests/scheduler_tests/test_utils.py
@@ -12,6 +12,7 @@
 A100_80G = NodeHardwareInfo(
     node_id="a100-80g",
+    num_gpus=1,
     tflops_fp16=312.0,
     gpu_name="",
     memory_gb=80.0,
@@ -20,6 +21,7 @@
 )
 A100_40G = NodeHardwareInfo(
     node_id="a100-40g",
+    num_gpus=1,
     tflops_fp16=312.0,
     gpu_name="",
     memory_gb=40.0,
@@ -28,6 +30,7 @@
 )
 RTX5090 = NodeHardwareInfo(
     node_id="rtx5090",
+    num_gpus=1,
     tflops_fp16=104.8,
     gpu_name="",
     memory_gb=32.0,
@@ -36,6 +39,7 @@
 )
 RTX4090 = NodeHardwareInfo(
     node_id="rtx4090",
+    num_gpus=1,
     tflops_fp16=82.6,
     gpu_name="",
     memory_gb=24.0,
@@ -78,6 +82,7 @@ def build_node(
     """Create a `Node` with hardware info and attach test-only coordinates/bandwidth."""
     hw = NodeHardwareInfo(
         node_id=node_id,
+        num_gpus=1,
         tflops_fp16=tflops,
         gpu_name="",
         memory_gb=mem_gb,
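
Review note: the sketch below is a minimal, standalone illustration of the launch pattern that the new launch.py docstring describes (tp_rank=0 stays in the main process, one subprocess per tp_rank > 0). It is not Parallax's actual API: run_rank and launch are hypothetical stand-ins for run_executor_process and Executor.create_from_args / Executor.run_loop.

import argparse
import multiprocessing


def run_rank(rank_args: argparse.Namespace) -> None:
    # Hypothetical stand-in for run_executor_process / Executor.run_loop:
    # a real rank would build its model shard and serve requests here.
    print(f"executor running with tp_rank={rank_args.tp_rank} of tp_size={rank_args.tp_size}")


def launch(tp_size: int) -> None:
    args = argparse.Namespace(tp_size=tp_size, tp_rank=None)

    # One subprocess per tp_rank > 0, mirroring the loop added in launch.py.
    executor_subprocs = []
    for tp_rank in range(1, tp_size):
        rank_args = argparse.Namespace(**vars(args))
        rank_args.tp_rank = tp_rank
        proc = multiprocessing.Process(target=run_rank, args=(rank_args,))
        proc.start()
        executor_subprocs.append(proc)

    # tp_rank=0 runs in the main process; in Parallax this is where
    # Executor.create_from_args(args) and executor.run_loop() would block.
    args.tp_rank = 0
    run_rank(args)

    # On shutdown or layer reallocation, stop the subprocesses before restarting.
    for proc in executor_subprocs:
        proc.join()


if __name__ == "__main__":
    launch(tp_size=2)

Separately, the scheduler-side changes above treat a node's memory budget as num_gpus * memory_gb (see layer_allocation.py and node.py), so a node reporting num_gpus=2 with memory_gb=80.0 is sized as 160 GB when layers are allocated.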