38 changes: 30 additions & 8 deletions src/parallax/launch.py
@@ -13,6 +13,7 @@
         --end-layer 28
 """
 
+import argparse
 import multiprocessing
 import os
 import tempfile
@@ -23,7 +24,7 @@
 from parallax.server.executor import Executor
 from parallax.server.http_server import launch_http_server
 from parallax.server.server_args import parse_args
-from parallax.utils.utils import get_current_device
+from parallax.utils.utils import get_current_device, fetch_model_from_hf
 from parallax_utils.ascii_anime import display_parallax_join
 from parallax_utils.logging_config import get_logger, set_log_level
 
@@ -44,12 +45,19 @@
     "zai-org/GLM-4.6": "mlx-community/GLM-4.6-4bit",
 }
 
+
+def run_executor_process(args):
+    """Run executor as a subprocess"""
+    executor = Executor.create_from_args(args)
+    executor.run_loop()
+
+
 if __name__ == "__main__":
     multiprocessing.set_start_method("spawn", force=True)
 
     gradient_server = None
     http_server_process = None
-    executor = None
+    executor_procs = []
     try:
         args = parse_args()
         set_log_level(args.log_level)
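Reviewer note: because the entry point forces the `spawn` start method, each executor subprocess boots a fresh interpreter and receives its target by pickling, so the target must be an importable module-level function — which is why `run_executor_process` is added at top level rather than as a closure inside `__main__`. A minimal sketch of that constraint (the `work` function and the rank count are hypothetical stand-ins, not part of this PR):

```python
import multiprocessing

def work(rank):
    # Module-level target: importable, hence picklable under the spawn start method.
    print(f"worker {rank} starting")

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    procs = [multiprocessing.Process(target=work, args=(r,)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```

Moving `work` inside the `if __name__` block would raise a pickling error under spawn, so the top-level placement here is load-bearing, not stylistic.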
@@ -70,22 +78,25 @@
         if mlx_model_repo is not None:
             args.model_path = mlx_model_repo
             logger.debug(f"Replace mlx model path: {mlx_model_repo}")
 
         if args.scheduler_addr is None:
+            # Launch without scheduler
             if args.log_level != "DEBUG":
                 display_parallax_join(args.model_path)
                 check_latest_release()
+
+            config = fetch_model_from_hf(args.model_path)
             # only launch http server on head node
             if args.start_layer == 0:
                 http_server_process = launch_http_server(args)
-            executor = Executor.create_from_args(args)
+
             launch_p2p_server(
                 initial_peers=args.initial_peers,
                 scheduler_addr=args.scheduler_addr,
                 relay_servers=args.relay_servers,
                 pp_start_layer=args.start_layer,
                 pp_end_layer=args.end_layer,
-                hidden_layers=executor.config.get("num_hidden_layers"),
+                hidden_layers=config.get("num_hidden_layers"),
                 tcp_port=args.tcp_port,
                 udp_port=args.udp_port,
                 dht_prefix=args.dht_prefix,
@@ -99,6 +110,7 @@
                 max_sequence_length=args.max_sequence_length,
             )
         else:
+            # Join scheduler
             gradient_server = launch_p2p_server(
                 initial_peers=args.initial_peers,
                 scheduler_addr=args.scheduler_addr,
@@ -136,14 +148,26 @@
                 display_parallax_join(args.model_path)
                 check_latest_release()
 
+            fetch_model_from_hf(args.model_path)
             # only launch http server on head node
             if args.start_layer == 0:
                 http_server_process = launch_http_server(args)
-            executor = Executor.create_from_args(args)
+
+            tp_rank_range = range(args.tp_size)
+            for tp_rank in tp_rank_range:
+                args_copy = argparse.Namespace(**vars(args))
+                args_copy.tp_rank = tp_rank
+                proc = multiprocessing.Process(
+                    target=run_executor_process,
+                    args=(args_copy,),
+                )
+                proc.start()
+                executor_procs.append(proc)
 
             if gradient_server is not None:
                 gradient_server.status = ServerState.READY
-            executor.run_loop()
+            for executor_process in executor_procs:
+                executor_process.join()
         except KeyboardInterrupt:
             logger.debug("Received interrupt signal, shutting down...")
         except Exception as e:
@@ -167,7 +191,5 @@ def terminate_http_server_process(process):
             t.start()
         if gradient_server is not None:
             gradient_server.shutdown()
-        if executor is not None:
-            executor.shutdown()
         if t is not None:
             t.join()
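Reviewer note: the net effect of this diff is that launch no longer drives a single in-process `Executor`; it forks one executor subprocess per tensor-parallel rank, giving each its own copy of the parsed args, and the main process blocks on joining them instead of calling `executor.run_loop()` directly. A minimal, self-contained sketch of that per-rank launch loop, assuming only the `tp_size`/`tp_rank` attributes seen in the diff (stub worker in place of `Executor.create_from_args`, hard-coded args in place of `parse_args()`):

```python
import argparse
import multiprocessing

def run_executor_process(args):
    # Stand-in for Executor.create_from_args(args).run_loop()
    print(f"executor for tp_rank={args.tp_rank} of tp_size={args.tp_size}")

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    args = argparse.Namespace(tp_size=2, tp_rank=None)  # parse_args() stand-in
    executor_procs = []
    for tp_rank in range(args.tp_size):
        # Shallow per-rank copy so each subprocess sees its own tp_rank
        args_copy = argparse.Namespace(**vars(args))
        args_copy.tp_rank = tp_rank
        proc = multiprocessing.Process(target=run_executor_process, args=(args_copy,))
        proc.start()
        executor_procs.append(proc)
    for proc in executor_procs:
        proc.join()
```

One caveat on the copy: `argparse.Namespace(**vars(args))` is shallow, so scalar fields like `tp_rank` diverge safely per rank, but any mutable attribute on `args` is still shared in the parent until it is pickled across the spawn boundary.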