73 changes: 54 additions & 19 deletions src/parallax/launch.py
@@ -21,7 +21,7 @@
from common.version_check import check_latest_release
from parallax.p2p.server import ServerState, launch_p2p_server
from parallax.server.executor import Executor
from parallax.server.http_server import launch_http_server
from parallax.server.http_server import launch_http_server, stop_http_server
from parallax.server.server_args import parse_args
from parallax_utils.ascii_anime import display_parallax_join
from parallax_utils.logging_config import get_logger, set_log_level
@@ -79,6 +79,9 @@
param_hosting_ratio=args.param_hosting_ratio,
kv_cache_ratio=args.kv_cache_ratio,
)
if gradient_server is not None:
gradient_server.status = ServerState.READY
executor.run_loop()
else:
gradient_server = launch_p2p_server(
initial_peers=args.initial_peers,
@@ -117,32 +120,64 @@
# only launch http server on head node
if args.start_layer == 0:
http_server_process = launch_http_server(args)
executor = Executor.create_from_args(args)

if gradient_server is not None:
gradient_server.status = ServerState.READY
executor.run_loop()
# Main execution loop with layer reallocation support
while True:
try:
executor = Executor.create_from_args(args, gradient_server=gradient_server)
if gradient_server is not None:
gradient_server.status = ServerState.READY

executor.run_loop()

# Check if layer allocation changed (executor exited due to reallocation)
if gradient_server is not None and gradient_server._layer_allocation_changed:
logger.warning(
"Layer allocation changed! Reloading executor with new layers..."
)
executor.shutdown()

if args.start_layer == 0:
http_server_process = stop_http_server(http_server_process)
if gradient_server.block_start_index == 0:
http_server_process = launch_http_server(args)

# Update args with new layer allocation
args.start_layer = gradient_server.block_start_index
args.end_layer = gradient_server.block_end_index
if gradient_server.model_name:
args.model_path = gradient_server.model_name

logger.info(
f"Creating new executor with layers [{args.start_layer}, {args.end_layer})"
)

gradient_server._layer_allocation_changed = False
continue # Create new executor in next iteration
else:
break # Normal exit
except KeyboardInterrupt:
logger.debug("Received interrupt signal, shutting down...")
break
except Exception as e:
logger.exception(f"Executor error: {e}")
# If layer allocation changed, try to reload
if gradient_server is not None and gradient_server._layer_allocation_changed:
logger.info("Attempting to reload executor after error...")
if executor is not None:
executor.shutdown()
continue
else:
raise
except KeyboardInterrupt:
logger.debug("Received interrupt signal, shutting down...")
except Exception as e:
logger.exception(e)
finally:
t = None
if http_server_process is not None:

def terminate_http_server_process(process):
logger.debug("Terminating HTTP server process...")
try:
process.kill()
process.join()
except Exception as e:
logger.error(f"Failed to terminate HTTP server process: {e}")

if http_server_process is not None:
t = threading.Thread(
target=terminate_http_server_process, args=(http_server_process,)
)
t.start()
t = threading.Thread(target=stop_http_server, args=(http_server_process,))
t.start()
if gradient_server is not None:
gradient_server.shutdown()
if executor is not None:
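
The changes above turn launch.py into a small supervisor: it keeps rebuilding the executor whenever the scheduler hands this node a different layer range. The following is a minimal, self-contained sketch of that control flow only, not the Parallax code itself; make_executor, allocation_changed, and apply_new_allocation are placeholder callables.

def supervise(make_executor, allocation_changed, apply_new_allocation):
    """Run an executor until layers are reassigned, then rebuild it with the new range."""
    while True:
        executor = make_executor()
        try:
            executor.run_loop()          # blocks until normal exit or a reallocation stop
        except KeyboardInterrupt:
            executor.shutdown()
            break
        if allocation_changed():
            executor.shutdown()          # release sockets and weights before rebuilding
            apply_new_allocation()       # copy the new [start_layer, end_layer) into args
            continue                     # next iteration creates a fresh executor
        break                            # run_loop returned without a reallocation: done
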
25 changes: 25 additions & 0 deletions src/parallax/p2p/server.py
@@ -247,6 +247,7 @@ def __init__(
self.announcer = None
self.connection_handler = None
self.stop_event = threading.Event()
self._layer_allocation_changed = False

def build_lattica(self):
self.lattica = Lattica.builder().with_listen_addrs(self.host_maddrs)
@@ -584,6 +585,30 @@ def _announcer_thread():
f"Heartbeat: Node {self.lattica.peer_id()}... "
f"Model: {model_name}, Layers: [{start_layer}, {end_layer})"
)
# Check if layer allocation changed
if (
start_layer != self.block_start_index
or end_layer != self.block_end_index
):
logger.warning(
f"Layer allocation changed! "
f"Current: [{self.block_start_index}, {self.block_end_index}) -> "
f"New: [{start_layer}, {end_layer})"
)
# Update layer allocation
self.block_start_index = start_layer
self.block_end_index = end_layer
if model_name:
self.model_name = model_name
# Set flag to trigger executor reload
self._layer_allocation_changed = True
# Set status to INITIALIZING to prevent scheduler from sending requests
# during rebalancing
self.status = ServerState.INITIALIZING
logger.info(
"Layer allocation updated. Executor will reload on next check. "
"Status set to INITIALIZING to prevent new requests."
)
else:
logger.warning(f"Heartbeat response: {response}")
else:
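
In essence, the heartbeat handler compares the layer range returned by the scheduler with the range it is currently serving and, on a mismatch, records the new range and raises a flag for the executor to poll. A stripped-down sketch of that detect-and-flag pattern (AllocationWatcher is an illustrative name, not a Parallax class):

class AllocationWatcher:
    """Tracks the layer range assigned by the scheduler and flags changes."""

    def __init__(self, start_layer: int, end_layer: int):
        self.start_layer = start_layer
        self.end_layer = end_layer
        self.changed = False

    def on_heartbeat(self, new_start: int, new_end: int) -> None:
        if (new_start, new_end) != (self.start_layer, self.end_layer):
            # Record the new assignment and signal the run loop to stop and reload.
            self.start_layer, self.end_layer = new_start, new_end
            self.changed = True
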
54 changes: 45 additions & 9 deletions src/parallax/server/executor.py
@@ -97,6 +97,8 @@ def __init__(
# GPU/SGLang Specialized Configs
attention_backend: Optional[str] = "torch_native",
moe_runner_backend: Optional[str] = "auto",
# Optional gradient server for layer reallocation detection
gradient_server: Optional[Any] = None,
):
# Backend
self.device = get_current_device()
@@ -144,6 +146,9 @@
self.finished_batch = []
self.start_layer = start_layer
self.end_layer = end_layer
self._should_stop = False # Flag to gracefully stop the executor
# Reference to gradient server for layer reallocation detection
self.gradient_server = gradient_server

self.is_first_peer = start_layer == 0
self.is_last_peer = end_layer == self.config.get("num_hidden_layers")
@@ -279,9 +284,9 @@ def __init__(
)

@classmethod
def create_from_args(cls, args: argparse.Namespace):
def create_from_args(cls, args: argparse.Namespace, gradient_server=None):
"""Create executor from command line arguments."""
return cls(**create_executor_config(args))
return cls(**create_executor_config(args, gradient_server))

def recv_requests_from_http(self) -> List[Request]:
"""Receives requests from http frontend"""
@@ -1144,7 +1149,8 @@ def run_loop(self):
logger.debug(
f"Executor for layers [{self.start_layer}, {self.end_layer}) starting run loop..."
)
while True:
self._should_stop = False
while not self._should_stop:
# 1. Ingest new requests from the http frontend
if self.is_first_peer:
http_requests = self.recv_requests_from_http()
@@ -1161,6 +1167,14 @@
)
self.finished_batch = []

# Check for layer reallocation signal (before batch processing)
if self.gradient_server is not None and self.gradient_server._layer_allocation_changed:
logger.info(
"Layer reallocation detected. Stopping executor to reload with new layers."
)
self._should_stop = True
break

# 4. Admit requests into running set up to capacity, then form batch
self.scheduler.admit_requests()
# 4.1 Check for request timeouts and abort timed out requests
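
The reallocation check above relies on cooperative cancellation: the loop polls a flag once per iteration and exits at a safe point instead of being torn down from outside. A tiny standalone illustration of that pattern, with Worker and watcher as placeholder names (not Parallax classes):

import time

class Worker:
    """Illustrative only: a loop that stops cooperatively when asked to reload."""

    def __init__(self, watcher):
        self.watcher = watcher          # e.g. the AllocationWatcher sketched earlier
        self._should_stop = False

    def run_loop(self):
        self._should_stop = False
        while not self._should_stop:
            if self.watcher.changed:    # polled at a safe point, before forming a batch
                self._should_stop = True
                break
            time.sleep(0.05)            # stand-in for one scheduling/processing step
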
@@ -1249,15 +1263,36 @@ def run_loop_in_background(self):
def shutdown(self):
"""Shuts down the executor."""
logger.debug("Executor shutting down...")
self.recv_from_peer_socket.close()
self.send_to_peer_socket.close()
self.recv_from_ipc_socket.close()
self.send_to_ipc_socket.close()
self.zmq_context.term()
self._should_stop = True
import time

time.sleep(0.1) # Give run_loop a moment to exit gracefully

try:
all_requests = [req for _, _, _, req in self.scheduler._request_queue] + list(
self.scheduler._running_requests.values()
)
for req in all_requests:
try:
self.scheduler.evict_request(req.request_id, RequestStatus.CANCELLED)
except Exception:
pass
except Exception:
pass

try:
self.recv_from_peer_socket.close()
self.send_to_peer_socket.close()
self.recv_from_ipc_socket.close()
self.send_to_ipc_socket.close()
self.zmq_context.term()
except Exception as e:
logger.debug(f"Error closing sockets (may already be closed): {e}")

logger.debug("Executor shutdown complete.")


def create_executor_config(args: argparse.Namespace):
def create_executor_config(args: argparse.Namespace, gradient_server=None):
"""Create executor configuration from command line arguments."""

config = {
@@ -1280,5 +1315,6 @@ def create_executor_config(args: argparse.Namespace):
"executor_output_ipc_addr": args.executor_output_ipc,
"attention_backend": args.attention_backend,
"moe_runner_backend": args.moe_runner_backend,
"gradient_server": gradient_server,
}
return config
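
Taken together with launch.py, the call sequence for the new keyword is roughly the following sketch (assuming args comes from parse_args() and gradient_server from launch_p2p_server, as in the diff above):

# create_from_args builds its config via create_executor_config(args, gradient_server)
executor = Executor.create_from_args(args, gradient_server=gradient_server)
if gradient_server is not None:
    gradient_server.status = ServerState.READY
executor.run_loop()        # returns early once _layer_allocation_changed is set
executor.shutdown()        # cancels queued/running requests, then closes sockets
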
25 changes: 25 additions & 0 deletions src/parallax/server/http_server.py
@@ -472,3 +472,28 @@ def launch_http_server(args):
process = mp.Process(target=http_server.run)
process.start()
return process


def stop_http_server(http_server_process):
"""
Stop HTTP server process if it exists.
"""
if http_server_process is not None:
logger.info("Stopping HTTP server process...")
try:
http_server_process.kill()
http_server_process.join()
except Exception as e:
logger.error(f"Failed to terminate HTTP server process: {e}")
return None
return http_server_process


def restart_http_server(args, http_server_process):
"""
Restart HTTP server with new args.
Stops the old server if it exists and starts a new one.
"""
http_server_process = stop_http_server(http_server_process)
logger.info("Restarting HTTP server...")
return launch_http_server(args)
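
On the head node, launch.py uses these helpers to rebind the HTTP frontend after a reallocation; a condensed sketch of that call pattern, reusing the names from the diff above:

http_server_process = stop_http_server(http_server_process)   # safe when the process is None
if gradient_server.block_start_index == 0:                     # this node is (still) the head node
    http_server_process = launch_http_server(args)
# Equivalent one-step form when the node stays the head:
# http_server_process = restart_http_server(args, http_server_process)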