diff --git a/src/parallax/launch.py b/src/parallax/launch.py index d12c8948..f41f5133 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -21,7 +21,7 @@ from common.version_check import check_latest_release from parallax.p2p.server import ServerState, launch_p2p_server from parallax.server.executor import Executor -from parallax.server.http_server import launch_http_server +from parallax.server.http_server import launch_http_server, stop_http_server from parallax.server.server_args import parse_args from parallax_utils.ascii_anime import display_parallax_join from parallax_utils.logging_config import get_logger, set_log_level @@ -79,6 +79,9 @@ param_hosting_ratio=args.param_hosting_ratio, kv_cache_ratio=args.kv_cache_ratio, ) + if gradient_server is not None: + gradient_server.status = ServerState.READY + executor.run_loop() else: gradient_server = launch_p2p_server( initial_peers=args.initial_peers, @@ -117,11 +120,55 @@ # only launch http server on head node if args.start_layer == 0: http_server_process = launch_http_server(args) - executor = Executor.create_from_args(args) - if gradient_server is not None: - gradient_server.status = ServerState.READY - executor.run_loop() + # Main execution loop with layer reallocation support + while True: + try: + executor = Executor.create_from_args(args, gradient_server=gradient_server) + if gradient_server is not None: + gradient_server.status = ServerState.READY + + executor.run_loop() + + # Check if layer allocation changed (executor exited due to reallocation) + if gradient_server is not None and gradient_server._layer_allocation_changed: + logger.warning( + "Layer allocation changed! Reloading executor with new layers..." + ) + executor.shutdown() + + if args.start_layer == 0: + http_server_process = stop_http_server(http_server_process) + if gradient_server.block_start_index == 0: + http_server_process = launch_http_server(args) + + # Update args with new layer allocation + args.start_layer = gradient_server.block_start_index + args.end_layer = gradient_server.block_end_index + if gradient_server.model_name: + args.model_path = gradient_server.model_name + + logger.info( + f"Creating new executor with layers [{args.start_layer}, {args.end_layer})" + ) + + gradient_server._layer_allocation_changed = False + continue # Create new executor in next iteration + else: + break # Normal exit + except KeyboardInterrupt: + logger.debug("Received interrupt signal, shutting down...") + break + except Exception as e: + logger.exception(f"Executor error: {e}") + # If layer allocation changed, try to reload + if gradient_server is not None and gradient_server._layer_allocation_changed: + logger.info("Attempting to reload executor after error...") + if executor is not None: + executor.shutdown() + continue + else: + raise except KeyboardInterrupt: logger.debug("Received interrupt signal, shutting down...") except Exception as e: @@ -129,20 +176,8 @@ finally: t = None if http_server_process is not None: - - def terminate_http_server_process(process): - logger.debug("Terminating HTTP server process...") - try: - process.kill() - process.join() - except Exception as e: - logger.error(f"Failed to terminate HTTP server process: {e}") - - if http_server_process is not None: - t = threading.Thread( - target=terminate_http_server_process, args=(http_server_process,) - ) - t.start() + t = threading.Thread(target=stop_http_server, args=(http_server_process,)) + t.start() if gradient_server is not None: gradient_server.shutdown() if executor is not None: diff --git a/src/parallax/p2p/server.py b/src/parallax/p2p/server.py index 54533e49..56e7babc 100644 --- a/src/parallax/p2p/server.py +++ b/src/parallax/p2p/server.py @@ -247,6 +247,7 @@ def __init__( self.announcer = None self.connection_handler = None self.stop_event = threading.Event() + self._layer_allocation_changed = False def build_lattica(self): self.lattica = Lattica.builder().with_listen_addrs(self.host_maddrs) @@ -584,6 +585,30 @@ def _announcer_thread(): f"Heartbeat: Node {self.lattica.peer_id()}... " f"Model: {model_name}, Layers: [{start_layer}, {end_layer})" ) + # Check if layer allocation changed + if ( + start_layer != self.block_start_index + or end_layer != self.block_end_index + ): + logger.warning( + f"Layer allocation changed! " + f"Current: [{self.block_start_index}, {self.block_end_index}) -> " + f"New: [{start_layer}, {end_layer})" + ) + # Update layer allocation + self.block_start_index = start_layer + self.block_end_index = end_layer + if model_name: + self.model_name = model_name + # Set flag to trigger executor reload + self._layer_allocation_changed = True + # Set status to INITIALIZING to prevent scheduler from sending requests + # during rebalancing + self.status = ServerState.INITIALIZING + logger.info( + "Layer allocation updated. Executor will reload on next check. " + "Status set to INITIALIZING to prevent new requests." + ) else: logger.warning(f"Heartbeat response: {response}") else: diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index 0819ea6d..dead8b84 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -97,6 +97,8 @@ def __init__( # GPU/SGLang Specialized Configs attention_backend: Optional[str] = "torch_native", moe_runner_backend: Optional[str] = "auto", + # Optional gradient server for layer reallocation detection + gradient_server: Optional[Any] = None, ): # Backend self.device = get_current_device() @@ -144,6 +146,9 @@ def __init__( self.finished_batch = [] self.start_layer = start_layer self.end_layer = end_layer + self._should_stop = False # Flag to gracefully stop the executor + # Reference to gradient server for layer reallocation detection + self.gradient_server = gradient_server self.is_first_peer = start_layer == 0 self.is_last_peer = end_layer == self.config.get("num_hidden_layers") @@ -279,9 +284,9 @@ def __init__( ) @classmethod - def create_from_args(cls, args: argparse.Namespace): + def create_from_args(cls, args: argparse.Namespace, gradient_server=None): """Create executor from command line arguments.""" - return cls(**create_executor_config(args)) + return cls(**create_executor_config(args, gradient_server)) def recv_requests_from_http(self) -> List[Request]: """Receives requests from http frontend""" @@ -1144,7 +1149,8 @@ def run_loop(self): logger.debug( f"Executor for layers [{self.start_layer}, {self.end_layer}) starting run loop..." ) - while True: + self._should_stop = False + while not self._should_stop: # 1. Ingest new requests from the http frontend if self.is_first_peer: http_requests = self.recv_requests_from_http() @@ -1161,6 +1167,14 @@ def run_loop(self): ) self.finished_batch = [] + # Check for layer reallocation signal (before batch processing) + if self.gradient_server is not None and self.gradient_server._layer_allocation_changed: + logger.info( + "Layer reallocation detected. Stopping executor to reload with new layers." + ) + self._should_stop = True + break + # 4. Admit requests into running set up to capacity, then form batch self.scheduler.admit_requests() # 4.1 Check for request timeouts and abort timed out requests @@ -1249,15 +1263,36 @@ def run_loop_in_background(self): def shutdown(self): """Shuts down the executor.""" logger.debug("Executor shutting down...") - self.recv_from_peer_socket.close() - self.send_to_peer_socket.close() - self.recv_from_ipc_socket.close() - self.send_to_ipc_socket.close() - self.zmq_context.term() + self._should_stop = True + import time + + time.sleep(0.1) # Give run_loop a moment to exit gracefully + + try: + all_requests = [req for _, _, _, req in self.scheduler._request_queue] + list( + self.scheduler._running_requests.values() + ) + for req in all_requests: + try: + self.scheduler.evict_request(req.request_id, RequestStatus.CANCELLED) + except Exception: + pass + except Exception: + pass + + try: + self.recv_from_peer_socket.close() + self.send_to_peer_socket.close() + self.recv_from_ipc_socket.close() + self.send_to_ipc_socket.close() + self.zmq_context.term() + except Exception as e: + logger.debug(f"Error closing sockets (may already be closed): {e}") + logger.debug("Executor shutdown complete.") -def create_executor_config(args: argparse.Namespace): +def create_executor_config(args: argparse.Namespace, gradient_server=None): """Create executor configuration from command line arguments.""" config = { @@ -1280,5 +1315,6 @@ def create_executor_config(args: argparse.Namespace): "executor_output_ipc_addr": args.executor_output_ipc, "attention_backend": args.attention_backend, "moe_runner_backend": args.moe_runner_backend, + "gradient_server": gradient_server, } return config diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index fd815688..b8d354c7 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -472,3 +472,28 @@ def launch_http_server(args): process = mp.Process(target=http_server.run) process.start() return process + + +def stop_http_server(http_server_process): + """ + Stop HTTP server process if it exists. + """ + if http_server_process is not None: + logger.info("Stopping HTTP server process...") + try: + http_server_process.kill() + http_server_process.join() + except Exception as e: + logger.error(f"Failed to terminate HTTP server process: {e}") + return None + return http_server_process + + +def restart_http_server(args, http_server_process): + """ + Restart HTTP server with new args. + Stops the old server if it exists and starts a new one. + """ + http_server_process = stop_http_server(http_server_process) + logger.info("Restarting HTTP server...") + return launch_http_server(args)