Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/parallax/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import argparse
import os
import signal
import subprocess
import sys
from pathlib import Path
Expand Down Expand Up @@ -90,7 +91,7 @@ def run_command(args):
if sub_process is not None:
try:
# Gracefully terminate the subprocess
sub_process.terminate()
sub_process.send_signal(signal.SIGINT)
# Wait for the subprocess to exit gracefully
try:
sub_process.wait(timeout=5)
Expand Down Expand Up @@ -176,7 +177,7 @@ def join_command(args):
try:
logger.info("Terminating subprocess...")
# Gracefully terminate the subprocess
sub_process.terminate()
sub_process.send_signal(signal.SIGINT)
logger.info("Subprocess terminated, waiting for exit...")
# Wait for the subprocess to exit gracefully
try:
Expand Down
49 changes: 35 additions & 14 deletions src/parallax/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import multiprocessing
import tempfile
import threading

from parallax.p2p.server import ServerState, launch_p2p_server
from parallax.server.executor import Executor
Expand All @@ -34,6 +35,10 @@

if __name__ == "__main__":
multiprocessing.set_start_method("spawn", force=True)

gradient_server = None
http_server_process = None
executor = None
try:
args = parse_args()
logger.debug(f"args: {args}")
Expand All @@ -45,7 +50,6 @@

logger.debug(f"executor_input_addr: {args.executor_input_ipc}")
logger.debug(f"executor_output_addr: {args.executor_output_ipc}")
gradient_server = None
# Hard code for mlx-community models
if get_current_device() == "mlx":
mlx_model_repo = MLX_MODEL_NAME_MAP.get(args.model_path, None)
Expand All @@ -55,7 +59,7 @@
if args.scheduler_addr is None:
# only launch http server on head node
if args.start_layer == 0:
launch_http_server(args)
http_server_process = launch_http_server(args)
executor = Executor.create_from_args(args)
launch_p2p_server(
initial_peers=args.initial_peers,
Expand Down Expand Up @@ -111,19 +115,36 @@
gradient_server.status = ServerState.INITIALIZING
# only launch http server on head node
if args.start_layer == 0:
launch_http_server(args)
http_server_process = launch_http_server(args)
executor = Executor.create_from_args(args)

try:
if gradient_server is not None:
gradient_server.status = ServerState.READY
executor.run_loop()
except KeyboardInterrupt:
logger.debug("Received interrupt signal, shutting down...")
finally:
if gradient_server is not None:
gradient_server.shutdown()
executor.shutdown()

if gradient_server is not None:
gradient_server.status = ServerState.READY
executor.run_loop()
except KeyboardInterrupt:
logger.debug("Received interrupt signal, shutting down...")
except Exception as e:
logger.exception(e)
finally:
t = None
if http_server_process is not None:

def terminate_http_server_process(process):
logger.debug("Terminating HTTP server process...")
try:
process.kill()
process.join()
except Exception as e:
logger.error(f"Failed to terminate HTTP server process: {e}")

if http_server_process is not None:
t = threading.Thread(
target=terminate_http_server_process, args=(http_server_process,)
)
t.start()
if gradient_server is not None:
gradient_server.shutdown()
if executor is not None:
executor.shutdown()
if t is not None:
t.join()
4 changes: 3 additions & 1 deletion src/parallax/server/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,4 +455,6 @@ def launch_http_server(args):
It creates a sub-process for the http server.
"""
http_server = ParallaxHttpServer(args)
mp.Process(target=http_server.run).start()
process = mp.Process(target=http_server.run)
process.start()
return process