Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/parallax/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from parallax.server.http_server import launch_http_server
from parallax.server.server_args import parse_args
from parallax.utils.utils import get_current_device
from parallax_utils.ascii_anime import display_parallax_join
from parallax_utils.logging_config import get_logger

logger = get_logger("parallax.launch")
Expand All @@ -47,7 +48,6 @@
try:
args = parse_args()
logger.debug(f"args: {args}")

args.recv_from_peer_addr = f"ipc://{tempfile.NamedTemporaryFile().name}"
args.send_to_peer_addr = f"ipc://{tempfile.NamedTemporaryFile().name}"
args.executor_input_ipc = f"ipc://{tempfile.NamedTemporaryFile().name}"
Expand All @@ -62,6 +62,8 @@
args.model_path = mlx_model_repo
logger.debug(f"Replace mlx model path: {mlx_model_repo}")
if args.scheduler_addr is None:
display_parallax_join(args.model_path)

# only launch http server on head node
if args.start_layer == 0:
http_server_process = launch_http_server(args)
Expand Down Expand Up @@ -118,6 +120,9 @@
f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}"
)
gradient_server.status = ServerState.INITIALIZING

display_parallax_join(args.model_path)

# only launch http server on head node
if args.start_layer == 0:
http_server_process = launch_http_server(args)
Expand Down
2 changes: 0 additions & 2 deletions src/parallax/server/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
pad_inputs,
pad_prefix_caches,
)
from parallax_utils.ascii_anime import display_parallax_join
from parallax_utils.logging_config import get_logger
from parallax_utils.utils import compute_max_batch_size

Expand Down Expand Up @@ -274,7 +273,6 @@ def __init__(
self.send_to_ipc_socket = get_zmq_socket(
self.zmq_context, zmq.PUSH, executor_output_ipc_addr, bind=False
)
display_parallax_join(model_repo)

@classmethod
def create_from_args(cls, args: argparse.Namespace):
Expand Down