67 changes: 48 additions & 19 deletions src/parallax/server/executor.py
@@ -80,6 +80,7 @@ def __init__(
prefill_priority: int = 0,
micro_batch_ratio: int = 2,
scheduler_wait_ms: int = 500,
request_timeout_s: Optional[int] = 600,
# Metrics Configs
layer_latency_update_every: int = 4096,
# KV Cache Configs
@@ -242,6 +243,8 @@ def __init__(
micro_batch_ratio=micro_batch_ratio,
is_first_peer=self.is_first_peer,
tokenizer=self.tokenizer,
kv_cache_manager=self.kv_cache_manager if self.device == "mlx" else None,
request_timeout_s=request_timeout_s,
)
logger.debug(
f"Scheduler initialized (max_batch_size={max_batch_size}, max_tokens={max_num_tokens_per_batch}, wait_ms={scheduler_wait_ms})"
@@ -781,7 +784,7 @@ def _handle_cuda_input_requests(self, requests: List[Request]):
req, IntermediateRequest
), "Non-first peers must receive IntermediateRequests."
if req.is_finished or req.hidden_states is None:
self.scheduler.evict_request(req.request_id, req.status)
self.scheduler.evict_request(req.request_id)
release_cuda_request(self.running_batch, req.request_id)
if not self.is_last_peer:
self.finished_batch.append(req)
@@ -805,8 +808,6 @@ def _handle_input_requests(self, requests: List[Request]):
# or IntermediateRequests from the last peer.
for req in requests:
if isinstance(req, InitialRequest):
if not self.kv_cache_manager.has_request(req.request_id):
self.kv_cache_manager.add_request(req, req.total_length)
self.scheduler.enque_request(req)
elif isinstance(req, IntermediateRequest):
original_req = self.scheduler.get_running_request(req.request_id)
@@ -876,14 +877,12 @@ def _handle_input_requests(self, requests: List[Request]):
f"kv cache manager has {self.kv_cache_manager.tokens_in_cache} tokens, "
f"memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB"
)
self.scheduler.evict_request(req.request_id, req.status)
self.scheduler.evict_request(req.request_id)
if not self.is_last_peer:
self.finished_batch.append(req)
else:
# This is an active request, add it to the scheduler queue to be processed.
self.scheduler.enque_request(req)
if not self.kv_cache_manager.has_request(req.request_id):
self.kv_cache_manager.add_request(req, req.total_length)

def _prepare_next_single_request(self, request: Request, hidden_states: Any) -> Request:
"""Handle request state changes both inter and intra peers.
@@ -1058,7 +1057,7 @@ def _process_batch_mlx(
]
# k_caches shape: (num_layers, B, num_kv_heads, L_padded, head_dim)
logger.debug(
f"Processed batch with {len(prepared_inputs['requests'])} requests, "
f"Processing batch with {len(prepared_inputs['requests'])} requests, "
f"request status: {prepared_inputs['requests'][0].status}, "
f"hidden_states shape: {hidden_states.shape}, "
f"k_caches shape: {k_caches.shape}, "
@@ -1117,6 +1116,29 @@ def process_batch(
)
return ret

def _release_and_evict_request(self, rid: str):
"""Release per-request resources and evict from scheduler. Best-effort, never raises."""
# Release resources
if self.device == "cuda":
from parallax.sglang.batch_info import release_cuda_request

try:
release_cuda_request(self.running_batch, rid)
except Exception:
pass
else:
try:
if hasattr(self, "kv_cache_manager") and self.kv_cache_manager is not None:
self.kv_cache_manager.release_request(rid)
except Exception:
pass

# Evict from scheduler
try:
self.scheduler.evict_request(rid)
except Exception:
pass

def run_loop(self):
"""The main loop of the executor."""
logger.debug(
Expand All @@ -1139,12 +1161,25 @@ def run_loop(self):
)
self.finished_batch = []

# 4. Check if we should form a batch
if not self.scheduler.should_dispatch():
time.sleep(0.01) # prevent busy waiting
continue
# 4. Admit requests into running set up to capacity, then form batch
self.scheduler.admit_requests()
# 4.1 Check for request timeouts and abort timed out requests
try:
timed_out_reqs = self.scheduler.get_timed_out_requests()
if timed_out_reqs:
for req in timed_out_reqs:
rid = req.request_id
logger.warning(
f"Request {rid} exceeded timeout ({req.timeout_s}s). Aborting and releasing resources."
)
self._release_and_evict_request(rid)

# 5. Form a batch from the scheduler's queue
# Notify downstream peers to abort if this peer is the first peer in a pipeline
if self.is_first_peer and not self.is_last_peer:
self.finished_batch.append(req)
except Exception:
# Non-fatal; continue serving
pass
batch_to_process = self.scheduler.form_batch()
if not batch_to_process:
continue
@@ -1206,13 +1241,7 @@
logger.exception(f"Error processing batch: {e}")
# Naive error handling: release and evict all requests in the batch
for req in batch_to_process:
self.scheduler.evict_request(req.request_id, req.status)
if self.device == "cuda":
from parallax.sglang.batch_info import release_cuda_request

release_cuda_request(self.running_batch, req.request_id)
else:
self.kv_cache_manager.release_request(req.request_id)
self._release_and_evict_request(req.request_id)

def run_loop_in_background(self):
"""Run the executor loop in the background."""
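The executor now expects the scheduler to expose `admit_requests()` and `get_timed_out_requests()`, but the scheduler changes are not part of this diff. The sketch below is only an assumption of how such a check could be driven by the new `request_timeout_s` setting and the `last_updated_time` field added to `Request`; the per-request `timeout_s` attribute and the clock choice are guesses, not the project's actual implementation.

```python
import time
from typing import Dict, List


class SchedulerTimeoutSketch:
    """Hypothetical sketch of the scheduler-side timeout check (not in this PR)."""

    def __init__(self, request_timeout_s: int = 600):
        # Assumed to be forwarded from the executor's new `request_timeout_s` argument.
        self.request_timeout_s = request_timeout_s
        self._running: Dict[str, object] = {}  # request_id -> Request

    def get_timed_out_requests(self) -> List[object]:
        """Return running requests whose last recorded activity exceeds the timeout."""
        now = time.time()
        timed_out = []
        for req in self._running.values():
            # A request that has never been touched is treated as just-started.
            last = getattr(req, "last_updated_time", None) or now
            timeout = getattr(req, "timeout_s", self.request_timeout_s)
            if now - last > timeout:
                timed_out.append(req)
        return timed_out
```

The executor only logs `req.timeout_s` and then calls `_release_and_evict_request`, so a per-request override with a scheduler-level default, as sketched here, is one plausible reading of that log line.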
2 changes: 2 additions & 0 deletions src/parallax/server/request.py
@@ -103,6 +103,8 @@ def __init__(
self.routing_table = routing_table
self.sampling_params = sampling_params or SamplingParams()
self.abort = False
self.ready_for_next_step = False
self.last_updated_time: Optional[float] = None

@property
def is_finished(self) -> bool:
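The two new `Request` fields suggest the intended bookkeeping: `last_updated_time` records when a request last made forward progress, so the timeout measures inactivity rather than total lifetime, and `ready_for_next_step` flags when the request can be scheduled again. Below is a minimal sketch of a refresh helper under that reading; the `touch` function and the wall-clock choice are hypothetical, not part of this PR.

```python
import time
from types import SimpleNamespace


def touch(request) -> None:
    """Hypothetical helper: record forward progress on a request.

    Assumes the two fields added in this PR (`last_updated_time`,
    `ready_for_next_step`); `time.time()` is a guess at the clock used.
    """
    request.last_updated_time = time.time()
    request.ready_for_next_step = True


# Minimal usage with a stand-in request carrying the new fields.
req = SimpleNamespace(ready_for_next_step=False, last_updated_time=None)
touch(req)
assert req.ready_for_next_step and req.last_updated_time is not None
```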