From b972f09a13cb67ebed42d913f1886a8d36a7dcc2 Mon Sep 17 00:00:00 2001 From: sibianl Date: Thu, 9 Oct 2025 16:11:58 +0800 Subject: [PATCH] fix(node): Close http server process in finally and fix the problem that finally does not take effect --- src/parallax/cli.py | 5 +-- src/parallax/launch.py | 49 +++++++++++++++++++++--------- src/parallax/server/http_server.py | 4 ++- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/src/parallax/cli.py b/src/parallax/cli.py index a59a2cb2..5873d44f 100644 --- a/src/parallax/cli.py +++ b/src/parallax/cli.py @@ -9,6 +9,7 @@ import argparse import os +import signal import subprocess import sys from pathlib import Path @@ -90,7 +91,7 @@ def run_command(args): if sub_process is not None: try: # Gracefully terminate the subprocess - sub_process.terminate() + sub_process.send_signal(signal.SIGINT) # Wait for the subprocess to exit gracefully try: sub_process.wait(timeout=5) @@ -176,7 +177,7 @@ def join_command(args): try: logger.info("Terminating subprocess...") # Gracefully terminate the subprocess - sub_process.terminate() + sub_process.send_signal(signal.SIGINT) logger.info("Subprocess terminated, waiting for exit...") # Wait for the subprocess to exit gracefully try: diff --git a/src/parallax/launch.py b/src/parallax/launch.py index f1a10648..5066d435 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -16,6 +16,7 @@ import multiprocessing import tempfile +import threading from parallax.p2p.server import ServerState, launch_p2p_server from parallax.server.executor import Executor @@ -34,6 +35,10 @@ if __name__ == "__main__": multiprocessing.set_start_method("spawn", force=True) + + gradient_server = None + http_server_process = None + executor = None try: args = parse_args() logger.debug(f"args: {args}") @@ -45,7 +50,6 @@ logger.debug(f"executor_input_addr: {args.executor_input_ipc}") logger.debug(f"executor_output_addr: {args.executor_output_ipc}") - gradient_server = None # Hard code for mlx-community models if get_current_device() == "mlx": mlx_model_repo = MLX_MODEL_NAME_MAP.get(args.model_path, None) @@ -55,7 +59,7 @@ if args.scheduler_addr is None: # only launch http server on head node if args.start_layer == 0: - launch_http_server(args) + http_server_process = launch_http_server(args) executor = Executor.create_from_args(args) launch_p2p_server( initial_peers=args.initial_peers, @@ -111,19 +115,36 @@ gradient_server.status = ServerState.INITIALIZING # only launch http server on head node if args.start_layer == 0: - launch_http_server(args) + http_server_process = launch_http_server(args) executor = Executor.create_from_args(args) - try: - if gradient_server is not None: - gradient_server.status = ServerState.READY - executor.run_loop() - except KeyboardInterrupt: - logger.debug("Received interrupt signal, shutting down...") - finally: - if gradient_server is not None: - gradient_server.shutdown() - executor.shutdown() - + if gradient_server is not None: + gradient_server.status = ServerState.READY + executor.run_loop() + except KeyboardInterrupt: + logger.debug("Received interrupt signal, shutting down...") except Exception as e: logger.exception(e) + finally: + t = None + if http_server_process is not None: + + def terminate_http_server_process(process): + logger.debug("Terminating HTTP server process...") + try: + process.kill() + process.join() + except Exception as e: + logger.error(f"Failed to terminate HTTP server process: {e}") + + if http_server_process is not None: + t = threading.Thread( + target=terminate_http_server_process, args=(http_server_process,) + ) + t.start() + if gradient_server is not None: + gradient_server.shutdown() + if executor is not None: + executor.shutdown() + if t is not None: + t.join() diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index bfcf15c5..8e853b76 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -455,4 +455,6 @@ def launch_http_server(args): It creates a sub-process for the http server. """ http_server = ParallaxHttpServer(args) - mp.Process(target=http_server.run).start() + process = mp.Process(target=http_server.run) + process.start() + return process