Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions src/parallax/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ def check_python_version():

def get_project_root():
"""Get the project root directory."""
# Find the project root by looking for pyproject.toml
# Search for the project root by looking for pyproject.toml in parent directories
current_dir = Path(__file__).parent
while current_dir != current_dir.parent:
if (current_dir / "pyproject.toml").exists():
return current_dir
current_dir = current_dir.parent

# Fallback to current working directory
# If not found, fallback to current working directory
return Path.cwd()


Expand All @@ -51,24 +51,42 @@ def run_command(args):
print(f"Error: Backend main.py not found at {backend_main}")
sys.exit(1)

# Build the command
# Build the command to run the backend main.py
cmd = [sys.executable, str(backend_main), "--dht-port", "5001", "--port", "3001"]

# Add optional arguments
# Add optional arguments if provided
if args.model_name:
cmd.extend(["--model-name", args.model_name])
if args.init_nodes_num:
cmd.extend(["--init-nodes-num", str(args.init_nodes_num)])

logger.info(f"Running command: {' '.join(cmd)}")

# Use Popen instead of run to control the subprocess
sub_process = None
try:
subprocess.run(cmd, check=True)
except subprocess.CalledProcessError as e:
logger.error(f"Command failed with exit code {e.returncode}")
sys.exit(e.returncode)
sub_process = subprocess.Popen(cmd)
# Wait for the subprocess to finish
return_code = sub_process.wait()
if return_code != 0:
logger.error(f"Command failed with exit code {return_code}")
sys.exit(return_code)
except KeyboardInterrupt:
logger.info("Received interrupt signal, shutting down...")
if sub_process is not None:
try:
# Gracefully terminate the subprocess
sub_process.terminate()
# Wait for the subprocess to exit gracefully
try:
sub_process.wait(timeout=5)
except subprocess.TimeoutExpired:
# If the process does not exit in 5 seconds, force kill
logger.info("Process didn't terminate gracefully, forcing kill...")
sub_process.kill()
sub_process.wait()
except Exception as e:
logger.error(f"Failed to terminate subprocess: {e}")
sys.exit(0)


Expand All @@ -87,11 +105,11 @@ def join_command(args):
print(f"Error: Launch script not found at {launch_script}")
sys.exit(1)

# Set environment variable
# Set environment variable for the subprocess
env = os.environ.copy()
env["SGL_ENABLE_JIT_DEEPGEMM"] = "0"

# Build the command
# Build the command to run the launch.py script
cmd = [
sys.executable,
str(launch_script),
Expand All @@ -114,13 +132,36 @@ def join_command(args):
logger.info(f"Running command: {' '.join(cmd)}")
logger.info(f"Scheduler address: {args.scheduler_addr}")

# Use Popen instead of run to control the subprocess
sub_process = None
try:
subprocess.run(cmd, check=True, env=env)
except subprocess.CalledProcessError as e:
logger.error(f"Command failed with exit code {e.returncode}")
sys.exit(e.returncode)
sub_process = subprocess.Popen(cmd, env=env)
# Wait for the subprocess to finish
return_code = sub_process.wait()
if return_code != 0:
logger.error(f"Command failed with exit code {return_code}")
sys.exit(return_code)
except KeyboardInterrupt:
logger.info("Received interrupt signal, shutting down...")
if sub_process is not None:
try:
logger.info("Terminating subprocess...")
# Gracefully terminate the subprocess
sub_process.terminate()
logger.info("Subprocess terminated, waiting for exit...")
# Wait for the subprocess to exit gracefully
try:
sub_process.wait(timeout=5)
except subprocess.TimeoutExpired:
# If the process does not exit in 5 seconds, force kill
logger.info("Process didn't terminate gracefully, forcing kill...")
sub_process.kill()
sub_process.wait()
logger.info("Subprocess exited gracefully.")
except Exception as e:
logger.error(f"Failed to terminate subprocess: {e}")
else:
logger.info("Subprocess not found, skipping shutdown...")
sys.exit(0)


Expand All @@ -139,14 +180,14 @@ def main():

subparsers = parser.add_subparsers(dest="command", help="Available commands")

# Run command
# Add 'run' command parser
run_parser = subparsers.add_parser(
"run", help="Start the Parallax scheduler (equivalent to scripts/start.sh)"
)
run_parser.add_argument("-n", "--init-nodes-num", type=int, help="Number of initial nodes")
run_parser.add_argument("-m", "--model-name", type=str, help="Model name")

# Join command
# Add 'join' command parser
join_parser = subparsers.add_parser(
"join", help="Join a distributed cluster (equivalent to scripts/join.sh)"
)
Expand Down
9 changes: 5 additions & 4 deletions src/parallax/p2p/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,16 +590,17 @@ def get_node_info(self, is_update: bool = False):

def shutdown(self):
self.stop_event.set()
if self.announcer is not None:
self.announcer.join()
if self.routing_table_updater is not None:
self.routing_table_updater.join()

self.status = ServerState.OFFLINE
if self.scheduler_addr is not None:
logger.info(f"Leave scheduler: {self.lattica.peer_id()}")
self.scheduler_stub.node_leave(self.get_node_info(is_update=True))

if self.announcer is not None:
self.announcer.join()
if self.routing_table_updater is not None:
self.routing_table_updater.join()


def launch_p2p_server(
initial_peers: List[str],
Expand Down