Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ async def serve_index():
if __name__ == "__main__":
args = parse_args()
logger.info(f"args: {args}")

display_parallax_run()
if args.log_level != "DEBUG":
display_parallax_run()
host_maddrs = args.host_maddrs
dht_port = args.dht_port
if args.dht_port is not None:
Expand Down
8 changes: 8 additions & 0 deletions src/backend/server/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def parse_args() -> argparse.Namespace:

parser.add_argument("--port", type=int, default=5000, help="Port to listen on")

parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Log level",
)

parser.add_argument("--model-name", type=str, default=None, help="Model name")

parser.add_argument("--init-nodes-num", type=int, default=None, help="Number of initial nodes")
Expand Down
243 changes: 148 additions & 95 deletions src/parallax/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,73 +42,156 @@ def get_project_root():
return Path.cwd()


def run_command(args):
"""Run the scheduler (equivalent to scripts/start.sh)."""
check_python_version()

project_root = get_project_root()
backend_main = project_root / "src" / "backend" / "main.py"

if not backend_main.exists():
print(f"Error: Backend main.py not found at {backend_main}")
sys.exit(1)

# Build the command to run the backend main.py
cmd = [
sys.executable,
str(backend_main),
"--dht-port",
"5001",
"--port",
"3001",
]

# Add optional arguments if provided
if args.model_name:
cmd.extend(["--model-name", args.model_name])
if args.init_nodes_num:
cmd.extend(["--init-nodes-num", str(args.init_nodes_num)])
if args.use_relay:
cmd.extend(get_relay_params())

def _flag_present(args_list: list[str], flag_names: list[str]) -> bool:
"""Return True if any of the given flags is present in args_list.

Supports forms: "--flag value", "--flag=value", "-f value", "-f=value".
"""
if not args_list:
return False
flags_set = set(flag_names)
for i, token in enumerate(args_list):
if token in flags_set:
return True
for flag in flags_set:
if token.startswith(flag + "="):
return True
return False


def _find_flag_value(args_list: list[str], flag_names: list[str]) -> str | None:
"""Find the value for the first matching flag in args_list, if present.

Returns the associated value for forms: "--flag value" or "--flag=value" or
"-f value" or "-f=value". Returns None if not found or value is missing.
"""
if not args_list:
return None
flags_set = set(flag_names)
for i, token in enumerate(args_list):
if token in flags_set:
# expect value in next token if exists and is not another flag
if i + 1 < len(args_list) and not args_list[i + 1].startswith("-"):
return args_list[i + 1]
return None
for flag in flags_set:
prefix = flag + "="
if token.startswith(prefix):
return token[len(prefix) :]
return None


def _execute_with_graceful_shutdown(cmd: list[str], env: dict[str, str] | None = None) -> None:
"""Execute a command in a subprocess and handle graceful shutdown on Ctrl-C.

This centralizes the common Popen + signal handling logic shared by
run_command and join_command.
"""
logger.info(f"Running command: {' '.join(cmd)}")

# Use Popen instead of run to control the subprocess
sub_process = None
try:
sub_process = subprocess.Popen(cmd)
# Start in a new session so we can signal the entire process group
sub_process = subprocess.Popen(cmd, env=env, start_new_session=True)
# Wait for the subprocess to finish
return_code = sub_process.wait()
if return_code != 0:
logger.error(f"Command failed with exit code {return_code}")
sys.exit(return_code)
except KeyboardInterrupt:
logger.info("Received interrupt signal, shutting down...")

# If another Ctrl-C arrives during cleanup, force-kill the whole group immediately
def _force_kill_handler(signum, frame):
try:
os.killpg(sub_process.pid, signal.SIGKILL)
except Exception:
try:
sub_process.kill()
except Exception:
pass
os._exit(130)

try:
signal.signal(signal.SIGINT, _force_kill_handler)
except Exception:
pass

if sub_process is not None:
try:
# Gracefully terminate the subprocess
sub_process.send_signal(signal.SIGINT)
logger.info("Terminating subprocess group...")
# Gracefully terminate the entire process group
try:
os.killpg(sub_process.pid, signal.SIGINT)
except Exception:
# Fall back to signaling just the child process
sub_process.send_signal(signal.SIGINT)

logger.info("Waiting for subprocess to exit...")
# Wait for the subprocess to exit gracefully
try:
sub_process.wait(timeout=5)
except subprocess.TimeoutExpired:
# If the process does not exit in 5 seconds, force kill
logger.info("Process didn't terminate gracefully, forcing kill...")
sub_process.kill()
sub_process.wait()
logger.info("SIGINT timeout; sending SIGTERM to process group...")
try:
os.killpg(sub_process.pid, signal.SIGTERM)
except Exception:
sub_process.terminate()
try:
sub_process.wait(timeout=5)
except subprocess.TimeoutExpired:
logger.info("SIGTERM timeout; forcing SIGKILL on process group...")
try:
os.killpg(sub_process.pid, signal.SIGKILL)
except Exception:
sub_process.kill()
sub_process.wait()
logger.info("Subprocess exited.")
except Exception as e:
logger.error(f"Failed to terminate subprocess: {e}")
else:
logger.info("Subprocess not found, skipping shutdown...")
sys.exit(0)


def join_command(args):
"""Join a distributed cluster (equivalent to scripts/join.sh)."""
def run_command(args, passthrough_args: list[str] | None = None):
"""Run the scheduler (equivalent to scripts/start.sh)."""
check_python_version()

if not args.scheduler_addr:
print("Error: Scheduler address is required. Use -s or --scheduler-addr")
project_root = get_project_root()
backend_main = project_root / "src" / "backend" / "main.py"

if not backend_main.exists():
print(f"Error: Backend main.py not found at {backend_main}")
sys.exit(1)

# Build the command to run the backend main.py
passthrough_args = passthrough_args or []
cmd = [sys.executable, str(backend_main)]
if not _flag_present(passthrough_args, ["--dht-port"]):
cmd.extend(["--dht-port", "5001"])
if not _flag_present(passthrough_args, ["--port"]):
cmd.extend(["--port", "3001"])

# Add optional arguments if provided
if args.model_name:
cmd.extend(["--model-name", args.model_name])
if args.init_nodes_num:
cmd.extend(["--init-nodes-num", str(args.init_nodes_num)])
if args.use_relay:
cmd.extend(get_relay_params())

# Append any passthrough args (unrecognized by this CLI) directly to the command
if passthrough_args:
cmd.extend(passthrough_args)

_execute_with_graceful_shutdown(cmd)


def join_command(args, passthrough_args: list[str] | None = None):
"""Join a distributed cluster (equivalent to scripts/join.sh)."""
check_python_version()

project_root = get_project_root()
launch_script = project_root / "src" / "parallax" / "launch.py"

Expand All @@ -121,64 +204,33 @@ def join_command(args):
env["SGL_ENABLE_JIT_DEEPGEMM"] = "0"

# Build the command to run the launch.py script
cmd = [
sys.executable,
str(launch_script),
"--max-num-tokens-per-batch",
"4096",
"--max-sequence-length",
"2048",
"--max-batch-size",
"8",
"--kv-block-size",
"1024",
"--host",
"0.0.0.0",
"--port",
"3000",
"--scheduler-addr",
args.scheduler_addr,
]
passthrough_args = passthrough_args or []

cmd = [sys.executable, str(launch_script)]
if not _flag_present(passthrough_args, ["--max-num-tokens-per-batch"]):
cmd.extend(["--max-num-tokens-per-batch", "4096"])
if not _flag_present(passthrough_args, ["--max-sequence-length"]):
cmd.extend(["--max-sequence-length", "2048"])
if not _flag_present(passthrough_args, ["--max-batch-size"]):
cmd.extend(["--max-batch-size", "8"])
if not _flag_present(passthrough_args, ["--kv-block-size"]):
cmd.extend(["--kv-block-size", "1024"])
# The scheduler address is now taken directly from the parsed arguments.
cmd.extend(["--scheduler-addr", args.scheduler_addr])

# Relay logic based on effective scheduler address
if args.use_relay or (
args.scheduler_addr != "auto" and not str(args.scheduler_addr).startswith("/")
):
logger.info("Using public relay servers")
cmd.extend(get_relay_params())

logger.info(f"Running command: {' '.join(cmd)}")
logger.info(f"Scheduler address: {args.scheduler_addr}")
# Append any passthrough args (unrecognized by this CLI) directly to the command
if passthrough_args:
cmd.extend(passthrough_args)

# Use Popen instead of run to control the subprocess
sub_process = None
try:
sub_process = subprocess.Popen(cmd, env=env)
# Wait for the subprocess to finish
return_code = sub_process.wait()
if return_code != 0:
logger.error(f"Command failed with exit code {return_code}")
sys.exit(return_code)
except KeyboardInterrupt:
logger.info("Received interrupt signal, shutting down...")
if sub_process is not None:
try:
logger.info("Terminating subprocess...")
# Gracefully terminate the subprocess
sub_process.send_signal(signal.SIGINT)
logger.info("Subprocess terminated, waiting for exit...")
# Wait for the subprocess to exit gracefully
try:
sub_process.wait(timeout=5)
except subprocess.TimeoutExpired:
# If the process does not exit in 5 seconds, force kill
logger.info("Process didn't terminate gracefully, forcing kill...")
sub_process.kill()
sub_process.wait()
logger.info("Subprocess exited gracefully.")
except Exception as e:
logger.error(f"Failed to terminate subprocess: {e}")
else:
logger.info("Subprocess not found, skipping shutdown...")
sys.exit(0)
logger.info(f"Scheduler address: {args.scheduler_addr}")
_execute_with_graceful_shutdown(cmd, env=env)


def main():
Expand Down Expand Up @@ -224,16 +276,17 @@ def main():
"-r", "--use-relay", action="store_true", help="Use public relay servers"
)

args = parser.parse_args()
# Accept unknown args and pass them through to the underlying python command
args, passthrough_args = parser.parse_known_args()

if not args.command:
parser.print_help()
sys.exit(1)

if args.command == "run":
run_command(args)
run_command(args, passthrough_args)
elif args.command == "join":
join_command(args)
join_command(args, passthrough_args)
else:
parser.print_help()
sys.exit(1)
Expand Down
6 changes: 4 additions & 2 deletions src/parallax/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
args.model_path = mlx_model_repo
logger.debug(f"Replace mlx model path: {mlx_model_repo}")
if args.scheduler_addr is None:
display_parallax_join(args.model_path)
if args.log_level != "DEBUG":
display_parallax_join(args.model_path)

# only launch http server on head node
if args.start_layer == 0:
Expand Down Expand Up @@ -122,7 +123,8 @@
)
gradient_server.status = ServerState.INITIALIZING

display_parallax_join(args.model_path)
if args.log_level != "DEBUG":
display_parallax_join(args.model_path)

# only launch http server on head node
if args.start_layer == 0:
Expand Down